I implemented a custom `exclude` function. Is it correct? Are there any efficiency issues?
import dgl
import torch as th
# Toy heterograph with 5 users and 5 items over three relations.
# NOTE(review): 'liked' reuses exactly the same (src, dst) index lists as
# 'like', so it is NOT the reverse of 'like' (e.g. user 0 -> item 2 under
# 'like', but item 0 -> user 2 under 'liked'). If 'liked' is meant to be the
# "liked_by" reverse relation, swap its two tensors -- confirm intent.
g = dgl.heterograph({
    ('user', 'like', 'item'): (th.tensor([0, 0, 1, 1, 1, 2, 3, 4, 4]), th.tensor([2, 3, 1, 2, 4, 2, 1, 0, 4])),
    ('item', 'liked', 'user'): (th.tensor([0, 0, 1, 1, 1, 2, 3, 4, 4]), th.tensor([2, 3, 1, 2, 4, 2, 1, 0, 4])),
    ('user', 'valid', 'item'): (th.tensor([0, 0, 1, 1, 2, 2]), th.tensor([0, 4, 3, 4, 0, 1])),
})
# Store each node's original ID as a feature. The size comes from the graph
# itself rather than a hard-coded 5, so it stays correct if the edge lists
# above change.
g.ndata['id'] = {ntype: th.arange(g.num_nodes(ntype)) for ntype in ('user', 'item')}
# One rating per 'valid' edge, in edge-ID order (6 'valid' edges above).
g['valid'].edata['rating'] = th.tensor([0, 0, 1, 1, 2, 2])
class EdgeTypeExcluder:
    """Callable ``exclude`` argument that removes whole edge types from sampling.

    It is passed as ``exclude=`` to ``dgl.dataloading.as_edge_prediction_sampler``,
    which calls it with each minibatch's seed edges and drops the returned edge
    IDs from the sampled neighborhood. This excluder ignores the seed edges and
    always returns *every* edge ID of the configured edge types. Those types
    therefore never appear in the sampled message-passing blocks.

    Efficiency: the ID tensors are built once in ``__init__``, and the same dict
    is returned on every call, so this class does no per-batch allocation.
    DGL still filters the sampled edges against these IDs each batch. That
    cost presumably grows with the number of excluded edges -- worth
    measuring on a large graph.

    NOTE(review): if the intent is simply "never sample these edge types",
    a per-edge-type fanout of 0 for them in ``NeighborSampler`` may avoid
    sampling them only to discard them -- confirm against the DGL version
    in use.

    Args:
        g: graph providing ``num_edges(etype)`` for each entry of ``edge_types``.
        edge_types: iterable of edge types (names or canonical triples) whose
            edges should all be excluded.
    """

    def __init__(self, g, edge_types):
        # One tensor per edge type holding IDs 0 .. num_edges(t) - 1,
        # i.e. every edge of that type.
        self.edge_types = {t: th.arange(g.num_edges(t)) for t in edge_types}

    def __call__(self, x):
        """Return the edge IDs to exclude for a minibatch.

        ``x`` (the minibatch's seed edges) is intentionally unused: the same
        edges are excluded for every batch.
        """
        return self.edge_types
# Exclude every 'valid' edge from neighbor sampling. The sampled blocks then
# never contain the edges being predicted (nor any other 'valid' edge).
excluder = EdgeTypeExcluder(g, [('user', 'valid', 'item')])
# Two layers, each taking all in-neighbors (fanout -1).
# NOTE(review): a per-edge-type fanout such as
# {'like': -1, 'liked': -1, 'valid': 0} may skip 'valid' edges during
# sampling instead of sampling then excluding them -- verify with your DGL
# version.
sampler = dgl.dataloading.NeighborSampler([-1] * 2)
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=excluder)
# Every 'valid' edge is a training seed edge. Uses num_edges() to match
# EdgeTypeExcluder rather than the older number_of_edges() spelling.
train_eid_dict = {k: th.arange(g.num_edges(k)) for k in ['valid']}
dataloader = dgl.dataloading.DataLoader(
    g, train_eid_dict, graph_sampler=sampler,
    batch_size=3, shuffle=True, drop_last=True, num_workers=4)
# Pull a single minibatch and dump what the sampler produced.
input_nodes, edge_subgraph, blocks = next(iter(dataloader))

print('====================================edge_subgraph=====================================')
print(edge_subgraph)
print(edge_subgraph.edges(etype='valid'))
for ntype in ('user', 'item'):
    print(edge_subgraph.nodes[ntype])
print(edge_subgraph.edges['valid'].data)

# One banner per sampled message-passing block, printed before that block.
block_banners = (
    '=======================================blocks0========================================',
    '=======================================blocks1========================================',
)
for idx, banner in enumerate(block_banners):
    print(banner)
    block = blocks[idx]
    print(block)
    print('user-like-item:', block.edges(etype='like'))
    print('item-liked_by-user:', block.edges(etype='liked'))
The output is as follows; it seems to work well.
====================================edge_subgraph=====================================
Graph(num_nodes={'item': 3, 'user': 2},
num_edges={('item', 'liked', 'user'): 0, ('user', 'like', 'item'): 0, ('user', 'valid', 'item'): 3},
metagraph=[('item', 'user', 'liked'), ('user', 'item', 'like'), ('user', 'item', 'valid')])
(tensor([0, 1, 1]), tensor([0, 1, 2]))
NodeSpace(data={'id': tensor([2, 1]), '_ID': tensor([2, 1])})
NodeSpace(data={'id': tensor([1, 3, 4]), '_ID': tensor([1, 3, 4])})
{'rating': tensor([2, 1, 1]), '_ID': tensor([5, 2, 3])}
=======================================blocks0========================================
Block(num_src_nodes={'item': 5, 'user': 5},
num_dst_nodes={'item': 5, 'user': 5},
num_edges={('item', 'liked', 'user'): 9, ('user', 'like', 'item'): 9, ('user', 'valid', 'item'): 0},
metagraph=[('item', 'user', 'liked'), ('user', 'item', 'like'), ('user', 'item', 'valid')])
user-like-item: (tensor([1, 2, 3, 1, 4, 4, 3, 1, 0]), tensor([0, 0, 1, 2, 2, 3, 4, 4, 4]))
item-liked_by-user: (tensor([3, 0, 4, 0, 1, 3, 2, 0, 2]), tensor([0, 0, 0, 1, 1, 2, 3, 4, 4]))
=======================================blocks1========================================
Block(num_src_nodes={'item': 5, 'user': 5},
num_dst_nodes={'item': 3, 'user': 2},
num_edges={('item', 'liked', 'user'): 5, ('user', 'like', 'item'): 5, ('user', 'valid', 'item'): 0},
metagraph=[('item', 'user', 'liked'), ('user', 'item', 'like'), ('user', 'item', 'valid')])
user-like-item: (tensor([1, 2, 3, 1, 4]), tensor([0, 0, 1, 2, 2]))
item-liked_by-user: (tensor([3, 0, 4, 0, 1]), tensor([0, 0, 0, 1, 1]))