Trouble Training Link Prediction on Heterograph with EdgeDataLoader

By making my graph bidirectional (connecting sources to users as before, and now also connecting users back to sources), things seem to be working now. I’m still having some issues computing the loss, but I will look into that.

@BarclayII

I basically ended up doing it like this, and it seems to be working now. Can you help me verify that it is correct?

import torch
import torch.nn as nn
import dgl
import dgl.nn as dglnn

class TestRGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, canonical_etypes):
        super(TestRGCN, self).__init__()

        self.conv1 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(in_feats[utype], hid_feats, norm='right')
                for utype, etype, vtype in canonical_etypes
                })
        self.conv2 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(hid_feats, out_feats, norm='right')
                for _, etype, _ in canonical_etypes
                })

    def forward(self, blocks, inputs):
        x = self.conv1(blocks[0], inputs)
        x = self.conv2(blocks[1], x)

        return x

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
                # edge_subgraph.apply_edges(self.apply_edges, etype=etype)
            return edge_subgraph.edata['score']

class TestModel(nn.Module):
    # here we have a model that first computes the representation and then predicts the scores for the edges
    def __init__(self, in_features, hidden_features, out_features, canonical_etypes):
        super().__init__()
        self.sage = TestRGCN(in_features, hidden_features, out_features, canonical_etypes)
        self.pred = HeteroScorePredictor()
    def forward(self, g, neg_g, blocks, x):
        x = self.sage(blocks, x)
        pos_score = self.pred(g, x)
        neg_score = self.pred(neg_g, x)
        return pos_score, neg_score 

def compute_loss(pos_score, neg_score, canonical_etypes):
    # Margin loss
    all_losses = []
    for given_type in canonical_etypes:
        n_edges = pos_score[given_type].shape[0]
        if n_edges == 0:
            continue
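        # hinge loss: each positive edge should score at least 1 higher than each of its negatives.
        # view(n_edges, -1) handles scores shaped (E,) or (E, 1) and avoids accidental (E, E, k) broadcasting.
        all_losses.append((1 - pos_score[given_type].view(n_edges, -1) + neg_score[given_type].view(n_edges, -1)).clamp(min=0).mean())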
    return torch.stack(all_losses, dim=0).mean()

model = TestModel(in_features={'source':700, 'user':800}, hidden_features=512, out_features=256, canonical_etypes=g.canonical_etypes)
...
for epoch in range(args.n_epochs):
    for input_nodes, positive_graph, negative_graph, blocks in dataloader:
        model.train()
        blocks = [b.to(torch.device('cuda')) for b in blocks]

        positive_graph = positive_graph.to(torch.device('cuda'))
        negative_graph = negative_graph.to(torch.device('cuda'))

        node_features = {'source': blocks[0].srcdata['source_embedding']['source'], 'user': blocks[0].srcdata['user_embedding']['user']}
        pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
        loss = compute_loss(pos_score, neg_score, g.canonical_etypes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
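The scoring and loss above can be sketched in plain PyTorch to make the shapes explicit. This is a minimal sketch, not DGL code: dot_scores and margin_loss are hypothetical helper names, dot_scores stands in for what u_dot_v computes per edge, and margin_loss follows the hinge loss used in DGL's link prediction guide (a positive edge should score at least margin higher than each of its negatives). It assumes the negative graph holds k negatives per positive edge, stored contiguously so that view(n_edges, -1) lines them up row by row.

```python
import torch

def dot_scores(h_src, h_dst, src, dst):
    """Score each edge src[i] -> dst[i] as the dot product of its endpoint embeddings.

    Returns shape (E, 1): one score per edge.
    """
    return (h_src[src] * h_dst[dst]).sum(dim=-1, keepdim=True)

def margin_loss(pos_score, neg_score, margin=1.0):
    """Hinge loss: penalize any negative that is not at least `margin` below its positive.

    Assumes neg_score holds k negatives per positive edge, stored contiguously
    (all negatives of edge 0 first, then edge 1, ...).
    """
    n_edges = pos_score.shape[0]
    pos = pos_score.view(n_edges, 1)    # (E, 1), broadcasts across the k negatives
    neg = neg_score.view(n_edges, -1)   # (E, k)
    return (margin - pos + neg).clamp(min=0).mean()
```

Flattening both scores with view(n_edges, ...) rather than unsqueeze keeps the result at shape (E, k) whether the scores arrive as (E,) or (E, 1).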

My input graph basically has a structure like this (with more nodes of course):

data_dict = {('source', 'has_follower', 'user'): (torch.tensor([0, 0]), torch.tensor([0, 0])), ('user', 'follows', 'source'): (torch.tensor([0, 0]), torch.tensor([0, 0]))}
dgl_graph = dgl.heterograph(data_dict)

And each node type has an embedding:

dgl_graph.nodes['source'].data['source_embedding'] = torch.zeros(1, 700)
dgl_graph.nodes['user'].data['user_embedding'] = torch.zeros(1, 800)

Your code looks correct.

Yes. I forgot to mention that. Otherwise x2 will be empty because the users don’t have any incoming neighbors.
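A type-level sketch of why x2 ends up empty. It assumes, in simplified form, that each HeteroGraphConv layer skips any relation whose source node type has no input and returns outputs keyed by destination node type. The helper names layer_output_ntypes and stack_layers are hypothetical, and the sketch only tracks which node types get an output, not the features themselves.

```python
def layer_output_ntypes(canonical_etypes, input_ntypes):
    """Node types that get an output from one HeteroGraphConv-style layer.

    A relation (src, rel, dst) can only send messages if `src` has input
    features; its results are keyed by the destination type `dst`.
    """
    return {dst for src, _, dst in canonical_etypes if src in input_ntypes}

def stack_layers(canonical_etypes, input_ntypes, n_layers):
    """Return the set of output node types after each of `n_layers` layers."""
    types = set(input_ntypes)
    history = []
    for _ in range(n_layers):
        types = layer_output_ntypes(canonical_etypes, types)
        history.append(types)
    return history

one_way = [('source', 'has_follower', 'user')]
both_ways = one_way + [('user', 'follows', 'source')]
print(stack_layers(one_way, {'source', 'user'}, 2))    # second layer: empty set
print(stack_layers(both_ways, {'source', 'user'}, 2))  # both types survive
```

With only the source-to-user relation, the first layer produces outputs for users alone, so the second layer has no source inputs to run its only relation on. Adding the reverse relation keeps both node types alive across layers.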


@BarclayII Thanks so much for your help so far, it’s been very useful.

I’m now trying to use my trained GCN to do some inference. Basically, I want to create an embedding of my graph and use it to train an LR model for classification. I saw this document (https://docs.dgl.ai/guide/minibatch-inference.html?highlight=nodedataloader%20device) on how to do this on a large graph, and I think I need to do the same thing (a NodeDataLoader instead of an EdgeDataLoader, since now I just want an embedding of my graph, right?). However, I am having trouble doing it.

I get the error RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method. It happens as soon as I start iterating over the new NodeDataLoader that I will eventually use to create the embedding. My actual code is more complex than this (and similar to the tutorial), but the same error arises when I do it directly in the model training loop as below, so I’ve provided this more minimal code.

for epoch in range(args.n_epochs):
        # this is the old edge data loader
        for input_nodes, positive_graph, negative_graph, blocks in dataloader:

            model.train()
            blocks = [b.to(torch.device('cuda')) for b in blocks]

            if epoch >= 3:
                t0 = time.time()

            positive_graph = positive_graph.to(torch.device('cuda'))
            negative_graph = negative_graph.to(torch.device('cuda'))
            node_features = {'source': blocks[0].srcdata['source_embedding']['source'].to(torch.device('cuda')), 'user': blocks[0].srcdata['user_embedding']['user'].to(torch.device('cuda'))}

            # forward propagation by using all nodes and extracting the user embeddings
            pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
            loss = compute_loss(pos_score, neg_score, g.canonical_etypes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch >= 3:
                dur.append(time.time() - t0)

            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | ""ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(), n_edges / np.mean(dur) / 1000))

            # evaluate the model
            source_feats = g.nodes['source'].data['source_embedding'].to(torch.device('cuda'))
            user_feats = g.nodes['user'].data['user_embedding'].to(torch.device('cuda'))
            node_features = {'source': source_feats, 'user': user_feats}
            new_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            new_dataloader = dgl.dataloading.NodeDataLoader(
                g, {k: torch.arange(g.number_of_nodes(k)) for k in g.ntypes}, new_sampler,
                batch_size=args.batch_size, shuffle=True, drop_last=False, pin_memory=True, num_workers=args.num_workers)

            # on this line is where the CUDA error happens
            for input_nodes, output_nodes, blocks in tqdm(new_dataloader):
                print(input_nodes)
                exit(1)

Error:

Traceback (most recent call last):
  File "GNN_model.py", line 378, in <module>
    main(args)
  File "GNN_model.py", line 340, in main
    for input_nodes, output_nodes, blocks in tqdm(new_dataloader):
  File "path/.local/lib/python3.6/site-packages/tqdm/std.py", line 1133, in __iter__
    for obj in iterable:
  File "path/lib/python3.6/site-packages/dgl/dataloading/pytorch/__init__.py", line 145, in __next__
    input_nodes, output_nodes, blocks = next(self.iter_)
  File "path/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "path/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "path/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "path/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "path/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "path/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "path/lib/python3.6/site-packages/dgl/dataloading/pytorch/__init__.py", line 121, in collate
    input_nodes, output_nodes, blocks = super().collate(items)
  File "path/lib/python3.6/site-packages/dgl/dataloading/dataloader.py", line 364, in collate
    items = utils.prepare_tensor_dict(self.g, items, 'items')
  File "path/lib/python3.6/site-packages/dgl/utils/checks.py", line 70, in prepare_tensor_dict
    for key, val in data.items()}
  File "path/lib/python3.6/site-packages/dgl/utils/checks.py", line 70, in <dictcomp>
    for key, val in data.items()}
  File "path/lib/python3.6/site-packages/dgl/utils/checks.py", line 42, in prepare_tensor
    ret = F.copy_to(F.astype(data, g.idtype), g.device)
  File "path/lib/python3.6/site-packages/dgl/backend/pytorch/tensor.py", line 112, in copy_to
    th.cuda.set_device(ctx.index)
  File "path/lib/python3.6/site-packages/torch/cuda/__init__.py", line 245, in set_device
    torch._C._cuda_setDevice(device)
  File "path/lib/python3.6/site-packages/torch/cuda/__init__.py", line 148, in _lazy_init
    "Cannot re-initialize CUDA in forked subprocess. " + msg)
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Do you have any idea why this could be happening? Why is CUDA trying to spawn another process?

It seems that your graph is on the GPU. For multiprocessing sampling you cannot put the graph on the GPU unless you set the multiprocessing start method to spawn, as the PyTorch error message says.

This is because NodeDataLoader with num_workers > 0 starts multiple worker processes to sample blocks in parallel, and forked workers cannot use CUDA once the parent process has initialized it. The same goes for PyTorch itself: if you pass a GPU tensor to a PyTorch DataLoader with worker processes, you will probably see a similar error message.
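A plain-PyTorch analogue of the usual fix: keep everything the workers touch on the CPU and move each finished batch to the GPU in the main process. This is a sketch under assumptions, not DGL's implementation; GatherRows and iterate_batches are hypothetical names, a feature tensor stands in for the graph, and the collate_fn is a small class (rather than a lambda) so it stays picklable if you later switch to the spawn start method.

```python
import torch
from torch.utils.data import DataLoader

class GatherRows:
    """Picklable collate_fn: turns a list of node IDs into (ids, rows).

    It only reads the CPU tensor it was built with, so it is safe to run in
    forked DataLoader workers, which must not initialize CUDA.
    """
    def __init__(self, cpu_feats):
        assert cpu_feats.device.type == 'cpu', 'worker-side data must stay on CPU'
        self.cpu_feats = cpu_feats

    def __call__(self, id_list):
        ids = torch.tensor(id_list, dtype=torch.long)
        return ids, self.cpu_feats[ids]

def iterate_batches(cpu_feats, batch_size, num_workers, device):
    """Yield (ids, rows) batches; only the final .to(device) touches CUDA."""
    loader = DataLoader(range(cpu_feats.shape[0]), batch_size=batch_size,
                        shuffle=False, num_workers=num_workers,
                        collate_fn=GatherRows(cpu_feats))
    for ids, rows in loader:
        # This copy runs in the main process, never inside a worker.
        yield ids.to(device), rows.to(device)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
```

The same division of labour applies to NodeDataLoader: a CPU graph for sampling, then blocks[i].to('cuda') inside the training or inference loop.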


@BarclayII

Ok, thanks for the help again. I was able to fix that issue by using PyTorch multiprocessing to call my main function (renamed to running_code), as shown below. I also had to move my graph to the CPU before the NodeDataLoader call in the inference function (the sampled blocks are then moved to the GPU inside the loop).

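import torch.multiprocessing as mp

# The spawn start method re-imports this module in the child process,
# so the launch code must sit behind the __main__ guard.
if __name__ == '__main__':
    procs = []
    ctx = mp.get_context('spawn')
    p = ctx.Process(target=running_code, args=(args,))
    p.start()
    procs.append(p)
    for p in procs:
        p.join()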

However, now I’m having issues again passing input features into the inference function. In the RGCN heterograph inference example (https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py), they just pass in the features (x in the inference function) and index them using the input_nodes returned by the NodeDataLoader, like so: h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}.

However, in my case I had previously passed in the features like below, so I’m unsure how to do a similar indexing (using input_nodes) and how to assign it to the output tensor like the example did.
node_features = {'source': blocks[0].srcdata['source_embedding']['source'].to(torch.device('cuda')), 'user': blocks[0].srcdata['user_embedding']['user'].to(torch.device('cuda'))}

I was actually able to figure it out. For anyone running into the same issue, this is the code I used. @BarclayII thanks a lot for your help, I think I finally have everything working now :slight_smile:

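# Method of the model class; it assumes the class also stores hid_feats, out_feats, n_layers and num_workers.
# x holds the input features for the entire graph: a dict mapping node type -> tensor.
@torch.no_grad()  # inference only: do not build an autograd graph
def inference(self, curr_g, x, batch_size):
    curr_g = curr_g.to('cpu')  # sampling workers must not touch CUDA

    for l, layer in enumerate([self.conv1, self.conv2]):
        # output buffer for this layer, filled batch by batch (kept on CPU)
        y = {k: torch.zeros(curr_g.number_of_nodes(k), self.hid_feats if l != self.n_layers - 1 else self.out_feats) for k in curr_g.ntypes}

        # one-layer full-neighbor sampler: each batch yields exactly one block
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        dataloader = dgl.dataloading.NodeDataLoader(
            curr_g, {k: torch.arange(curr_g.number_of_nodes(k)) for k in curr_g.ntypes}, sampler,
            batch_size=batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers)

        for input_nodes, output_nodes, blocks in tqdm(dataloader):
            block = blocks[0].to(torch.device('cuda'))

            # gather this batch's input features by global node ID
            h = {k: x[k][input_nodes[k].long()].to(torch.device('cuda')) for k in input_nodes.keys()}

            h = layer(block, h)

            # scatter the outputs back into the buffer by global node ID
            for k in h.keys():
                y[k][output_nodes[k].long()] = h[k].cpu()

        x = y  # every node's layer-l output becomes the input of layer l + 1

    return y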
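The gather/scatter pattern in the inference function above can be checked in a small, dependency-light setting. This sketch is a simplification, not DGL: it uses a homogeneous graph stored as a dense adjacency matrix, mean aggregation followed by a weight matrix in place of GraphConv, and hypothetical helper names (layerwise_inference, full_graph). Selecting the in-neighbors of each batch of output nodes plays the role of MultiLayerFullNeighborSampler(1), and full_graph is the one-shot reference the batched version should match.

```python
import torch

def layerwise_inference(adj, x, weights, batch_size):
    """Layer-by-layer full-neighbor inference with mean aggregation.

    adj[v, u] = 1.0 means an edge u -> v (row v lists v's in-neighbors).
    Each layer fills y for all nodes in batches of output nodes before
    y becomes the input of the next layer.
    """
    n = adj.shape[0]
    for w in weights:
        y = torch.zeros(n, w.shape[1])
        for start in range(0, n, batch_size):
            output_nodes = torch.arange(start, min(start + batch_size, n))
            rows = adj[output_nodes]                                      # in-edges of this batch
            input_nodes = (rows.sum(dim=0) > 0).nonzero(as_tuple=True)[0]
            h = x[input_nodes]                                            # gather by global ID
            local = rows[:, input_nodes]                                  # block-like local adjacency
            deg = local.sum(dim=1, keepdim=True).clamp(min=1)
            y[output_nodes] = ((local @ h) / deg) @ w                     # scatter by global ID
        x = y
    return x

def full_graph(adj, x, weights):
    """Reference: the same computation on the whole graph at once."""
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    for w in weights:
        x = ((adj @ x) / deg) @ w
    return x
```

Because each batch sees its output nodes' complete in-neighborhoods, batching changes only the order of work, not the result, which is why the layer loop (not the blocks list) supplies the depth.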

Yes. I forgot to mention that. Otherwise x2 will be empty because the users don’t have any incoming neighbors.

@BarclayII Can you explain this a bit further? So basically nodes would only be updated based on their incoming neighbors and not outgoing edges?

If this is true, then can a node that only has outgoing edges (no incoming neighbors) impact another node’s representation?

@hockeybro12 Thanks for the pointer to this thread. I do have a question about the way you defined node_features.

What we basically need is a dictionary mapping from node type to the tensor, right? Do I understand correctly that this is replaced by h in the version that is working?

Yes. As input to the inference function, I still pass in x, which is:

source_feats = g.nodes['source'].data['source_embedding'].to(torch.device('cuda'))
user_feats = g.nodes['user'].data['user_embedding'].to(torch.device('cuda'))
node_features_for_inference = {'source': source_feats, 'user': user_feats}

@hockeybro12 How did you make your graph bidirectional? I think I am facing a similar issue.

@hockeybro12 Why are you using NodeDataLoader here instead of EdgeDataLoader if you want to do link prediction? I thought that the edges were sampled, not the nodes.

@hockeybro12 @BarclayII
One thing I am also unsure about is that only the first element of blocks is selected in the inference code above. Since blocks is a list of DGLGraphs, why select only the first one instead of iterating through all of them?

I use NodeDataLoader to create an embedding of the graph after training it with link prediction. If you are doing the link prediction training itself, you should use EdgeDataLoader.

Originally I only had edges in one direction, from sources to the users who follow them. I then added an edge type in the opposite direction, from each user to the sources they follow. So I still have two node types, sources and users, but now I have two edge types, one for each direction (has_follower and follows in the example above).

So, in this case we are iterating through the layers and then passing the blocks. Normally, in a forward function, we would do:

def forward(self, blocks, inputs):
    x = self.conv1(blocks[0], inputs)
    x = self.conv2(blocks[1], x)

However, in the inference function we are passing it in ourselves:

for l, layer in enumerate([self.conv1, self.conv2]):
    ...
    for input_nodes, output_nodes, blocks in tqdm(dataloader):
        block = blocks[0].to(torch.device('cuda'))
        ...
        h = layer(block, h)

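That is why we index it like so: the inference dataloader is built with MultiLayerFullNeighborSampler(1), so each iteration yields a list containing exactly one block, and each layer in the outer loop is applied to that single block.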


@hockeybro12 Thanks for your response :pray:. I’ve got a follow-up question here.

Did you add the edges manually by swapping the order of source and target node or is there any clever built-in way to do that?

I had to add it manually when creating the graph.
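Adding the reverse relations manually amounts to swapping the source and destination ID tensors of each relation in the data_dict before building the graph. A minimal sketch; add_reverse_relations is a hypothetical helper, and the reverse relation names come from a mapping you choose. Newer DGL releases also include a dgl.transforms.AddReverse transform that may do this for you; check the documentation of your installed version before relying on it.

```python
import torch

def add_reverse_relations(data_dict, reverse_names):
    """Return a copy of a heterograph data_dict with one reverse relation per
    forward relation, built by swapping source and destination IDs.

    reverse_names maps each forward relation name to the name to use for its
    reverse (a naming choice of yours, e.g. 'has_follower' -> 'follows').
    """
    out = dict(data_dict)
    for (src_type, rel, dst_type), (src_ids, dst_ids) in data_dict.items():
        rev_key = (dst_type, reverse_names[rel], src_type)
        if rev_key in out:
            raise ValueError(f'relation {rev_key} already exists')
        out[rev_key] = (dst_ids, src_ids)
    return out

forward_only = {('source', 'has_follower', 'user'): (torch.tensor([0, 0, 1]), torch.tensor([0, 2, 1]))}
both_ways = add_reverse_relations(forward_only, {'has_follower': 'follows'})
```

The result can then be passed to dgl.heterograph(both_ways) as usual.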

Correct.

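Yes, it can: its outgoing edges are incoming edges of other nodes, so it sends messages to them. However, those other nodes cannot affect its representation through message passing, since it has no incoming edges to receive messages on.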
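A one-round message-passing sketch makes the asymmetry concrete. It uses plain sum aggregation with index_add_ and no self-loops; GraphConv additionally normalizes and applies weights and a bias, so treat this as a simplified model of the direction of information flow, not of GraphConv itself. sum_incoming is a hypothetical helper name.

```python
import torch

def sum_incoming(x, src, dst):
    """One round of sum message passing over directed edges src[i] -> dst[i].

    Each node receives only its in-neighbors' features; a node with no
    incoming edges receives a zero vector, however many outgoing edges it has.
    """
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src])
    return out

x = torch.tensor([[1., 0.], [0., 1.], [5., 5.]])
src = torch.tensor([0, 0, 1])  # node 0 only has outgoing edges
dst = torch.tensor([1, 2, 2])
out = sum_incoming(x, src, dst)
```

Here node 0 influences nodes 1 and 2, but its own aggregated message stays zero.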

I built the ScorePredictor similar to yours and it looks as follows:

from typing import Dict, Tuple

import dgl
import torch
import torch.nn as nn

class ScorePredictor(nn.Module):
    def forward(
        self,
        edge_subgraph: dgl.DGLHeteroGraph,
        x: Dict[str, torch.Tensor],
        eval_edge_type: str,
    ) -> Dict[Tuple[str, str, str], torch.Tensor]:
        """Perform score prediction only on the evaluation edge type.

        :param edge_subgraph: subgraph to be evaluated
        :param x: dictionary mapping node type to features
        :param eval_edge_type: edge type to be evaluated
        :return: dictionary mapping edge type to the scores for the subgraph
        """
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x

            # only test on evaluation edge type
            edge_subgraph.apply_edges(
                dgl.function.u_dot_v('x', 'x', 'score'), etype=eval_edge_type)
            return edge_subgraph.edata['score']

Yet I am receiving a KeyError saying that the field 'x' cannot be found:

  File "/.../ScorePredictor.py", line 40, in forward
    dgl.function.u_dot_v('x', 'x', 'score'), etype=eval_edge_type)
  File "/.../lib/python3.7/site-packages/dgl/heterograph.py", line 4064, in apply_edges
    edata = core.invoke_gsddmm(g, func)
  File "/...lib/python3.7/site-packages/dgl/core.py", line 195, in invoke_gsddmm
    x = alldata[func.lhs][func.lhs_field]
  File "/.../lib/python3.7/site-packages/dgl/view.py", line 66, in __getitem__
    return self._graph._get_n_repr(self._ntid, self._nodes)[key]
  File "/...k/lib/python3.7/site-packages/dgl/frame.py", line 373, in __getitem__
    return self._columns[name].data
KeyError: 'x'

Do you have an idea how that can be, even though I am setting edge_subgraph.ndata['x'] = x here? x is a dictionary mapping node type to feature tensors. @BarclayII @hockeybro12

Did you try printing x at that function? Does it have the features you expect?

x here is a dictionary mapping node type to a tensor:

>>> x
{'protein': tensor([[ 1.5659e+00,  7.6492e-01, -1.3036e-01,  1.0010e+00,  7.0579e-01,
          1.3227e-02,  1.3354e-01,  1.3097e-01,  2.2603e-02,  1.6489e-01,
          1.1034e+00,  4.6917e-01],
       ...,
       grad_fn=<AddBackward0>)}

which is also saved in the ndata of the edge_subgraph:

>>> edge_subgraph.ndata
{'disease': {'feature': tensor([], size=(0, 500)), '_ID': tensor([], dtype=torch.int64)}, 'drug': {'feature': tensor([], size=(0, 7467)), '_ID': tensor([], dtype=torch.int64)}, 'protein': {'feature': tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000], ..., ]), 
'_ID': tensor([327,  35, 264,  27, 211, 114, 355, 367, 108, 303,  25, 238, 188, 194,
        272, 351, 157,  63, 343, 245, 331, 167, 120]), 
'x': tensor([[ 1.5659e+00,  7.6492e-01, -1.3036e-01,  1.0010e+00,  7.0579e-01, 1.3227e-02,  1.3354e-01,  1.3097e-01,  2.2603e-02,  1.6489e-01, 1.1034e+00,  4.6917e-01], ...,
        , grad_fn=<AddBackward0>)}}

it has the shape:

>>> x['protein'].shape
 torch.Size([23, 12])

which matches the number of protein nodes in the edge_subgraph:

>>> edge_subgraph
Graph(num_nodes={'disease': 0, 'drug': 0, 'protein': 23},
         ...,
     )

I am not sure if this is correct, since it is a dictionary. But when I try to pass just the tensor inside x, I get an error saying that x has to be a dictionary. I could not find the source code of dgl.function.u_dot_v(), so I am not sure whether it can handle features given as a dictionary.
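One hypothesis worth checking, not confirmed in this thread: the failing lookup in the traceback is alldata[func.lhs][func.lhs_field], i.e. the 'x' field on the source-node side of the evaluated relation, and the printed edge_subgraph.ndata shows 'x' only under 'protein'. If eval_edge_type has 'drug' or 'disease' as an endpoint type, that endpoint has no 'x' at all, even with zero nodes. The sketch below is a hypothetical pure-Python check; the relation names in it ('targets', 'interacts') are made up for illustration.

```python
def missing_endpoint_features(canonical_etypes, eval_edge_type, feats):
    """Return endpoint node types of eval_edge_type that have no entry in feats.

    feats is the node-type -> tensor dict assigned to ndata. eval_edge_type
    may be a relation name or a canonical (src, rel, dst) triple; a name is
    resolved against canonical_etypes and must match exactly one relation.
    """
    if isinstance(eval_edge_type, tuple):
        canonical = eval_edge_type
    else:
        matches = [c for c in canonical_etypes if c[1] == eval_edge_type]
        if len(matches) != 1:
            raise ValueError(f'{eval_edge_type!r} matches {len(matches)} relations')
        canonical = matches[0]
    src_type, _, dst_type = canonical
    # dict.fromkeys de-duplicates while keeping order (src == dst is common)
    return [t for t in dict.fromkeys((src_type, dst_type)) if t not in feats]

# Hypothetical relation names, for illustration only.
etypes = [('drug', 'targets', 'protein'), ('protein', 'interacts', 'protein')]
```

If the check reports a missing type, one thing to try is giving every endpoint type of the evaluated relation an entry in x (a zero-row tensor of the right width for types with no nodes), or calling it with edge_subgraph.canonical_etypes and your actual eval_edge_type to confirm the cause first.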