Error when training DistSAGE with OGB-Products

Hi, I ran into a bug when training on OGB-Products in an 8-node cluster.

If I set a small batch size, e.g., 1024, the last worker finishes the first epoch with correct output. However, all the other workers fail with the following error message:

 ** On entry to cusparseCreateCsr() dimension mismatch: nnz > rows * cols

Traceback (most recent call last):
  File "baseline/graphsage/train_dist_sage.py", line 392, in <module>
    main(args)
  File "baseline/graphsage/train_dist_sage.py", line 349, in main
    run(args, device, data)
  File "baseline/graphsage/train_dist_sage.py", line 250, in run
    batch_pred = model(blocks, batch_inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/distributed.py", line 705, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "baseline/graphsage/train_dist_sage.py", line 88, in forward
    h = layer(block, h)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/dgl/nn/pytorch/conv/sageconv.py", line 231, in forward
    graph.update_all(msg_fn, fn.mean('m', 'neigh'))
  File "/usr/local/lib/python3.6/dist-packages/dgl/heterograph.py", line 4686, in update_all
    ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
  File "/usr/local/lib/python3.6/dist-packages/dgl/core.py", line 283, in message_passing
    ndata = invoke_gspmm(g, mfunc, rfunc)
  File "/usr/local/lib/python3.6/dist-packages/dgl/core.py", line 258, in invoke_gspmm
    z = op(graph, x)
  File "/usr/local/lib/python3.6/dist-packages/dgl/ops/spmm.py", line 170, in func
    return gspmm(g, 'copy_lhs', reduce_op, x, None)
  File "/usr/local/lib/python3.6/dist-packages/dgl/ops/spmm.py", line 64, in gspmm
    lhs_data, rhs_data)
  File "/usr/local/lib/python3.6/dist-packages/dgl/backend/pytorch/sparse.py", line 307, in gspmm
    return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
  File "/usr/local/lib/python3.6/dist-packages/torch/cuda/amp/autocast_mode.py", line 217, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/dgl/backend/pytorch/sparse.py", line 87, in forward
    out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
  File "/usr/local/lib/python3.6/dist-packages/dgl/sparse.py", line 162, in _gspmm
    arg_e_nd)
  File "/usr/local/lib/python3.6/dist-packages/dgl/_ffi/_ctypes/function.py", line 190, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/usr/local/lib/python3.6/dist-packages/dgl/_ffi/base.py", line 64, in check_call
    raise DGLError(py_str(_LIB.DGLGetLastError()))
dgl._ffi.base.DGLError: [11:01:50] /opt/dgl/src/array/cuda/spmm.cu:233: Check failed: e == CUSPARSE_STATUS_SUCCESS: CUSPARSE ERROR: 3
Stack trace:
  [bt] (0) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7f4884d1f32f]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(void dgl::aten::cusparse::CusparseCsrmm2<float, long>(DLContext const&, dgl::aten::CSRMatrix const&, float const*, float const*, float*, int)+0x160) [0x7f4885925300]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(void dgl::aten::SpMMCsr<2, long, 32>(std::string const&, std::string const&, dgl::BcastOff const&, dgl::aten::CSRMatrix const&, dgl::runtime::NDArray, dgl::runtime::NDArray, dgl::runtime::NDArray, std::vector<dgl::runtime::NDArray, std::allocator<dgl::runtime::NDArray> >)+0xdc) [0x7f488596e95c]
  [bt] (3) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(dgl::aten::SpMM(std::string const&, std::string const&, std::shared_ptr<dgl::BaseHeteroGraph>, dgl::runtime::NDArray, dgl::runtime::NDArray, dgl::runtime::NDArray, std::vector<dgl::runtime::NDArray, std::allocator<dgl::runtime::NDArray> >)+0x2633) [0x7f4884e7e593]
  [bt] (4) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(+0x6b6838) [0x7f4884e8a838]
  [bt] (5) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(+0x6b6dd1) [0x7f4884e8add1]
  [bt] (6) /usr/local/lib/python3.6/dist-packages/dgl/libdgl.so(DGLFuncCall+0x48) [0x7f4885423538]
  [bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f498946ddae]
  [bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7f498946d71f]

But things work fine with a large batch size, e.g., 16384.

Could anyone please give me a suggestion?

Is this happening during training or during validation?

It’s a bit weird. The last worker has finished the training loop and printed its epoch information.

However, all the other workers are still in the training loop and hit the error during prediction.

But I’ve set train_nid to be split evenly across the workers, roughly as sketched below.
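For reference, this is roughly how the split is done (only a sketch following the standard DistGraph setup; the config path and graph name are illustrative):

import dgl

# Sketch: connect to the distributed graph and split the training node IDs
# so that every worker gets (almost) the same number of seed nodes.
dgl.distributed.initialize('ip_config.txt')        # illustrative config path
g = dgl.distributed.DistGraph('ogbn-products')     # illustrative graph name
pb = g.get_partition_book()
train_nid = dgl.distributed.node_split(g.ndata['train_mask'], pb,
                                       force_even=True)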

Hi,

Could you try to dump the failing block using a try/except? Or save its edge lists for us to debug? Thanks.
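Something along these lines should do it (only a sketch; batch_pred, blocks, and batch_inputs follow the names in the training script from the traceback, and the dump path is arbitrary):

import dgl
import torch as th

try:
    batch_pred = model(blocks, batch_inputs)
except dgl.DGLError:
    # Dump the edge lists of the offending blocks so the failure can be
    # reproduced offline.
    for i, block in enumerate(blocks):
        src, dst = block.edges()
        th.save({'src': src.cpu(), 'dst': dst.cpu(),
                 'num_src_nodes': block.number_of_src_nodes(),
                 'num_dst_nodes': block.number_of_dst_nodes()},
                'bad_block_%d.pt' % i)
    raise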

Also, what’s your CUDA version?

@VoVAllen

Thanks a lot for your reply. I’ll upload the try/except output later.

As for the versions: CUDA 11.0, DGL 0.6.1, Python 3.6.9.

CUDA 11.0 is a buggy version. Could you try a later CUDA version such as 11.1, 11.2, or 11.3?

Also this seems related to [Bugfix] Fix CUDA 11.1 crashing when number of edges is larger than number of node pairs by BarclayII · Pull Request #3265 · dmlc/dgl · GitHub

So it’s probably not really related to the CUDA version after all.
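If you want to check whether your sampled blocks hit that case, a quick sanity check like this (a sketch, using the same blocks as in the training loop) would flag any block whose edge count exceeds the number of src/dst node pairs:

# Sketch: flag blocks where nnz (edges) exceeds rows * cols (node pairs),
# which is the condition the cuSPARSE error above complains about.
for i, block in enumerate(blocks):
    n_edges = block.number_of_edges()
    n_pairs = block.number_of_src_nodes() * block.number_of_dst_nodes()
    if n_edges > n_pairs:
        print('block %d: %d edges > %d src/dst node pairs'
              % (i, n_edges, n_pairs))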

Also, is it possible for you to use a newer DGL version?
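For example, something like pip install --upgrade dgl-cu111 -f https://data.dgl.ai/wheels/repo.html should pull a newer CUDA build (the exact wheel name depends on your CUDA version; please check the DGL install page for the right one).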