Codes explanation

Hi,
I am a newbie to DGL. Could anyone explain in detail what the following code snippet does? It is taken from lines 60 to 85 of softmax.py, in class EdgeSoftmax.
Thanks a lot!

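# Method of class EdgeSoftmax in softmax.py; relies on the module-level
# import `import dgl.function as fn`.
@staticmethod
def backward(ctx, grad_out):
    """Backward function.

    Pseudo-code:

    .. code:: python

        g, out = ctx.backward_cache
        grad_out = dgl.EData(g, grad_out)
        out = dgl.EData(g, out)
        sds = out * grad_out  # type dgl.EData
        sds_sum = sds.dst_sum()  # type dgl.NData
        grad_score = sds - out * sds_sum  # multiple expressions
        return grad_score.data
    """
    g = ctx.backward_cache
    g = g.local_var()
    out, = ctx.saved_tensors
    # clear backward cache explicitly
    ctx.backward_cache = None
    g.edata['out'] = out
    # sds: softmax output times upstream gradient, on every edge
    g.edata['grad_s'] = out * grad_out
    # sds_sum: each destination node sums 'grad_s' over its incoming edges
    g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
    # overwrite 'out' on each edge with out * accum of its destination node
    g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
    grad_score = g.edata['grad_s'] - g.edata['out']
    # one entry per forward() input; only the edge scores get a gradient
    return None, grad_score, None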

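This is the backward function of edge softmax.
For each edge $(i, j)$ from node $i$ to node $j$ with score $z_{ij}$, the forward function normalizes the scores over all edges that share the destination node $j$:

$$a_{ij} = \frac{\exp(z_{ij})}{\sum_{k \in \mathcal{N}(j)} \exp(z_{kj})}$$

where $\mathcal{N}(j)$ is the set of nodes with an edge into $j$. The goal of the backward function is to derive $\partial L / \partial z_{ij}$ from $\partial L / \partial a_{ij}$. Using $\partial a_{kj} / \partial z_{ij} = a_{kj}(\delta_{ki} - a_{ij})$ and the chain rule:

$$\frac{\partial L}{\partial z_{ij}} = a_{ij}\,\frac{\partial L}{\partial a_{ij}} - a_{ij} \sum_{k \in \mathcal{N}(j)} a_{kj}\,\frac{\partial L}{\partial a_{kj}}$$

In the code, `out` is $a_{ij}$, `grad_out` is $\partial L / \partial a_{ij}$, `grad_s` is the first term, `accum` is the sum over $\mathcal{N}(j)$ computed on node $j$, and `out * accum` is the second term.
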
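A minimal sketch, in plain Python with no DGL or PyTorch, that mirrors the backward code step by step on a flat edge list and checks numerically that the result is the gradient of the forward softmax. The function names, the list-of-destinations edge representation, and the toy loss are illustrative assumptions, not part of DGL. Only the destination of each edge is needed, because every sum is taken per destination node.

```python
import math
from collections import defaultdict


def edge_softmax_forward(dst, z):
    """a_e = exp(z_e) / sum of exp(z_k) over all edges k with dst[k] == dst[e]."""
    # Subtract the per-destination maximum for numerical stability.
    zmax = defaultdict(lambda: -math.inf)
    for e, v in enumerate(dst):
        zmax[v] = max(zmax[v], z[e])
    ex = [math.exp(z[e] - zmax[v]) for e, v in enumerate(dst)]
    denom = defaultdict(float)
    for e, v in enumerate(dst):
        denom[v] += ex[e]
    return [ex[e] / denom[v] for e, v in enumerate(dst)]


def edge_softmax_backward(dst, out, grad_out):
    """Mirror of the DGL backward steps; returns dL/dz_e for every edge."""
    # g.edata['grad_s'] = out * grad_out
    grad_s = [a * da for a, da in zip(out, grad_out)]
    # g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum')):
    # each destination node sums grad_s over its incoming edges.
    accum = defaultdict(float)
    for e, v in enumerate(dst):
        accum[v] += grad_s[e]
    # g.apply_edges(fn.e_mul_v('out', 'accum', 'out')):
    # each edge multiplies its out by the accum of its destination node.
    out_accum = [out[e] * accum[v] for e, v in enumerate(dst)]
    # grad_score = g.edata['grad_s'] - g.edata['out']
    return [gs - oa for gs, oa in zip(grad_s, out_accum)]


# Hypothetical toy graph: edges 0-2 point into node 0, edges 3-4 into node 1.
dst = [0, 0, 0, 1, 1]
z = [0.5, -1.0, 2.0, 0.3, 0.7]
w = [1.0, -2.0, 0.5, 3.0, 1.5]  # toy loss L = sum_e w_e * a_e, so dL/da = w


def loss(scores):
    return sum(wi * ai for wi, ai in zip(w, edge_softmax_forward(dst, scores)))


analytic = edge_softmax_backward(dst, edge_softmax_forward(dst, z), w)

# Central finite differences of L with respect to each edge score.
eps = 1e-6
numeric = []
for e in range(len(z)):
    zp, zm = list(z), list(z)
    zp[e] += eps
    zm[e] -= eps
    numeric.append((loss(zp) - loss(zm)) / (2 * eps))

max_err = max(abs(x - y) for x, y in zip(analytic, numeric))
print(max_err < 1e-6)  # True
```

One consequence worth noticing: the gradients of all edges that share a destination sum to zero, because adding the same constant to every score into a node leaves its softmax unchanged. An edge that is the only one into its destination therefore always gets a zero gradient.
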
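Why write the backward function by hand instead of letting autograd derive it? To make it faster and to save GPU memory: the hand-written version only saves the softmax output `out` (plus the graph) for the backward pass, and computes the gradient with one message-passing step and one edge-wise step. However, you do not need to care about these details; knowing what it does and how to call it is enough.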

Thanks for the detailed explanation, I really appreciate it!

Thanks again for your detailed explanation of my first question. Could you please also take a look at my second question? I posted it 3 days ago with the title “Codes explanation 2”. Thanks a lot!