How to write a custom sparse operator with dglsp and autograd

I would like to write a custom sparse operator with `dgl.sparse` and `torch.autograd.Function`, but I have not been able to figure out how. I have simplified the code and would like to discuss two cases:

Here is my code:

import torch
import torch.nn as nn
from torch.autograd import Function
import dgl.sparse as dglsp

class SparseOpFunction(Function):
    """Minimal reproducer for the two autograd failures described above.

    ``sparse_matrix`` is a ``dglsp.SparseMatrix`` (``main`` passes the
    result of ``dglsp.sddmm``), not a ``torch.Tensor``, so autograd does not
    track it as a differentiable input.
    """
    @staticmethod
    def forward(ctx, sparse_matrix, weight):
        ctx.save_for_backward(weight)
        # save_for_backward only accepts tensors, so the sparse matrix is
        # stashed directly on ctx instead.
        ctx.sparse_matrix = sparse_matrix
        # Case 1:
        # The result here is a SparseMatrix, not a tensor, so no grad_fn is
        # recorded for it; the loss built from it later fails backward with
        # "element 0 of tensors does not require grad and does not have a
        # grad_fn".
        out = dglsp.sddmm(sparse_matrix, weight, weight.t())
        # Case 2:
        # Case 2 does reach backward, so presumably spmm yields a dense
        # tensor and the graph is recorded; the failure is in backward.
        # out = dglsp.spmm(sparse_matrix, weight)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # dummy backward function
        # NOTE(review): a gradient is returned for `sparse_matrix`, which is
        # not a tensor input; autograd requires None in that slot. This is
        # the Case 2 error ("returned a gradient different than None at
        # position 1").
        grad_sparse_matrix = ctx.sparse_matrix
        # NOTE(review): ctx.saved_tensors is a tuple, so this is (weight,),
        # not a tensor gradient.
        grad_weight = ctx.saved_tensors
        return grad_sparse_matrix, grad_weight

class SparseOp(nn.Module):
    """Module that owns a learnable weight and applies ``SparseOpFunction``."""

    def __init__(self):
        super().__init__()
        # 2x4 weight, initialised to ones.
        self.weight = nn.Parameter(torch.ones(2, 4))

    def forward(self, input):
        """Run ``SparseOpFunction`` on ``input`` using this module's weight."""
        w = self.weight
        return SparseOpFunction.apply(input, w)

def main():
    """Build a small graph, run the custom op on an SDDMM result, and backprop."""
    # COO coordinates: first row = row indices, second row = column indices.
    coords = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    A = dglsp.spmatrix(coords, shape=(2, 2))
    B = torch.ones(2, 4)
    B.requires_grad = True
    C = torch.ones(4, 2)
    op = SparseOp()

    sampled = dglsp.sddmm(A, B, C)
    result = op(sampled)
    result.sum().backward()


if __name__ == "__main__":
    main()

For Case 1, the backward pass fails with the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

It seems that the computational graph is not being captured correctly.

For Case 2, the backward pass fails with the following error:

RuntimeError: function SparseOpFunctionBackward returned a gradient different than None at position 1, but the corresponding forward input was not a Variable

It appears that a gradient cannot be assigned to a SparseMatrix object.

Any ideas or suggestions would be greatly appreciated!

Hi @junliang-lin,

The structure of a sparse matrix does not have gradients; only its value tensor does. To support autograd in your case, you can pass the value tensor explicitly to the autograd function.

class SparseOpFunction(Function):
    """Suggested variant that passes the non-zero values as a separate tensor.

    Only tensors can carry gradients through an autograd Function, so the
    differentiable ``value`` tensor is passed explicitly next to the
    (non-differentiable) sparse structure.
    """
    @staticmethod
    def forward(ctx, sparse_matrix, value, weight):
        ctx.save_for_backward(weight)
        # NOTE(review): elsewhere in this thread the call is
        # dglsp.val_like(sparse_matrix, value); confirm that SparseMatrix
        # also exposes a val_like method.
        sparse_matrix = sparse_matrix.val_like(value)
        out = dglsp.sddmm(sparse_matrix, weight, weight.t())
        # NOTE(review): `out` is a SparseMatrix, not a tensor, so autograd
        # still records no grad_fn for it (the follow-up reply reports this).
        # Returning `out.val` avoids that.
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # dummy backward function
        # NOTE(review): ctx.sparse_matrix is never assigned in forward above,
        # so this line would raise AttributeError if backward were reached.
        grad_sparse_matrix = ctx.sparse_matrix
        # NOTE(review): ctx.saved_tensors is a tuple, i.e. (weight,).
        grad_weight = ctx.saved_tensors
        # NOTE(review): SparseMatrix values are accessed as `.val` elsewhere
        # in this thread; `.value` looks like a typo.
        return None, grad_sparse_matrix.value, grad_weight

