The batch_a and batch_b are the subject and object (entity 1 and entity 2) of the triplets. For example, if the input data is,
[[e1, r2, e2],
 [e3, r4, e4],
 [e5, r6, e6],
 ...]
Then batch_a would be [e1, e3, e5], and batch_b would be [e2, e4, e6].
So for each triplet, you get num_rel scores (one per candidate relation), each calculated the same way as the loss:
# how the loss is calculated in the current RGCN code
e1 = entity_emb[triplets[:, 0]]
r = self.rel_W[triplets[:, 1]]
e2 = entity_emb[triplets[:, 2]]
score = torch.sum(e1 * r * e2, dim=1)
Each of those scores corresponds to one candidate relation, and you would rank them exactly the same way the current code ranks all of the entities to calculate MRR/hits.
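Here is a rough sketch of how the per-relation scores could be computed in one shot. It assumes the DistMult-style score from the snippet above and that rel_W has shape (num_rel, dim), which is what the indexing with triplets[:, 1] suggests. The function name score_all_relations and the toy tensors are made up for illustration and are not part of the RGCN code.

```python
import torch

def score_all_relations(entity_emb, rel_W, triplets):
    """Hypothetical helper: DistMult score of every relation for each (e1, e2) pair.

    entity_emb: (num_entities, dim), rel_W: (num_rel, dim),
    triplets: (N, 3) integer tensor of [e1, r, e2] ids.
    Returns a (N, num_rel) tensor of scores.
    """
    e1 = entity_emb[triplets[:, 0]]   # (N, dim)
    e2 = entity_emb[triplets[:, 2]]   # (N, dim)
    # scores[n, r] = sum_d e1[n, d] * rel_W[r, d] * e2[n, d]
    return (e1 * e2) @ rel_W.t()      # (N, num_rel)

# Toy data, purely illustrative
torch.manual_seed(0)
entity_emb = torch.randn(6, 4)        # 6 entities, dim 4
rel_W = torch.randn(3, 4)             # 3 relations
triplets = torch.tensor([[0, 1, 2], [3, 0, 4], [5, 2, 1]])

scores = score_all_relations(entity_emb, rel_W, triplets)
labels = triplets[:, 1]               # the true relation is the ranking target
```

The column of scores picked out by labels is the same value the loss code computes for the true triplet; the other columns are the scores for the corrupted relations.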
For example, to rank,
# same as current ranking code
_, indices = scores.sort(dim=1, descending=True)
match = indices == label.view(-1, 1)
ranks = match.nonzero()
return ranks[:, 1] + 1 # index from nonzero() starts at 0
And to calculate MRR,
rr = 1.0 / ranks.float()
return rr.mean()
This way you get the MRR for the relations instead of the entities, which makes more sense to me as “link prediction”.
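To make the ranking and MRR steps concrete, here is a small self-contained sketch that reuses the two snippets above on a hand-written score matrix. The function names relation_ranks and mrr_and_hits, the hits@k helper, and the toy numbers are my own illustration, not something from the existing code.

```python
import torch

def relation_ranks(scores, labels):
    """1-based rank of the true relation within each row of scores."""
    _, indices = scores.sort(dim=1, descending=True)
    match = indices == labels.view(-1, 1)
    return match.nonzero()[:, 1] + 1  # index from nonzero() starts at 0

def mrr_and_hits(ranks, k=1):
    """Mean reciprocal rank and hits@k for a 1-D tensor of ranks."""
    rr = 1.0 / ranks.float()
    hits = (ranks <= k).float().mean()
    return rr.mean().item(), hits.item()

# Toy scores: 3 triplets x 3 candidate relations (illustrative values)
scores = torch.tensor([[0.9, 0.1, 0.3],
                       [0.2, 0.5, 0.7],
                       [0.4, 0.6, 0.1]])
labels = torch.tensor([0, 1, 0])      # true relation id per triplet

ranks = relation_ranks(scores, labels)
mrr, hits1 = mrr_and_hits(ranks, k=1)
```

With these toy values the true relation is ranked 1st, 2nd and 2nd, so the MRR is (1 + 0.5 + 0.5) / 3.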
Let me know if this makes sense, or if I’m missing something. Thanks!