Hi all,
I’ve been playing around with heterographs whose node types have different feature sizes and different numbers of nodes, for node classification and similar tasks.
I put together a small example as a sanity check for myself and found it to be quite useful.
Graph Construction
Let’s construct a simple heterograph with two node types (A and B) that have different feature sizes (1 and 2) and different numbers of nodes (6 and 5).
```python
import torch
import dgl

# edges between the 6 nodes of type A
u1 = [0, 1, 2, 3, 4, 5]
v1 = [4, 5, 0, 1, 2, 3]
# edges between the 5 nodes of type B
u2 = [0, 1, 2, 3, 4]
v2 = [4, 0, 1, 2, 3]
# edges between node type A and node type B
u3 = [0, 1, 2, 3, 4]
v3 = [0, 1, 2, 3, 4]

# NOTE: in DGL heterographs, each node type is indexed independently, with IDs starting from 0.
# This matters: if my IDs had started at 1, DGL would have counted an extra node (ID 0) of each type,
# which causes trouble when assigning features.
graph = dgl.heterograph({
    ('a', 'same_a', 'a'): (u1, v1),
    ('b', 'same_b', 'b'): (u2, v2),
    ('a', 'diff_a', 'b'): (u3, v3),
    ('b', 'diff_b', 'a'): (v3, u3),  # reverse of diff_a
}, num_nodes_dict={'a': 6, 'b': 5})

# assign features to the nodes of each type
graph.nodes['a'].data['feat'] = torch.ones((6, 1))  # 6 nodes of type A, 1 feature each
graph.nodes['b'].data['feat'] = torch.ones((5, 2))  # 5 nodes of type B, 2 features each
```
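To see why the starting index matters, here is a stdlib-only sketch of the rule DGL applies when `num_nodes_dict` is not given: for each node type, the node count is inferred as the largest ID appearing in any relation that touches that type, plus one. `infer_num_nodes` is a hypothetical helper that only mimics this rule; it is not a DGL function.

```python
# Hypothetical helper mimicking how dgl.heterograph infers node counts
# when num_nodes_dict is omitted: largest ID per node type, plus one.
def infer_num_nodes(data_dict):
    counts = {}
    for (src_type, _rel, dst_type), (u, v) in data_dict.items():
        for ntype, ids in ((src_type, u), (dst_type, v)):
            if ids:
                counts[ntype] = max(counts.get(ntype, 0), max(ids) + 1)
    return counts

# Edge lists from the example above (IDs start at 0).
u1, v1 = [0, 1, 2, 3, 4, 5], [4, 5, 0, 1, 2, 3]
u2, v2 = [0, 1, 2, 3, 4], [4, 0, 1, 2, 3]
u3, v3 = [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]
data = {
    ('a', 'same_a', 'a'): (u1, v1),
    ('b', 'same_b', 'b'): (u2, v2),
    ('a', 'diff_a', 'b'): (u3, v3),
    ('b', 'diff_b', 'a'): (v3, u3),
}
print(infer_num_nodes(data))  # {'a': 6, 'b': 5}

# The same edges with IDs shifted to start at 1: each type gains a phantom node 0.
shifted = {k: ([i + 1 for i in u], [i + 1 for i in v]) for k, (u, v) in data.items()}
print(infer_num_nodes(shifted))  # {'a': 7, 'b': 6}
```

With the shifted IDs, `torch.ones((6, 1))` and `torch.ones((5, 2))` would no longer match the inferred node counts, which is exactly the feature-assignment trouble mentioned in the comment.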
Model Implementation
Our model now has to handle inputs of different sizes. This can be a little confusing, since the per-relation modules inside HeteroGraphConv can be configured in fairly elaborate ways. I suggest a simple modification that makes it easy to change the feature sizes later if you want to.
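Instead of passing the input size and the list of relation types separately, as the tutorials usually do, I combine them into a single rel_shape_dict that maps each relation to its input size. For a relation between two different node types the size is a (source, destination) tuple, which SAGEConv accepts as in_feats; for a relation within one node type a single int is enough.
Note that the values of this dictionary are only used in the first layer. The second layer takes hid_feats-sized inputs for every relation.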
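If you'd rather not write the sizes by hand, the dictionary can be derived from the per-type feature sizes and the canonical edge types. This is a sketch under my own assumptions: `build_rel_shape_dict` is a hypothetical helper, it always emits `(source, destination)` tuples, and it assumes each relation name is unique (as it is in this graph). In DGL you could take the triples from `graph.canonical_etypes` and the sizes from `graph.nodes[ntype].data['feat'].shape[1]`.

```python
# Hypothetical helper: build {relation: (src_feat_size, dst_feat_size)}
# from (src_type, relation, dst_type) triples and per-type feature sizes.
def build_rel_shape_dict(canonical_etypes, feat_sizes):
    return {rel: (feat_sizes[src], feat_sizes[dst])
            for src, rel, dst in canonical_etypes}

# Canonical edge types and feature sizes of the example graph.
canonical_etypes = [('a', 'same_a', 'a'), ('b', 'same_b', 'b'),
                    ('a', 'diff_a', 'b'), ('b', 'diff_b', 'a')]
feat_sizes = {'a': 1, 'b': 2}

rel_shape_dict = build_rel_shape_dict(canonical_etypes, feat_sizes)
print(rel_shape_dict)
# {'same_a': (1, 1), 'same_b': (2, 2), 'diff_a': (1, 2), 'diff_b': (2, 1)}
```

Emitting tuples everywhere keeps the helper simple; for the same-type relations the tuple form is equivalent to passing a single int.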
```python
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class Model(nn.Module):
    def __init__(self, rel_shape_dict, hid_feats, out_feats):
        super().__init__()
        # first layer: the input size of each relation comes from rel_shape_dict
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.SAGEConv(size, hid_feats, aggregator_type='mean')
            for rel, size in rel_shape_dict.items()}, aggregate='mean')
        # second layer: every relation now receives hid_feats-sized inputs
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.SAGEConv(hid_feats, out_feats, aggregator_type='mean')
            for rel in rel_shape_dict.keys()}, aggregate='mean')

    def forward(self, graph, inputs):
        # inputs: dict mapping node type -> node feature tensor
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

# int = same size for source and destination; tuple = (source size, destination size)
rel_shape_dict = {'same_a': 1, 'same_b': 2, 'diff_a': (1, 2), 'diff_b': (2, 1)}
# rel_shape_dict = {'same_a': (1, 1), 'same_b': (2, 2), 'diff_a': (1, 2), 'diff_b': (2, 1)}  # this is equivalent
model = Model(rel_shape_dict, 10, 1)

# quickly put the inputs together and bam, you've got a prediction
inputs = {'a': graph.nodes['a'].data['feat'], 'b': graph.nodes['b'].data['feat']}
model(graph, inputs)  # dict: 'a' -> tensor of shape (6, 1), 'b' -> tensor of shape (5, 1)
```
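It can also help to check which relations feed each node type. HeteroGraphConv runs one SAGEConv per relation and then, for each destination node type, combines the outputs of all relations ending there; with `aggregate='mean'` that is an element-wise mean. Below is a stdlib-only sketch of that bookkeeping using made-up per-relation outputs; `group_by_dst` and `mean_aggregate` are hypothetical helpers, not DGL functions, and the numbers are not real model outputs.

```python
# Hypothetical helpers sketching HeteroGraphConv's cross-relation step.
def group_by_dst(canonical_etypes):
    # destination node type -> relations ending in it
    groups = {}
    for _src, rel, dst in canonical_etypes:
        groups.setdefault(dst, []).append(rel)
    return groups

def mean_aggregate(per_rel_out, groups):
    # Element-wise mean over the outputs of the relations ending in each node type.
    result = {}
    for dst, rels in groups.items():
        outs = [per_rel_out[r] for r in rels]
        result[dst] = [[sum(vals) / len(vals) for vals in zip(*rows)]
                       for rows in zip(*outs)]
    return result

canonical_etypes = [('a', 'same_a', 'a'), ('b', 'same_b', 'b'),
                    ('a', 'diff_a', 'b'), ('b', 'diff_b', 'a')]
groups = group_by_dst(canonical_etypes)
print(groups)  # {'a': ['same_a', 'diff_b'], 'b': ['same_b', 'diff_a']}

# Made-up per-relation outputs with out_feats = 1: 6 rows for 'a', 5 rows for 'b'.
per_rel_out = {
    'same_a': [[1.0]] * 6, 'diff_b': [[3.0]] * 6,
    'same_b': [[0.0]] * 5, 'diff_a': [[4.0]] * 5,
}
out = mean_aggregate(per_rel_out, groups)
print({k: (len(v), len(v[0])) for k, v in out.items()})  # {'a': (6, 1), 'b': (5, 1)}
```

Each node type here receives messages from exactly two relations, so every output row is the average of two per-relation rows, and the resulting shapes match those of the real model's output.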
Disclaimer
Please note this is by no means a good model or graph; it is merely intended for demonstration purposes. If you’d like to add to it or improve it, please feel free!