Hi, the bipartite case is indeed a bit more complicated than the homogeneous graph case, and we should explain it more clearly. As shown on the doc page (your first figure), the SAGEConv module on a bipartite graph requires two node feature tensors: one for the source nodes (i.e., u) and one for the destination nodes (i.e., v). When there are multiple layers, you need to maintain two node feature tensors instead of one: an h_user for user features and an h_item for item features. This means that at each layer you need two SAGEConv modules, one for message passing from user->item and one for item->user. Each direction also needs a graph whose edges point that way, since SAGEConv only passes messages from source nodes to destination nodes. Here is a code snippet to get you started (I haven't tested it; it is just meant to show the idea):
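import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class GraphSAGEBipartite(nn.Module):
    def __init__(self, user_in_feats, item_in_feats, h_feats):
        super().__init__()
        # On a bipartite graph, in_feats is a (source_dim, destination_dim) pair.
        self.u2i_conv1 = SAGEConv((user_in_feats, item_in_feats), h_feats, 'mean')
        self.i2u_conv1 = SAGEConv((item_in_feats, user_in_feats), h_feats, 'mean')
        self.u2i_conv2 = SAGEConv(h_feats, h_feats, 'mean')
        self.i2u_conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, user_in_feat, item_in_feat):
        # Assumption: g has user -> item edges; reverse it for item -> user messages.
        g_rev = dgl.reverse(g)
        # Layer 1: features are passed as (source, destination).
        h_user = F.relu(self.i2u_conv1(g_rev, (item_in_feat, user_in_feat)))
        h_item = F.relu(self.u2i_conv1(g, (user_in_feat, item_in_feat)))
        # Layer 2: both updates read the layer-1 outputs, so use new names.
        h_user_out = self.i2u_conv2(g_rev, (h_item, h_user))
        h_item_out = self.u2i_conv2(g, (h_user, h_item))
        return h_user_out, h_item_out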
For more general cases (more than two types of nodes and edges), we provide a wrapper module called HeteroGraphConv, which essentially does the above for you. You can check out its doc here: 3.3 Heterogeneous GraphConv Module — DGL 0.6.1 documentation