Is there a way to run a GraphSAGE model with half precision?

I am using PyTorch DDP to train the GraphSAGE model on the Reddit dataset. Can someone help me run the model in FP16?

Currently we do not support message passing with fp16 tensors.
As a workaround, cast your fp16 input tensors to fp32 before any update_all or apply_edges call, then cast the fp32 outputs back to fp16.
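
Here is a minimal sketch of that cast-in/cast-out pattern. It is written in plain PyTorch so it runs without DGL installed. The decorator `fp32_message_passing` and the function `mean_aggregate` are hypothetical names. `mean_aggregate` stands in for the neighbor-mean step a SAGEConv-style layer would normally do with `update_all`. In an actual DGL model you would apply the same fp32 casts around that `update_all` (or `apply_edges`) call instead.

```python
import functools

import torch


def fp32_message_passing(fn):
    """Hypothetical helper: run `fn` in fp32 when given fp16 tensors, then cast
    an fp32 tensor result back to fp16."""

    def up(x):
        # Upcast only fp16 tensors; leave index tensors and Python scalars alone.
        if isinstance(x, torch.Tensor) and x.dtype == torch.float16:
            return x.float()
        return x

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        had_half = any(
            isinstance(a, torch.Tensor) and a.dtype == torch.float16
            for a in list(args) + list(kwargs.values())
        )
        out = fn(*[up(a) for a in args], **{k: up(v) for k, v in kwargs.items()})
        # Downcast the fp32 result so the rest of the model stays in fp16.
        if had_half and isinstance(out, torch.Tensor) and out.dtype == torch.float32:
            out = out.half()
        return out

    return wrapper


@fp32_message_passing
def mean_aggregate(feat, src, dst, num_dst):
    """Hypothetical stand-in for a mean-reduce message-passing step:
    each destination node averages the features of its source neighbors."""
    out = torch.zeros(num_dst, feat.shape[1], dtype=feat.dtype)
    out.index_add_(0, dst, feat[src])  # sum messages per destination node
    deg = torch.zeros(num_dst, dtype=feat.dtype)
    deg.index_add_(0, dst, torch.ones_like(dst, dtype=feat.dtype))  # in-degrees
    return out / deg.clamp(min=1).unsqueeze(1)  # avoid division by zero


# Usage: 3 source nodes, 2 destination nodes, edges 0->1, 1->1, 2->0.
feat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float16)
src = torch.tensor([0, 1, 2])
dst = torch.tensor([1, 1, 0])
h = mean_aggregate(feat, src, dst, 2)
print(h.dtype)  # torch.float16
```

The aggregation runs in fp32, so sums over many neighbors are less likely to overflow or lose precision. The input and output stay fp16 for the surrounding layers.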


I am a beginner with DGL, so it would be really helpful if you could point out exactly which lines I need to change in train_sampling_multi_gpu.py and sageconv.py, and what the changes should be.
Thank you

I wrote an example to show how to use half precision in DGL:
