Speeding Up/Optimizing Large Interaction Graph Inference

I have implemented an interaction graph model that uses two networks: one runs during message passing, processing each node's neighbors to generate a message, and the other runs in the reduce step to update the node value. I am running it on a very large graph with 800k nodes and 4M edges. The problem is that a single inference pass takes a few seconds and uses only about 20% of the GPU. Does anyone know how to go about optimizing performance?
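One common way to keep a GPU busy with this kind of two-network model is to avoid per-node work in the reduce step. Instead, run the message network once as a single batched call over all edges, sum the messages into their destination nodes with one scatter-add, and then run the update network once over all nodes. The sketch below assumes plain PyTorch with the graph stored as `src`/`dst` edge-index tensors. `InteractionLayer`, its layer sizes, and the sum aggregation are hypothetical choices for illustration, not the poster's code or any graph library's API.

```python
import torch
import torch.nn as nn


class InteractionLayer(nn.Module):
    """Hypothetical layer: an edge MLP builds messages, and a node MLP updates
    each node from the sum of its incoming messages."""

    def __init__(self, node_dim: int, msg_dim: int, hidden: int = 64):
        super().__init__()
        # Message network: sees sender and receiver features for every edge.
        self.msg_mlp = nn.Sequential(
            nn.Linear(2 * node_dim, hidden), nn.ReLU(), nn.Linear(hidden, msg_dim)
        )
        # Update network: sees a node's own features plus its aggregated messages.
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + msg_dim, hidden), nn.ReLU(), nn.Linear(hidden, node_dim)
        )

    def forward(self, x: torch.Tensor, src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
        # One batched MLP call over all E edges (gather sender and receiver rows).
        msgs = self.msg_mlp(torch.cat([x[src], x[dst]], dim=-1))
        # Sum each edge's message into its destination node with a single scatter-add.
        agg = torch.zeros(x.size(0), msgs.size(-1), device=x.device, dtype=msgs.dtype)
        agg.index_add_(0, dst, msgs)
        # One batched MLP call over all N nodes.
        return self.node_mlp(torch.cat([x, agg], dim=-1))


# Usage on a small random graph; move tensors and model to "cuda" for real runs.
torch.manual_seed(0)
num_nodes, num_edges = 1000, 5000
x = torch.randn(num_nodes, 16)
src = torch.randint(0, num_nodes, (num_edges,))
dst = torch.randint(0, num_nodes, (num_edges,))
layer = InteractionLayer(node_dim=16, msg_dim=32).eval()

# inference_mode skips autograd bookkeeping, which is not needed for inference.
with torch.inference_mode():
    out = layer(x, src, dst)
print(out.shape)
```

With this structure the cost is a few large matrix multiplications plus one scatter, which GPUs handle well. At 4M edges, the per-edge hidden activations can take a significant amount of memory, so processing the edges in chunks and accumulating into `agg` may be needed on smaller GPUs.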

Are you running an example? Could you please share a link to the code you are running?

Thanks.