I am using GAT with a heterograph:
graph = dgl.heterograph({
    ('front', 'pos', 'back'): edges,
    ('back', 'pos', 'front'): rev_edges,
})
and call
torch.save(model.state_dict(), 'xxx.pth')
to save the model. When I load it with load_state_dict, it raises this error:
RuntimeError: Error(s) in loading state_dict for NeoPolicy:
Missing key(s) in state_dict: "gcn.gnn_layers.0.attn_l", "gcn.gnn_layers.0.attn_r", "gcn.gnn_layers.0.fc.weight", "gcn.gnn_layers.1.attn_l", "gcn.gnn_layers.1.attn_r", "gcn.gnn_layers.1.fc.weight".
Unexpected key(s) in state_dict: "gcn.gnn_layers.0.mods.neg.attn_l", "gcn.gnn_layers.0.mods.neg.attn_r", "gcn.gnn_layers.0.mods.neg.fc_src.weight", "gcn.gnn_layers.0.mods.neg.fc_dst.weight", "gcn.gnn_layers.0.mods.pos.attn_l", "gcn.gnn_layers.0.mods.pos.attn_r", "gcn.gnn_layers.0.mods.pos.fc_src.weight", "gcn.gnn_layers.0.mods.pos.fc_dst.weight", "gcn.gnn_layers.1.mods.neg.attn_l", "gcn.gnn_layers.1.mods.neg.attn_r", "gcn.gnn_layers.1.mods.neg.fc_src.weight", "gcn.gnn_layers.1.mods.neg.fc_dst.weight", "gcn.gnn_layers.1.mods.pos.attn_l", "gcn.gnn_layers.1.mods.pos.attn_r", "gcn.gnn_layers.1.mods.pos.fc_src.weight", "gcn.gnn_layers.1.mods.pos.fc_dst.weight".
Do I need some special operation to save this kind of model?
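
A minimal sketch for narrowing this down, assuming the mismatch comes from the model being constructed differently at save time and at load time rather than from the saving step itself. The `mods.pos` / `mods.neg` prefixes and the `fc_src` / `fc_dst` names in the unexpected keys look like GATConv layers wrapped per edge type (as `dgl.nn.HeteroGraphConv` does), while the missing keys look like a bare GATConv with a single `fc`; that reading is a guess from the key names, not something confirmed here. The helper `diff_state_dict_keys` is hypothetical and only compares key names, using layer 0 keys copied from the error message:

```python
def diff_state_dict_keys(saved_keys, model_keys):
    """Hypothetical helper: compare checkpoint keys with the keys the model expects.

    Returns (missing, unexpected), using the same meaning as the
    load_state_dict error message.
    """
    saved, expected = set(saved_keys), set(model_keys)
    missing = sorted(expected - saved)      # model expects them, checkpoint lacks them
    unexpected = sorted(saved - expected)   # checkpoint has them, model does not
    return missing, unexpected


# Keys copied from the error message (layer 0 only, for brevity).
saved_keys = [
    "gcn.gnn_layers.0.mods.neg.attn_l",
    "gcn.gnn_layers.0.mods.neg.attn_r",
    "gcn.gnn_layers.0.mods.neg.fc_src.weight",
    "gcn.gnn_layers.0.mods.neg.fc_dst.weight",
    "gcn.gnn_layers.0.mods.pos.attn_l",
    "gcn.gnn_layers.0.mods.pos.attn_r",
    "gcn.gnn_layers.0.mods.pos.fc_src.weight",
    "gcn.gnn_layers.0.mods.pos.fc_dst.weight",
]
model_keys = [
    "gcn.gnn_layers.0.attn_l",
    "gcn.gnn_layers.0.attn_r",
    "gcn.gnn_layers.0.fc.weight",
]

missing, unexpected = diff_state_dict_keys(saved_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With the real objects, the same comparison would be diff_state_dict_keys(torch.load('xxx.pth').keys(), model.state_dict().keys()). If the two key sets differ the way they do in the error, the likely fix is to build the model at load time with the same layer type and constructor arguments that were used when the checkpoint was saved, and only then call load_state_dict; no special save operation should be needed under that assumption.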