Porting DGL code: KeyError: '_logits'

I’m working with a DGL implementation of the Graph Attention Network (GAT) that was written for DGL 0.4.3.post2, and I’m porting it to a newer DGL version. The code contains a softmax normalization function that now raises an error. I’m relatively new to DGL and I’m having trouble understanding what the error means.

Here’s the relevant part of the code (updated):

    def normalize(self, logits):
        self._logits_name = "_logits"
        self._normalizer_name = "_norm"
        self.g.edata[self._logits_name] = logits

        self.g.update_all(fn.copy_u(self._logits_name, self._logits_name),
                          fn.sum(self._logits_name, self._normalizer_name))

        normalizer = self.g.ndata.pop(self._normalizer_name)
        scores = self.g.edata.pop(self._logits_name)
        return scores, normalizer

    def edge_softmax(self):

        if self.l0 == 0:
            scores = self.softmax(self.g, self.g.edata.pop('a'))
        else:
            scores, normalizer = self.normalize(self.g.edata.pop('a'))
            self.g.ndata['z'] = normalizer[:,0,:].unsqueeze(1)

        self.g.edata['a'] = scores[:,0,:].unsqueeze(1)

And here is the Python traceback:

    self.g.update_all(fn.copy_u(self._logits_name, self._logits_name),
  File "/root/venv/lib/python3.9/site-packages/dgl/heterograph.py", line 5110, in update_all
    ndata = core.message_passing(
  File "/root/venv/lib/python3.9/site-packages/dgl/core.py", line 398, in message_passing
    ndata = invoke_gspmm(g, mfunc, rfunc)
  File "/root/venv/lib/python3.9/site-packages/dgl/core.py", line 361, in invoke_gspmm
    x = alldata[mfunc.target][mfunc.in_field]
  File "/root/venv/lib/python3.9/site-packages/dgl/view.py", line 80, in __getitem__
    return self._graph._get_n_repr(self._ntid, self._nodes)[key]
  File "/root/venv/lib/python3.9/site-packages/dgl/frame.py", line 688, in __getitem__
    return self._columns[name].data
KeyError: '_logits'

I get a KeyError, but I can’t work out why in this context. Can you help me understand what is wrong?

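Since your _logits feature is set on edata, i.e. on the edges rather than on the nodes, you need copy_e instead of copy_u. copy_u builds each message from a feature of the source node, so DGL looks up _logits in the node data (the _get_n_repr call in your traceback), where it does not exist. copy_e reads the feature from the edge data instead.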
