I fitted the GNN code in the Docs to work with GraphSAGE and everything works well up until the loss needs to be calculated. There I’m getting a Memory error with it requiring `1000+ GB memory`

for two tensors as:

pos_score.shape = torch.Size([266617, 1], neg_score.shape = torch.Size([1333085, 1],

I don’t understand why such tensors would need such a huge amount of memory for loss to be computed. Has anybody else come across this?

Graph Dimensions:

Graph(num_nodes={‘customer’: 74845, ‘product’: 1835},

num_edges={(‘customer’, ‘customerBoughtProducts’, ‘product’): 266617, (‘product’, ‘customerBoughtProductsBack’, ‘customer’): 266617},

metagraph=[(‘customer’, ‘product’, ‘customerBoughtProducts’), (‘product’, ‘customer’, ‘customerBoughtProductsBack’)])

GNN Code:

```
class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
return graph.edges[etype].data['score']
class GraphSAGE(nn.Module):
def __init__(self, in_feats, hid_feats, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel: dglnn.SAGEConv(in_feats, hid_feats, 'mean')
for rel in rel_names}, aggregate='sum')
self.conv2 = dglnn.HeteroGraphConv({
rel: dglnn.SAGEConv(hid_feats, hid_feats, 'mean')
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
return h
def construct_negative_graph(graph, k, etype):
utype, _, vtype = etype
src, dst = graph.edges(etype=etype)
neg_src = src.repeat_interleave(k)
neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,)).to(device)
return dgl.heterograph(
{etype: (neg_src, neg_dst)},
num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
def compute_loss(pos_score, neg_score):
n_edges = pos_score.shape[0]
return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()
class Model(nn.Module):
def __init__(self, in_features, hidden_features, rel_names):
super().__init__()
self.sage = GraphSAGE(in_features, hidden_features, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, neg_g, x, etype):
h = self.sage(g, x)
return self.pred(g, h, etype), self.pred(neg_g, h, etype)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
k = 5
graph = linkpred_graph.to(device)
model = Model(512, 1024, graph.etypes)
model.to(device)
user_feats = graph.nodes['customer'].data['embedding'].float()
item_feats = graph.nodes['product'].data['embedding'].float()
node_features = {'customer': user_feats, 'product': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
negative_graph = construct_negative_graph(graph, k, ('customer', 'customerBoughtProducts', 'product'))
pos_score, neg_score = model(graph, negative_graph, node_features, ('customer', 'customerBoughtProducts', 'product'))
loss = compute_loss(pos_score, neg_score)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
```

Error Stacktrace:

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-15967fe0215e> in <module>
96 negative_graph = construct_negative_graph(graph, k, ('customer', 'customerBoughtProducts', 'product'))
97 pos_score, neg_score = model(graph, negative_graph, node_features, ('customer', 'customerBoughtProducts', 'product'))
---> 98 loss = compute_loss(pos_score, neg_score)
99 opt.zero_grad()
100 loss.backward()
<ipython-input-9-15967fe0215e> in compute_loss(pos_score, neg_score)
55 def compute_loss(pos_score, neg_score):
56 n_edges = pos_score.shape[0]
---> 57 return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()
58
59
RuntimeError: CUDA out of memory. Tried to allocate 1324.05 GiB (GPU 0; 14.76 GiB total capacity; 2.12 GiB already allocated; 11.30 GiB free; 2.29 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```