Node-level classification in multiple graphs

Hi there,

I’ve seen examples in the tutorials where all of the nodes are classified (link) and examples where the graph as a whole is classified (link). However, I’m interested in performing node-level regression across different graphs, where each graph has only one node of interest.

To put it in terms of the karate club example, let’s say that I have data for clubs whose instructors are labeled as good or bad, and I want to classify the instructors of new clubs (not the club itself and not every member, just the instructor).

Is there any example/doc available that might guide me?

Thanks in advance

Check out our example of GAT on PPI.


I’ll do that! Thank you

I have been taking a look at the example, but the size of the labels tensor is N×121 (where N is the number of nodes in the batched graph). Given that the batch contains two sub-graphs, I could only provide two labels (a 2×121 tensor), one for each sub-graph.

Maybe I didn’t explain myself correctly. In the problem I am dealing with, I am only interested in one node of each graph, and I only have labels for those nodes (we can safely assume it’s always node 0 in each graph, if that helps).

I understand that this is doable, and even less demanding computationally, but I can’t figure out how to implement the loss function. I would have to replace BCEWithLogitsLoss and f1_score with functions that only take one node into account (or a few nodes, in the case of batches).

Any help would be appreciated!

There can be two approaches:

  1. Get back the list of sub-graphs with their node features updated and compute the loss on the corresponding nodes. You can recover a list of DGLGraphs with g_list = dgl.unbatch(bg), where bg is a batched graph.
  2. Simply index into the node features of the batched graph and compute the loss on them. See the example below; a sketch of computing the loss this way is given after this list.
    import dgl
    import torch

    # Two small graphs with a scalar feature 'h' per node
    g1 = dgl.DGLGraph()
    g1.add_nodes(2)
    g1.ndata['h'] = torch.tensor([[1.], [2.]])

    g2 = dgl.DGLGraph()
    g2.add_nodes(3)
    g2.ndata['h'] = torch.tensor([[3.], [4.], [5.]])

    # Batch them into a single graph: g1's nodes become 0-1,
    # g2's nodes become 2-4
    bg = dgl.batch([g1, g2])

    # Get the node feature of the first node in the first graph
    # and the second node in the second graph
    print(bg.ndata['h'][[0, 3]])  # tensor([[1.], [4.]])
    
    Note that nodes in the batched graph are relabeled with consecutive integers starting from 0, following the order of the sub-graphs in the list and the order of nodes within each sub-graph.
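
As a concrete illustration, here is a minimal sketch of computing the loss only on node 0 of each sub-graph, continuing from the example above. The labels and the choice of BCEWithLogitsLoss are placeholders for your own setup:

    import dgl
    import torch
    import torch.nn as nn

    # Node 0 of g1 sits at position 0 in the batched graph; node 0 of g2
    # sits at position 2 (right after g1's two nodes).
    target_idx = torch.tensor([0, 2])

    # Approach 2: index directly into the batched node features
    logits = bg.ndata['h'][target_idx]           # shape (2, 1)

    # Approach 1 yields the same values via unbatching
    g_list = dgl.unbatch(bg)
    assert torch.equal(logits, torch.stack([g.ndata['h'][0] for g in g_list]))

    # Compute the loss on the nodes of interest only
    labels = torch.tensor([[1.], [0.]])          # dummy label per sub-graph
    loss = nn.BCEWithLogitsLoss()(logits, labels)
    print(loss)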

Hope this helps.


Hello,

Thank you very much for your response. I’ve done as you suggested, and I think the loss function works properly now.

However, I’m getting NaNs in the predictions and I don’t know how to fix this. From the GAT example:

# No NaNs are found in 'inputs' or 'h' at this point
h = inputs
for l in range(self.num_layers):
    h = self.gat_layers[l](h).flatten(1)
    # At this point 'h' contains NaNs (more as we go in deeper layers)

I’m working on a new dataset, but the dataset class that I’ve coded is quite similar to the PPIDataset example.

I’ve uploaded all the code with a tiny version of the dataset I’m generating to https://github.com/ljmanso/socnav . I’ve tried to keep it as simple as possible so that it’s easy to understand.

Any help would be appreciated.

The NaN values appear starting from ret = self.g.ndata['ft'] / self.g.ndata['z'].
==> This suggests that some values in self.g.ndata['z'] are 0.
==> self.g.ndata['z'] is obtained by summing the unnormalized attention scores over each node’s incoming edges. If some entries are zero, the corresponding nodes most likely have no incoming edges at all, which I’ve verified.
==> To circumvent this issue, common practices are keeping only the largest connected component of the graph or adding self loops. Since your graphs are pretty small, I think it’s better to add self loops here. This can be done by adding self.add_edges(self.nodes(), self.nodes()) as the last line of the SocNavGraph initialization.
==> Verified that NaN values disappear :sunglasses:
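
A minimal sketch of the diagnosis and the fix, using the same DGLGraph API as above (the toy graph is made up for illustration):

    import dgl

    # A toy graph in which nodes 0 and 2 have no incoming edges
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)

    # These nodes get a zero attention normalizer, which produces NaNs
    print((g.in_degrees() == 0).nonzero())  # nodes 0 and 2

    # Self loops guarantee that every node has at least one in-edge
    g.add_edges(g.nodes(), g.nodes())
    print(g.in_degrees())                   # every degree is now >= 1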


That worked! Thanks!