The graph I want to deal with has different types of nodes with different feature dimensions, and each node type has its own update function. How can I do message passing in this case?
I tried the following.
For example, suppose there are two kinds of nodes in the graph: one kind has feature1 and the other has feature2. I give every node both features and set the feature that doesn't belong to its type to zero.
This is kinda stupid and I don’t know if it really works.
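A minimal sketch of the zero-padding idea in plain PyTorch, without DGL. The node counts, the dimensions `d1`/`d2`, the `node_type` tensor and the weight `W` are all hypothetical. It shows why the trick can work: a bias-free linear projection maps a zero vector to a zero vector, so the padded feature adds nothing to a summed message.

```python
import torch

# Hypothetical setup: 3 nodes of type 0 (own feature1, dim 4)
# and 2 nodes of type 1 (own feature2, dim 6).
d1, d2 = 4, 6
node_type = torch.tensor([0, 0, 0, 1, 1])
num_nodes = node_type.numel()

# Every node gets both features; the one it does not own stays zero.
feature1 = torch.zeros(num_nodes, d1)
feature2 = torch.zeros(num_nodes, d2)
feature1[node_type == 0] = torch.randn(int((node_type == 0).sum()), d1)
feature2[node_type == 1] = torch.randn(int((node_type == 1).sum()), d2)

# Bias-free projection from feature2 space to feature1 space (hypothetical W).
W = torch.randn(d2, d1)
# Type-0 nodes only have zero-padded feature2, so their projected messages are zero.
msg_from_type0 = feature2[node_type == 0] @ W
print(torch.all(msg_from_type0 == 0))  # tensor(True)
```

The guarantee only holds for bias-free linear maps. If the message function adds a bias or applies a nonlinearity such as `sigmoid` before summing, the padded zeros produce non-zero messages and would need to be masked out explicitly.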
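```python
import torch

def msg_func(self, edges):
    # Per-edge weight matrices selected by edge type, shape (E, d_in, d_out)
    weight_feature1_type = self.weight_feature1[edges.data['type']]
    weight_feature2_type = self.weight_feature2[edges.data['type']]
    # Batched vector-matrix product (E, 1, d_in) @ (E, d_in, d_out) -> (E, d_out);
    # a plain torch.matmul of (E, d_in) with (E, d_in, d_out) would give (E, E, d_out)
    msg_feature1 = torch.bmm(edges.src['feature2'].unsqueeze(1), weight_feature1_type).squeeze(1)
    msg_feature2 = torch.bmm(edges.src['feature1'].unsqueeze(1), weight_feature2_type).squeeze(1)
    return {'msg_feature1': msg_feature1, 'msg_feature2': msg_feature2}

def red_func(self, nodes):
    # Sum incoming messages over the mailbox dimension (dim=1)
    return {'feature1': torch.sum(nodes.mailbox['msg_feature1'], dim=1),
            'feature2': torch.sum(nodes.mailbox['msg_feature2'], dim=1)}
```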
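To check the shapes outside DGL, here is a pure-PyTorch emulation of one message/reduce step. It is a sketch, not DGL's API: `src_feat`, `edge_type` and `dst` stand in for `edges.src['feature2']`, `edges.data['type']` and the edge destinations, and all sizes are made up. `index_add_` plays the role of summing over the mailbox.

```python
import torch

E, num_types, d_in, d_out, num_nodes = 5, 3, 6, 4, 4
src_feat = torch.randn(E, d_in)            # stand-in for edges.src['feature2']
edge_type = torch.tensor([0, 2, 1, 0, 2])  # stand-in for edges.data['type']
weight = torch.randn(num_types, d_in, d_out)
w_per_edge = weight[edge_type]             # one weight matrix per edge: (E, d_in, d_out)

# torch.matmul broadcasts the 2-D input against every edge's matrix.
bad = torch.matmul(src_feat, w_per_edge)   # shape (E, E, d_out), not one message per edge

# Batched vector-matrix product: exactly one message per edge.
msg = torch.bmm(src_feat.unsqueeze(1), w_per_edge).squeeze(1)  # (E, d_out)

# Sum messages into their destination nodes (hypothetical edge destinations).
dst = torch.tensor([1, 1, 3, 0, 3])
out = torch.zeros(num_nodes, d_out).index_add_(0, dst, msg)
print(bad.shape, msg.shape, out.shape)
```

`torch.einsum('ei,eio->eo', src_feat, w_per_edge)` computes the same per-edge messages as the `bmm` line if you find it easier to read. In this emulation node 2 receives no messages, so its row of `out` stays zero.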