Error loading GAT for heterograph

When using GAT with a heterograph

graph = dgl.heterograph({('front', 'pos', 'back'): edges,
                         ('back', 'pos', 'front'): rev_edges,
                         })

and use state_dict

torch.save(model.state_dict(), 'xxx.pth')

to save the model, loading it with load_state_dict raises this error:

RuntimeError: Error(s) in loading state_dict for NeoPolicy:
	Missing key(s) in state_dict: "gcn.gnn_layers.0.attn_l", "gcn.gnn_layers.0.attn_r", "gcn.gnn_layers.0.fc.weight", "gcn.gnn_layers.1.attn_l", "gcn.gnn_layers.1.attn_r", "gcn.gnn_layers.1.fc.weight". 
	Unexpected key(s) in state_dict: "gcn.gnn_layers.0.mods.neg.attn_l", "gcn.gnn_layers.0.mods.neg.attn_r", "gcn.gnn_layers.0.mods.neg.fc_src.weight", "gcn.gnn_layers.0.mods.neg.fc_dst.weight", "gcn.gnn_layers.0.mods.pos.attn_l", "gcn.gnn_layers.0.mods.pos.attn_r", "gcn.gnn_layers.0.mods.pos.fc_src.weight", "gcn.gnn_layers.0.mods.pos.fc_dst.weight", "gcn.gnn_layers.1.mods.neg.attn_l", "gcn.gnn_layers.1.mods.neg.attn_r", "gcn.gnn_layers.1.mods.neg.fc_src.weight", "gcn.gnn_layers.1.mods.neg.fc_dst.weight", "gcn.gnn_layers.1.mods.pos.attn_l", "gcn.gnn_layers.1.mods.pos.attn_r", "gcn.gnn_layers.1.mods.pos.fc_src.weight", "gcn.gnn_layers.1.mods.pos.fc_dst.weight". 

Do I need some special operation to save this kind of model?
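
A mismatch like the one above can be narrowed down by comparing the parameter names in the checkpoint with the names the current model expects. The ".mods.<edge type>." segment in the unexpected keys suggests the checkpoint came from a model that wrapped one GATConv per edge type, while the model being loaded uses plain GATConv layers. Below is a minimal sketch using only plain Python; diff_state_dict_keys is a hypothetical helper, not part of PyTorch or DGL, and the sample keys are a subset copied from the error message.

```python
# Hypothetical helper (not part of PyTorch or DGL): compares the parameter
# names stored in a checkpoint with the names the current model expects.
def diff_state_dict_keys(saved_keys, model_keys):
    saved, expected = set(saved_keys), set(model_keys)
    missing = sorted(expected - saved)      # model expects these, checkpoint lacks them
    unexpected = sorted(saved - expected)   # checkpoint has these, model does not
    return missing, unexpected


# A few key names from the error message, used here as sample data.
# With real objects one would pass torch.load('xxx.pth').keys() and
# model.state_dict().keys() instead.
saved_keys = [
    "gcn.gnn_layers.0.mods.pos.attn_l",
    "gcn.gnn_layers.0.mods.pos.attn_r",
    "gcn.gnn_layers.0.mods.pos.fc_src.weight",
    "gcn.gnn_layers.0.mods.pos.fc_dst.weight",
]
model_keys = [
    "gcn.gnn_layers.0.attn_l",
    "gcn.gnn_layers.0.attn_r",
    "gcn.gnn_layers.0.fc.weight",
]

missing, unexpected = diff_state_dict_keys(saved_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If both lists are non-empty, the model being loaded was most likely built with a different configuration than the one that was saved, so no special save operation should be needed.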

Hi, could you provide an example code snippet for reproducing the error?

After testing, I found that I was loading with a different model configuration than the one I saved.
But I have another question. Since GATConv supports bipartite graphs, it can be used as

u, v = [0,1,2,3], [4,5,6,7] 
graph = dgl.DGLGraph((u, v))
gat = GATConv(in_feats=(6, 6), out_feats=6, num_heads=1)
features = torch.randn((8,6))
x = gat(graph, (features, features))

But as I understand it, the last line should pass the source node features and the destination node features separately:

x = gat(graph, (features[:4,:], features[4:,:]))

Is there any problem? Thanks

Your code looks correct to me.

But if I use x = gat(graph, (features[:4,:], features[4:,:])), I get

DGLError: Expect number of features to match number of nodes (len(u)). Got 4 and 8 instead.

I see. This is because your graph is a general homogeneous graph rather than a bipartite graph. In this case you will need to use

x = gat(graph, features)

Alternatively you can do

graph = dgl.bipartite((u, v))
...
x = gat(graph, (features[:4,:], features[4:,:]))

I tried this, but got the same error.

u, v = [0,1,2,3], [4,5,6,7] 
graph = dgl.bipartite((u, v))
gat = GATConv(in_feats=(6, 6), out_feats=6, num_heads=1)
features = torch.randn((8,6))
x = gat(graph, (features[:4,:], features[4:,:]))

I see what’s happening. The ID spaces of source nodes and destination nodes are separate. As a result, destination IDs [4,5,6,7] imply 8 destination nodes (IDs 0 through 7). To address it, number each side from 0:

u, v = [0, 1, 2, 3], [0, 1, 2, 3] 
graph = dgl.bipartite((u, v))
gat = GATConv(in_feats=(6, 6), out_feats=6, num_heads=1)
features = torch.randn((8,6))
x = gat(graph, (features[:4,:], features[4:,:]))
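
The separate ID spaces can be illustrated without DGL. The sketch below assumes each side's node count is taken as the largest ID on that side plus one, which is consistent with the "4 and 8" in the error above; inferred_counts and relabel_from_zero are hypothetical helpers written for this illustration.

```python
# Sketch under an assumption: each side's node count is the largest ID on
# that side plus one, because source and destination IDs are numbered
# independently. Both helpers are hypothetical, not DGL functions.
def inferred_counts(u, v):
    return max(u) + 1, max(v) + 1


def relabel_from_zero(ids):
    # Map the distinct IDs, in sorted order, onto 0..k-1.
    mapping = {old: new for new, old in enumerate(sorted(set(ids)))}
    return [mapping[i] for i in ids]


u, v = [0, 1, 2, 3], [4, 5, 6, 7]
print(inferred_counts(u, v))        # (4, 8): 8 destination nodes, only 4 feature rows

v_local = relabel_from_zero(v)
print(v_local)                      # [0, 1, 2, 3]
print(inferred_counts(u, v_local))  # (4, 4): matches features[:4] and features[4:]
```

Relabeling this way keeps the edge structure intact and only changes how destination nodes are numbered, so row i of features[4:,:] corresponds to destination node i.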

Thank you, it works!

One last question:
Both the bipartite API of GATConv and the HeteroGraphConv API work on bipartite graphs. Are there any differences?

E.g., If I use

dgl.nn.HeteroGraphConv({'UV': dgl.nn.GATConv((6, 6), 6, 1), 
                        'VU': dgl.nn.GATConv((6, 6), 6, 1)})

Is it the same as the implementation above?

There can be some subtle differences:

  1. dgl.nn.HeteroGraphConv allows you to specify different GNNConv for different edge types, e.g. SAGEConv for type UV and GATConv for type VU.
  2. Your code snippet actually creates two GATConv modules, each with its own parameters.

If you want a single GATConv whose parameters are shared across edge types, then GATConv alone will be enough.

So if I share the parameters in HeteroGraphConv, the two are the same?

If you use a single GATConv, then it will be the same whether you use HeteroGraphConv or not.
