Changing Graph Structure after Construction and Running Models with Partial Edges

Hi,

I was wondering: if I have already built a heterograph like the one below, can I add new edge types to it or otherwise change its structure? For example, can I add a new ‘user agrees_with source’ relation? If so, how would I do that?

data_dict = {
    ('source', 'has_follower', 'user'): (torch.tensor([0, 0]), torch.tensor([0, 0])),
    ('user', 'follows', 'source'): (torch.tensor([0, 0]), torch.tensor([0, 0])),
}
dgl_graph = dgl.heterograph(data_dict)

I also have another question. I constructed another graph in which I added a new ‘tweet’ node type, connected to only some sources and some users. When I try to train this model with link prediction, I get an error in the score predictor. Here is the code I’m using for that:

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for etype in edge_subgraph.canonical_etypes:
                if edge_subgraph.num_edges(etype) <= 0:
                    continue
                edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
            return edge_subgraph.edata['score']

And here is the error:

Traceback (most recent call last):
  File "lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "GNN_model.py", line 1039, in running_code
    pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
  File "lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "GNN_model.py", line 200, in forward
    pos_score = self.pred(g, x)
  File "lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "GNN_model.py", line 187, in forward
    edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
  File "lib/python3.6/site-packages/dgl/heterograph.py", line 4119, in apply_edges
    edata = core.invoke_gsddmm(g, func)
  File "lib/python3.6/site-packages/dgl/core.py", line 201, in invoke_gsddmm
    x = alldata[func.lhs][func.lhs_field]
  File "lib/python3.6/site-packages/dgl/view.py", line 66, in __getitem__
    return self._graph._get_n_repr(self._ntid, self._nodes)[key]
  File "lib/python3.6/site-packages/dgl/frame.py", line 386, in __getitem__
    return self._columns[name].data
KeyError: 'h'

I’m not sure what the issue is here, since I’m skipping edge types that have no edges in the score predictor.

I was wondering: if I have already built a heterograph like the one below, can I add new edge types to it or otherwise change its structure? For example, can I add a new ‘user agrees_with source’ relation? If so, how would I do that?

You cannot modify the schema of a heterograph once it is created. You need to create a new heterograph.
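For example, a minimal sketch of rebuilding the graph with one extra relation (the endpoints of the new edges are placeholders):

import torch
import dgl

data_dict = {
    ('source', 'has_follower', 'user'): (torch.tensor([0, 0]), torch.tensor([0, 0])),
    ('user', 'follows', 'source'): (torch.tensor([0, 0]), torch.tensor([0, 0])),
}
g = dgl.heterograph(data_dict)

# Copy the existing edges into a new data dict and add the new relation
new_data_dict = {etype: g.edges(etype=etype) for etype in g.canonical_etypes}
new_data_dict[('user', 'agrees_with', 'source')] = (torch.tensor([0]), torch.tensor([0]))

# Keep the node counts of the original graph so node IDs stay valid
num_nodes_dict = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
new_g = dgl.heterograph(new_data_dict, num_nodes_dict=num_nodes_dict)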

For the error, can you check for which canonical edge type it occurs? Can you also check whether the corresponding source and destination node types have the feature ‘h’?
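For instance, a quick check along these lines (a sketch; x is the feature dict passed to the predictor):

# Print edge types whose endpoint node types are missing features in x
for utype, etype, vtype in edge_subgraph.canonical_etypes:
    missing = [t for t in (utype, vtype) if t not in x]
    if missing:
        print(etype, 'is missing input features for node types', missing)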

Turns out I had a bug, I was able to find it with your debugging help, thanks.

On an unrelated note, it seems like I can add edges to my graph when they already exist (meaning an edge of the same type between the same nodes exists). How does DGL handle this when training something like link prediction? Will it treat it as if there was only one edge there? Also, is there an efficient way I can check for this (make sure I don’t add a duplicate edge) when adding the edges?

On an unrelated note, it seems like I can add edges to my graph when they already exist (meaning an edge of the same type between the same nodes exists).

Yes, the graph data structures in DGL allow multigraphs.

How does DGL handle this when training something like link prediction? Will it treat it as if there was only one edge there?

No, they are treated as two separate edges.

Also, is there an efficient way I can check for this (make sure I don’t add a duplicate edge) when adding the edges?

You may find APIs such as DGLGraph.has_edges_between and DGLGraph.edge_ids helpful.
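For instance, a minimal sketch that skips pairs which already have an edge (u, v, and the edge type are placeholders):

import torch

u = torch.tensor([0, 1])
v = torch.tensor([0, 0])
etype = ('user', 'follows', 'source')

# Boolean mask with one entry per (u, v) pair
exists = dgl_graph.has_edges_between(u, v, etype=etype).bool()
dgl_graph.add_edges(u[~exists], v[~exists], etype=etype)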

Thanks, that was helpful.

I had another question: can I train link prediction and node classification simultaneously, with a weighting parameter between the two losses? How would I do this with data loaders in DGL?

Currently, I am training node classification with a node data loader like so:

dataloader = dgl.dataloading.NodeDataLoader(
    graph, {'source': train_idx}, sampler,
    batch_size=args.batch_size, shuffle=True, drop_last=False,
    num_workers=args.num_workers)
...
for iteration, (input_nodes, output_nodes, blocks) in enumerate(dataloader):

and I am training Link Prediction like so:

dataloader = dgl.dataloading.EdgeDataLoader(
    graph, train_eid_dict, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    batch_size=args.batch_size, shuffle=True, drop_last=False,
    pin_memory=True, num_workers=args.num_workers)
...
for iteration, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(dataloader):

I’m now wondering if it would be possible to combine these two training objectives into one, so my graph is updated both based on a loss from Link Prediction and a loss from Node Classification.

Could I just use the EdgeDataLoader as normal and train for link prediction, and then collect the embeddings of the source nodes and apply a classification loss to those (applied to the whole graph)? That way the model would be trained for both LP and node classification. But how would I add the hyperparameter that controls the weighting between the two losses, given that I can’t control how many source nodes will appear in each iteration of EdgeDataLoader? Or is there some other way to do this?

  1. Do you want to use the same model for both node classification and link prediction?
  2. What do you mean by “my graph is updated both based on a loss from Link Prediction and a loss from Node Classification”?

Yes

I currently have two ways to train my graph: node classification and link prediction. I am looking for a way to do them simultaneously. It’s similar to training any ML model where you have two different loss functions: you sum the losses and apply that to the model optimizer. In this case, I want my two loss functions to be node classification (where the loss applied to my graph is based on the labels of the ‘source’ nodes) and link prediction (where the loss is based on the positive/negative edges).

I currently have two ways to train my graph: node classification and link prediction. I am looking for a way to do them simultaneously. It’s similar to training any ML model where you have two different loss functions: you sum the losses and apply that to the model optimizer. In this case, I want my two loss functions to be node classification (where the loss applied to my graph is based on the labels of the ‘source’ nodes) and link prediction (where the loss is based on the positive/negative edges).

But you are updating the model rather than the graph, right?

Could I just use the EdgeDataLoader as normal and train for link prediction, and then collect the embeddings of the source nodes and apply a classification loss to those (applied to the whole graph)? That way the model would be trained for both LP and node classification. But how would I add the hyperparameter that controls the weighting between the two losses, given that I can’t control how many source nodes will appear in each iteration of EdgeDataLoader? Or is there some other way to do this?

I think you can simply construct two dataloaders, one for node classification and one for link prediction. During training, you just fetch mini-batches from the two dataloaders separately and reweight the task-specific loss with some hyperparameters. Note that the validation/test edges in link prediction should not be in the graph for node classification as well.
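For instance, a minimal sketch of that pattern (lambda_lp is a hypothetical weighting hyperparameter, and the two loss helpers stand in for whatever each task computes on its mini-batch):

lambda_lp = 0.5  # hypothetical weight for the link prediction loss
# Note: zip stops at the shorter dataloader; cycle the shorter one if that matters
for batch_nc, batch_lp in zip(dataloader_nc, dataloader_lp):
    loss_nc = compute_nc_loss(batch_nc)  # node classification loss (placeholder helper)
    loss_lp = compute_lp_loss(batch_lp)  # link prediction loss (placeholder helper)
    loss = loss_nc + lambda_lp * loss_lp
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()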

Yes, sorry I misspoke.

What do you mean by this? The node classification model has a train/validation set, but link prediction doesn’t, right? So could I not train the node classification model with the train/validation/test split and simultaneously train the entire model (using the entire graph) with link prediction? Of course, when I evaluate, I will evaluate as if I had trained only with node classification.

Can I do something like this?

First, change the forward function to handle both Link Prediction (LP) and Node Classification (NC) inputs:

def forward(self, g, neg_g, blocks, x):
    x = self.sage(blocks, x)
    # Node classification: no pair graphs are given, just return the embeddings
    if g is None:
        return x
    # Link prediction: score the positive and negative graphs
    pos_score = self.pred(g, x)
    neg_score = self.pred(neg_g, x)
    return pos_score, neg_score

And then the training loop:

# indices of the nodes used for node classification
train_mask_tensor = torch.from_numpy(train_mask)
train_idx = torch.nonzero(train_mask_tensor).squeeze()
test_mask_tensor = torch.from_numpy(test_mask)
test_idx = torch.nonzero(test_mask_tensor).squeeze()

# dataloader for node classification (NC)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args.n_layers)
dataloader_nc = dgl.dataloading.NodeDataLoader(
    curr_g, {'source': train_idx}, sampler,
    batch_size=args.batch_size, shuffle=True, drop_last=False,
    num_workers=args.num_workers)
test_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args.n_layers)
dataloader_test = dgl.dataloading.NodeDataLoader(
    curr_g, {'source': test_idx}, test_sampler,
    batch_size=args.batch_size, shuffle=True, drop_last=False,
    num_workers=args.num_workers)

# dataloader for Link Prediction (LP)
train_eid_dict = {
    canonical_etype: torch.arange(g.num_edges(canonical_etype), dtype=torch.int64).to(torch.device('cuda'))
    for canonical_etype in g.canonical_etypes}
dataloader_lp = dgl.dataloading.EdgeDataLoader(
    g, train_eid_dict, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    batch_size=args.batch_size, shuffle=True, drop_last=False,
    pin_memory=True, num_workers=args.num_workers)

loss_fcn_nc = nn.CrossEntropyLoss()
loss_fcn_nc = loss_fcn_nc.to(torch.device('cuda'))

best_train_acc = 0.0

for epoch in range(args.n_epochs):
    model.train()
    optimizer.zero_grad()
    for iteration, ((input_nodes_nc, output_nodes_nc, blocks_nc),
                    (input_nodes_lp, positive_graph, negative_graph,
                     blocks_lp)) in enumerate(zip(dataloader_nc, dataloader_lp)):
        # first, run the node classification part through the model
        blocks_nc = [b.to(torch.device('cuda')) for b in blocks_nc]
        input_features_nc = blocks_nc[0].srcdata     # returns a dict
        output_labels_nc = blocks_nc[-1].dstdata['source_label']['source']    # label tensor for the 'source' nodes
        output_labels_nc = (output_labels_nc - 1).long()
        output_labels_nc = torch.squeeze(output_labels_nc)
        node_features_nc = get_features_given_blocks(curr_g, blocks_nc, graph_style)
        output_predictions = model(None, None, blocks_nc, node_features_nc)['source']
        loss_nc = loss_fcn_nc(output_predictions, output_labels_nc)

        # now compute the loss for LP 
        blocks_lp = [b.to(torch.device('cuda')) for b in blocks_lp]
        positive_graph = positive_graph.to(torch.device('cuda'))
        negative_graph = negative_graph.to(torch.device('cuda'))
        node_features_lp = get_features_given_blocks(curr_g, blocks_lp, graph_style, adding_advice=adding_advice)
        pos_score, neg_score = model(positive_graph, negative_graph, blocks_lp, node_features_lp)
        loss_lp = compute_loss(pos_score, neg_score, curr_g.canonical_etypes)

        # sum the losses and update the model
        loss = loss_nc + loss_lp
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.25)
        optimizer.step()
        optimizer.zero_grad()

        # now do the evaluation
        # ...

What do you mean by this? The node classification model has a train/validation set, but link prediction doesn’t, right? So could I not train the node classification model with the train/validation/test split and simultaneously train the entire model (using the entire graph) with link prediction? Of course, when I evaluate, I will evaluate as if I had trained only with node classification.

Typically, node classification has a train/val/test split of nodes and link prediction has a train/val/test split of edges. If you want to evaluate your model for node classification, you still need to use the train/val/test split of nodes. It’s just that for node classification you only optimize the loss on the training nodes, while for link prediction you optimize the loss over the edges.

The code looks fine to me.

Yeah, and I’m not cheating by doing this, right (using a train/val/test split for node classification and no split for link prediction)?

Why would you want to use a train/val/test split for link prediction? Is it just to evaluate the model?

It depends on whether you are really performing a joint training for node classification/link prediction. If link prediction is just an auxiliary task, then the setting should be fine.

Yeah, since my graph is really big and diverse, I’m trying to use link prediction as an extra task to improve the learned representations and hopefully help the node classification training. I’ve seen some decent results doing the two separately (training with link prediction, then using the resulting graph embeddings to classify the source nodes).

@mufeili

I have another question, which is kind of unrelated, but I don’t think I need to make a new topic for it. I understand that when training GNNs for link prediction, if I want two nodes to have embeddings closer together, I should connect them with an edge. However, what should I do if I want two nodes’ embeddings to be farther apart from each other once the graph is trained? Assuming they are not already connected, how can I modify my link prediction training process to push Node A and Node B apart by a certain amount (let’s say I know which nodes A and B are beforehand)?

You should not modify the graph structure in link prediction. The objective for link prediction is to maximize the score of a pair of nodes being connected, as below:

\log s(u, v)

To push apart the embeddings of two nodes, you can simply minimize the score for the pair of nodes to be connected. If you choose a margin \Delta, you can minimize

\max(0, \Delta+\log s(u, v))
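In code, a sketch of that margin objective (assuming raw dot-product scores squashed with a sigmoid, so that \log s(u, v) \le 0):

import torch.nn.functional as F

def push_apart_loss(scores, margin=1.0):
    # scores: raw scores for the node pairs to push apart.
    # Minimizing relu(margin + log sigmoid(score)) drives log s(u, v) below -margin.
    return F.relu(margin + F.logsigmoid(scores)).mean()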

Yeah, I was talking about graph creation: if I want two nodes to be closer together, I connect them with an edge.

How can I do this in DGL with EdgeDataLoader for two specific nodes (let’s say I know their IDs)?

How can I do this in DGL with EdgeDataLoader for two specific nodes (let’s say I know their IDs)?

You can pass a negative_sampler when creating an EdgeDataLoader instance.

Currently I use the dgl.dataloading.negative_sampler.Uniform(5) sampler.

From the source code (https://docs.dgl.ai/_modules/dgl/dataloading/negative_sampler.html#Uniform), I see that it samples destination nodes given a source node. So, in this case, are you saying that I should not choose these destination nodes randomly for the specific nodes I want to push apart, and instead only choose the nodes that I want to push them apart from (the destination being the one I want to push apart from the source)?

In that case, how can I modify the _generate function, and how would the sampler know which edges need to be pushed apart? Maybe I could pass a dictionary whose keys are source nodes and whose values are the destination nodes to push away from them; if the source node appears in that dictionary as a key, I return the stored node as the negative destination (maybe with some scaling)? Something like this:

def _generate(self, g, eids, canonical_etype, ids_to_push):
    _, _, vtype = canonical_etype
    shape = F.shape(eids)
    dtype = F.dtype(eids)
    ctx = F.context(eids)
    shape = (shape[0] * self.k,)
    src, _ = g.find_edges(eids, etype=canonical_etype)
    src = F.repeat(src, self.k, 0)
    dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype))
    # Overwrite the random destination wherever the source node has a
    # designated node it should be pushed apart from
    for i in range(len(src)):
        s = int(src[i])
        if s in ids_to_push:
            dst[i] = ids_to_push[s]
    return src, dst

So, in this case, are you saying that I should not choose these destination nodes randomly for the specific nodes I want to push apart, and instead only choose the nodes that I want to push them apart from (the destination being the one I want to push apart from the source)?

If you want to choose specific nodes to push apart, you need a custom negative sampler, as in the code snippet you shared. Nevertheless, find_edges may introduce additional overhead.

Another possibility is to construct two graphs and two EdgeDataLoader instances, one for maximizing the score of positive edges, one for minimizing the score of negative edges.
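For instance, a sketch of the second EdgeDataLoader (push_src/push_dst are placeholder ID tensors; g_sampling=g keeps message passing on the original graph, assuming your DGL version supports that argument):

# A separate graph holding only the node pairs to push apart
push_g = dgl.heterograph(
    {('user', 'push_apart', 'source'): (push_src, push_dst)},
    num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes})
push_eids = {('user', 'push_apart', 'source'):
             torch.arange(push_g.num_edges('push_apart'))}
dataloader_push = dgl.dataloading.EdgeDataLoader(
    push_g, push_eids, sampler, g_sampling=g,
    batch_size=args.batch_size, shuffle=True, drop_last=False,
    num_workers=args.num_workers)
# In the training loop, minimize the scores of these edges
# instead of maximizing them.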

Isn’t find_edges part of the normal Uniform Negative Sampler implementation?

By this, are you saying that we should choose negative edges and aim to minimize their score, basically doing the opposite of the link prediction negative sampler? Would this have a benefit over the previous method?