Hi. I’m doing a project on the PinSage repo and I’m struggling to understand how `dgl.sampling.random_walk` works in this sampler:
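```python
import dgl
import torch

def __iter__(self):
    while True:
        # Seed items for this batch, drawn uniformly at random.
        heads = torch.randint(
            0, self.g.num_nodes(self.item_type), (self.batch_size,)
        )
        # Random walk along the metapath; [0] is the traces tensor,
        # [:, 2] selects its third column.
        tails = dgl.sampling.random_walk(
            self.g,
            heads,
            metapath=[self.item_to_user_etype, self.user_to_item_etype],
        )[0][:, 2]
        # Negative samples: random items, unrelated to the walk.
        neg_tails = torch.randint(
            0, self.g.num_nodes(self.item_type), (self.batch_size,)
        )
        # Drop walks that stopped early (their traces are padded with -1).
        mask = (tails != -1)
        yield heads[mask], tails[mask], neg_tails[mask]
```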
If the graph is bipartite, how does the sampling work? Since the edges go from item nodes to user nodes, wouldn’t the `tails` list contain nodes of both types? I assume it somehow only contains item nodes, but how?
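For reference, a metapath walk can be imitated on a tiny hand-made graph in plain Python. This is a simplified sketch, not DGL’s implementation: the edge-type names `clicked-by`/`clicks`, the toy adjacency lists and the helper `metapath_walk` are all hypothetical. It assumes the behaviour described in the `dgl.sampling.random_walk` documentation: each hop follows the next edge type in `metapath`, node IDs are local to each node type, and a walk that reaches a node with no outgoing edge of the required type is padded with `-1`.

```python
import random

# Hypothetical toy bipartite graph. As in a DGL heterograph, node IDs are
# local to each node type: item 1 and user 1 are different nodes.
clicks = {0: [0, 1], 1: [1, 2]}                   # user -> items
clicked_by = {0: [0], 1: [0, 1], 2: [1], 3: []}   # item -> users (item 3: no clicks)

# Each edge type maps source nodes of one type to destinations of another.
EDGES = {"clicked-by": clicked_by, "clicks": clicks}


def metapath_walk(start_item, metapath, rng):
    """Return a trace [start, hop1, hop2, ...], padded with -1 after a dead end."""
    trace = [start_item]
    current = start_item
    for etype in metapath:
        if current == -1:
            trace.append(-1)  # the walk already stopped; keep padding
            continue
        neighbors = EDGES[etype].get(current, [])
        current = rng.choice(neighbors) if neighbors else -1
        trace.append(current)
    return trace


rng = random.Random(0)
metapath = ["clicked-by", "clicks"]  # item -> user -> item
heads = [0, 1, 2, 3]
traces = [metapath_walk(h, metapath, rng) for h in heads]
tails = [t[2] for t in traces]       # position 2 is always an item (or -1)

# Same filtering idea as `mask = (tails != -1)` in the sampler.
pairs = [(h, t) for h, t in zip(heads, tails) if t != -1]
print(traces)
print(pairs)
```

Because this metapath is item → user → item, position 0 of every trace is an item, position 1 a user and position 2 an item again, so the third column only holds item IDs or `-1`. DGL’s `random_walk` also returns a second tensor giving the node type ID of each trace position, which can be used to check this on the real graph.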