DGNConv fails when called without eig_vec

I am trying to use `DGNConv` as described in its DGL 0.9.0 documentation page.

According to the documentation, the layer does not need `eig_vec` when no directional aggregators are given. However, the following example, which uses only the non-directional `sum` aggregator, fails:

import dgl
import torch as th
from dgl.nn import DGNConv
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
feat = th.ones(6, 10)
conv = DGNConv(10, 10, ['sum'], ['identity'], 2.5)
ret = conv(g, feat)

Error message:

KeyError                                  Traceback (most recent call last)
Input In [1], in <cell line: 7>()
      5 feat = th.ones(6, 10)
      6 conv = DGNConv(10, 10, ['sum'], ['identity'], 2.5)
----> 7 ret = conv(g, feat)

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/dgnconv.py:207, in DGNConv.forward(self, graph, node_feat, edge_feat, eig_vec)
    205 if self.use_eig_vec:
    206     graph.ndata['eig'] = eig_vec
--> 207 return super().forward(graph, node_feat, edge_feat)

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/pnaconv.py:264, in PNAConv.forward(self, graph, node_feat, edge_feat)
    241 def forward(self, graph, node_feat, edge_feat=None):
    242     r"""
    243     Description
    244     -----------
   (...)
    262         should be the same as out_size.
    263     """
--> 264     h_cat = torch.cat([
    265         tower(
    266             graph,
    267             node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
    268             edge_feat
    269         )
    270         for ti, tower in enumerate(self.towers)
    271     ], dim=1)
    272     h_out = self.mixing_layer(h_cat)
    273     # add residual connection

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/pnaconv.py:265, in <listcomp>(.0)
    241 def forward(self, graph, node_feat, edge_feat=None):
    242     r"""
    243     Description
    244     -----------
   (...)
    262         should be the same as out_size.
    263     """
    264     h_cat = torch.cat([
--> 265         tower(
    266             graph,
    267             node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
    268             edge_feat
    269         )
    270         for ti, tower in enumerate(self.towers)
    271     ], dim=1)
    272     h_out = self.mixing_layer(h_cat)
    273     # add residual connection

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/pnaconv.py:126, in PNAConvTower.forward(self, graph, node_feat, edge_feat)
    123     assert edge_feat is not None, "Edge features must be provided."
    124     graph.edata['a'] = edge_feat
--> 126 graph.update_all(self.message, self.reduce_func)
    127 h = self.U(
    128     torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
    129 )
    130 h = h * snorm_n

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/heterograph.py:4895, in DGLHeteroGraph.update_all(self, message_func, reduce_func, apply_node_func, etype)
   4893 _, dtid = self._graph.metagraph.find_edge(etid)
   4894 g = self if etype is None else self[etype]
-> 4895 ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
   4896 if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata:
   4897     # Replace infinity with zero for isolated nodes
   4898     key = list(ndata.keys())[0]

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/core.py:365, in message_passing(g, mfunc, rfunc, afunc)
    363 else:
    364     orig_eid = g.edata.get(EID, None)
--> 365     msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
    366 # reduce phase
    367 if is_builtin(rfunc):

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/core.py:85, in invoke_edge_udf(graph, eid, etype, func, orig_eid)
     82 dstdata = graph._node_frames[dtid].subframe(v)
     83 ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
     84                    etype, srcdata, edata, dstdata)
---> 85 return func(ebatch)

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/nn/pytorch/conv/dgnconv.py:40, in DGNConvTower.message(self, edges)
     38 else:
     39     f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
---> 40 return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']}

File ~/opt/miniconda3/envs/gli/lib/python3.10/site-packages/dgl-0.9.0-py3.10-macosx-11.1-arm64.egg/dgl/frame.py:607, in Frame.__getitem__(self, name)
    594 def __getitem__(self, name):
    595     """Return the column of the given name.
    596 
    597     Parameters
   (...)
    605         Column data.
    606     """
--> 607     return self._columns[name].data

KeyError: 'eig'

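Thanks for reporting this! The traceback shows the cause. `DGNConv.forward` stores `graph.ndata['eig']` only when `use_eig_vec` is set, but `DGNConvTower.message` always reads `edges.src['eig']` and `edges.dst['eig']`. I will fix this as soon as possible and close this issue.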