I would like to write a custom sparse operator but have not been able to figure out how to do it. I have simplified the code below and would like to discuss two cases.
Here is my code:
import torch
import torch.nn as nn
from torch.autograd import Function

import dgl.sparse as dglsp


class SparseOpFunction(Function):
    @staticmethod
    def forward(ctx, sparse_matrix, weight):
        ctx.save_for_backward(weight)
        ctx.sparse_matrix = sparse_matrix
        # Case 1:
        out = dglsp.sddmm(sparse_matrix, weight, weight.t())
        # Case 2:
        # out = dglsp.spmm(sparse_matrix, weight)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # dummy backward function
        grad_sparse_matrix = ctx.sparse_matrix
        grad_weight = ctx.saved_tensors
        return grad_sparse_matrix, grad_weight


class SparseOp(nn.Module):
    def __init__(self):
        super(SparseOp, self).__init__()
        self.weight = nn.Parameter(torch.ones(2, 4))

    def forward(self, input):
        return SparseOpFunction.apply(input, self.weight)


def main():
    indices = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    A = dglsp.spmatrix(indices, shape=(2, 2))
    B = torch.ones(2, 4)
    C = torch.ones(4, 2)
    B.requires_grad = True

    sparse_op = SparseOp()
    out = dglsp.sddmm(A, B, C)
    out = sparse_op(out)
    loss = out.sum()
    loss.backward()


if __name__ == "__main__":
    main()
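For reference, the plain dense version of this pattern is what I am trying to mirror. A minimal sketch with no DGL involved (the class name and the gradient formulas here are my own, purely for illustration):

import torch
from torch.autograd import Function

class DenseOpFunction(Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        # standard matmul gradients: one return value per forward input
        return grad_output @ weight.t(), x.t() @ grad_output

x = torch.ones(2, 2, requires_grad=True)
w = torch.ones(2, 4, requires_grad=True)
DenseOpFunction.apply(x, w).sum().backward()

What I cannot work out is how to get the same behaviour when one of the inputs and/or the output is a SparseMatrix.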
For Case 1, the following error is raised during the backward pass:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
It seems that the computational graph is not being captured correctly.
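My guess is that, because the custom Function returns a SparseMatrix rather than a Tensor, autograd never attaches a grad_fn to the values of the result. A quick way to see this, assuming SparseMatrix exposes its non-zero values through .val, would be to check right after the sparse_op call in main():

out = sparse_op(out)
# If the graph is not captured, I expect the value tensor to be detached:
# requires_grad == False and grad_fn is None.
print(out.val.requires_grad, out.val.grad_fn)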
For Case 2, the following error is raised during the backward pass:
RuntimeError: function SparseOpFunctionBackward returned a gradient different than None at position 1, but the corresponding forward input was not a Variable
It appears that a gradient cannot be assigned to a SparseMatrix object.
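As far as I understand the torch.autograd.Function contract, backward has to return one value per input to forward, with None for any input that is not a differentiable tensor, which I assume includes the SparseMatrix. If that is right, the backward would need to look roughly like this (the gradient for weight is only a placeholder):

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors  # saved_tensors is a tuple, so unpack it
        # None for the SparseMatrix input, a placeholder gradient for `weight`
        return None, torch.ones_like(weight)

That would at least match the autograd contract, but it leaves open how a gradient is supposed to reach the sparse values, which is really what I am after.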
Any ideas or suggestions would be greatly appreciated!