Hi,
I am new to DGL. Could someone explain in detail what the following code snippet does? It is taken from softmax.py (lines 60–85), in the class EdgeSoftmax.
Thanks a lot!
@staticmethod
def backward(ctx, grad_out):
    """Backward function.

    Computes the gradient of the loss with respect to the raw edge scores,
    given ``grad_out`` (dL/dy, one entry per edge) and the forward output
    ``out`` (y, the per-edge softmax values saved by forward()).

    The normalization runs over the incoming edges of each destination node
    (see ``dst_sum`` below). For the edges e entering a node v, the softmax
    Jacobian-vector product is::

        dL/ds_e = y_e * g_e - y_e * sum_{e' -> v} (y_e' * g_e')

    where g_e = dL/dy_e. The first term is ``sds``; the sum is ``sds_sum``.

    Pseudo-code:

    .. code:: python

        g, out = ctx.backward_cache
        grad_out = dgl.EData(g, grad_out)
        out = dgl.EData(g, out)
        sds = out * grad_out  # type dgl.EData
        sds_sum = sds.dst_sum()  # type dgl.NData
        grad_score = sds - out * sds_sum  # multiple expressions
        return grad_score.data

    Returns:
        A 3-tuple with one gradient per forward() input. Only the second
        entry (the edge-score gradient) is non-None.
        NOTE(review): assumes forward() takes three inputs after ctx, with
        the edge scores second -- confirm against forward().
    """
    # The graph was stashed on ctx by forward() (not shown here).
    g = ctx.backward_cache
    # Work on a local view so the temporary fields written below
    # ('out', 'grad_s', 'accum') do not leak onto the stored graph.
    g = g.local_var()
    # `out` is the forward softmax output y, saved in forward().
    out, = ctx.saved_tensors
    # clear backward cache explicitly (drop ctx's reference to the graph
    # now that it has been retrieved)
    ctx.backward_cache = None
    g.edata['out'] = out
    # Per edge: grad_s_e = y_e * g_e   ("sds" in the pseudo-code)
    g.edata['grad_s'] = out * grad_out
    # Per destination node v: accum[v] = sum over incoming edges of grad_s
    # ("sds_sum" in the pseudo-code). copy_e sends each edge's grad_s as a
    # message, and sum reduces the messages at the destination node.
    g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
    # Per edge e = (u -> v): out_e <- y_e * accum[v]. This overwrites the
    # 'out' field, which is safe because it is a local copy.
    g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
    # dL/ds_e = y_e * g_e - y_e * accum[v]
    grad_score = g.edata['grad_s'] - g.edata['out']
    return None, grad_score, None