I am trying to use `DGNConv` (see the DGNConv page of the DGL 0.9.0 documentation).
The documentation states that the layer does not need `eig_vec` when no directional aggregators are given. However, the following example fails:
import dgl  # minimal repro: DGNConv forward without eig_vec raises KeyError: 'eig'
import torch as th
from dgl.nn import DGNConv
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))  # 6 directed edges over node ids 0-5
feat = th.ones(6, 10)  # one 10-dim all-ones feature row per node id 0-5
conv = DGNConv(10, 10, ['sum'], ['identity'], 2.5)  # only 'sum' (no directional aggregator), so eig_vec should be optional per the docs
ret = conv(g, feat)  # eig_vec omitted -> fails with KeyError: 'eig' (traceback below)
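Error message (root cause, visible at the bottom of the traceback: `DGNConvTower.message` reads `edges.src['eig']` and `edges.dst['eig']` when building the message, but `DGNConv.forward` only sets `graph.ndata['eig']` when `self.use_eig_vec` is true):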
KeyError Traceback (most recent call last)
Input In [1], in <cell line: 7>()
5 feat = th.ones(6, 10)
6 conv = DGNConv(10, 10, ['sum'], ['identity'], 2.5)
----> 7 ret = conv(g, feat)
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/dgnconv.py:207, in DGNConv.forward(self, graph, node_feat, edge_feat, eig_vec)
205 if self.use_eig_vec:
206 graph.ndata['eig'] = eig_vec
--> 207 return super().forward(graph, node_feat, edge_feat)
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/pnaconv.py:264, in PNAConv.forward(self, graph, node_feat, edge_feat)
241 def forward(self, graph, node_feat, edge_feat=None):
242 r"""
243 Description
244 -----------
(...)
262 should be the same as out_size.
263 """
--> 264 h_cat = torch.cat([
265 tower(
266 graph,
267 node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
268 edge_feat
269 )
270 for ti, tower in enumerate(self.towers)
271 ], dim=1)
272 h_out = self.mixing_layer(h_cat)
273 # add residual connection
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/pnaconv.py:265, in <listcomp>(.0)
241 def forward(self, graph, node_feat, edge_feat=None):
242 r"""
243 Description
244 -----------
(...)
262 should be the same as out_size.
263 """
264 h_cat = torch.cat([
--> 265 tower(
266 graph,
267 node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
268 edge_feat
269 )
270 for ti, tower in enumerate(self.towers)
271 ], dim=1)
272 h_out = self.mixing_layer(h_cat)
273 # add residual connection
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/pnaconv.py:126, in PNAConvTower.forward(self, graph, node_feat, edge_feat)
123 assert edge_feat is not None, "Edge features must be provided."
124 graph.edata['a'] = edge_feat
--> 126 graph.update_all(self.message, self.reduce_func)
127 h = self.U(
128 torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
129 )
130 h = h * snorm_n
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/heterograph.py:4895, in DGLHeteroGraph.update_all(self, message_func, reduce_func, apply_node_func, etype)
4893 _, dtid = self._graph.metagraph.find_edge(etid)
4894 g = self if etype is None else self[etype]
-> 4895 ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
4896 if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata:
4897 # Replace infinity with zero for isolated nodes
4898 key = list(ndata.keys())[0]
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/core.py:365, in message_passing(g, mfunc, rfunc, afunc)
363 else:
364 orig_eid = g.edata.get(EID, None)
--> 365 msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
366 # reduce phase
367 if is_builtin(rfunc):
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/core.py:85, in invoke_edge_udf(graph, eid, etype, func, orig_eid)
82 dstdata = graph._node_frames[dtid].subframe(v)
83 ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
84 etype, srcdata, edata, dstdata)
---> 85 return func(ebatch)
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/dgnconv.py:40, in DGNConvTower.message(self, edges)
38 else:
39 f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
---> 40 return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']}
File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/frame.py:607, in Frame.__getitem__(self, name)
594 def __getitem__(self, name):
595 """Return the column of the given name.
596
597 Parameters
(...)
605 Column data.
606 """
--> 607 return self._columns[name].data
KeyError: 'eig'