"The input graph for the user-defined edge function does not contain valid edges"

I built a GNN with DGL, but when I move my model and data to the GPU, I get this DGL warning:

/usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: The input graph for the user-defined edge function does not contain valid edges
return warnings.warn(message, category=category, stacklevel=1)

And an error that says:

RuntimeError: CUDA error: invalid configuration argument

I think the source of the problem is this model I use as my DGL message_func:

class MessageFunc(torch.nn.Module):
    def __init__(self, embed_size, hidden_size):
        super().__init__()
        
        self.rnn = torch.nn.LSTM(embed_size * 2, hidden_size)

    def forward(self, edges):
        node_reps, edge_reps = edges.src["rep"], edges.data["rep"]
        inputs = torch.cat([node_reps, edge_reps], 1)
        inputs = inputs.unsqueeze(0)
        rnn_state = edges.src["sum_incoming_state"]
        hiddens, cells = rnn_state[:, 0].contiguous(), rnn_state[:, 1].contiguous()
        hiddens, cells = hiddens.unsqueeze(0), cells.unsqueeze(0)
        outputs, (new_hiddens, new_cells) = self.rnn(inputs, (hiddens, cells))
        new_hiddens, new_cells = new_hiddens.squeeze(0), new_cells.squeeze(0)
        new_rnn_state = torch.stack([new_hiddens, new_cells], 1)
        new_data = { "state": new_rnn_state }
        return new_data

What could I be doing wrong and how can I fix it?

Can you provide a code snippet for reproducing the issue? This should include a toy graph as well as the lines of code for invoking the user-defined function.

I found that the problem occurs when I input a graph with no nodes or edges. This somehow breaks CUDA: after it happens I cannot create any new tensors in my program, because every attempt raises the same CUDA error. To reproduce it, take MessageFunc as defined above and an MPN defined like this:

class MessagePassingNet(torch.nn.Module):
    def __init__(self, hidden_size,
                 embed_size):
        super().__init__()
        
        self.make_msgs = MessageFunc(embed_size, hidden_size)
        self.reduce_msgs = dgl.function.sum("state", "sum_incoming_state")

        self.hidden_size = hidden_size
        
    def forward(self, graph):
        graph.nodes["a"].data["sum_incoming_state"] = (
            torch.randn(graph.number_of_nodes("a"), 2, self.hidden_size, device = graph.device)
        )
        
        graph.update_all(self.make_msgs, self.reduce_msgs, etype = "b")
        return graph

with a graph defined like this:

device = torch.device("cuda")  # the failure only shows up on GPU
a = dgl.heterograph({("a", "b", "a"):[]}, device = device)
a.nodes["a"].data["rep"] = torch.zeros(0,10, device = device)
a.edges["b"].data["rep"] = torch.zeros(0,10, device = device)

I was able to avoid this problem by checking whether a graph has edges before passing it through the MPN, but it took me a long time to find this simple fix. I’m not sure why it works fine on CPU but breaks completely on GPU. I don’t know whether this is purely a problem with how the model uses torch.nn.LSTM, or whether DGL could do more to help users avoid this case.
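
For anyone who hits the same thing, the edge check could look roughly like the sketch below. The helper name `apply_if_has_edges` is hypothetical (it is not part of DGL); it only assumes the graph object exposes `num_edges(etype)`, which DGL graphs do.

```python
def apply_if_has_edges(model, graph, etype):
    """Run `model` on `graph` only if it has at least one edge of type `etype`.

    Hypothetical helper, not a DGL API. It relies only on the graph
    exposing `num_edges(etype)`.
    """
    if graph.num_edges(etype) == 0:
        # No edges means no messages to compute, so skip the model and
        # hand the graph back untouched instead of calling the LSTM
        # on an empty batch.
        return graph
    return model(graph)
```

With the model and graph from above this would be called as `apply_if_has_edges(mpn, a, "b")`, where `mpn` is a MessagePassingNet instance. Note that a skipped graph never gets `sum_incoming_state` written to its nodes, so any code that reads that field afterwards needs to handle the empty case too.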