Hi @czkkkkkk,

Thanks for your reply; this solution looks promising. However, it still reports that the SparseMatrix does not require grad and does not have a grad_fn. I therefore modified it slightly as follows:

import torch
import torch.nn as nn
from torch.autograd import Function
import dgl.sparse as dglsp

class SparseOpFunction(Function):
    """Autograd op computing the values of ``dglsp.sddmm(A, W, W.t())``.

    A ``dglsp.SparseMatrix`` is not a ``torch.Tensor``, so autograd cannot
    track it. The sparsity structure is therefore passed as a plain,
    non-differentiable argument. The differentiable non-zero values are
    passed separately as ``value``.

    Only the output's value tensor is returned. Callers reassemble a sparse
    matrix with ``dglsp.val_like``.
    """

    @staticmethod
    def forward(ctx, sparse_matrix, value, weight):
        """Compute the SDDMM output values.

        Args:
            sparse_matrix: ``dglsp.SparseMatrix`` supplying the sparsity
                structure. Its own values are replaced by ``value``.
            value: Non-zero values aligned with ``sparse_matrix.val``.
            weight: Dense matrix ``W``. SDDMM requires ``W @ W.t()`` to have
                the same shape as ``sparse_matrix``.

        Returns:
            The ``.val`` tensor of ``dglsp.sddmm(A, W, W.t())``, where ``A``
            is ``sparse_matrix`` with values ``value``.
        """
        ctx.save_for_backward(value, weight)
        # Not a tensor, so it cannot go through save_for_backward.
        ctx.sparse_matrix = sparse_matrix
        mat = dglsp.val_like(sparse_matrix, value)
        return dglsp.sddmm(mat, weight, weight.t()).val

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(sparse_matrix, value, weight)``.

        Per non-zero k at (row_k, col_k):
        ``out_k = value_k * dot(W[row_k], W[col_k])``.
        """
        value, weight = ctx.saved_tensors
        structure = ctx.sparse_matrix
        grad_value = grad_weight = None

        if ctx.needs_input_grad[1]:
            # d out_k / d value_k = (W @ W.T)[row_k, col_k]; SDDMM on the
            # incoming gradient yields exactly grad_output_k times that.
            grad_value = dglsp.sddmm(
                dglsp.val_like(structure, grad_output), weight, weight.t()
            ).val

        if ctx.needs_input_grad[2]:
            # W appears as both the row and the column factor. With
            # G = A-structure holding grad_output * value:
            # dL/dW = G @ W + G.T @ W.
            g = dglsp.val_like(structure, grad_output * value)
            grad_weight = dglsp.spmm(g, weight) + dglsp.spmm(g.T, weight)

        # The sparsity structure itself is not differentiable.
        return None, grad_value, grad_weight

class SparseOp(nn.Module):
    """Module holding a learnable weight; applies ``SparseOpFunction``."""

    def __init__(self):
        super().__init__()
        # 2x4 weight, initialised to ones.
        self.weight = nn.Parameter(torch.ones(2, 4))

    def forward(self, sparse_matrix, value):
        """Apply the op to ``sparse_matrix`` with non-zero values ``value``.

        Returns the output value tensor produced by ``SparseOpFunction``.
        """
        w = self.weight
        return SparseOpFunction.apply(sparse_matrix, value, w)

def main():
    """Run the custom op on an SDDMM result's values and backpropagate."""
    # COO coordinates: first row = row indices, second row = column indices.
    coords = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    A = dglsp.spmatrix(coords, shape=(2, 2))
    B = torch.ones(2, 4)
    B.requires_grad = True
    C = torch.ones(4, 2)
    sparse_op = SparseOp()

    D = dglsp.sddmm(A, B, C)
    # The op consumes D's structure and values and returns new values only.
    new_vals = sparse_op(D, D.val)
    # Reattach those values to A's sparsity pattern before reducing.
    result = dglsp.val_like(A, new_vals)
    result.sum().backward()


if __name__ == "__main__":
    main()

The custom sparse operator now outputs only the value tensor (E_val), and I reassemble the output matrix from the input matrix's sparsity structure and that value tensor. It works now, and I appreciate your help!

1 Like