Backward function of GSpMM

Hi there,
Could you please explain the backward function of GSpMM (dgl/sparse.py at master · dmlc/dgl · GitHub)?

There are several things I don't understand:

  1. Why is the input graph reversed?
  2. If I use only sum as the aggregation in SpMM, what would the corresponding backward operation be, and why?
  3. Why is SDDMM used in the backward function of GSpMM?

It would be great if you could provide a detailed description; it would really help me understand.

Best,
Khaled

You can find more explanation in our paper (Appendix B).


Hi, I was going through the lemmas in the appendix and thought it might be a good idea to ask here.

In the attached figure of Lemma 4, dL/dy_v does not involve the reverse graph, whereas dL/dx_u uses reverse edges. Could you please explain why this is the case? Does the figure assume there is an edge from u to v, as drawn? I assume that when taking the partial derivative with respect to x_u, the representation of node u, the operation is performed over all of u's outgoing edges. But why is that? Shouldn't it be the opposite? A bit more explanation would be very helpful.

Let's use a simpler case. Suppose the message function \phi_z computes the message as m_e = x_u and the reduce function \rho sums all the received messages: x_v^{new}=\sum_{(u, e, v)\in\mathcal{E}}m_e. You can verify that this computation is equivalent to multiplying the node embedding matrix X by the adjacency matrix A, where A_{vu} counts the edges from u to v (so row v of A selects the in-neighbors of v):

X^{new} = AX
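To check this equivalence numerically, here is a minimal NumPy sketch. It is not DGL's implementation: the toy graph and the helper names `spmm_sum` and `adjacency` are hypothetical, and A is built with A[v, u] counting the edges u -> v so that it matches the row convention above.

```python
import numpy as np

def spmm_sum(src, dst, X, num_dst):
    """Hypothetical helper: copy-source message (m_e = x_u) + sum reduce.

    out[v] = sum of X[u] over all edges u -> v.
    """
    out = np.zeros((num_dst, X.shape[1]), dtype=X.dtype)
    # np.add.at accumulates correctly even when dst has repeated indices
    np.add.at(out, dst, X[src])
    return out

def adjacency(src, dst, num_nodes):
    """Hypothetical helper: dense A with A[v, u] = number of edges u -> v."""
    A = np.zeros((num_nodes, num_nodes))
    np.add.at(A, (dst, src), 1.0)
    return A

# Toy graph (made up): 0->1, 0->2, 1->2, 2->0, 3->2
src = np.array([0, 0, 1, 2, 3])
dst = np.array([1, 2, 2, 0, 2])
N = 4
rng = np.random.default_rng(0)
X = rng.standard_normal((N, 3))

X_new = spmm_sum(src, dst, X, N)
A = adjacency(src, dst, N)
print(np.allclose(X_new, A @ X))  # True
```

Node 3 has no incoming edges, so its row of X_new is zero, while node 2 receives x_0 + x_1 + x_3.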

Then, given the gradient \frac{\partial L}{\partial X^{new}}, the gradient of X is:

\frac{\partial L}{\partial X}=A^{\top}\frac{\partial L}{\partial X^{new}}
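Written out entry by entry (same symbols as above, with A_{vu} counting the edges from u to v), the chain rule step behind this result is:

```latex
\begin{aligned}
X^{new}_{vk} &= \sum_{u} A_{vu} X_{uk}
\quad\Longrightarrow\quad
\frac{\partial X^{new}_{vk}}{\partial X_{uj}} = A_{vu}\,\delta_{jk} \\
\frac{\partial L}{\partial X_{uj}}
&= \sum_{v,k} \frac{\partial L}{\partial X^{new}_{vk}}\,\frac{\partial X^{new}_{vk}}{\partial X_{uj}}
 = \sum_{v} A_{vu}\,\frac{\partial L}{\partial X^{new}_{vj}}
 = \left(A^{\top}\frac{\partial L}{\partial X^{new}}\right)_{uj}
\end{aligned}
```

The sum over v only picks up nodes v with A_{vu} \neq 0, i.e. the destinations of u's outgoing edges.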

Transposing A is equivalent to reversing the graph, i.e. flipping the direction of every edge.
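A minimal NumPy sketch of this point, under the same assumptions as a toy example (hypothetical graph and helper name `spmm_sum`, not DGL code): running the same copy-and-sum SpMM on the reversed edge list, with src and dst swapped, produces exactly A^T times the upstream gradient.

```python
import numpy as np

def spmm_sum(src, dst, X, num_dst):
    # Copy-source message + sum reduce: out[v] += X[u] for each edge u -> v
    out = np.zeros((num_dst, X.shape[1]), dtype=X.dtype)
    np.add.at(out, dst, X[src])
    return out

# Toy graph (made up): 0->1, 0->2, 1->2, 2->0, 3->2
src = np.array([0, 0, 1, 2, 3])
dst = np.array([1, 2, 2, 0, 2])
N = 4
A = np.zeros((N, N))
np.add.at(A, (dst, src), 1.0)  # A[v, u] = number of edges u -> v

rng = np.random.default_rng(1)
G = rng.standard_normal((N, 3))  # stands in for dL/dX_new

# Backward pass: the same SpMM, but on the reversed graph (swap src and dst)
grad_X = spmm_sum(dst, src, G, N)
print(np.allclose(grad_X, A.T @ G))  # True
```

This is why a sum-only SpMM can reuse its own forward kernel for the backward pass: only the edge direction changes.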

Another way to see it: by the multivariable chain rule, the gradient of a variable is the sum of its partial derivatives through every quantity that depends on it. In a graph, x_u is used by all of its outgoing edges during forward propagation, so its gradient must be summed over those edges, which are exactly the incoming edges of u in the reversed graph.
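The per-edge view can be checked directly. In this minimal NumPy sketch (toy graph and loss are hypothetical: L = sum(G * X_new) for a fixed random G, so dL/dX_new = G), each outgoing edge (u, v) of u passes G[v] back to u, and the result is compared with a central finite-difference gradient.

```python
import numpy as np

# Toy graph (made up): 0->1, 0->2, 1->2, 2->0, 3->2
src = np.array([0, 0, 1, 2, 3])
dst = np.array([1, 2, 2, 0, 2])
N, D = 4, 3
rng = np.random.default_rng(2)
X = rng.standard_normal((N, D))
G = rng.standard_normal((N, D))  # fixed upstream gradient dL/dX_new

def forward(X):
    # x_v_new = sum of x_u over edges u -> v
    X_new = np.zeros((N, D))
    np.add.at(X_new, dst, X[src])
    return X_new

def loss(X):
    # Hypothetical linear loss chosen so that dL/dX_new == G
    return float(np.sum(G * forward(X)))

# Chain rule edge by edge: x_u feeds every out-edge (u, v),
# so each such edge contributes dL/dx_v_new to the gradient of x_u.
grad = np.zeros((N, D))
for u, v in zip(src, dst):
    grad[u] += G[v]

# Independent check with central finite differences
eps = 1e-6
num = np.zeros((N, D))
for i in range(N):
    for j in range(D):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        num[i, j] = (loss(Xp) - loss(Xm)) / (2 * eps)

print(np.allclose(grad, num, atol=1e-5))  # True
```

Looping over u's outgoing edges in the original graph is the same as looping over u's incoming edges in the reversed graph, which is the scatter a reversed-graph SpMM performs.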


Thank you so much. This really helps me understand.