I’m working with a DGL implementation of the Graph Attention Network (GAT) that was written for DGL version 0.4.3.post2. The code has a softmax normalization function that raises an error. I’m relatively new to DGL and can’t work out what the error means.
Here’s the relevant part of the code (updated):
```python
def normalize(self, logits):
    self._logits_name = "_logits"
    self._normalizer_name = "_norm"
    self.g.edata[self._logits_name] = logits
    self.g.update_all(fn.copy_u(self._logits_name, self._logits_name),
                      fn.sum(self._logits_name, self._normalizer_name))
    normalizer = self.g.ndata.pop(self._normalizer_name)
    scores = self.g.edata.pop(self._logits_name)
    return scores, normalizer

def edge_softmax(self):
    if self.l0 == 0:
        scores = self.softmax(self.g, self.g.edata.pop('a'))
    else:
        scores, normalizer = self.normalize(self.g.edata.pop('a'))
        self.g.ndata['z'] = normalizer[:,0,:].unsqueeze(1)
    self.g.edata['a'] = scores[:,0,:].unsqueeze(1)
```
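For context, the aggregation that `normalize` appears to aim for can be sketched without DGL: each edge’s logit is summed into the edge’s destination node, giving one normalizer per node. The minimal plain-Python illustration below assumes a graph given as a list of destination indices with one scalar value per edge. `sum_edge_values_to_dst` and `normalize_per_dst` are hypothetical helpers, not DGL functions. The division step in `normalize_per_dst` is also an assumption about how such a normalizer is typically used, because the snippet above does not show it.

```python
from typing import List, Sequence


def sum_edge_values_to_dst(num_nodes: int,
                           dst: Sequence[int],
                           values: Sequence[float]) -> List[float]:
    """Hypothetical helper: sum each edge's value into its destination node.

    Plain-Python analogue of a "message = edge feature, reduce = sum" step.
    """
    if len(dst) != len(values):
        raise ValueError("expected exactly one value per edge")
    totals = [0.0] * num_nodes
    for d, v in zip(dst, values):
        totals[d] += v  # the message is the edge's own value, not a node's
    return totals


def normalize_per_dst(num_nodes: int,
                      dst: Sequence[int],
                      values: Sequence[float]) -> List[float]:
    """Hypothetical helper: divide each edge's value by its destination's total.

    Assumes values are non-negative (e.g. already exponentiated), so every
    node that has incoming edges gets a positive total.
    """
    totals = sum_edge_values_to_dst(num_nodes, dst, values)
    return [v / totals[d] for d, v in zip(dst, values)]


# Hypothetical 3-node graph with edges 0->2, 1->2, 2->0 (only destinations needed).
dst = [2, 2, 0]
values = [1.0, 3.0, 2.0]
print(sum_edge_values_to_dst(3, dst, values))  # [2.0, 0.0, 4.0]
print(normalize_per_dst(3, dst, values))       # [0.25, 0.75, 1.0]
```

Node 1 has no incoming edges, so its total stays at 0.0. Edges that share destination node 2 split that node’s total between them.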
And here is the Python traceback:
```
    self.g.update_all(fn.copy_u(self._logits_name, self._logits_name),
  File "/root/venv/lib/python3.9/site-packages/dgl/heterograph.py", line 5110, in update_all
    ndata = core.message_passing(
  File "/root/venv/lib/python3.9/site-packages/dgl/core.py", line 398, in message_passing
    ndata = invoke_gspmm(g, mfunc, rfunc)
  File "/root/venv/lib/python3.9/site-packages/dgl/core.py", line 361, in invoke_gspmm
    x = alldata[mfunc.target][mfunc.in_field]
  File "/root/venv/lib/python3.9/site-packages/dgl/view.py", line 80, in __getitem__
    return self._graph._get_n_repr(self._ntid, self._nodes)[key]
  File "/root/venv/lib/python3.9/site-packages/dgl/frame.py", line 688, in __getitem__
    return self._columns[name].data
KeyError: '_logits'
```
I get a `KeyError: '_logits'`, even though `_logits` is assigned to `self.g.edata` immediately before the `update_all` call, and I can’t figure out why. Can you help me understand what might be wrong?