MarginRankingLoss in Batch Graph Training

Hello there,

I am a bit of a newbie, so I’m not sure whether my question makes sense.
I have a graph with 1K nodes and 5K edges, where each node has 5 features and one numerical value as its “label”. Based on each node’s 5 features, I am trying to predict the relative ordering of the nodes’ “labels” using MarginRankingLoss. This appears to be working.
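For concreteness, the single-graph setup could be sketched roughly as follows. This is only an illustration under assumptions: `pairwise_ranking_loss` is a hypothetical helper name, `scores` stands for the model's one-value-per-node output, every pair of nodes with different labels is compared, and the margin of 1.0 is arbitrary.

```python
import torch
import torch.nn as nn


def pairwise_ranking_loss(scores, labels, margin=1.0):
    """MarginRankingLoss over all node pairs (i, j), i < j, of ONE graph.

    scores: 1-D tensor, one predicted value per node (assumed model output).
    labels: 1-D tensor, one ground-truth value per node.
    """
    loss_fn = nn.MarginRankingLoss(margin=margin)
    n = scores.numel()
    # Indices of all unordered node pairs (i < j).
    i, j = torch.triu_indices(n, n, offset=1, device=scores.device)
    # +1 if node i should rank above node j, -1 if below, 0 for ties.
    target = torch.sign(labels[i] - labels[j])
    keep = target != 0  # ties carry no ordering information, so skip them
    if not keep.any():
        return scores.sum() * 0.0  # zero loss that still supports backward()
    return loss_fn(scores[i][keep], scores[j][keep], target[keep])


labels = torch.tensor([3.0, 2.0, 1.0])
good = pairwise_ranking_loss(torch.tensor([3.0, 2.0, 1.0]), labels)  # correct order
bad = pairwise_ranking_loss(torch.tensor([1.0, 2.0, 3.0]), labels)   # reversed order
```

With 1K nodes this compares roughly 500K pairs per call, so sampling a subset of pairs may be preferable on larger graphs.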

I want to generalize this approach to multiple graphs (i.e. train on several graphs at once). For example, suppose I have three graphs, each with 1K nodes and 5K edges (for simplicity’s sake; the actual graph sizes may vary and differ from one graph to another), with each node having 5 features (this will not vary).

If I use dgl.batch() to batch them all into a single graph, how would this affect the loss calculation? I assume that MarginRankingLoss over the entire batched graph would not be meaningful, since I am not interested in the relative ordering of every node against all other nodes (some of them essentially belong to different graphs).

Would it be correct if I did something similar to:

for each epoch:
   for each connected component in the batched graph:
      optimizer.zero_grad()
      calculate MarginRankingLoss for the nodes in the component
      loss.backward()
      optimizer.step()

and “hope” that the average loss over all components will go down once enough epochs have passed? If so, how would I be able to isolate the logits and labels of a particular connected component so that I can compute MarginRankingLoss on those nodes only?
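One possible way to isolate each graph's nodes is sketched below, as an assumption-laden illustration rather than a confirmed recipe. It relies on dgl.batch() keeping the nodes of each input graph in one contiguous block, so the per-node outputs can be cut up by node count. `batched_ranking_loss` is a hypothetical helper, and `nodes_per_graph` is assumed to come from something like `bg.batch_num_nodes().tolist()` on the batched graph `bg`. The sketch uses plain PyTorch tensors so it runs without DGL.

```python
import torch
import torch.nn as nn


def batched_ranking_loss(scores, labels, nodes_per_graph, margin=1.0):
    """Average of per-graph pairwise MarginRankingLoss values.

    scores, labels: 1-D tensors covering all nodes of the batched graph.
    nodes_per_graph: list of ints, node count of each original graph
        (assumed: bg.batch_num_nodes().tolist() for a DGL batched graph).
    """
    loss_fn = nn.MarginRankingLoss(margin=margin)
    losses = []
    # torch.split cuts the tensors into one contiguous chunk per graph.
    for s, y in zip(torch.split(scores, nodes_per_graph),
                    torch.split(labels, nodes_per_graph)):
        n = s.numel()
        # Pairs are formed only inside the current graph's chunk.
        i, j = torch.triu_indices(n, n, offset=1, device=s.device)
        target = torch.sign(y[i] - y[j])
        keep = target != 0  # skip tied labels
        if keep.any():
            losses.append(loss_fn(s[i][keep], s[j][keep], target[keep]))
    if not losses:
        return scores.sum() * 0.0  # zero loss that still supports backward()
    return torch.stack(losses).mean()


# Two "graphs" with 3 and 2 nodes; each is ordered correctly on its own,
# even though the orderings would conflict if all 5 nodes were pooled.
scores = torch.tensor([3.0, 2.0, 1.0, 10.0, 20.0], requires_grad=True)
labels = torch.tensor([3.0, 2.0, 1.0, 1.0, 2.0])
loss = batched_ranking_loss(scores, labels, [3, 2])
loss.backward()
```

Averaging the per-graph losses and calling backward() and optimizer.step() once per batch is a design choice in this sketch; stepping after every component, as in the loop above, would also be possible but makes one update per graph instead of one per batch.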

Thank you in advance,
Andreas