I am new to DGL. Could someone explain in detail what the following code snippet does? What are the dimensions D, E, and V in the tensor-size comments below?
This code is taken from utils.py in …/pytorch/rgcn. Thanks a lot!
def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
    """ Perturb one element in the triplets
    """
    n_batch = (test_size + batch_size - 1) // batch_size
    ranks = []
    for idx in range(n_batch):
        print("batch {} / {}".format(idx, n_batch))
        batch_start = idx * batch_size
        batch_end = min(test_size, (idx + 1) * batch_size)
        batch_a = a[batch_start: batch_end]
        batch_r = r[batch_start: batch_end]
        emb_ar = embedding[batch_a] * w[batch_r]
        emb_ar = emb_ar.transpose(0, 1).unsqueeze(2)  # size: D x E x 1
        emb_c = embedding.transpose(0, 1).unsqueeze(1)  # size: D x 1 x V
        # out-prod and reduce sum
        out_prod = torch.bmm(emb_ar, emb_c)  # size D x E x V
        score = torch.sum(out_prod, dim=0)  # size E x V
        score = torch.sigmoid(score)
        target = b[batch_start: batch_end]
        ranks.append(sort_and_rank(score, target))
    return torch.cat(ranks)
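For anyone who wants to check the shapes by hand, here is a minimal sketch with made-up sizes. It assumes (this is my reading of the code, not something stated in utils.py) that `embedding` has one row per entity (V x D), that `w` has one row per relation type (R x D), and that E is the number of triplets in the current batch. The names `emb_ar_3d` and `score_matmul` and all the sizes are hypothetical. Under those assumptions, the `bmm` followed by the sum over `dim=0` appears to give the same result as a plain matrix product between the E x D batch and the transposed entity table.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
V, D, R, E = 7, 4, 3, 5  # entities, embedding dim, relation types, triplets in batch

torch.manual_seed(0)
embedding = torch.randn(V, D)        # assumed: one row per entity
w = torch.randn(R, D)                # assumed: one row per relation type
batch_a = torch.randint(0, V, (E,))  # entity index of each triplet in the batch
batch_r = torch.randint(0, R, (E,))  # relation index of each triplet in the batch

# Same steps as inside the loop of perturb_and_get_rank.
emb_ar = embedding[batch_a] * w[batch_r]         # E x D, element-wise product
emb_ar_3d = emb_ar.transpose(0, 1).unsqueeze(2)  # D x E x 1
emb_c = embedding.transpose(0, 1).unsqueeze(1)   # D x 1 x V
out_prod = torch.bmm(emb_ar_3d, emb_c)           # D x E x V, one outer product per dim
score = torch.sum(out_prod, dim=0)               # E x V, summed over the D dimension

# The same scores computed as a single matrix product.
score_matmul = emb_ar @ embedding.t()            # E x V

print(out_prod.shape, score.shape)
print(torch.allclose(score, score_matmul, atol=1e-5))
```

If this reading is right, row i of `score` holds one score per candidate entity for the i-th triplet in the batch, which is what `sort_and_rank` then compares against `target`.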