Can't pass edge weights correctly or get node representations at the inference step

Hello team

After going through the documentation and lots of questions on this forum, I was able to put together some code to train a link prediction model on a bipartite graph. However, I can't figure out two things:

  1. How to pass the edge weights correctly.
  2. How to get node embeddings (the inference step) after training the model.

Below is the code to reproduce the errors I’m getting. Most of the code is taken from this thread:

    import torch
    import torch.nn as nn
    import dgl
    import dgl.nn as dglnn

    # HeteroScorePredictor and compute_loss are used below but not shown in this post.

    # create sample data: 3 'source' nodes, 6 'user' nodes, edges in both directions
    data_dict = {('source', 'has_follower', 'user'): (torch.tensor([0, 1, 0, 0, 1, 2]), torch.tensor([0, 1, 2, 3, 4, 5])),
                 ('user', 'follows', 'source'): (torch.tensor([0, 1, 2, 3, 4, 5]), torch.tensor([0, 1, 0, 0, 1, 2]))}

    dgl_graph = dgl.heterograph(data_dict)

    # set node and edge features (the edge 'feature' is meant to be used as a scalar edge weight)
    dgl_graph.nodes['source'].data['source_embedding'] = torch.randn(dgl_graph.num_nodes('source'), 70)
    dgl_graph.nodes['user'].data['user_embedding'] = torch.randn(dgl_graph.num_nodes('user'), 80)
    dgl_graph.edges['has_follower'].data['feature'] = torch.randn(dgl_graph.num_edges('has_follower'), 1)
    dgl_graph.edges['follows'].data['feature'] = torch.randn(dgl_graph.num_edges('follows'), 1)

    # set up edge sampler
    eid_dict = {etype: dgl_graph.edges(etype=etype, form='eid') for etype in dgl_graph.canonical_etypes}
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)  # 2 layers
    dataloader = dgl.dataloading.EdgeDataLoader(dgl_graph,
                                                eid_dict,
                                                sampler,
                                                negative_sampler=dgl.dataloading.negative_sampler.Uniform(5))  # 5 negative samples

    # define GCN

    class TestRGCN(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats, canonical_etypes):
            super(TestRGCN, self).__init__()
            self.in_feats = in_feats
            self.hid_feats = hid_feats
            self.out_feats = out_feats
            # one GraphConv per relation; the input size depends on the source node type
            self.conv1 = dglnn.HeteroGraphConv({
                    etype: dglnn.GraphConv(in_feats[utype], hid_feats, norm='right')
                    for utype, etype, vtype in canonical_etypes
                    })

            self.conv2 = dglnn.HeteroGraphConv({
                    etype: dglnn.GraphConv(hid_feats, out_feats, norm='right')
                    for _, etype, _ in canonical_etypes
                    })

        def forward(self, blocks, inputs, edge_features):
            print({k: v.shape for k, v in inputs.items()})

            # note: the same edge_features dict is passed to both layers
            x = self.conv1(blocks[0], inputs, mod_kwargs={'has_follower': {'edge_weight': edge_features['has_follower']},
                                                          'follows': {'edge_weight': edge_features['follows']}})

            x = self.conv2(blocks[1], x, mod_kwargs={'has_follower': {'edge_weight': edge_features['has_follower']},
                                                     'follows': {'edge_weight': edge_features['follows']}})
            return x

    class TestModel(nn.Module):
        # first computes the node representations, then predicts the scores for the edges
        def __init__(self, in_features, hidden_features, out_features, canonical_etypes):
            super().__init__()
            self.sage = TestRGCN(in_features, hidden_features, out_features, canonical_etypes)
            self.pred = HeteroScorePredictor()

        def forward(self, g, neg_g, blocks, x, edge_features):
            x = self.sage(blocks, x, edge_features)
            pos_score = self.pred(g, x)
            neg_score = self.pred(neg_g, x)
            return self.sage, pos_score, neg_score

    # initiate model

    model = TestModel(in_features={'source': 70, 'user': 80},
                      hidden_features=128,
                      out_features=64,
                      canonical_etypes=dgl_graph.canonical_etypes)

    optimizer = torch.optim.Adam(model.parameters())

    # run model

    for epoch in range(1):
        for input_nodes, positive_graph, negative_graph, blocks in dataloader:
            model.train()

            node_features = {'source': blocks[0].srcdata['source_embedding']['source'],
                             'user': blocks[0].srcdata['user_embedding']['user'],
                             }
            # edge weights are read from blocks[0] only
            edge_features = {'has_follower': blocks[0].edata['feature'][('source', 'has_follower', 'user')],
                             'follows': blocks[0].edata['feature'][('user', 'follows', 'source')]}
            sage_model, pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features, edge_features)
            loss = compute_loss(pos_score, neg_score, dgl_graph.canonical_etypes)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Problem 1:
I'm getting the AssertionError below when self.conv2 executes. It complains about the edge weight shape, but it does not complain when self.conv1 runs in the forward function. What is happening here, and how can I fix it?

    390             aggregate_fn = fn.copy_src('h', 'm')
    391             if edge_weight is not None:
--> 392                 assert edge_weight.shape[0] == graph.number_of_edges()
    393                 graph.edata['_edge_weight'] = edge_weight
    394                 aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

Problem 2:
As you can see in the code, I'm also returning sage_model from model(). I thought this module could be used to get the node embeddings, so I do something like:

    user_em = sage_model(blocks, node_features, edge_features)['user']
    user_em.shape
    # torch.Size([1, 64])

This is not what I want: there are six users in my dataset, and I expect one embedding for each of the six user nodes. How can I get those?

Regards
yolo

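You need to take the edge features from each block, not only from blocks[0]. Each block has its own set of edges, and blocks[1] usually has fewer edges than blocks[0], so passing blocks[0]'s edge weights to conv2 fails the edge_weight.shape[0] == graph.number_of_edges() assertion. conv1 works because its weights really do come from blocks[0].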

For example:

    def forward(self, blocks, inputs, edge_features):
        # the edge_features argument is ignored: each layer reads the weights of its own block
        print({k: v.shape for k, v in inputs.items()})

        # layer 1: edge weights of the edges in blocks[0]
        edge_features = {'has_follower': blocks[0].edata['feature'][('source', 'has_follower', 'user')],
                         'follows': blocks[0].edata['feature'][('user', 'follows', 'source')]}
        x = self.conv1(blocks[0], inputs, mod_kwargs={'has_follower': {'edge_weight': edge_features['has_follower']},
                                                      'follows': {'edge_weight': edge_features['follows']}})

        # layer 2: edge weights of the (different, usually smaller) edge set of blocks[1]
        edge_features = {'has_follower': blocks[1].edata['feature'][('source', 'has_follower', 'user')],
                         'follows': blocks[1].edata['feature'][('user', 'follows', 'source')]}
        x = self.conv2(blocks[1], x, mod_kwargs={'has_follower': {'edge_weight': edge_features['has_follower']},
                                                 'follows': {'edge_weight': edge_features['follows']}})
        return x

To get embeddings for every node, you need to run inference on the whole graph rather than on the sampled blocks. You can refer to how the evaluate function is implemented in the example at https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_unsupervised.py#L35

Thanks @VoVAllen.
Could you help me understand why, after the final layer, the first dimension of each tensor is not equal to the number of nodes:

    shape after conv1
    {'source': torch.Size([3, 128]), 'user': torch.Size([5, 128])}
    shape after conv2
    {'source': torch.Size([1, 64]), 'user': torch.Size([4, 64])}

After self.conv2 runs, I expected the source and user tensors to keep the same number of rows (source = 3, user = 5), but they shrink to source = 1 and user = 4. I don't understand why.

Hi,

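A block is different from a subgraph: it is a flattened computation graph induced from the original graph. Each block computes outputs only for its destination nodes, and those are a subset of its source nodes, so the number of rows shrinks layer by layer. The destination nodes of the last block are exactly the nodes of the current minibatch's positive and negative pair graphs (the endpoints of the sampled edges), which is why conv2 returns fewer rows than conv1. You can refer to the animation in Introduction of Neighbor Sampling for GNN Training (DGL 0.7.1 documentation).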
