Hello @mufeili. Your suggestions have been useful for handling the classification part (i.e., aggregating the node representations of each node type). However, it is the design of the module, specifically where to define the node embeddings, that I am a little confused about.
```python
# similar to the tutorial on heterographs - mean aggregation changed to sum
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # one weight matrix per edge type
        self.weight = nn.ModuleDict({
            name: nn.Linear(in_size, out_size) for name in etypes
        })

    def forward(self, G, feat_dict):
        funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:
            # transform source-node features with the edge-type-specific weight
            Wh = self.weight[etype](feat_dict[srctype])
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            # message: copy transformed features; reduce: sum over neighbours
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.sum('m', 'h'))
        # cross-type reducer also sums contributions across edge types
        G.multi_update_all(funcs, 'sum')
        return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}
```
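For reference, this is roughly how I exercise the layer on a toy heterograph (the node types, edge types, and sizes below are made up for illustration):

```python
import dgl
import torch

# toy heterograph: 3 users, 2 items, one made-up edge type in each direction
G = dgl.heterograph({
    ('user', 'clicks', 'item'): ([0, 1, 2], [0, 0, 1]),
    ('item', 'clicked-by', 'user'): ([0, 0, 1], [0, 1, 2]),
})

in_size, out_size = 8, 4
layer = HeteroRGCNLayer(in_size, out_size, G.etypes)

# one random feature matrix per node type
feat_dict = {ntype: torch.randn(G.number_of_nodes(ntype), in_size)
             for ntype in G.ntypes}

h_dict = layer(G, feat_dict)
print({k: v.shape for k, v in h_dict.items()})  # user: (3, 4), item: (2, 4)
```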
The HeteroRGCN module:
```python
class HeteroRGCN(nn.Module):
    def __init__(self, in_size, hidden_size, n_classes, etypes):
        super(HeteroRGCN, self).__init__()
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, hidden_size, etypes)
        self.classify = nn.Linear(hidden_size, n_classes)
        self.in_size = in_size
        self.embed = None

    def forward(self, G):
        # WHERE SHOULD THESE BE DEFINED?
        # one learnable embedding matrix per node type, sized by the input graph
        embed_dict = {ntype: nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), self.in_size))
                      for ntype in G.ntypes}
        for key, embed in embed_dict.items():
            nn.init.xavier_uniform_(embed)
        self.embed = nn.ParameterDict(embed_dict)
        h_dict = self.layer1(G, self.embed)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(G, h_dict)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        # rest of the code to take mean of node representations
```
I am confused about how to define self.embed: it is a ParameterDict whose size depends on the number of nodes per node type. Hence I moved it to forward(), since the size of the ParameterDict will vary across input graphs. But it should ideally live in __init__(), not least because in the version above the embeddings are re-created and re-initialized with Xavier on every forward pass. How do I modify the HeteroRGCN module to achieve this? In particular, I would like to save the state_dict and load it again so that I can later generate graph embeddings for unseen graphs.
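For concreteness, here is the kind of __init__()-based variant I can picture (a sketch only; HeteroRGCNFixed is a made-up name). It registers the embeddings once, so they appear in the state_dict, but it hard-wires the node counts of the graph passed at construction time, which is exactly what breaks for unseen graphs of a different size:

```python
class HeteroRGCNFixed(nn.Module):
    """Hypothetical variant: embeddings created in __init__ for one fixed graph."""
    def __init__(self, G, in_size, hidden_size, n_classes):
        super(HeteroRGCNFixed, self).__init__()
        # embeddings sized by THIS graph's node counts; saved in state_dict,
        # but unusable for graphs with different numbers of nodes
        embed_dict = {ntype: nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), in_size))
                      for ntype in G.ntypes}
        for key, embed in embed_dict.items():
            nn.init.xavier_uniform_(embed)
        self.embed = nn.ParameterDict(embed_dict)
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, G.etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, hidden_size, G.etypes)
        self.classify = nn.Linear(hidden_size, n_classes)

    def forward(self, G):
        h_dict = self.layer1(G, self.embed)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(G, h_dict)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        # rest of the code to take mean of node representations
        return h_dict
```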