I feed the graph into GraphSAGE for feature extraction. When I run it on certain graphs, all of the model parameters become NaN.
I'm sure there are no NaNs in the features fed into the model.
This is my model:
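This is roughly how I verify that before training (a minimal sketch; `g` here stands for one of my graphs and the `'feat'` field matches the code below):

import torch

# sanity check on the raw node features before they go into the model
feats = g.ndata['feat'].float()
assert not torch.isnan(feats).any(), "NaN found in input features"
assert not torch.isinf(feats).any(), "Inf found in input features"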
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

# ----------- 2. create model -------------- #
# build a four-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        # mean aggregation; L2-normalize the output of each SAGE layer
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean', norm=lambda x: F.normalize(x))
        self.bn1 = nn.BatchNorm1d(h_feats, eps=1e-5)
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean', norm=lambda x: F.normalize(x))
        self.bn2 = nn.BatchNorm1d(h_feats, eps=1e-5)
        self.conv3 = SAGEConv(h_feats, h_feats, 'mean', norm=lambda x: F.normalize(x))
        self.bn3 = nn.BatchNorm1d(h_feats, eps=1e-5)
        self.conv4 = SAGEConv(h_feats, 1, 'mean', norm=lambda x: F.normalize(x))

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat.float())
        h = self.bn1(h)
        h = F.relu(h)
        h = self.conv2(g, h)
        h = self.bn2(h)
        h = F.relu(h)
        h = self.conv3(g, h)
        h = self.bn3(h)
        h = F.relu(h)
        h = self.conv4(g, h)
        return h
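To try to narrow down where the NaNs first appear, I've also been running something like the sketch below before training (hypothetical debugging helper, not part of the model; it assumes `SAGE_model` is an instance of the class above):

import torch

# register a forward hook on every submodule and report the first layer
# whose output contains NaN
def report_nan(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and torch.isnan(output).any():
            print(f"NaN first appears in the output of {name}")
    return hook

for name, module in SAGE_model.named_modules():
    module.register_forward_hook(report_nan(name))

# anomaly detection additionally points at the backward op that produced NaN gradients
torch.autograd.set_detect_anomaly(True)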
My trainer:

import numpy as np
from tqdm import tqdm

def train(SAGE_model, PRED_grid, train_loader):
    optimizer = torch.optim.Adam(SAGE_model.parameters(), lr=0.0001, weight_decay=5e-4)
    for epoch in np.arange(200) + 1:
        epoch_loss = 0
        for batch_id, batch in tqdm(enumerate(train_loader)):
            SAGE_model.train()
            node_features = batch.ndata['feat'].float()
            node_labels = batch.ndata['label'].float()
            # train only on nodes whose first feature equals 3
            train_mask = batch.ndata['feat'][:, 0]
            train_mask = train_mask == 3
            train_mask = train_mask.type(torch.bool)
            PRED_grid_logits = SAGE_model(batch, node_features)
            # re-weight the positive class by the negative/positive ratio
            total_grid_positive = node_labels[train_mask].sum()
            total_grid_negative = node_labels[train_mask].shape[0] - total_grid_positive
            pos_grid_weight = total_grid_negative / total_grid_positive
            grid_loss = F.binary_cross_entropy_with_logits(
                PRED_grid_logits[train_mask], node_labels[train_mask], pos_weight=pos_grid_weight)
            epoch_loss = epoch_loss + grid_loss.item()
            optimizer.zero_grad()
            grid_loss.backward()
            torch.nn.utils.clip_grad_norm_(SAGE_model.parameters(), 0.3)
            torch.nn.utils.clip_grad_norm_(PRED_grid.parameters(), 0.3)
            optimizer.step()

train(SAGE_model, PRED_grid, train_loader)
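While debugging I also add per-step checks along these lines right after optimizer.step() in the loop above (a hypothetical snippet, not part of my actual trainer):

# flag the first step at which the loss or any parameter goes NaN
if torch.isnan(grid_loss).any():
    print(f"NaN loss at epoch {epoch}, batch {batch_id}")
for name, p in SAGE_model.named_parameters():
    if torch.isnan(p).any():
        print(f"NaN in parameter {name} at epoch {epoch}, batch {batch_id}")
        break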