Link Prediction with Multiple Node Types (attributed, heterogeneous)

Hi community,

I have a dimensionality problem when trying to do link prediction between a user node type and an item node type. I took inspiration from the DGL guide section “5.3 Link Prediction” (DGL 0.6.1 documentation).
Currently, the item nodes have 5 attributes whereas the user nodes only have 3 attributes. This keeps giving me a dimensionality problem with the well-known error message: “mat1 and mat2 shapes cannot be multiplied”. If I use dummy features of the same dimension, the problem does not occur.

My question is whether it is possible to do link prediction between two node types with a different number of attributes and, if so, how? I see that some posts recommend using an MLP to align the dimensionality, but I cannot find anything describing why and how this solves the issue.

I am quite new to GNNs, so it could be that I am missing some essential knowledge in the area.

Thanks in advance.

Current code, which works:

hetero_graph = g

# number of negative samples per positive edge
k = 5

# dummy features of the same dimension (5) for both node types -- this version works
g.nodes['member'].data['features'] = torch.randn(7000, 5)
g.nodes['course'].data['features'] = torch.randn(88, 5)

user_feats = hetero_graph.nodes['member'].data['features'].float()
item_feats = hetero_graph.nodes['course'].data['features'].float()

# Model(in_features, hidden_features, out_features, rel_names)
model = Model(5, 10, 5, hetero_graph.etypes)

node_features = {'member': user_feats, 'course': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    negative_graph = construct_negative_graph(hetero_graph, k, ('member', 'm_attends', 'course'))
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('member', 'm_attends', 'course'))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
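
The helpers construct_negative_graph, Model and compute_loss above presumably follow the linked guide. For completeness, a rough sketch of the negative-graph construction (uniform corruption of destination nodes, as in the guide; check the linked docs for the exact version) looks like this:

import dgl
import torch

def construct_negative_graph(graph, k, etype):
    # corrupt the destination of each positive edge k times with random nodes
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})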

I see that some posts recommend using an MLP to align the dimensionality, but I cannot find anything describing why and how this solves the issue.

Let’s say item nodes have 5 attributes and user nodes have 3 attributes. You can use an MLP to project the 3 attributes to hidden representations of size 5, or vice versa.

Makes sense, thanks. Would that need to be applied prior to the convolutional layer? Right now the difference in attribute counts causes problems even before the first convolution, but every paper I read says that the MLP should be part of the prediction/score function.

Yes, typically you will do that before the first graph convolution layer.
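
To make this concrete, here is a minimal sketch of the idea (the class and layer choices here are illustrative, not from this thread): a per-node-type linear projection maps both feature sets to a common dimension before the first heterogeneous convolution.

import torch
import torch.nn as nn
import dgl.nn as dglnn

class ProjectedRGCN(nn.Module):
    def __init__(self, user_in, item_in, hidden, out, rel_names):
        super().__init__()
        # project each node type's raw attributes to a common hidden size
        self.proj = nn.ModuleDict({
            'member': nn.Linear(user_in, hidden),   # e.g. 3 -> hidden
            'course': nn.Linear(item_in, hidden),   # e.g. 5 -> hidden
        })
        self.conv1 = dglnn.HeteroGraphConv(
            {rel: dglnn.GraphConv(hidden, out) for rel in rel_names},
            aggregate='sum')

    def forward(self, graph, inputs):
        # align dimensionality first, then run the graph convolution
        h = {ntype: torch.relu(self.proj[ntype](x)) for ntype, x in inputs.items()}
        return self.conv1(graph, h)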

Thanks so much! That helped a ton. The model is now running perfectly.

I am using an EdgeDataLoader for training, validation and testing, but what is the recommended approach in DGL when I want to predict (get a score for) all possible links between users and items that are not currently linked? There are users who are not connected to any items so far, but I would still like to get scores for those as well. Should I create them in a new graph and use the embeddings to get the probability of a link being present? Or is there a way to obtain a user × item table with scores? :slight_smile: The latter would be highly preferred.
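
Since the score function used later in this thread is a dot product (u_dot_v), one way to get the full user × item score table is a single matrix multiplication over the final embeddings. This is only a sketch; the dict h and how it is produced are assumptions, not code from the thread:

import torch

# h is assumed to be the dict of final node embeddings produced by the GNN,
# computed once over the full graph (no sampling) in evaluation mode
with torch.no_grad():
    user_emb = h['member']            # (num_members, d)
    item_emb = h['course']            # (num_courses, d)
    # dot-product scores for every (member, course) pair, matching u_dot_v
    scores = user_emb @ item_emb.t()  # (num_members, num_courses)
    probs = torch.sigmoid(scores)     # optional: squash to (0, 1)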

I got everything to run, thanks @mufeili.
Now I have some theoretical questions regarding link prediction models that I hope you can help me understand:

  1. Is it preferable to have categorical variables (e.g. university, city and the like) represented as node features, or as edges between users like in a social network?
  2. I get an AUC on the test set of ~0.12, which suggests the model does find patterns but predicts the complete opposite of what it is supposed to. I have tried changing almost everything, but nothing seems to fix it. Any idea what could cause this? I have listed the things I think could influence it below:

Thanks! :slight_smile:

My graph is as follows:

Graph(num_nodes={'course': 88, 'member': 144428, 'non-member': 8459},
      num_edges={('course', 'm_attended_by', 'member'): 9325, ('course', 'nm_attended_by', 'non-member'): 10978, ('member', 'm_attends', 'course'): 9325, ('non-member', 'nm_attends', 'course'): 10978},
      metagraph=[('course', 'member', 'm_attended_by'), ('course', 'non-member', 'nm_attended_by'), ('member', 'course', 'm_attends'), ('non-member', 'course', 'nm_attends')])

Dataloader:

device = 'cpu'

#data = train_ori.dataset

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
sampler = dgl.dataloading.MultiLayerNeighborSampler([8, 8])

trainloader = dgl.dataloading.EdgeDataLoader(
    # The following arguments are specific to EdgeDataLoader.
    g,                                      # The graph
    eids=train_eid_dict,                    # The edges to iterate over
    block_sampler=sampler,                  # The neighbor sampler
    negative_sampler=negative_sampler,      # The negative sampler
    # Note: reverse_etypes only takes effect together with exclude='reverse_types'
    #exclude='reverse_types',
    reverse_etypes={'m_attends': 'm_attended_by', 'm_attended_by': 'm_attends',
                    'nm_attends': 'nm_attended_by', 'nm_attended_by': 'nm_attends'},
    device=device,                          # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    #batch_size=len(train_eid_dict['m_attends']),
    batch_size=2048,                        # Batch size
    shuffle=True,                           # Whether to shuffle the edges for every epoch
    drop_last=True,                         # Whether to drop the last incomplete batch
    num_workers=0                           # Number of sampler processes
)

And finally the loop for running the model:

for epoch in range(200):
    model.train()
    train_loss = 0.0
    for input_nodes, positive_graph, negative_graph, blocks in trainloader:
        blocks = [b.to(torch.device(device)) for b in blocks]

        positive_graph = positive_graph.to(torch.device(device))
        negative_graph = negative_graph.to(torch.device(device))

        node_features = {'member': blocks[0].srcdata['features']['member'].float(),
                         'non-member': blocks[0].srcdata['features']['non-member'].float(),
                         'course': blocks[0].srcdata['features']['course'].float()}

        pos_score, neg_score, _ = model(positive_graph, negative_graph, blocks, node_features)
        loss = compute_loss(pos_score, neg_score, g.canonical_etypes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()/positive_graph.number_of_edges()

  1. Is it preferable to have categorical variables (e.g. university, city and the like) represented as node features, or as edges between users like in a social network?

I generally prefer having them as node features. Having an additional relation type can be costly, so I would avoid that if possible.

  2. I get an AUC on the test set of ~0.12, which suggests the model does find patterns but predicts the complete opposite of what it is supposed to. I have tried changing almost everything, but nothing seems to fix it. Any idea what could cause this? I have listed the things I think could influence it below:

Have you tried computing AUC on the training set? Does that give a reasonable number? Did you observe loss decreasing over time?
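
For reference, a minimal way to compute AUC from the positive/negative score dicts returned per canonical edge type (a sketch using scikit-learn; the helper name and shapes are my assumptions, matching the loss function below):

import torch
from sklearn.metrics import roc_auc_score

def compute_auc(pos_score, neg_score, canonical_etypes):
    # label positive edges 1 and sampled negative edges 0, then score with ROC AUC
    scores, labels = [], []
    for etype in canonical_etypes:
        if etype not in pos_score or pos_score[etype].shape[0] == 0:
            continue
        scores.append(pos_score[etype].flatten())
        labels.append(torch.ones(pos_score[etype].numel()))
        scores.append(neg_score[etype].flatten())
        labels.append(torch.zeros(neg_score[etype].numel()))
    scores = torch.cat(scores).detach().cpu().numpy()
    labels = torch.cat(labels).cpu().numpy()
    return roc_auc_score(labels, scores)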

I generally prefer having them as node features. Having an additional relation type can be costly, so I would avoid that if possible.

Perfect. Will keep it that way then. Thanks!

Have you tried computing AUC on the training set? Does that give a reasonable number? Did you observe loss decreasing over time?
Yes. It is around 0.2 as well, so the model does capture the patterns, just the wrong way around, and the loss keeps decreasing as well. It could be fixed by simply swapping the positive and negative scores. Otherwise it might be our score function that is wrong. :confused: Maybe you can spot whether we swapped some elements around?

HeteroScorePredictor:

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for etype in edge_subgraph.canonical_etypes:
                # score = dot product of source and destination node representations
                edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
                #edge_subgraph.apply_edges(self.apply_edges, etype=etype)
            return edge_subgraph.edata['score']

Compute Loss:

def compute_loss(pos_score, neg_score, canonical_etypes):
    # Margin loss
    all_losses = []
    for given_type in canonical_etypes:
        n_edges = pos_score[given_type].shape[0]
        if n_edges == 0:
            continue
        all_losses.append((1 - neg_score[given_type].view(n_edges, -1) + 
                           pos_score[given_type].unsqueeze(1)).clamp(min=0).mean())
    return torch.stack(all_losses, dim=0).mean()

The loss term does not appear correct to me. The loss should go higher as the positive score goes lower and the negative score goes higher. The loss function you used goes the other way.
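
In other words, the positive and negative scores should swap places in the margin term. A corrected sketch along those lines, keeping the same per-edge-type structure as the code above, would be:

def compute_loss(pos_score, neg_score, canonical_etypes):
    # Margin loss: penalize when a negative edge scores within a margin of 1
    # of a positive edge, so the loss grows as pos_score falls or neg_score rises.
    all_losses = []
    for given_type in canonical_etypes:
        n_edges = pos_score[given_type].shape[0]
        if n_edges == 0:
            continue
        pos = pos_score[given_type].view(n_edges, -1)   # (E, 1)
        neg = neg_score[given_type].view(n_edges, -1)   # (E, k)
        all_losses.append((1 - pos + neg).clamp(min=0).mean())
    return torch.stack(all_losses, dim=0).mean()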
