I am using PyTorch DDP to run the GraphSAGE model on the Reddit dataset. Can someone help me run the model with FP16?
Currently we do not support message passing with fp16 tensors.
A workaround is to cast your fp16 input tensors to fp32 before any `update_all` and `apply_edges` calls, and then cast the fp32 output tensors back to fp16.
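Below is a minimal sketch of that workaround, assuming plain PyTorch so it runs without a graph library. `index_add_` stands in for an `update_all` call that does mean aggregation over in-neighbors, similar to what a mean-aggregator SAGEConv layer computes. The function name `mean_aggregate_fp32` and the `src`/`dst` edge tensors are hypothetical. In real DGL code you would keep the existing `update_all` call and wrap it with the same two casts.

```python
import torch


def mean_aggregate_fp32(h, src, dst, num_nodes):
    """Mean of in-neighbor features, computed in fp32 and returned in h's dtype.

    Hypothetical stand-in for DGL's
    g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')).
    src/dst are int64 tensors holding each edge's source and destination node.
    """
    in_dtype = h.dtype
    h32 = h.float()  # upcast fp16 -> fp32 before message passing

    # Sum each edge's message h[src] into its destination node.
    summed = torch.zeros(num_nodes, h32.shape[1], dtype=torch.float32, device=h.device)
    summed.index_add_(0, dst, h32[src])

    # Divide by in-degree; nodes without in-edges keep a zero vector.
    deg = torch.zeros(num_nodes, dtype=torch.float32, device=h.device)
    deg.index_add_(0, dst, torch.ones_like(dst, dtype=torch.float32))
    mean = summed / deg.clamp(min=1).unsqueeze(1)

    return mean.to(in_dtype)  # cast the fp32 result back to fp16


# Usage: 3 nodes with fp16 features and edges 0->2, 1->2, 2->0.
h = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float16)
src = torch.tensor([0, 1, 2])
dst = torch.tensor([2, 2, 0])
out = mean_aggregate_fp32(h, src, dst, num_nodes=3)
print(out.dtype)  # torch.float16
```

Only the aggregation step runs in fp32. The result is cast back so the surrounding fp16 layers still receive the dtype they expect.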
I am a beginner with DGL. It would be really helpful if you could tell me specifically which lines I need to change in train_sampling_multi_gpu.py and sageconv.py, and what the changes should be.
Thank you
I wrote an example to show how to use half precision in DGL: