MarginRankingLoss in Batch Graph Training

Hello there,

I am a bit of a newbie, so I'm not sure if what I'm about to ask makes sense.
I have a graph of 1K nodes / 5K edges, where each node has 5 features and one numerical value as a "label". Using MarginRankingLoss, I am trying to correctly predict the relative ordering of each node's "label" with respect to the rest of the nodes, based on the values of each node's 5 features. This appears to be working.
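For context, my per-graph setup looks roughly like this (a minimal sketch with random stand-ins for the model output and the labels; `scores` is hypothetical and would really come from my model):

```python
import torch

torch.manual_seed(0)
num_nodes = 1000
scores = torch.randn(num_nodes, requires_grad=True)  # stand-in for model(g, feats)
labels = torch.randn(num_nodes)                      # each node's numerical "label"

# Sample random node pairs; the target is +1 if node i should outrank node j.
i = torch.randint(0, num_nodes, (4096,))
j = torch.randint(0, num_nodes, (4096,))
target = (labels[i] > labels[j]).float() * 2 - 1

loss_fn = torch.nn.MarginRankingLoss(margin=0.1)
loss = loss_fn(scores[i], scores[j], target)
loss.backward()
```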

I want to generalize this approach to more graphs (i.e. train on several graphs at once). For example, let's suppose I have three graphs of 1K nodes / 5K edges (for simplicity's sake; actual graph sizes may vary and may differ among them), with each node having 5 features (this will not vary).

If I use dgl.batch() to batch them all into a single graph, how would this affect the loss calculation? I assume that MarginRankingLoss on the entire batched graph would not have any particular meaning, since I am not interested in the relative ordering of all nodes compared to all other nodes (some of them essentially belong to different graphs)?
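One idea I had (not sure if it's idiomatic) is to keep the whole-batch computation but mask out cross-graph pairs, since the batched graph knows how many nodes each original graph contributed. A toy sketch, with a hypothetical per-node graph assignment standing in for what `bg.batch_num_nodes()` would give me in DGL:

```python
import torch

# Hypothetical node counts for three small graphs in a batch; in DGL this
# would come from bg.batch_num_nodes().
counts = torch.tensor([4, 3, 5])
graph_id = torch.repeat_interleave(torch.arange(len(counts)), counts)

# All candidate node pairs (i, j); keep only pairs within the same graph,
# so the ranking loss never compares nodes from different graphs.
i, j = torch.combinations(torch.arange(graph_id.numel()), r=2).unbind(1)
same_graph = graph_id[i] == graph_id[j]
i, j = i[same_graph], j[same_graph]
```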

Would it be correct if I did something similar to:

for each epoch:
   for each connected component in the batched graph:
      calculate MarginRankingLoss for the nodes in the component

and "hope" that the average loss across all components will go down once enough epochs have passed? If so, how would I be able to isolate the logits and labels of a particular connected component so that I can compute MarginRankingLoss only on those nodes?
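Concretely, is something like the following sketch the right direction? (The names are hypothetical: in DGL I'd get `counts` from `bg.batch_num_nodes()` and `logits` from running the model on the whole batched graph; here I use random stand-ins.)

```python
import torch

loss_fn = torch.nn.MarginRankingLoss(margin=0.1)

counts = [4, 3, 5]                                   # stand-in for bg.batch_num_nodes()
logits = torch.randn(sum(counts), requires_grad=True)  # stand-in for model(bg, feats)
labels = torch.randn(sum(counts))

total = 0.0
# torch.split slices the node-level tensors back into per-graph chunks,
# so each loss term only compares nodes within one original graph.
for lg, lb in zip(torch.split(logits, counts), torch.split(labels, counts)):
    # All intra-graph pairs (i, j), with the ranking target from the labels.
    i, j = torch.combinations(torch.arange(lg.numel()), r=2).unbind(1)
    target = (lb[i] > lb[j]).float() * 2 - 1
    total = total + loss_fn(lg[i], lg[j], target)

loss = total / len(counts)   # average over graphs in the batch
loss.backward()
```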

Thank you in advance,