I created a unidirectional bipartite graph as follows:
G = dgl.heterograph({ ('user', 'rates', 'item') : M.nonzero()})
G.nodes['user'].data['h'] = torch.randn(len(users), 512)
G.nodes['item'].data['h'] = torch.randn(len(items), 512)
G.edges['rates'].data['label'] = weights
I want to apply dgl.nn.GATConv to this graph for node classification, as follows:
eids = np.arange(G.number_of_edges())
#eids = np.random.permutation(eids)
test_size = int(len(eids) * test_prop)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
train_loader = dgl.dataloading.DataLoader(G, eids[2*test_size:], sampler, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
gatconv = dgl.nn.GATConv((G.nodes['user'].data['h'].shape[1], G.nodes['item'].data['h'].shape[1]), 2, 3)
for input_nodes, edge_subgraph, blocks in tqdm(valid_loader):
    edge_subgraph = edge_subgraph.to(torch.device(device))
    blocks = [b.to(torch.device(device)) for b in blocks]
    users = blocks[0].nodes['user'].data['h']
    items = blocks[0].nodes['item'].data['h']
    labels = blocks[0].edata['label']
    res = gatconv(blocks[0], (users, items))
However, I get the following error:
AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.
How can I solve this problem? Thank you in advance.