DGL Heterograph for Graph Classification

I have been going through the DGL tutorials and have questions about these two:

https://docs.dgl.ai/en/0.4.x/tutorials/basics/4_batch.html
https://docs.dgl.ai/en/0.4.x/tutorials/basics/5_hetero.html

I would like to classify a set of heterographs. The concept of data loading is clear (from tutorial 4), and so is the concept of HeteroRGCN for node classification (from tutorial 5). But if I want to do graph classification using the HeteroRGCN described in the tutorial, what should forward() look like? Would taking the mean of the nodes (of different types) make sense for graph classification? I would love to hear some ideas on how to implement a forward() for graph classification.

You can take the average of the node representations for each node type and then combine them further.
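A minimal sketch of that idea, assuming h_dict maps each node type to a tensor of node representations for a single graph (the readout_by_type helper and the final plain mean over types are assumptions, purely for illustration):

import torch

def readout_by_type(h_dict):
    # Average node representations within each node type,
    # giving one vector per node type.
    type_means = [h.mean(dim=0) for h in h_dict.values()]
    # Combine the per-type vectors; a plain mean is the simplest choice.
    return torch.stack(type_means, dim=0).mean(dim=0)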

Hi @mufeili. Thanks for your reply.
I did take the mean of the node representations of each node type. It is the combination of these that I need some ideas on :slight_smile: . A weighted mean comes to mind. Do you have other suggestions I could try?

Once you take the mean of node representations per node type, you can treat them as the node representations in another graph and apply any node-based pooling method like the ones here.
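For example, here is a sketch of one such pooling applied on top of the per-type means; the TypeAttentionPooling module and its single-linear-layer gate are assumptions written in the spirit of global attention pooling, not part of DGL's API:

import torch
import torch.nn as nn

class TypeAttentionPooling(nn.Module):
    # Hypothetical pooling over per-node-type mean vectors,
    # in the spirit of global attention pooling.
    def __init__(self, hidden_size):
        super(TypeAttentionPooling, self).__init__()
        self.gate = nn.Linear(hidden_size, 1)

    def forward(self, h_dict):
        # Stack the per-type means into a (num_types, hidden_size) tensor.
        type_means = torch.stack([h.mean(dim=0) for h in h_dict.values()])
        # Attention weights over node types, then a weighted sum.
        alpha = torch.softmax(self.gate(type_means), dim=0)
        return (alpha * type_means).sum(dim=0)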


Thanks! That seems like a neat way to handle this. :slight_smile:

Hello @mufeili. Your suggestions have been useful for handling the classification part (i.e., aggregating the node representations of each node type). However, I am a little confused about the design of the forward() function.

# similar to the tutorial on heterographs; mean aggregation changed to sum
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # One weight matrix per edge type.
        self.weight = nn.ModuleDict({
            name: nn.Linear(in_size, out_size) for name in etypes
        })

    def forward(self, G, feat_dict):
        funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:
            # Transform source-node features with this edge type's weight.
            Wh = self.weight[etype](feat_dict[srctype])
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.sum('m', 'h'))
        # Aggregate the messages from all edge types by summation.
        G.multi_update_all(funcs, 'sum')
        return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}

The HeteroRGCN module:

class HeteroRGCN(nn.Module):
    def __init__(self, in_size, hidden_size, n_classes, etypes):
        super(HeteroRGCN, self).__init__()
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, hidden_size, etypes)
        self.classify = nn.Linear(hidden_size, n_classes)
        self.in_size = in_size
        self.embed = None

    def forward(self, G):
        # WHERE SHOULD THESE BE DEFINED?
        embed_dict = {ntype: nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), self.in_size))
                      for ntype in G.ntypes}
        for key, embed in embed_dict.items():
            nn.init.xavier_uniform_(embed)
        self.embed = nn.ParameterDict(embed_dict)

        h_dict = self.layer1(G, self.embed)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(G, h_dict)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}

        # rest of the code to take the mean of node representations

I am confused about how to define self.embed: it is a ParameterDict whose sizes depend on the number of nodes per node type. Hence I moved it into forward(), since the size of the ParameterDict varies across input graphs, but ideally it should live in __init__(). How do I modify the HeteroRGCN module to achieve this? This matters especially because I would like to save the state_dict and load it again, possibly to generate graph embeddings for unseen graphs.

It sounds like you are learning node embeddings from scratch for multiple graphs. In this case, you can initialize the embeddings for each graph and store them in ndata of the graphs. See FAQ14 here.
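A sketch of that approach, assuming the embeddings are created once per graph and stored under a 'feat' key (both the key name and the attach_embeddings helper are assumptions, not DGL API):

import torch
import torch.nn as nn

def attach_embeddings(G, in_size):
    # One-time initialization per graph: the learnable embeddings live in
    # ndata, so the module itself no longer depends on graph size.
    params = []
    for ntype in G.ntypes:
        embed = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), in_size))
        nn.init.xavier_uniform_(embed)
        G.nodes[ntype].data['feat'] = embed
        params.append(embed)
    return params  # hand these to the optimizer separately

# Inside HeteroRGCN.forward(), read the features from the graph instead:
#     feat_dict = {ntype: G.nodes[ntype].data['feat'] for ntype in G.ntypes}
#     h_dict = self.layer1(G, feat_dict)

Since the embeddings no longer live inside the module, state_dict() then only carries the layer weights, which is exactly what you would want to reuse on unseen graphs; the per-graph embeddings are trained (or re-initialized) separately.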
