batch_c are the subject and object, or entity 1 and entity 2, from the triplets. For example, if the input data is:
```
[[e1, r2, e2],
 [e3, r4, e4],
 ...]
```
batch_a would be
batch_b would be
So for each triplet, you get num_rel scores, one per relation, each calculated the same way as the score used in the loss:
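```python
import torch

# how the loss is calculated in the current RGCN code (DistMult-style score)
e1 = entity_emb[triplets[:, 0]]        # subject embeddings
r = self.rel_W[triplets[:, 1]]         # relation weights, one vector per relation
e2 = entity_emb[triplets[:, 2]]        # object embeddings
score = torch.sum(e1 * r * e2, dim=1)  # one score per triplet
```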
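A minimal sketch of getting all num_rel scores per triplet at once, assuming entity_emb has shape (num_entities, dim) and rel_W has shape (num_rel, dim) as in the snippet above. The function name all_relation_scores is hypothetical. Because the score is a sum of elementwise products, multiplying e1 * e2 by every row of rel_W is the same as evaluating the loss score with each relation in turn:

```python
import torch

def all_relation_scores(entity_emb, rel_W, triplets):
    # hypothetical helper; entity_emb: (num_entities, dim), rel_W: (num_rel, dim)
    # triplets: (N, 3) long tensor of [subject, relation, object] ids
    e1 = entity_emb[triplets[:, 0]]  # (N, dim)
    e2 = entity_emb[triplets[:, 2]]  # (N, dim)
    # scores[i, k] = sum_d e1[i, d] * rel_W[k, d] * e2[i, d]
    # i.e. the same score as in the loss, evaluated for every relation k
    return (e1 * e2) @ rel_W.t()  # (N, num_rel)

torch.manual_seed(0)
entity_emb = torch.randn(5, 4)
rel_W = torch.randn(3, 4)
triplets = torch.tensor([[0, 2, 1], [3, 0, 4]])
scores = all_relation_scores(entity_emb, rel_W, triplets)
print(scores.shape)  # torch.Size([2, 3])
```

The column of scores picked out by each triplet's own relation id equals the per-triplet score used in the loss, so the relation ranking is consistent with training.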
Each of these scores corresponds to one relation, and you would rank them exactly the same way the current code ranks all of the entities, then calculate MRR/Hits.
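For example, to rank:
```python
# same as current ranking code (function name is illustrative)
def rank_scores(scores, label):
    # scores: (num_triplets, num_rel); label: (num_triplets,) true relation ids
    _, indices = scores.sort(dim=1, descending=True)
    match = indices == label.view(-1, 1)
    ranks = match.nonzero()
    return ranks[:, 1] + 1  # index from nonzero() starts at 0
```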
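A small worked example of that ranking step on made-up scores for two triplets and three relations, as a sketch only (rank_true_relation is a hypothetical name, and the score values are invented for illustration). Row 0 sorts to relation order [1, 2, 0], so its true relation 2 lands at rank 2; row 1 sorts to [0, 2, 1], so its true relation 0 is rank 1:

```python
import torch

def rank_true_relation(scores, label):
    # scores: (N, num_rel); label: (N,) index of the true relation per triplet
    _, indices = scores.sort(dim=1, descending=True)
    match = indices == label.view(-1, 1)
    # nonzero() lists matches in row order, one per row,
    # so column 1 holds the 0-based position of the true relation
    return match.nonzero()[:, 1] + 1

scores = torch.tensor([[0.1, 0.9, 0.5],
                       [0.7, 0.2, 0.4]])
label = torch.tensor([2, 0])
ranks = rank_true_relation(scores, label)
print(ranks)  # tensor([2, 1])
```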
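And to calculate MRR:
```python
rr = 1.0 / ranks.float()  # reciprocal rank per triplet
mrr = torch.mean(rr)      # MRR is the mean over all test triplets
```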
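A minimal sketch computing MRR together with Hits@k from the 1-based relation ranks, assuming ranks comes from the ranking step above. The helper name mrr_and_hits and the default k values are hypothetical (1, 3 and 10 are common choices):

```python
import torch

def mrr_and_hits(ranks, ks=(1, 3, 10)):
    # hypothetical helper; ranks: 1-based rank of the true relation per triplet
    ranks = ranks.float()
    mrr = torch.mean(1.0 / ranks).item()
    # Hits@k: fraction of triplets whose true relation is in the top k
    hits = {k: torch.mean((ranks <= k).float()).item() for k in ks}
    return mrr, hits

mrr, hits = mrr_and_hits(torch.tensor([2, 1, 4, 1]))
print(mrr)   # (1/2 + 1 + 1/4 + 1) / 4 = 0.6875
print(hits)  # {1: 0.5, 3: 0.75, 10: 1.0}
```

When ranking relations rather than entities, num_rel is usually much smaller than the number of entities, so a k close to num_rel makes Hits@k trivially 1; smaller k values are likely more informative here.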
This way you get the MRR for the relations instead of the entities, which makes more sense to me as “link prediction”.
Let me know if this makes sense, or if I’m missing something. Thanks!