Alternate weight matrix in a bipartite graph

Hi,

I’m running some experiments for my research in which I need to use a different weight matrix for each node type in a heterogeneous bipartite graph.

For example, during message passing and aggregation, every node of type A must use a weight matrix W, and every node of type B must use a weight matrix W’. Nodes of type A should have no influence on W’, and vice versa.
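To make the requirement concrete, here is a minimal plain-PyTorch sketch of what it means, assuming sum aggregation and a 1-dimensional input feature (both are illustrative choices, not part of the question). Messages sent by A nodes are transformed with `W`, messages sent by B nodes with `W_prime`, and a loss computed only on B-node outputs leaves `W_prime` without a gradient. The helper `propagate` is hypothetical; the edges are the same toy edges used in the DGL example in the answer.

```python
import torch

torch.manual_seed(0)

# Toy bipartite graph: 4 nodes of type A and 4 nodes of type B.
# A -> B edges as (source A index, destination B index)
a2b_src = torch.tensor([1, 2])
a2b_dst = torch.tensor([2, 3])
# B -> A edges as (source B index, destination A index)
b2a_src = torch.tensor([2, 2])
b2a_dst = torch.tensor([1, 3])

h_a = torch.randn(4, 1)  # features of type-A nodes
h_b = torch.randn(4, 1)  # features of type-B nodes

# One weight matrix per node type: W for messages sent by A, W_prime for messages sent by B
W = torch.randn(1, 2, requires_grad=True)
W_prime = torch.randn(1, 2, requires_grad=True)


def propagate(h_src, weight, src, dst, num_dst):
    """Hypothetical helper: transform source features with `weight`, then sum them per destination node."""
    messages = h_src[src] @ weight  # (num_edges, out_dim)
    out = torch.zeros(num_dst, weight.shape[1])
    return out.index_add(0, dst, messages)


new_b = propagate(h_a, W, a2b_src, a2b_dst, num_dst=4)        # B nodes aggregate A messages via W
new_a = propagate(h_b, W_prime, b2a_src, b2a_dst, num_dst=4)  # A nodes aggregate B messages via W'

# A loss on the B-node outputs only reaches W, never W_prime
new_b.sum().backward()
print(W.grad is not None, W_prime.grad is None)  # True True
```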

Is it possible to implement this using DGL?

Below is an example; see if it meets your needs.

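```python
import dgl
import torch
from dgl.nn.pytorch import HeteroGraphConv, GraphConv

# Create a bipartite graph with two node types, A and B
g = dgl.heterograph({
    ('A', 'A->B', 'B'): (torch.tensor([1, 2]), torch.tensor([2, 3])),
    ('B', 'B->A', 'A'): (torch.tensor([2, 2]), torch.tensor([1, 3]))
})
# Set initial node features (4 nodes of each type, 1 feature each)
g.ndata['h'] = {'A': torch.randn(4, 1), 'B': torch.randn(4, 1)}
# Use a separate GraphConv for each relation. In a bipartite graph each relation
# has exactly one source node type, so this amounts to one weight matrix per node type.
# Newer DGL versions raise an error for nodes with no incoming edges (e.g. B nodes
# 0 and 1 here) unless allow_zero_in_degree=True is set.
model = HeteroGraphConv({
    'A->B': GraphConv(1, 2, norm='none', bias=False, allow_zero_in_degree=True),
    'B->A': GraphConv(1, 2, norm='none', bias=False, allow_zero_in_degree=True),
})
# Returns a dict {node type: new features}; out['A'] and out['B'] have shape (4, 2)
out = model(g, g.ndata['h'])
```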