import torch
import dgl
from dgl.nn import GATConv

# Node features, dense weight matrix, graph, and model are all placed on cuda:1.
batchX = torch.randn(size=(3000, 400)).float().to('cuda:1')
graph = torch.randn(size=(3000, 400)).float().to('cuda:1')
g = (graph != 0).int().to_sparse().indices().to('cuda:1')
g = dgl.graph((g[0], g[1]), num_nodes=3000, device='cuda:1')
edge_weight = graph[graph != 0].flatten().float().to('cuda:1')
model = GATConv(400, 128, num_heads=1).float().to('cuda:1')

# After this forward pass, cuda:0 also shows memory usage.
model(g, batchX, edge_weight)
I notice that after the forward pass, cuda:0 also shows memory usage.
After stepping into the forward function of GATConv, I found that the call to g.in_degrees() triggers this, even though the graph itself is on cuda:1.
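For reference, a minimal sketch that isolates just the in_degrees() call (this is a hypothetical check, not part of the original reproduction; it assumes a machine with at least two visible GPUs, and note that a CUDA context created on cuda:0 may show up in nvidia-smi rather than in PyTorch's allocator counters):

import torch
import dgl

# Tiny graph constructed directly on cuda:1.
src = torch.tensor([0, 1, 2], device='cuda:1')
dst = torch.tensor([1, 2, 0], device='cuda:1')
g = dgl.graph((src, dst), num_nodes=3, device='cuda:1')

print('cuda:0 allocated before:', torch.cuda.memory_allocated('cuda:0'))
print('cuda:0 reserved before: ', torch.cuda.memory_reserved('cuda:0'))

deg = g.in_degrees()  # the call that appears to touch cuda:0
print('in_degrees result device:', deg.device)

print('cuda:0 allocated after: ', torch.cuda.memory_allocated('cuda:0'))
print('cuda:0 reserved after:  ', torch.cuda.memory_reserved('cuda:0'))

Checking nvidia-smi before and after this snippet (or setting CUDA_VISIBLE_DEVICES so that only one GPU is exposed) makes it easier to tell whether the extra usage on cuda:0 comes from an actual tensor allocation or just from an unwanted CUDA context being initialized there.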