Is GlobalAttentionPooling supposed to work with heterographs?

I want a global readout over one node type of a heterograph, but I can't make GlobalAttentionPooling work. Here is a simple repro.

Take this graph:

import dgl
import dgl.nn
import torch

g = dgl.heterograph({('a', 'a2b', 'b'): ([0, 1, 2], [0, 1, 0])})
g.nodes['a'].data['x'] = torch.ones((3, 2))
gap = dgl.nn.GlobalAttentionPooling(torch.nn.Linear(2, 1))

And try to apply

gap(g, g.nodes['a'].data['x'])

This raises an exception: “AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.” In glob.py I see that it tries to do

graph.ndata['gate'] = gate

which fails, because graph is a heterograph. I then thought I'd extract a subgraph containing only the nodes I want with g.node_type_subgraph('a'), but that also fails, since the result has no edges :frowning:

Any ideas how I could pull this off? (Loving the dgl.ai library, by the way!)

For the time being, you can use a workaround:

gap(dgl.graph(([], []), num_nodes=g.num_nodes(ntype='a')), g.nodes['a'].data['x'])

I agree this is something DGL should fix. Perhaps the easiest fix will be to allow g.node_type_subgraph to return a graph with no edges.
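The empty edge list is fine because attention pooling only reads node features, not edges. For a single graph, the computation GlobalAttentionPooling performs boils down to roughly the following (a sketch of the math, not DGL's actual implementation):

```python
import torch

x = torch.ones((3, 2))             # features of the 'a' nodes, as in the repro
gate_nn = torch.nn.Linear(2, 1)    # the same gate network passed to the module

# Gate scores, softmax-normalized across the nodes of the graph:
gate = torch.softmax(gate_nn(x), dim=0)

# Gate-weighted sum of node features -> one readout vector per graph:
readout = (gate * x).sum(dim=0, keepdim=True)   # shape (1, 2)
```

Since no message passing is involved, a graph with the right number of nodes and no edges gives the same readout.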
