GraphSAGE with node and edge features

I'm trying to modify this GraphSAGE class so that message passing uses both the node and the edge features. I have come up with the code below:

import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class GraphSAGE(nn.Module):
    def __init__(self, g, in_feats, h_feats):
        # g is not used here; it is kept so existing calls still work
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, "gcn")
        self.dropout = nn.Dropout(0.1)
        self.conv2 = SAGEConv(h_feats, h_feats, "gcn")
        self.conv3 = SAGEConv(h_feats, h_feats, "gcn")

    def forward(self, g, node_features, edge_features):
        # local_scope() keeps the temporary ndata/edata out of the caller's graph
        with g.local_scope():
            g.ndata["hn"] = node_features
            g.edata["he"] = edge_features

            # Case 1: perform two rounds of message passing,
            # one using node features, one using edge features
            g.update_all(fn.copy_u("hn", "m"), fn.sum("m", "hn_aggr"))
            g.update_all(fn.copy_e("he", "m"), fn.sum("m", "he_aggr"))
            # g.ndata["hn_aggr"] and g.ndata["he_aggr"] can then be combined further

            # So far only the aggregated node features reach the SAGEConv stack;
            # g.ndata["he_aggr"] is computed but not used yet
            h = self.conv1(g, g.ndata["hn_aggr"])
            h = F.relu(h)
            h = self.dropout(h)
            h = self.conv2(g, h)
            h = self.conv3(g, h)
            return h

The FAQ here mentions combining the aggregated values. Could you clarify this further? Some help on how to integrate it into the class above would be great.

If your edge feature is a scalar weight, DGL's SAGEConv accepts an edge_weight argument in its forward function. If that is not the case, could you describe what you want to achieve mathematically?
