Hi! I'm trying to use exclude='reverse_types' in dgl.dataloading.as_edge_prediction_sampler. My understanding is that the message flow graphs (MFGs) should then contain neither the edges of the positive graph nor their reverse edges. But after executing:
import dgl
import torch

torch.manual_seed(0)

# Random bipartite graph: 30 'click' edges plus their 'clicked-by' reverses.
user = torch.randint(10, (30,))
item = torch.randint(10, (30,))
g = dgl.heterograph({
    ('user', 'click', 'item'): (user, item),
    ('item', 'clicked-by', 'user'): (item, user)})

neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
sampler = dgl.dataloading.as_edge_prediction_sampler(
    dgl.dataloading.NeighborSampler([2, 2, 2]),
    exclude='reverse_types',
    reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    negative_sampler=neg_sampler)

# Seed edges: every edge ID of every edge type.
dataloader = dgl.dataloading.DataLoader(
    g,
    {
        etype: torch.arange(g.num_edges(etype))
        for etype in g.etypes
    },
    sampler,
    batch_size=2, shuffle=True, drop_last=False, num_workers=1)

input_nodes, pos_graph, neg_graph, mfgs = next(iter(dataloader))
print(pos_graph["click"].edges())
print(pos_graph["clicked-by"].edges())
print(mfgs[2]["click"].edges())
print(mfgs[2]["clicked-by"].edges())
I get:
(tensor([0, 1]), tensor([0, 1]))
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
(tensor([0, 0, 2, 3, 4, 1, 5, 4, 6, 6, 7, 6, 5]), tensor([0, 0, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]))
(tensor([0, 0, 3]), tensor([0, 0, 1]))
This seems to mean that my positive edge (0, 0) of type ('user', 'click', 'item') is not excluded from the message flow graph, and that its reverse edge (0, 0) of type ('item', 'clicked-by', 'user') is not excluded either.
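One thing I have not ruled out is that the IDs printed above are local to each sampled graph rather than original node IDs, in which case (0, 0) in pos_graph and (0, 0) in mfgs[2] need not be the same edge. Below is a minimal sketch of the comparison I think is needed, written with plain Python lists so it runs without DGL. The helper names and the toy ID lists are hypothetical; my assumption is that in DGL the mapping lists would come from the dgl.NID node feature of pos_graph and of the block's source and destination nodes.

```python
def to_original_ids(local_src, local_dst, src_nid, dst_nid):
    """Translate local (src, dst) endpoint indices into original node IDs.

    src_nid[i] / dst_nid[i] is the original ID of local node i. In DGL I
    assume these would be read from the dgl.NID node data of the graph.
    """
    return {(src_nid[s], dst_nid[d]) for s, d in zip(local_src, local_dst)}


def leaked_edges(pos_edges, mfg_edges, reverse=False):
    """Return positive edges (or their reverses) that also appear in the MFG."""
    if reverse:
        pos_edges = {(d, s) for s, d in pos_edges}
    return pos_edges & mfg_edges


# Toy, made-up mappings (not taken from the run above).
pos_click = to_original_ids([0, 1], [0, 1], src_nid=[7, 3], dst_nid=[4, 9])
mfg_click = to_original_ids([0, 0, 2], [0, 0, 1],
                            src_nid=[5, 3, 8], dst_nid=[4, 9])
mfg_clicked_by = to_original_ids([0, 0], [0, 0],
                                 src_nid=[4, 9], dst_nid=[7, 3])

# Local edge (0, 0) shows up in both graphs, yet no original edge is shared.
leaked = leaked_edges(pos_click, mfg_click)
# A reverse edge that really is present would be reported here.
leaked_rev = leaked_edges(pos_click, mfg_clicked_by, reverse=True)
print(leaked, leaked_rev)
```

If both sets come back empty on the real batch after mapping to original IDs, then the exclusion is working and my reading of the printed output was wrong.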
Can someone explain where the mistake is?