I'm trying to modify this GraphSAGE class so that message passing uses both the node and the edge features. I have come up with the code below:
```python
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
from torch.nn import Dropout


class GraphSAGE(nn.Module):
    def __init__(self, g, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, "gcn")
        self.dropout = Dropout(0.1)
        self.conv2 = SAGEConv(h_feats, h_feats, "gcn")
        self.conv3 = SAGEConv(h_feats, h_feats, "gcn")

    def forward(self, g, node_features, edge_features):
        g.ndata["hn"] = node_features
        g.edata["he"] = edge_features
        # Case 1: perform two rounds of message passing,
        # one using node features, one using edge features
        g.update_all(fn.copy_u("hn", "m"), fn.sum("m", "hn_aggr"))
        g.update_all(fn.copy_e("he", "m"), fn.sum("m", "he_aggr"))
        # You can then further combine g.ndata['hn_aggr'] and g.ndata['he_aggr']
        # (as written, he_aggr is computed but never used below)
        node_features = g.ndata["hn_aggr"]
        h = self.conv1(g, node_features)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.conv2(g, h)
        h = self.conv3(g, h)
        return h
```
The FAQ here mentions combining the aggregated values. Could you clarify this a bit more? Some help on how to integrate it into the class above would be great.