Hello! I am trying to write a custom layer for a heterogenous graph and have run into a problem with assigning data to src/dstdata. The issue only seems to crop up when I use batching, it seems to work as expected when I use the full graph.
I’m not sure if I am doing something wrong here, but I don’t understand why the new data is not in g.dstdata.
Thanks in advance!
toy_graph = {('a','a<>b','b') : (torch.tensor([0,0,1,1,2,3,2,3,0]),torch.tensor([0,1,1,2,1,2,3,4,0])),
('b','b<>a','a') : (torch.tensor([0,1,1,2,1,2,3,4,3]),torch.tensor([0,0,1,1,2,3,2,3,0]))
}
g = dgl.heterograph(toy_graph)
g.nodes['a'].data['A'] = torch.arange(8).view(4,2).float()
g.nodes['b'].data['B'] = torch.arange(10).view(5,2).float()
# create the sampler
fanouts = [i for i in range(1,3)]
sampler = dgl.dataloading.NeighborSampler(fanouts=fanouts)
# create the training dataloader
train_loader = dgl.dataloading.DataLoader(
g,
{'a':g.nodes('a')},
sampler,
batch_size=1,
shuffle=True,
)
# create test module
class test(nn.Module):
def __init__(
self,
in_dim,
out_dim,
):
super().__init__()
self.A_proj = nn.Linear(in_dim, out_dim , bias = False)
self.B_proj = nn.Linear(in_dim, out_dim , bias = False)
def forward(self, g, x):
with g.local_scope():
if isinstance(x, tuple):
x_src , x_dst = x
S = self.A_proj(x_src['A']['a'])
D = self.B_proj(x_dst['B']['b'])
g.srcdata['A_new'] = {'a' : S}
g.dstdata['B_new'] = {'b' : D}
print(g.srcdata.keys())
print(g.dstdata.keys())
return g.dstdata['D']
input_nodes, output_nodes, blocks = next(iter(train_loader))
x = (blocks[0].srcdata, blocks[0].dstdata)
test_layer(blocks[0], x)
Output:
dict_keys([‘A’, ‘_ID’, ‘A_new’, ‘B’, ‘B_new’])
dict_keys([‘A’, ‘_ID’, ‘B’])