HELP! NaN in matrix after HeteroGraphConv

I’d be grateful if anyone with more experience could take a look and point me in the right direction.

I’m encountering an issue with my DGL-based heterogeneous graph convolution model: during training, NaNs start appearing in the feature matrices after the convolution, and the loss gradually becomes NaN as well.

The crux of the problem is that the issue doesn’t occur in the first batch. In fact, if I reduce the learning rate, the NaNs appear later in training, which means the first few batches train completely normally. If the issue appeared from the very first batch, I would consider changing the model altogether. Instead, training starts out correct, but as it progresses, large regions of NaNs appear in the matrices, and eventually the loss becomes NaN, preventing convergence.
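Because it only blows up after some normal batches, the first thing worth pinning down is which operation first produces a non-finite gradient. PyTorch’s anomaly detection can do that; a minimal debugging sketch (it slows training down considerably, so it is not something to leave on):

import torch as th

# Debugging sketch: with anomaly detection on, loss.backward() raises
# an error at the first operation whose gradient becomes NaN/Inf,
# with a traceback pointing at the offending forward op.
th.autograd.set_detect_anomaly(True)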

I’ve checked that the input data contains no NaNs; the NaNs first appear in the output of self.conv1 (see the hook sketch after the model code below for one way to verify this). I have never encountered this kind of problem before. Has anyone seen it? Below is my model code and training code:

import torch as th
import torch.nn as nn
from torch.utils.data import DataLoader

import dgl  # needed for dgl.batch below
from dgl.nn import HeteroGraphConv, GraphConv
import torch.nn.functional as F

class HeteroGNN(nn.Module):
    def __init__(self, node_in_feats, instance_in_feats, hidden_size, out_size):
        super(HeteroGNN, self).__init__()
        self.conv1 = HeteroGraphConv({
            'node2instance': GraphConv(node_in_feats, hidden_size),
            'instance2instance': GraphConv(instance_in_feats, hidden_size)
        }, aggregate='sum')
        self.conv2 = HeteroGraphConv({
            'node2instance': GraphConv(hidden_size, out_size),
            'instance2instance': GraphConv(hidden_size, out_size)
        }, aggregate='sum')
        self.aligner = nn.Linear(node_in_feats, out_size)

    def forward(self, g, h):
        ans = self.conv1(g, h)
        # print('instance after HeteroGNN step 1:', th.isnan(ans['instance']).any())

        ans = {k: F.relu(v) for k, v in ans.items()}
        # print('instance after relu:', ans['instance'])

        ans = self.conv2(g, ans)

        # 'node' features never pass through the convs here, so project
        # them to out_size directly from the raw inputs.
        ans['node'] = self.aligner(h['node'])
        return ans

class HeteroGraphModel(nn.Module):
    def __init__(self, node_in_feats, instance_in_feats, hidden_feats1, hidden_feats2):
        super(HeteroGraphModel, self).__init__()
        self.gnn = HeteroGNN(node_in_feats, instance_in_feats, hidden_feats1, hidden_feats2)
        self.type_classifier = nn.Linear(hidden_feats2, 1)

    def forward(self, g, inputs):
        h = self.gnn(g, inputs)  # returns a dict whose entries should line up in feature size

        instance_feats = h['instance']
        node_feats = h['node']

        all_feats = th.cat([instance_feats, node_feats], dim=0)
        # print('all_feats:', all_feats.shape)

        logits = self.type_classifier(all_feats)
        logits = logits.squeeze(1)

        # probabilities = th.sigmoid(logits)
        # print('probabilities:', probabilities.shape)
        return logits
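To double-check that the NaNs really originate in conv1 rather than somewhere else, a forward hook can be attached to that layer. This is only a sketch: it assumes, as in the model above, that HeteroGraphConv returns a dict mapping node type to features, and nan_check_hook is just an illustrative name:

def nan_check_hook(module, inputs, output):
    # HeteroGraphConv returns a dict of node type -> feature tensor,
    # so inspect each entry for NaNs as soon as the layer runs.
    for ntype, feat in output.items():
        if th.isnan(feat).any():
            print(f'NaN in {type(module).__name__} output for node type {ntype!r}')

# Attach to the first conv layer once the model exists:
# model.gnn.conv1.register_forward_hook(nan_check_hook)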

import torch.optim as optim

def train(model, dataloader, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for g, labels in dataloader:
            optimizer.zero_grad()
            # print('instance_metric:', th.isnan(g.nodes['instance'].data['metric']).any())
            # print('node_metric:', th.isnan(g.nodes['node'].data['metric']).any())
            outputs = model(g, {'instance': g.nodes['instance'].data['metric'],
                                'node': g.nodes['node'].data['metric']})

            # print('labels:', labels)
            # print('outputs:', outputs)
            if th.isnan(outputs[0]).any():  # skip batches whose outputs already contain NaNs
                continue
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
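Since the NaNs only show up after some batches of normal training, one standard mitigation worth noting is clipping the gradient norm before the optimizer step (a minimal sketch; max_norm=1.0 is an arbitrary starting value, not something from the code above):

# Sketch: between loss.backward() and optimizer.step(), cap the global
# gradient norm so one bad batch cannot push the weights into a region
# that produces NaNs.
loss.backward()
th.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()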

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with th.no_grad():
        for g, labels in dataloader:
            outputs = model(g, {'instance': g.nodes['instance'].data['metric'],
                                'node': g.nodes['node'].data['metric']})

            loss = criterion(outputs, labels)
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f'Evaluation Loss: {avg_loss:.4f}')
    return avg_loss

device = th.device('cpu')
print('device:', device)

node_in_feats = test_data[0][0].nodes['node'].data['metric'].shape[1]
instance_in_feats = test_data[0][0].nodes['instance'].data['metric'].shape[1]
model = HeteroGraphModel(node_in_feats, instance_in_feats, hidden_feats1=20, hidden_feats2=5).to(device)
criterion = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = optim.Adam(model.parameters(), lr=0.000001)

def collate_fn(batch):
    graphs, labels = map(list, zip(*batch))
    batched_graph = dgl.batch(graphs)
    labels = th.stack(labels, dim=0).squeeze()
    return batched_graph, labels

dataloader = DataLoader(test_data, batch_size=1, shuffle=False, collate_fn=collate_fn)

num_epochs = 20
train(model, dataloader, criterion, optimizer, num_epochs)
evaluate(model, dataloader, criterion)

@Molujia, I suggest you open an issue on the DGL GitHub repo with your code attached as a file and the steps to reproduce. That will help the developers or others diagnose and fix the issue.
