I want to compute a global readout over one node type of a heterograph, but I can’t make GlobalAttentionPooling work. Here is a simple repro.
Take this graph:
import dgl
import dgl.nn
import torch
g = dgl.heterograph({('a', 'a2b', 'b'): ([0, 1, 2], [0, 1, 0])})
g.nodes['a'].data['x'] = torch.ones((3, 2))
gap = dgl.nn.GlobalAttentionPooling(torch.nn.Linear(2, 1))
Then try to apply the pooling layer to the features of node type 'a':
gap(g, g.nodes['a'].data['x'])
This raises an exception: “AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.” In glob.py I see that it tries to do
graph.ndata['gate'] = gate
which fails because graph is a heterograph with more than one node type. I then thought I’d extract a subgraph that contains only the nodes I want with g.node_type_subgraph('a'), but that also fails, as there are no edges.
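For reference, here is a minimal sketch of the computation I am after, assuming a single (unbatched) graph. In that case attention pooling reduces to a softmax over the nodes of one type followed by a weighted sum of their features. It uses plain torch rather than GlobalAttentionPooling. The helper name attention_pool_single_type is hypothetical, not part of DGL, and the sketch does not handle batched graphs.

```python
import torch


def attention_pool_single_type(feat, gate_nn):
    """Attention-pool the features of one node type (hypothetical helper).

    Assumes a single, unbatched graph.
    feat:    (num_nodes, in_dim) features of the chosen node type
    gate_nn: module mapping (num_nodes, in_dim) -> (num_nodes, 1)
    """
    gate = gate_nn(feat)                   # one raw attention score per node
    weights = torch.softmax(gate, dim=0)   # normalize the scores across nodes
    # Weighted sum over nodes; keepdim gives a (1, in_dim) readout
    return (weights * feat).sum(dim=0, keepdim=True)


# Same features as g.nodes['a'].data['x'] in the repro
x = torch.ones((3, 2))
gate_nn = torch.nn.Linear(2, 1)
readout = attention_pool_single_type(x, gate_nn)
print(readout.shape)
```

This works for one graph, but it bypasses DGL's per-graph segmentation, so it would not give a per-graph readout on a batch of heterographs.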
Any ideas on how I could pull AttentionPooling off? (Loving the dgl.ai library, by the way!)