Unsupervised GraphSAGE (pool) model params become NaN

I've run the unsupervised GraphSAGE code successfully on a graph with 280 million edges and 350 thousand nodes.
But when I switch the aggregator type from mean to pool, the loss becomes NaN after several steps. Digging in, I found the cause: the model parameters themselves become NaN (the input is the same).

I tried several methods to prevent this, but the NaNs are still there:
weight decay: optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
gradient clipping: th.nn.utils.clip_grad_value_(model.parameters(), 0.1) (after loss.backward())
a lower learning rate: 0.0005
None of these helped.
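For reference, the training step with these mitigations looks roughly like this (a sketch; model, optimizer, loss_fn, and the batch variables are placeholders for my actual pipeline):

```python
import torch as th

def train_step(model, optimizer, loss_fn, inputs, labels):
    """One training step with gradient clipping and a NaN check (sketch)."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    # Clip gradients element-wise to [-0.1, 0.1].
    th.nn.utils.clip_grad_value_(model.parameters(), 0.1)
    # Fail fast if any gradient is already non-finite before the update.
    for name, p in model.named_parameters():
        if p.grad is not None and not th.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite gradient in {name}")
    optimizer.step()
    return loss.item()
```

Even with this in place, the parameters eventually turn NaN with the pool aggregator.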

Is there any solution to this problem? Thanks!

In addition, the parameters that contain NaN are:
layers.0.fc_pool.weight is nan tensor([[ 5.5008e-21, 9.6237e-17, 9.9735e-19, …, 5.4157e-17,
1.0626e-16, -2.6394e-17],
[ nan, nan, nan, …, nan,
nan, nan],
[ 1.7808e-41, 1.7544e-41, 6.3507e-42, …, -1.4998e-41,
9.1533e-42, 1.5507e-41],
…,
[-6.5875e-42, 5.5646e-42, 7.1158e-42, …, 1.9212e-42,
-5.8448e-42, 1.3438e-41],
[-1.3378e-41, -1.5144e-41, -1.7052e-41, …, 7.5782e-42,
2.4520e-40, 1.0720e-42],
[ nan, nan, nan, …, nan,
nan, nan]])

layers.0.fc_pool.bias is nan tensor([-1.3928e-04, nan, -2.2153e-22, -2.6898e-02, -2.6840e-08,
-1.2008e-09, -2.3190e-16, nan, -1.4850e-04, nan,
nan, -5.0809e-13, nan, -2.9750e-06, nan,
nan, -2.1870e-03, nan, -5.4535e-06, nan,
nan, nan, -1.4432e-19, nan, -1.5761e-02,
nan, nan, -1.8105e-02, nan, -2.5433e-09,
nan, -6.3304e-28, -1.9594e-02, -4.9870e-11, -4.0917e-26,
-2.4668e-11, nan, nan, -2.4117e-04, nan,
-3.2424e-03, nan, -3.4410e-17, nan, -6.0637e-09,
-1.2643e-07, -5.8581e-23, -5.9553e-23, nan, nan,
-2.5109e-05, -8.8669e-10, -1.0904e-26, -1.1007e-07, -3.1032e-08,
nan, nan, -2.4066e-27, nan, -5.0381e-13,
-3.1144e-29, -5.2692e-03, -6.6346e-04, -1.8339e-27, nan,
nan, nan, -5.5579e-06, nan, nan,
-2.7867e-10, -1.0740e-29, -7.9432e-07, -4.5979e-06, -1.4128e-18,
-2.2100e-03, -8.8159e-32, nan, nan, nan,
-3.2046e-10, -4.0596e-10, nan, -1.1623e-15, nan,
-7.6139e-24, -1.1265e-20, nan, nan, -5.6108e-29,
nan, nan, -9.7590e-06, -3.0783e-03, nan,
nan, nan, nan, nan, nan,
nan, -1.9898e-29, nan, nan, -4.9342e-17,
nan, -1.0301e-30, nan, nan, nan,
nan, -5.7506e-09, -1.1353e-09, -1.7868e-02, nan,
nan, -2.0967e-02, nan, -7.2300e-19, nan,
-2.5128e-11, nan, nan, nan, nan,
-4.0450e-13, nan, -2.1093e-11, -4.1777e-26, nan,
-1.0727e-15, nan, -7.7383e-05, -4.5813e-17, -2.5254e-06,
-6.0359e-08, nan, nan, nan, nan,
nan, -9.1410e-07, nan, -5.0010e-07, nan,
nan, nan, nan, -3.8262e-07, -7.5550e-05,
-6.7823e-10, -5.7485e-12, -1.6971e-02, -1.0599e-10, -1.9904e-18,
nan, -4.8678e-28, -4.1970e-05, nan, nan,
nan, -1.2832e-02, nan, nan, -3.4204e-07,
-6.0900e-07, -5.1721e-23, -3.2729e-10, -2.0397e-04, nan,
nan, -1.5824e-18, nan, -1.5816e-02, nan,
-3.6280e-05, nan, nan, -2.3959e-05, nan,
-5.3371e-04, -3.3169e-25, -4.4591e-10, -2.4337e-27, -3.2034e-12,
nan, nan, nan, -6.6855e-12, -7.0434e-19,
-1.9926e-06, nan])

layers.0.fc_self.weight is nan tensor([[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
…,
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan]])

layers.0.fc_self.bias is nan tensor([ nan, nan, nan, nan, nan,
nan, nan, nan, -6.7657e-03, nan,
nan, nan, nan, nan, -3.0758e-05,
nan, nan, nan, nan, nan,
nan, -1.9036e-22, nan, nan, nan,
nan, nan, nan, nan, -3.5856e-03,
nan, nan, -5.6798e-05, nan, nan,
nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan,
nan, -2.6925e-25, nan, nan, nan,
nan, nan, nan, nan, nan,
nan, nan, nan, -1.2760e-05, 4.0789e-03,
nan, nan, nan, nan])

layers.0.fc_neigh.weight is nan tensor([[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
…,
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan]])

layers.0.fc_neigh.bias is nan tensor([ nan, nan, nan, nan, nan,
nan, nan, nan, -6.7657e-03, nan,
nan, nan, nan, nan, -3.0758e-05,
nan, nan, nan, nan, nan,
nan, -1.9036e-22, nan, nan, nan,
nan, nan, nan, nan, -3.5856e-03,
nan, nan, -5.6798e-05, nan, nan,
nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan,
nan, -2.6925e-25, nan, nan, nan,
nan, nan, nan, nan, nan,
nan, nan, nan, -1.2760e-05, 4.0789e-03,
nan, nan, nan, nan])

layers.1.fc_pool.weight is nan tensor([[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
…,
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan],
[nan, nan, nan, …, nan, nan, nan]])

layers.1.fc_pool.bias is nan tensor([ nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan,
nan, -0.0493, nan, nan, nan, nan, nan, nan,
-0.0230, nan, nan, nan, nan, nan, nan, 0.0184,
nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan])

layers.1.fc_self.weight is nan tensor([[ nan, nan, nan, …, nan, nan, nan],
[ 0.0084, 0.0032, -0.0362, …, 0.0094, 0.0098, 0.0055],
[ 0.0123, 0.0018, -0.0065, …, 0.0173, 0.0008, 0.0177],
…,
[ 0.0232, 0.0058, -0.0073, …, 0.0335, -0.0012, 0.0330],
[-0.0161, -0.0029, 0.0074, …, -0.0226, -0.0005, -0.0228],
[ 0.0022, -0.0020, 0.0031, …, -0.0015, 0.0032, -0.0058]])

layers.1.fc_self.bias is nan tensor([ nan, 5.0541e-03, 2.9021e-03, -3.3436e-03, -6.8956e-04,
2.7058e-03, -2.0125e-02, -2.8778e-04, nan, -2.0819e-02,
-3.8774e-02, -9.5625e-04, -6.2541e-04, -3.2863e-03, -3.2434e-03,
-5.4319e-03, 2.2528e-02, -8.7812e-03, -6.1004e-04, 8.2608e-03,
3.3619e-03, 4.2159e-03, 3.9068e-03, 2.1328e-02, -3.3189e-03,
7.3474e-04, 5.0444e-03, -1.6041e-02, -2.8115e-03, 2.9139e-03,
3.5329e-03, 3.0742e-03, 2.0996e-03, -3.4317e-03, 2.6775e-04,
1.2546e-02, nan, 4.6222e-02, -3.1182e-03, 5.2070e-04,
-2.5719e-03, 2.7390e-03, -3.4861e-05, 3.1348e-03, -8.5339e-03,
1.3642e-03, -1.3749e-02, -2.4847e-02, 5.3479e-03, -2.9757e-02,
-7.7645e-03, -3.2868e-03, 3.0779e-03, 3.3113e-03, -1.0141e-03,
4.9990e-03, -7.2492e-04, -4.6203e-03, -3.2852e-03, 3.5352e-03,
-4.7092e-03, 4.1258e-03, -3.3914e-03, 5.9574e-04])

layers.1.fc_neigh.weight is nan tensor([[ nan, nan, nan, …, nan, nan, nan],
[-0.0044, 0.0052, -0.0022, …, -0.0061, 0.0145, 0.0378],
[-0.0021, -0.0020, -0.0016, …, -0.0007, -0.0013, 0.0117],
…,
[ 0.0020, -0.0038, -0.0035, …, -0.0021, -0.0036, 0.0217],
[ 0.0010, 0.0026, 0.0023, …, 0.0014, 0.0019, -0.0154],
[ 0.0027, -0.0009, -0.0011, …, 0.0005, -0.0003, -0.0017]])

layers.1.fc_neigh.bias is nan tensor([ nan, 5.0537e-03, 2.9021e-03, -3.3436e-03, -6.8939e-04,
2.7058e-03, -2.0061e-02, -2.8778e-04, nan, -2.0837e-02,
-4.0465e-02, -9.5625e-04, -6.1662e-04, -3.2863e-03, -3.2434e-03,
-5.4319e-03, 2.2536e-02, -8.7977e-03, -6.1004e-04, 8.2608e-03,
3.3619e-03, 4.2159e-03, 3.6595e-03, 2.1438e-02, -3.3189e-03,
7.3474e-04, 5.0444e-03, -1.6040e-02, -2.8115e-03, 2.9147e-03,
3.5329e-03, 3.0742e-03, 2.0996e-03, -3.4317e-03, 2.6775e-04,
1.2559e-02, nan, 4.6341e-02, -3.1182e-03, 5.2070e-04,
-2.5719e-03, 2.7390e-03, -3.4861e-05, 3.1348e-03, -8.5339e-03,
1.3642e-03, -1.3659e-02, -2.4820e-02, 5.3546e-03, -2.8776e-02,
-7.7644e-03, -3.2868e-03, 3.0779e-03, 3.3113e-03, -1.0141e-03,
4.9990e-03, -7.2492e-04, -4.6203e-03, -3.2852e-03, 3.5352e-03,
-4.7092e-03, 4.1258e-03, -3.3914e-03, 5.9574e-04])
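(For context, the listing above comes from a simple scan over the model parameters, roughly:)

```python
import torch as th

def report_nan_params(model):
    """Print every parameter tensor that contains at least one NaN."""
    bad = []
    for name, param in model.named_parameters():
        if th.isnan(param).any():
            print(f"{name} is nan", param.data)
            bad.append(name)
    return bad
```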

By switching the aggregation type from mean to pool, the computation changes from

graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']

to

graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']

One possible interpretation is that the feature norm becomes very large with F.relu(self.fc_pool(feat_src)), where self.fc_pool is a linear layer. Have you set the norm argument when initializing SAGEConv? Applying feature normalization might help, since the result of max pooling can have an excessively large norm.
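A minimal sketch of that suggestion in plain PyTorch (the function and shapes here are illustrative, not the actual SAGEConv internals; in DGL you would instead pass such a callable as the norm argument of SAGEConv):

```python
import torch
import torch.nn.functional as F

# Max pooling over neighbors keeps the largest activation per dimension,
# so the pooled feature's norm can blow up. L2-normalizing the result
# keeps it bounded.
def pooled_neigh(feat_src, fc_pool):
    h = F.relu(fc_pool(feat_src))             # (num_neighbors, d)
    h_neigh = h.max(dim=0).values             # element-wise max over neighbors
    return F.normalize(h_neigh, p=2, dim=-1)  # unit norm, prevents blow-up
```

If the unnormalized pooled features are the problem, bounding their norm this way should keep the downstream linear layers and the loss finite.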