Sorry, I should have given more context. I am creating graphs on temporal sets, using the following code: https://github.com/yule-BUAA/DNNTSP/blob/master/model/weighted_graph_conv.py
This is where the code breaks in the forward pass:
def forward(self, graph, node_features, edge_weights):
    r"""Compute weighted graph convolution.
    -----
    Input:
    graph : DGLGraph, batched graph.
    node_features : torch.Tensor, input features for nodes (n_1+n_2+..., in_features) or (n_1+n_2+..., T, in_features)
    edge_weights : torch.Tensor, input weights for edges (T, n_1^2+n_2^2+..., n^2)
    Output:
    shape: (N, T, out_features)
    """
    graph = graph.local_var()
    # multi W first to project the features, with bias
    # (N, F) / (N, T, F)
    graph.ndata['n'] = node_features
    # edge_weights, shape (T, N^2)
    graph.edata['e'] = edge_weights.t()
    graph.send(graph.edges(), self.gcn_message)
    graph.recv(graph.nodes(), self.gcn_reduce)
    node_features = graph.ndata.pop('h')
    output = self.linear(node_features)
    return output

@staticmethod
def gcn_message(edges: dgl.EdgeBatch):
    if edges.src['n'].dim() == 2:
        # (B, T, 1) (B, 1, F), matmul -> matmul (B, T, F)
        return {'msg': torch.matmul(edges.data['e'].unsqueeze(dim=-1), edges.src['n'].unsqueeze(dim=1))}
    elif edges.src['n'].dim() == 3:
        # (B, T, 1) (B, T, F), mul -> (B, T, F)
        return {'msg': torch.mul(edges.data['e'].unsqueeze(dim=-1), edges.src['n'])}
    else:
        raise ValueError(f"wrong shape for edges.src['n'], the length of shape is {edges.src['n'].dim()}")

@staticmethod
def gcn_reduce(nodes: dgl.NodeBatch):
    # propagate, the first dimension is batch
    # h, tensor, shape, (B, N, T, F) -> (B, T, F)
    return {'h': torch.sum(nodes.mailbox['msg'], 1)}
The code fails at graph.send(graph.edges(), self.gcn_message) with a CUDA out-of-memory error.
Now for the inputs to forward:
g is the DGL graph, with 100-1000 nodes per datapoint. The code fails even with batch size 1.
Say in my current forward pass g has 600 nodes, so N = 600.
edge_weights has size nts x (N*N), where nts is the number of temporal sets and is on the order of 10-100. Taking nts = 80 here, edge_weights is 80 x 360000, about 28.8 million entries, and is sparse.
node_features is 600 x 32.
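A rough size estimate may help explain the failure. In the 2-D branch of gcn_message, the product has shape (E, T, F), where E is the number of edges, T is nts and F is the feature width. The sketch below counts the bytes of that one tensor, assuming E = N*N and float32 values (the dtype is an assumption); message_tensor_bytes is a hypothetical helper, not part of DGL or DNNTSP.

```python
def message_tensor_bytes(num_nodes, num_steps, num_features, bytes_per_value=4):
    """Bytes needed for one dense (E, T, F) message tensor.

    Assumes one edge per ordered node pair (E = N * N) and
    float32 storage unless bytes_per_value says otherwise.
    """
    num_edges = num_nodes * num_nodes
    return num_edges * num_steps * num_features * bytes_per_value


# Numbers from the post: N = 600 nodes, nts = 80, F = 32 features.
N, T, F = 600, 80, 32
msg_bytes = message_tensor_bytes(N, T, F)
print(f"message tensor: {msg_bytes / 2**30:.2f} GiB")
```

With these numbers the single message tensor comes to about 3.4 GiB, before counting anything else the framework allocates. At N = 1000 the same formula gives roughly 9.5 GiB.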
Do I need to change/optimize any of the functions here?
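One possible direction, offered as a sketch rather than a tested fix: when every ordered node pair has an edge, summing the weighted messages per destination node is a batched matrix product, which avoids building the (E, T, F) intermediate. The plain-PyTorch code below compares a reference that mimics gcn_message plus gcn_reduce with a single einsum. It assumes 2-D node features and that edge k runs from node k // N to node k % N; that ordering is an assumption about how the graph was built and would need checking against the actual edge list. Both function names are hypothetical, and the final self.linear projection is left out.

```python
import torch


def message_passing_reference(src, dst, node_features, edge_weights, num_nodes):
    """Mimics gcn_message + gcn_reduce for 2-D node features, without DGL.

    src, dst      : LongTensor (E,), edge endpoints
    node_features : (N, F)
    edge_weights  : (T, E)
    returns       : (N, T, F)
    """
    e = edge_weights.t()                                       # (E, T)
    # Per-edge message; this (E, T, F) tensor is the large intermediate.
    msg = e.unsqueeze(-1) * node_features[src].unsqueeze(1)
    out = torch.zeros(num_nodes, e.shape[1], node_features.shape[1], dtype=msg.dtype)
    out.index_add_(0, dst, msg)                                # sum messages per destination
    return out


def dense_equivalent(node_features, edge_weights, num_nodes):
    """Same sum as one einsum, assuming edge k goes from k // N to k % N."""
    T = edge_weights.shape[0]
    adj = edge_weights.reshape(T, num_nodes, num_nodes)        # adj[t, u, v]: weight of u -> v
    return torch.einsum('tuv,uf->vtf', adj, node_features)     # (N, T, F)


# Small check on random data with the assumed edge ordering.
torch.manual_seed(0)
N, T, F = 5, 3, 4
src = torch.arange(N).repeat_interleave(N)   # 0,0,0,0,0,1,1,...
dst = torch.arange(N).repeat(N)              # 0,1,2,3,4,0,1,...
x = torch.randn(N, F)
w = torch.randn(T, N * N)
ref = message_passing_reference(src, dst, x, w, N)
fast = dense_equivalent(x, w, N)
print(torch.allclose(ref, fast, atol=1e-5))
```

The einsum output has only N*T*F values, so for N = 600, nts = 80, F = 32 it is about 1.5 million floats. Since edge_weights is sparse, a sparse matrix product might reduce memory further, though that is not shown here.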