Aaditya1 committed
Commit 373715d
1 Parent(s): c24de3e

Create collating_graphormer.pyx

Files changed (1):
  1. collating_graphormer.pyx +134 -0
collating_graphormer.pyx ADDED
@@ -0,0 +1,134 @@
+# Copyright (c) Microsoft Corporation and HuggingFace
+# Licensed under the MIT License.
+
+from typing import Any, Dict, List, Mapping
+
+import numpy as np
+import torch
+
+from ...utils import is_cython_available, requires_backends
+
+
+if is_cython_available():
+    import pyximport
+
+    pyximport.install(setup_args={"include_dirs": np.get_include()})
+    from . import algos_graphormer  # noqa E402
+
+
+def convert_to_single_emb(x, offset: int = 512):
+    feature_num = x.shape[1] if len(x.shape) > 1 else 1
+    feature_offset = 1 + np.arange(0, feature_num * offset, offset, dtype=np.int64)
+    x = x + feature_offset
+    return x
+
+
+def preprocess_item(item, keep_features=True):
+    requires_backends(preprocess_item, ["cython"])
+
+    if keep_features and "edge_attr" in item.keys():  # edge_attr
+        edge_attr = np.asarray(item["edge_attr"], dtype=np.int64)
+    else:
+        edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)  # same embedding for all
+
+    if keep_features and "node_feat" in item.keys():  # input_nodes
+        node_feature = np.asarray(item["node_feat"], dtype=np.int64)
+    else:
+        node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64)  # same embedding for all
+
+    edge_index = np.asarray(item["edge_index"], dtype=np.int64)
+
+    input_nodes = convert_to_single_emb(node_feature) + 1
+    num_nodes = item["num_nodes"]
+
+    if len(edge_attr.shape) == 1:
+        edge_attr = edge_attr[:, None]
+    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
+    attn_edge_type[edge_index[0], edge_index[1]] = convert_to_single_emb(edge_attr) + 1
+
+    # node adj matrix [num_nodes, num_nodes] bool
+    adj = np.zeros([num_nodes, num_nodes], dtype=bool)
+    adj[edge_index[0], edge_index[1]] = True
+
+    shortest_path_result, path = algos_graphormer.floyd_warshall(adj)
+    max_dist = np.amax(shortest_path_result)
+
+    input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type)
+    attn_bias = np.zeros([num_nodes + 1, num_nodes + 1], dtype=np.single)  # with graph token
+
+    # combine
+    item["input_nodes"] = input_nodes + 1  # we shift all indices by one for padding
+    item["attn_bias"] = attn_bias
+    item["attn_edge_type"] = attn_edge_type
+    item["spatial_pos"] = shortest_path_result.astype(np.int64) + 1  # we shift all indices by one for padding
+    item["in_degree"] = np.sum(adj, axis=1).reshape(-1) + 1  # we shift all indices by one for padding
+    item["out_degree"] = item["in_degree"]  # for undirected graph
+    item["input_edges"] = input_edges + 1  # we shift all indices by one for padding
+    if "labels" not in item:
+        item["labels"] = item["y"]
+
+    return item
+
+
+class GraphormerDataCollator:
+    def __init__(self, spatial_pos_max=20, on_the_fly_processing=False):
+        if not is_cython_available():
+            raise ImportError("Graphormer preprocessing needs Cython (pyximport)")
+
+        self.spatial_pos_max = spatial_pos_max
+        self.on_the_fly_processing = on_the_fly_processing
+
+    def __call__(self, features: List[dict]) -> Dict[str, Any]:
+        if self.on_the_fly_processing:
+            features = [preprocess_item(i) for i in features]
+
+        if not isinstance(features[0], Mapping):
+            features = [vars(f) for f in features]
+        batch = {}
+
+        max_node_num = max(len(i["input_nodes"]) for i in features)
+        node_feat_size = len(features[0]["input_nodes"][0])
+        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
+        max_dist = max(len(i["input_edges"][0][0]) for i in features)
+        edge_input_size = len(features[0]["input_edges"][0][0][0])
+        batch_size = len(features)
+
+        batch["attn_bias"] = torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float)
+        batch["attn_edge_type"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
+        batch["spatial_pos"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
+        batch["in_degree"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)
+        batch["input_nodes"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
+        batch["input_edges"] = torch.zeros(
+            batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long
+        )
+
+        for ix, f in enumerate(features):
+            for k in ["attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges"]:
+                f[k] = torch.tensor(f[k])
+
+            if len(f["attn_bias"][1:, 1:][f["spatial_pos"] >= self.spatial_pos_max]) > 0:
+                f["attn_bias"][1:, 1:][f["spatial_pos"] >= self.spatial_pos_max] = float("-inf")
+
+            batch["attn_bias"][ix, : f["attn_bias"].shape[0], : f["attn_bias"].shape[1]] = f["attn_bias"]
+            batch["attn_edge_type"][ix, : f["attn_edge_type"].shape[0], : f["attn_edge_type"].shape[1], :] = f[
+                "attn_edge_type"
+            ]
+            batch["spatial_pos"][ix, : f["spatial_pos"].shape[0], : f["spatial_pos"].shape[1]] = f["spatial_pos"]
+            batch["in_degree"][ix, : f["in_degree"].shape[0]] = f["in_degree"]
+            batch["input_nodes"][ix, : f["input_nodes"].shape[0], :] = f["input_nodes"]
+            batch["input_edges"][
+                ix, : f["input_edges"].shape[0], : f["input_edges"].shape[1], : f["input_edges"].shape[2], :
+            ] = f["input_edges"]
+
+        batch["out_degree"] = batch["in_degree"]
+
+        sample = features[0]["labels"]
+        if len(sample) == 1:  # one task
+            if isinstance(sample[0], float):  # regression
+                batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
+            else:  # binary classification
+                batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
+        else:  # multi task classification, left to float to keep the NaNs
+            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))
+
+        return batch
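
For reference, a minimal usage sketch of the collator added above, assuming the `algos_graphormer` Cython module compiles via `pyximport` as set up at the top of the file. The two toy graphs (`edge_index`, `num_nodes`, `y`) are illustrative assumptions following the input format that `preprocess_item` reads; they are not data from this commit.

import numpy as np
from torch.utils.data import DataLoader

# Two toy undirected graphs: "edge_index" is a (2, num_edges) array of
# directed edge endpoints, "y" is copied to "labels" during preprocessing.
graphs = [
    {"edge_index": np.array([[0, 1, 1, 2], [1, 0, 2, 1]]), "num_nodes": 3, "y": [0]},
    {"edge_index": np.array([[0, 1], [1, 0]]), "num_nodes": 2, "y": [1]},
]

# Option 1: preprocess up front, then collate. preprocess_item mutates its
# argument in place, so each dict is shallow-copied to keep `graphs` reusable.
items = [preprocess_item(dict(g)) for g in graphs]
batch = GraphormerDataCollator(spatial_pos_max=20)(items)

# Option 2: let the collator preprocess each sample on the fly in a DataLoader.
loader = DataLoader(graphs, batch_size=2, collate_fn=GraphormerDataCollator(on_the_fly_processing=True))
batch = next(iter(loader))

print(batch["input_nodes"].shape)  # (2, max_node_num, 1): node features, zero-padded per graph
print(batch["attn_bias"].shape)    # (2, max_node_num + 1, max_node_num + 1): the +1 is the graph token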