Trouble Training Link Prediction on Heterograph with EdgeDataLoader

Hi,

I’m having some issues training a link prediction model on a heterograph using the EdgeDataLoader. Specifically, I have a graph with two node types, source and user, with the relation that a user is a follower of a source. The source nodes have a feature called source_embedding with dimension 750, and the user nodes have a user_embedding feature with dimension 800. I’ve been able to construct my graph successfully in DGL.

I’m not sure what exactly the inputs to my RGCN model should be. In the EdgeDataLoader, I can see that the features I am interested in (user_embedding, source_embedding) are there in blocks. In the Link Prediction example (https://docs.dgl.ai/guide/training-link.html), they pass in a dictionary of node features and edge types. The batch link prediction example (https://docs.dgl.ai/guide/minibatch-link.html) seems to be off, as they don’t even pass in the positive/negative graph to the model, which leads to my confusion. I haven’t been able to find an example that does what the non-batch link prediction guide does, but in mini-batches.

I have features for both source and user. I can pass in the blocks and a node feature like so:

pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features, ('source', 'has_follower', 'user'))

However, what would node_features be? According to the non-batch link prediction example, it should be something like below. Are edge features not included?

node_features = {'source': blocks[0].srcdata['source'], 'user': blocks[0].srcdata['user']}

Further, what would the dimension of in_features be in my model, since the features of my two node types have different dimensions?

Finally, what would the model’s forward function look like? Would it iterate over edge types as in the non-batch link prediction example? Maybe something like below? But then how does it handle the various edge types? Are node_features constructed correctly?

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            return edge_subgraph.edata['score']

class TestModel(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = FakeNewsRGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroScorePredictor()
    def forward(self, g, neg_g, blocks, x, etype):
        # run both part of the graph through it given the input feature
        h = self.sage(blocks, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features, ('source', 'has_follower', 'user'))

The batch link prediction example (https://docs.dgl.ai/guide/minibatch-link.html) seems to be off, as they don’t even pass in the positive/negative graph to the model, which leads to my confusion.

Yes, the documentation there is wrong. Will fix that in https://github.com/dmlc/dgl/pull/2233.

However, what would node_features be? According to the non-batch link prediction example, it should be something like below. Are edge features not included?

Your code is correct for node_features. Edge features are indeed not included. To include them, you need to change either the GNN model or the decoder so that it considers edge features.
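For the decoder route, one option is to concatenate the two endpoint representations with the edge's own feature vector and score the result with a small MLP. The sketch below is plain PyTorch and only illustrates the idea; EdgeFeatureScorer and its dimensions are hypothetical. In DGL the same computation would be wrapped in a function passed to apply_edges, reading edges.src, edges.dst and edges.data.

```python
import torch
import torch.nn as nn

class EdgeFeatureScorer(nn.Module):
    """Hypothetical decoder: score(u, v) = MLP([h_u, h_v, e_uv])."""

    def __init__(self, node_dim, edge_dim, hidden_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, h_src, h_dst, e_feat):
        # h_src, h_dst: (num_edges, node_dim) representations of each edge's endpoints
        # e_feat:       (num_edges, edge_dim) input features of each edge
        return self.mlp(torch.cat([h_src, h_dst, e_feat], dim=1))  # (num_edges, 1)

# usage with random tensors for 5 edges
scorer = EdgeFeatureScorer(node_dim=16, edge_dim=4)
scores = scorer(torch.randn(5, 16), torch.randn(5, 16), torch.randn(5, 4))
```

The alternative (changing the GNN model) would instead feed edge features into message passing, which changes every layer rather than only the scoring step.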

Further, what would the dimension of in_features be in my model, since the features of my two node types have different dimensions?

You can specify a different in_features dimension for each node type in the GNN model’s __init__ method.

class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
                # you can change in_feat to something like in_feat[ntype] here
                rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                for rel in rel_names
                })
        self.conv2 = dglnn.HeteroGraphConv({
                rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                for rel in rel_names
                })

    def forward(self, blocks, x):
        # one block per layer: blocks[0] feeds conv1, blocks[1] feeds conv2
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x

Finally, what would the model’s forward function look like? Would it iterate over edge types as in the non-batch link prediction example? Maybe something like below? But then how does it handle the various edge types? Are node_features constructed correctly?

Since the iteration of edge types is already handled in the ScorePredictor, you don’t have to iterate over the edge types in Model again. My fix to the documentation goes with the following Model definition:

class ScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            return edge_subgraph.edata['score']

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_classes,
                 etypes):
        super().__init__()
        self.rgcn = StochasticTwoLayerRGCN(
            in_features, hidden_features, out_features, etypes)
        self.pred = ScorePredictor()
    def forward(self, positive_graph, negative_graph, blocks, x):
        x = self.rgcn(blocks, x)
        pos_score = self.pred(positive_graph, x)
        neg_score = self.pred(negative_graph, x)
        return pos_score, neg_score
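For reference, dgl.function.u_dot_v('x', 'x', 'score') stores, for every edge, the dot product of its source and destination node representations. A plain-PyTorch sketch of the same computation for a single edge type, with the edge list given as hypothetical src/dst index tensors, is:

```python
import torch

def u_dot_v(h_src, h_dst, src, dst):
    """Per-edge dot product, mirroring what u_dot_v writes to edata['score'].

    h_src: (num_src_nodes, D) representations of the source node type
    h_dst: (num_dst_nodes, D) representations of the destination node type
    src, dst: (num_edges,) endpoint IDs of each edge
    """
    # gather the endpoint rows, multiply elementwise, sum over the feature dim;
    # keepdim gives one score per edge in a (num_edges, 1) tensor
    return (h_src[src] * h_dst[dst]).sum(dim=1, keepdim=True)

h_source = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
h_user = torch.tensor([[3.0, 4.0], [1.0, 1.0], [0.0, 1.0]])
src = torch.tensor([0, 1, 1])
dst = torch.tensor([0, 1, 2])
scores = u_dot_v(h_source, h_user, src, dst)  # [[3.], [2.], [2.]]
```

Because each relation reads the representation of its source type and of its destination type, both node types must be present in x for the predictor to work; a missing type is what the KeyError: 'x' further down the thread is about.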

The training loop should look like this:

model = Model(in_features, hidden_features, out_features, num_classes, etypes)
model = model.cuda()
opt = torch.optim.Adam(model.parameters())

for input_nodes, positive_graph, negative_graph, blocks in dataloader:
    blocks = [b.to(torch.device('cuda')) for b in blocks]
    positive_graph = positive_graph.to(torch.device('cuda'))
    negative_graph = negative_graph.to(torch.device('cuda'))
    input_features = blocks[0].srcdata['features']
    pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
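The loop above leaves compute_loss undefined. One common choice is a margin (hinge) loss that asks each positive edge to score at least 1 higher than each of its negative samples. A minimal sketch, assuming k negatives per positive edge stored so that neg_score.view(n, -1) puts the negatives of positive edge i in row i:

```python
import torch

def compute_loss(pos_score, neg_score):
    """Margin loss max(0, 1 - pos + neg), averaged over all (positive, negative) pairs.

    pos_score: (n,) or (n, 1) scores of the positive edges
    neg_score: (n * k,) or (n * k, 1) scores of the negative edges
    """
    n = pos_score.shape[0]
    pos = pos_score.view(n, -1)  # (n, 1)
    neg = neg_score.view(n, -1)  # (n, k), row i = negatives of positive edge i
    return (1 - pos + neg).clamp(min=0).mean()

pos = torch.tensor([[3.0], [0.5]])
neg = torch.tensor([[0.0], [1.0], [0.5], [2.0]])  # k = 2 negatives per positive
loss = compute_loss(pos, neg)  # terms [0, 0, 1, 2.5] -> 0.875
```

Reshaping both tensors with view(n, -1) keeps the result at shape (n, k); mixing (n, 1, 1) and (n, k) shapes would silently broadcast to (n, n, k) instead.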

@BarclayII

Thanks for the help and for updating the document.

For handling features of different types, does that mean etypes should actually be ntypes? Otherwise, how will I do something like in_feat[ntype]? I instead passed in ntypes for rel_names.

I’m now getting this error:

  File "GNN_model.py", line 308, in <module>
    main(args)
  File "GNN_model.py", line 263, in main
    pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
  File "path/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "GNN_model.py", line 128, in forward
    h = self.sage(blocks, x)
  File "path/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "GNN_model.py", line 100, in forward
    x = self.conv1(blocks[0], x)
  File "path/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "path/lib/python3.6/site-packages/dgl/nn/pytorch/hetero.py", line 149, in forward
    dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
  File "path/lib/python3.6/site-packages/dgl/nn/pytorch/hetero.py", line 149, in <dictcomp>
    dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
TypeError: unhashable type: 'slice'

My code is the same as what you posted above, and my node_features are:

node_features = {'source': blocks[0].srcdata['source'], 'user': blocks[0].srcdata['user']}

My blocks is:

[Block(num_src_nodes={'source': 174, 'user': 60},
      num_dst_nodes={'source': 174, 'user': 60},
      num_edges={('source', 'has_follower', 'user'): 232},
      metagraph=[('source', 'user', 'has_follower')]), Block(num_src_nodes={'source': 174, 'user': 60},
      num_dst_nodes={'source': 10, 'user': 60},
      num_edges={('source', 'has_follower', 'user'): 232},
      metagraph=[('source', 'user', 'has_follower')])]

Do you have any idea what the problem could be now? Maybe I shouldn’t actually be passing in ntypes instead of etypes? But then how would I set the various input feature sizes?

You can get all the triplets of source node type, edge type, and destination node type with g.canonical_etypes, so you can change the initialization of StochasticTwoLayerRGCN to accept a dictionary of input feature sizes and replace rel_names with the triplets. For instance, try this code:

class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat_dict, hidden_feat, out_feat, canonical_etypes):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
                # each relation's input size is that of its source node type
                etype : dglnn.GraphConv(in_feat_dict[utype], hidden_feat, norm='right')
                for utype, etype, vtype in canonical_etypes
                })
        self.conv2 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                for _, etype, _ in canonical_etypes
                })

    def forward(self, blocks, x):
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x
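Since each relation reads the features of its source node type, only node types that are the source of some relation contribute an input size to conv1. A small pure-Python sketch makes that visible; conv1_input_sizes is a hypothetical helper (not a DGL function), and the relation names follow the ones used elsewhere in this thread:

```python
def conv1_input_sizes(in_feat_dict, canonical_etypes):
    """Return ({etype: input size} for the first layer, node types whose input
    size is never used because they are not the source of any relation)."""
    sizes = {etype: in_feat_dict[utype] for utype, etype, _ in canonical_etypes}
    used = {utype for utype, _, _ in canonical_etypes}
    unused = set(in_feat_dict) - used
    return sizes, unused

in_feats = {'source': 700, 'user': 800}
one_way = [('source', 'has_follower', 'user')]
both_ways = one_way + [('user', 'follows', 'source')]

sizes, unused = conv1_input_sizes(in_feats, one_way)
# sizes == {'has_follower': 700}, unused == {'user'}
sizes2, unused2 = conv1_input_sizes(in_feats, both_ways)
# sizes2 == {'has_follower': 700, 'follows': 800}, unused2 == set()
```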

@BarclayII
Ok, I did that and it seems to initialize my parameters correctly, but I’m still having trouble passing the features correctly. Thanks for your continual help.

Initializing my conv layer with in_feats[utype] doesn’t seem to use my user embedding as an input feature: since g.canonical_etypes is [('source', 'has_follower', 'user')], it only sets the input size to in_feat_dict['source'] and ignores in_feat_dict['user'].

For the forward pass, I tried to pass the features to self.rgcn from your example like so:

node_features = {'source': blocks[0].srcdata['source_embedding']['source'], 'user': blocks[0].srcdata['user_embedding']['user']}

However, when printing the output of the forward pass (set up just like your StochasticTwoLayerRGCN), the first convolution layer only returns features for my user nodes, and the second convolution layer returns an empty dict! I’m not sure why this happens.

class TestRGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, canonical_etypes):
        super(TestRGCN, self).__init__()

        self.conv1 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(in_feats[utype], hid_feats, norm='right')
                for utype, etype, vtype in canonical_etypes
                })
        self.conv2 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(hid_feats, out_feats, norm='right')
                for _, etype, _ in canonical_etypes
                })

    def forward(self, blocks, inputs):

        x1 = self.conv1(blocks[0], inputs)
        x2 = self.conv2(blocks[1], x1)

        return x2

Output of x1:

{'user': tensor([[ 0.0052, -0.0134, -0.0723,  ...,  0.0719, -0.0329,  0.0110],
        [-0.1712, -0.4903,  0.0603,  ..., -0.1914, -0.3530,  0.0801],
        [-0.0950, -0.2891,  0.0049,  ..., -0.0823, -0.2171,  0.0522],
        ...,
        [-0.1543, -0.4432,  0.0472,  ..., -0.1642, -0.3208,  0.0728],
        [ 0.0238,  0.0432, -0.0881,  ...,  0.1066,  0.0054,  0.0021],
        [ 0.1437,  0.3533, -0.1740,  ...,  0.2709,  0.2142, -0.0417]],
       device='cuda:0', grad_fn=<SumBackward1>)}

Output of x2: Empty!

For now, let’s assume that I continue with the output as x1 and run the next step: self.pred(positive_graph, x)

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            return edge_subgraph.edata['score']

class TestModel(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, canonical_etypes):
        super().__init__()
        self.sage = TestRGCN(in_features, hidden_features, out_features, canonical_etypes)
        self.pred = HeteroScorePredictor()
    def forward(self, g, neg_g, blocks, x):
        x = self.sage(blocks, x)
        pos_score = self.pred(g, x)
        neg_score = self.pred(neg_g, x)
        return pos_score, neg_score 

Then, I get the following error:

Traceback (most recent call last):
  File "GNN_model.py", line 304, in <module>
    main(args)
  File "GNN_model.py", line 259, in main
    pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
  File "path/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "GNN_model.py", line 116, in forward
    pos_score = self.pred(g, x)
  File "path/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "GNN_model.py", line 102, in forward
    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
  File "path/lib/python3.6/site-packages/dgl/heterograph.py", line 4119, in apply_edges
    edata = core.invoke_gsddmm(g, func)
  File "path/lib/python3.6/site-packages/dgl/core.py", line 201, in invoke_gsddmm
    x = alldata[func.lhs][func.lhs_field]
  File "path/lib/python3.6/site-packages/dgl/view.py", line 66, in __getitem__
    return self._graph._get_n_repr(self._ntid, self._nodes)[key]
  File "path/lib/python3.6/site-packages/dgl/frame.py", line 386, in __getitem__
    return self._columns[name].data
KeyError: 'x'

Upon investigating this further with a custom u_dot_v function, it seems like edges.src doesn’t have an 'x' value, but edges.dst does.

def apply_edges(self, edges):
    # this fails
    h_u = edges.src['h']
    # this succeeds
    h_v = edges.dst['h']
    score = self.W(torch.cat([h_u, h_v], 1))
    return {'score': score}

I’m not sure what is wrong at these two points. Maybe there is something wrong with my input features still? Maybe edge_subgraph.ndata['h'] = x isn’t doing the right thing?

By making my graph bidirectional (connecting users back to sources, in addition to the existing source-to-user edges), things seem to be working now. I’m having some issues computing the loss, but will look into that.

@BarclayII

I basically ended up doing it like this, and it seems to be working now. Can you help me verify that it is correct?

class TestRGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, canonical_etypes):
        super(TestRGCN, self).__init__()

        self.conv1 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(in_feats[utype], hid_feats, norm='right')
                for utype, etype, vtype in canonical_etypes
                })
        self.conv2 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(hid_feats, out_feats, norm='right')
                for _, etype, _ in canonical_etypes
                })

    def forward(self, blocks, inputs):
        x = self.conv1(blocks[0], inputs)
        x = self.conv2(blocks[1], x)

        return x

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
                # edge_subgraph.apply_edges(self.apply_edges, etype=etype)
            return edge_subgraph.edata['score']

class TestModel(nn.Module):
    # here we have a model that first computes the representation and then predicts the scores for the edges
    def __init__(self, in_features, hidden_features, out_features, canonical_etypes):
        super().__init__()
        self.sage = TestRGCN(in_features, hidden_features, out_features, canonical_etypes)
        self.pred = HeteroScorePredictor()
    def forward(self, g, neg_g, blocks, x):
        x = self.sage(blocks, x)
        pos_score = self.pred(g, x)
        neg_score = self.pred(neg_g, x)
        return pos_score, neg_score 

def compute_loss(pos_score, neg_score, canonical_etypes):
    # Margin (hinge) loss: each positive edge should score at least 1 higher
    # than each of its negative samples, i.e. max(0, 1 - pos + neg)
    all_losses = []
    for given_type in canonical_etypes:
        n_edges = pos_score[given_type].shape[0]
        if n_edges == 0:
            continue
        pos = pos_score[given_type].view(n_edges, -1)  # (n_edges, 1)
        neg = neg_score[given_type].view(n_edges, -1)  # (n_edges, k)
        all_losses.append((1 - pos + neg).clamp(min=0).mean())
    return torch.stack(all_losses, dim=0).mean()

model = TestModel(in_features={'source':700, 'user':800}, hidden_features=512, out_features=256, canonical_etypes=g.canonical_etypes)
...
for epoch in range(args.n_epochs):
        for input_nodes, positive_graph, negative_graph, blocks in dataloader:
            model.train()
            blocks = [b.to(torch.device('cuda')) for b in blocks]

            positive_graph = positive_graph.to(torch.device('cuda'))
            negative_graph = negative_graph.to(torch.device('cuda'))

            node_features = {'source': blocks[0].srcdata['source_embedding']['source'], 'user': blocks[0].srcdata['user_embedding']['user']}
            pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
            loss = compute_loss(pos_score, neg_score, g.canonical_etypes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

My input graph basically has a structure like this (with more nodes of course):

data_dict = {('source', 'has_follower', 'user'): (torch.tensor([0, 0]), torch.tensor([0, 0])), ('user', 'follows', 'source'): (torch.tensor([0, 0]), torch.tensor([0, 0]))}

And each node type has an embedding:

dgl_graph.nodes['source'].data['source_embedding'] = torch.zeros(1, 700)
dgl_graph.nodes['user'].data['user_embedding'] = torch.zeros(1, 800)

Your code looks correct.

Yes. I forgot to mention that. Otherwise x2 will be empty because the users don’t have any incoming neighbors.
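The empty x2 can be traced layer by layer. Judging from the outputs printed earlier in this thread, a relation only contributes when its node types have input features, so with only source→user edges the first layer outputs just user, and the second layer then has no source features to aggregate. The pure-Python sketch below is a simplified model of that bookkeeping (output_ntypes and trace_layers are hypothetical helpers, not DGL calls), assuming a relation is used on a block only if both its source and destination types are in the input dict:

```python
def output_ntypes(input_ntypes, canonical_etypes):
    """Destination types produced by one HeteroGraphConv-style layer (simplified):
    a relation is used only if both of its node types are in the input dict."""
    return {vtype for utype, _, vtype in canonical_etypes
            if utype in input_ntypes and vtype in input_ntypes}

def trace_layers(canonical_etypes, num_layers, input_ntypes=('source', 'user')):
    """Node types present in the output of each successive layer."""
    ntypes = set(input_ntypes)
    trace = []
    for _ in range(num_layers):
        ntypes = output_ntypes(ntypes, canonical_etypes)
        trace.append(ntypes)
    return trace

one_way = [('source', 'has_follower', 'user')]
both_ways = one_way + [('user', 'follows', 'source')]

print(trace_layers(one_way, 2))    # [{'user'}, set()]  -> x2 is empty
print(trace_layers(both_ways, 2))  # each layer keeps both 'source' and 'user'
```

With the reverse relation, every node type receives messages in every layer, which is why the bidirectional graph fixes both the empty x2 and the missing source features in the score predictor.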


@BarclayII Thanks so much for your help so far, it’s been very useful.

I’m now trying to use my trained GCN to do some inference. Basically, I want to create an embedding of my graph and use it to train an LR model for classification. I saw this document (https://docs.dgl.ai/guide/minibatch-inference.html?highlight=nodedataloader%20device) on how to do this on a large graph, and I think I need to do the same thing (a node data loader instead of an edge data loader, since now I just want an embedding of my graph, right?). However, I am having trouble doing it.

I get RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method. This happens as soon as I start iterating over the new NodeDataLoader that I will eventually use to create the embedding. My actual code is more complex than this (and similar to the tutorial), but the same error arises when I do it as below inside the training loop, so I’ve provided this more minimal code.

for epoch in range(args.n_epochs):
        # this is the old edge data loader
        for input_nodes, positive_graph, negative_graph, blocks in dataloader:

            model.train()
            blocks = [b.to(torch.device('cuda')) for b in blocks]

            if epoch >= 3:
                t0 = time.time()

            positive_graph = positive_graph.to(torch.device('cuda'))
            negative_graph = negative_graph.to(torch.device('cuda'))
            node_features = {'source': blocks[0].srcdata['source_embedding']['source'].to(torch.device('cuda')), 'user': blocks[0].srcdata['user_embedding']['user'].to(torch.device('cuda'))}

            # forward propagation by using all nodes and extracting the user embeddings
            pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
            loss = compute_loss(pos_score, neg_score, g.canonical_etypes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch >= 3:
                dur.append(time.time() - t0)

            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | ""ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(), n_edges / np.mean(dur) / 1000))

            # evaluate the model
            source_feats = g.nodes['source'].data['source_embedding'].to(torch.device('cuda'))
            user_feats = g.nodes['user'].data['user_embedding'].to(torch.device('cuda'))
            node_features = {'source': source_feats, 'user': user_feats}
            new_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            new_dataloader = dgl.dataloading.NodeDataLoader(
                g, {k: torch.arange(g.number_of_nodes(k)) for k in g.ntypes}, new_sampler, batch_size=args.batch_size, shuffle=True, drop_last=False, pin_memory=True, num_workers=args.num_workers)

            # on this line is where the CUDA error happens
            for input_nodes, output_nodes, blocks in tqdm(new_dataloader):
                print(input_nodes)
                exit(1)

Error:

Traceback (most recent call last):
  File "GNN_model.py", line 378, in <module>
    main(args)
  File "GNN_model.py", line 340, in main
    for input_nodes, output_nodes, blocks in tqdm(new_dataloader):
  File "path/.local/lib/python3.6/site-packages/tqdm/std.py", line 1133, in __iter__
    for obj in iterable:
  File "path/lib/python3.6/site-packages/dgl/dataloading/pytorch/__init__.py", line 145, in __next__
    input_nodes, output_nodes, blocks = next(self.iter_)
  File "path/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "path/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "path/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "path/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "path/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "path/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "path/lib/python3.6/site-packages/dgl/dataloading/pytorch/__init__.py", line 121, in collate
    input_nodes, output_nodes, blocks = super().collate(items)
  File "path/lib/python3.6/site-packages/dgl/dataloading/dataloader.py", line 364, in collate
    items = utils.prepare_tensor_dict(self.g, items, 'items')
  File "path/lib/python3.6/site-packages/dgl/utils/checks.py", line 70, in prepare_tensor_dict
    for key, val in data.items()}
  File "path/lib/python3.6/site-packages/dgl/utils/checks.py", line 70, in <dictcomp>
    for key, val in data.items()}
  File "path/lib/python3.6/site-packages/dgl/utils/checks.py", line 42, in prepare_tensor
    ret = F.copy_to(F.astype(data, g.idtype), g.device)
  File "path/lib/python3.6/site-packages/dgl/backend/pytorch/tensor.py", line 112, in copy_to
    th.cuda.set_device(ctx.index)
  File "path/lib/python3.6/site-packages/torch/cuda/__init__.py", line 245, in set_device
    torch._C._cuda_setDevice(device)
  File "path/lib/python3.6/site-packages/torch/cuda/__init__.py", line 148, in _lazy_init
    "Cannot re-initialize CUDA in forked subprocess. " + msg)
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Do you have any idea why this could be happening? Why is CUDA trying to spawn another process?

It seems that your graph is on GPU. For multiprocessing sampling you cannot put the graph on GPU unless you set the multiprocessing mode to spawn as PyTorch stated.

This is because NodeDataLoader with num_workers will instantiate multiple worker processes to sample blocks in parallel. The same goes for PyTorch as well: if you pass a GPU tensor to a PyTorch DataLoader you will probably see similar error messages.


@BarclayII

Ok, thanks for the help again. I was able to fix that issue by using PyTorch multiprocessing to call my main function (renamed to running_code) like so. I also had to move my graph to the CPU before the NodeDataLoader call in the inference function (and then I think NodeDataLoader would put parts on GPU).

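import torch.multiprocessing as mp

if __name__ == '__main__':
    # run the main code in a process started with the 'spawn' method
    ctx = mp.get_context('spawn')
    procs = []
    p = ctx.Process(target=running_code, args=(args,))
    p.start()
    procs.append(p)
    for p in procs:
        p.join()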

However, now I’m having issues again passing in input features into the inference function. In the RGCN Heterograph inference example (https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py), they just pass in the features (x in the inference function), and index it appropriately with the NodeDataLoader like so: h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}.

However, in my case I had previously passed in the features like below, so I’m unsure how to do a similar indexing (using input_nodes) and how to assign it to the output tensor like the example did.
node_features = {'source': blocks[0].srcdata['source_embedding']['source'].to(torch.device('cuda')), 'user': blocks[0].srcdata['user_embedding']['user'].to(torch.device('cuda'))}

I was actually able to figure it out. For anyone looking for the same issue, this is the code I used. @BarclayII thanks a lot for your help, I think I finally have everything working now :slight_smile:

# x are the input features for the entire graph
@torch.no_grad()
def inference(self, curr_g, x, batch_size):
        # sampling runs on the CPU, so keep the graph there; each block is moved to the GPU below
        curr_g = curr_g.to('cpu')

        for l, layer in enumerate([self.conv1, self.conv2]):
            # full-graph output buffer for this layer, one tensor per node type
            y = {k: torch.zeros(curr_g.number_of_nodes(k), self.hid_feats if l != self.n_layers - 1 else self.out_feats) for k in curr_g.ntypes}

            # one-layer sampler: every batch yields exactly one block
            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.NodeDataLoader(
                curr_g, {k: torch.arange(curr_g.number_of_nodes(k)) for k in curr_g.ntypes}, sampler, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers)

            for input_nodes, output_nodes, blocks in tqdm(dataloader):
                block = blocks[0].to(torch.device('cuda'))

                # gather this batch's input features by global node ID
                h = {k: x[k][input_nodes[k].long()].to(torch.device('cuda')) for k in input_nodes.keys()}

                h = layer(block, h)

                # scatter the outputs back into the full-graph buffer
                for k in h.keys():
                    y[k][output_nodes[k].long()] = h[k].cpu()

            x = y

        return y
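The core of this function is a gather/scatter pattern: read each batch's inputs by global node ID, run the layer, and write the outputs back into a full-graph buffer at their global IDs. A DGL-free sketch of just that pattern, where toy_layer is a hypothetical stand-in for a GNN layer and the batches are hand-written (input_nodes, output_nodes) pairs; it relies on the block convention that a batch's output nodes come first among its input nodes:

```python
import torch

def toy_layer(h_src, num_dst):
    # hypothetical stand-in for a layer applied to a block: the destination
    # nodes are the first num_dst source nodes, so transform just those rows
    return 2 * h_src[:num_dst]

def layerwise_apply(x, batches, out_dim):
    """x: {ntype: (num_nodes, D)} full-graph features.
    batches: list of ({ntype: input_ids}, {ntype: output_ids}) pairs."""
    y = {k: torch.zeros(v.shape[0], out_dim) for k, v in x.items()}
    for input_nodes, output_nodes in batches:
        # gather: global IDs -> rows of the full feature tensors
        h = {k: x[k][input_nodes[k]] for k in input_nodes}
        # run the (toy) layer for every node type that has outputs
        h = {k: toy_layer(h[k], len(output_nodes[k])) for k in output_nodes}
        # scatter: write results back at their global IDs
        for k in h:
            y[k][output_nodes[k]] = h[k]
    return y

x = {'user': torch.arange(8.).view(4, 2)}
batches = [({'user': torch.tensor([2, 0, 1])}, {'user': torch.tensor([2, 0])}),
           ({'user': torch.tensor([1, 3])}, {'user': torch.tensor([1, 3])})]
y = layerwise_apply(x, batches, out_dim=2)  # equals 2 * x['user'] row by row
```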

Yes. I forgot to mention that. Otherwise x2 will be empty because the users don’t have any incoming neighbors.

@BarclayII Can you explain this a bit further? So basically nodes would only be updated based on their incoming neighbors and not outgoing edges?

If this is true, then can a node that only has outgoing edges (no incoming neighbors) impact another node’s representation?

@hockeybro12 Thanks for the pointer to this thread, I indeed have a question as to the way you defined node_features.

What we basically need is a dictionary mapping from node type to the tensor, right? Do I understand correctly that this is replaced by h in the version that is working?

Yes. As input to the inference function, I still pass in x, which is:

source_feats = g.nodes['source'].data['source_embedding'].to(torch.device('cuda'))
user_feats = g.nodes['user'].data['user_embedding'].to(torch.device('cuda'))
node_features_for_inference = {'source': source_feats, 'user': user_feats}

@hockeybro12 How did you make your graph bidirectional? I think I am facing a similar issue

@hockeybro12 Why are you using the NodeDataLoader here instead of the EdgeDataLoader if you want to do link prediction? I thought that the edges were sampled, and not the nodes.

@hockeybro12 @BarclayII
One thing I am also unsure about is that only the first element of the blocks is selected in the code below. Since blocks is a list of DGLGraphs, why only select the first one and not iterate through all?

I use the NodeDataLoader to create an embedding of the graph after training it with link prediction. If you are doing link prediction, you should use the EdgeDataLoader.

Originally I only had edges from sources to the users that follow them. I then added another edge type, so users also point to the sources they follow. So I still have two node types, sources and users, but now I have two edge types (has_follower and follows in the data_dict above), one for each direction.
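Adding the reverse direction amounts to adding, for every relation, a second relation with the source and destination IDs swapped. A sketch that works on the data_dict before it is passed to dgl.heterograph; with_reverse_relations and the reverse_names mapping are hypothetical names, and the relation names follow the data_dict shown earlier in this thread:

```python
def with_reverse_relations(data_dict, reverse_names):
    """Return a copy of data_dict with one reversed relation per entry.

    data_dict: {(src_type, etype, dst_type): (src_ids, dst_ids)}
    reverse_names: {etype: name to use for the reversed relation}
    """
    out = dict(data_dict)
    for (utype, etype, vtype), (src, dst) in data_dict.items():
        # same edges with endpoints swapped, so node types swap roles too
        out[(vtype, reverse_names[etype], utype)] = (dst, src)
    return out

data_dict = {('source', 'has_follower', 'user'): ([0, 0, 1], [0, 1, 1])}
both = with_reverse_relations(data_dict, {'has_follower': 'follows'})
# both[('user', 'follows', 'source')] == ([0, 1, 1], [0, 0, 1])
# g = dgl.heterograph(both)  # requires DGL; tensor IDs work the same way
```

If the graph already exists, the ID tensors for a relation can be read back with g.edges(etype=('source', 'has_follower', 'user')) and fed into the same function.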

So, in this case we are iterating through the layers and then passing the blocks. Normally, in a forward function, we would do:

def forward(self, blocks, inputs):
        x = self.conv1(blocks[0], inputs)
        x = self.conv2(blocks[1], x)
        return x

However, in the inference function we are passing it in ourselves:

for l, layer in enumerate([self.conv1, self.conv2]):
...
       for input_nodes, output_nodes, blocks in tqdm(dataloader):
                block = blocks[0].to(torch.device('cuda'))
                ...
                h = layer(block, h)

I think this is why we index it like so, but I’m not 100% sure.
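The difference is the order of the two loops. During training each mini-batch is pushed through all layers, so the sampler builds one block per layer; in layer-wise inference the outer loop is over layers, so each batch only needs one block, which is what MultiLayerFullNeighborSampler(1) produces. A schematic sketch, where sample_blocks is a hypothetical stand-in for a sampler that returns one block per requested layer:

```python
def sample_blocks(batch, num_layers):
    # hypothetical stand-in for a sampler: one "block" per requested layer
    return [(batch, layer) for layer in range(num_layers)]

def training_order(batches, layers):
    calls = []
    for batch in batches:                           # outer loop: batches
        blocks = sample_blocks(batch, len(layers))  # one block per layer
        for layer, block in zip(layers, blocks):
            calls.append((layer, block))
    return calls

def inference_order(batches, layers):
    calls = []
    for layer in layers:                            # outer loop: layers
        for batch in batches:
            blocks = sample_blocks(batch, 1)        # exactly one block
            calls.append((layer, blocks[0]))
    return calls
```

Both orders apply every layer to every batch; only the number of blocks sampled per batch changes.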


@hockeybro12 Thanks for your response :pray: . I’ve got a follow-up question here.

Did you add the edges manually by swapping the order of source and target node or is there any clever built-in way to do that?