GraphSAGE for heterogeneous graphs

I'm looking for an implementation similar to the original unsupervised GraphSAGE, but for heterogeneous graphs where nodes may have different initial feature vector sizes. The idea is to project the features into a common space and then run regular unsupervised GraphSAGE.
Is there an existing implementation for this?

Probably PinSAGE is what you want. https://arxiv.org/pdf/1806.01973.pdf

However, DGL doesn't have a mature implementation yet. @BarclayII is working on the related API (https://github.com/dmlc/dgl/pull/1249). We hope to release it soon.

Is there any ETA for this feature?

The core features will be done by the end of February. Model development will start next month.

BTW, the original GraphSAGE only works on homogeneous graphs, so @navmarri, if you could point us to some published papers about using GraphSAGE on heterogeneous graphs, we could try to include that in our next release.

There is no paper per se; I came across a blog post from Uber about multimodal data.

To summarize the blog post: they just add an extra projection layer that maps nodes with different feature sizes to a common space. Is it possible to do that in DGL, i.e. add a projection layer for a few nodes and exclude the others?

It’s very doable. The idea is that when we create the graph, we label user nodes from 0 to #users - 1 and item nodes from #users to #users + #items - 1. During forward propagation, we apply two FC layers that act as the two individual projection layers. Their outputs now have the same feature dimension, so they can be concatenated in node ID order (users first, then items) and fed to the DGLGraph as normal.

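import torch as th

def some_forward_func(self, g, user_feats, item_feats):
    # user_feats shape: (#users, D1)
    # item_feats shape: (#items, D2)
    num_users, num_items = user_feats.shape[0], item_feats.shape[0]
    # self.user_fc and self.item_fc are the two FC (projection) layers
    user_feats = self.user_fc(user_feats)  # now of shape (#users, D)
    item_feats = self.item_fc(item_feats)  # now of shape (#items, D)
    # concatenate in node ID order: users first, then items
    all_feats = th.cat([user_feats, item_feats], 0)
    g.ndata['h'] = all_feats
    g.update_all(...)  # placeholder: your message and reduce functions go here
    # optionally, you may split them back
    user_feats, item_feats = th.split(g.ndata['h'], [num_users, num_items])
    return user_feats, item_feats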
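
To make the node ID layout and the projection step concrete, here is a minimal pure-Python sketch of the same idea. It is only an illustration under stated assumptions: `make_projection`, `project`, `global_item_id` and `mean_aggregate` are hypothetical helpers, not DGL or PyTorch APIs. The weight matrices stand in for the two FC layers, and a simple mean over incoming neighbors stands in for whatever `g.update_all(...)` would compute in a real model.

```python
import random

def make_projection(in_dim, out_dim, rng):
    # Hypothetical stand-in for a bias-free FC layer: out_dim rows, in_dim columns.
    return [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]

def project(weights, feats):
    # Multiply every feature vector by the weight matrix: (N, in_dim) -> (N, out_dim).
    return [[sum(w * x for w, x in zip(row, f)) for row in weights] for f in feats]

def global_item_id(num_users, item_idx):
    # Items are numbered after all users: #users .. #users + #items - 1.
    return num_users + item_idx

def mean_aggregate(h, edges, num_nodes):
    # Assumed aggregation: each node receives the mean of its in-neighbors' features.
    # Nodes without incoming edges end up with a zero vector.
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for src, dst in edges:
        counts[dst] += 1
        for k in range(dim):
            sums[dst][k] += h[src][k]
    return [[s / c for s in row] if c else row for row, c in zip(sums, counts)]

rng = random.Random(0)
num_users, num_items = 3, 2
D1, D2, D = 4, 6, 5  # user feature size, item feature size, common size

user_feats = [[rng.random() for _ in range(D1)] for _ in range(num_users)]
item_feats = [[rng.random() for _ in range(D2)] for _ in range(num_items)]

user_w = make_projection(D1, D, rng)
item_w = make_projection(D2, D, rng)

# Project both node types to the common space, then stack users first, items second.
all_feats = project(user_w, user_feats) + project(item_w, item_feats)

# Edges use global node IDs: user 0 -> item 0, user 1 -> item 0, item 1 -> user 2.
edges = [
    (0, global_item_id(num_users, 0)),
    (1, global_item_id(num_users, 0)),
    (global_item_id(num_users, 1), 2),
]

h = mean_aggregate(all_feats, edges, num_users + num_items)

# Split back by node type using the same ID layout.
user_out, item_out = h[:num_users], h[num_users:]
```

Because users always occupy the first `num_users` rows, splitting the result back by node type is just a slice at that offset, which is what `th.split` does in the snippet above.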