Getting Node Feature Shapes Right - A Heterograph Example

Hi all,
I’ve been playing around with heterographs that have different feature sizes and numbers of nodes per node type, for tasks such as node classification.
I put together a small example as a sanity check for myself and found it to be quite useful.

Graph Construction

Let's construct a simple heterograph with two node types (A, B), two different feature sizes (1 and 2), and different numbers of nodes (6 and 5).

```python
import torch
import dgl

# links between the 6 nodes of type A
u1 = [0,1,2,3,4,5]
v1 = [4,5,0,1,2,3]

# links between the 5 nodes of type B
u2 = [0,1,2,3,4]
v2 = [4,0,1,2,3]

# links from type A nodes to type B nodes (reversed below for B -> A)
u3 = [0,1,2,3,4]
v3 = [0,1,2,3,4]

# NOTE: in DGL heterographs, each node type is indexed independently, with ids starting from 0.
# This matters: had my ids started at 1, each node type would have needed one extra node (id 0),
# and the feature tensors assigned below would no longer match the node counts.

graph = dgl.heterograph({
    ('a', 'same_a', 'a'): (u1, v1),
    ('b', 'same_b', 'b'): (u2, v2),
    ('a', 'diff_a', 'b'): (u3, v3),
    ('b', 'diff_b', 'a'): (v3, u3)},
    num_nodes_dict={'a': 6, 'b': 5})

# assign features to our nodes' data: 1 feature per 'a' node, 2 per 'b' node
graph.nodes['a'].data['feat'] = torch.ones((6,1))
graph.nodes['b'].data['feat'] = torch.ones((5,2))
```
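Since the 0-based, per-type indexing is easy to get wrong, a quick check before building the graph can help. The sketch below is a hypothetical helper (not a DGL function) in plain Python: it takes the same edge dictionary you would pass to `dgl.heterograph` and reports any id that falls outside `0 .. num_nodes - 1` for its node type, plus any relation whose source and destination lists differ in length.

```python
# Hypothetical helper (not part of DGL): checks that every edge endpoint is a
# valid 0-based id for its node type before the graph is built.
def check_edge_ids(edges, num_nodes):
    """edges: {(src_type, etype, dst_type): (u, v)}; num_nodes: {ntype: count}."""
    problems = []
    for (stype, etype, dtype), (u, v) in edges.items():
        if len(u) != len(v):
            problems.append(f"{etype}: {len(u)} sources but {len(v)} destinations")
        # source ids must fit the source type, destination ids the destination type
        for ntype, ids in ((stype, u), (dtype, v)):
            bad = [i for i in ids if not 0 <= i < num_nodes[ntype]]
            if bad:
                problems.append(f"{etype}: ids {bad} out of range for '{ntype}' "
                                f"(valid: 0..{num_nodes[ntype] - 1})")
    return problems

u1, v1 = [0, 1, 2, 3, 4, 5], [4, 5, 0, 1, 2, 3]
u2, v2 = [0, 1, 2, 3, 4], [4, 0, 1, 2, 3]
u3, v3 = [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]
edges = {('a', 'same_a', 'a'): (u1, v1),
         ('b', 'same_b', 'b'): (u2, v2),
         ('a', 'diff_a', 'b'): (u3, v3),
         ('b', 'diff_b', 'a'): (v3, u3)}
print(check_edge_ids(edges, {'a': 6, 'b': 5}))  # → []

# The same edges numbered from 1 no longer fit 6 'a' nodes and 5 'b' nodes.
shifted = {k: ([i + 1 for i in u], [i + 1 for i in v]) for k, (u, v) in edges.items()}
print(len(check_edge_ids(shifted, {'a': 6, 'b': 5})))  # → 6
```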

Model Implementation

Now our model has to handle inputs of different sizes. This can be a little confusing, because the modules inside HeteroGraphConv can get quite elaborate. I suggest a simple modification that makes it easy to change the feature sizes later if you want to.

Here, instead of specifying the number of input features and the relation types separately, as is common in the tutorials, I combine them into a single rel_shape_dict that maps each relation to its input feature size.

Note that the values of this dictionary (the input sizes) are only used in the first layer. From the second layer on, every node type has hid_feats features, so the model uses that size instead.

```python
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, rel_shape_dict, hid_feats, out_feats):
        super().__init__()  # required before assigning submodules
        # First layer: one SAGEConv per relation, sized by that relation's input
        # features. An int means src and dst share the size; a tuple is (src, dst).
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.SAGEConv(size, hid_feats, aggregator_type='mean')
            for rel, size in rel_shape_dict.items()}, aggregate='mean')

        # Second layer: every node type now has hid_feats features.
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.SAGEConv(hid_feats, out_feats, aggregator_type='mean')
            for rel in rel_shape_dict.keys()}, aggregate='mean')

    def forward(self, graph, inputs):
        # inputs: dict mapping node type -> feature tensor
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

rel_shape_dict = {'same_a' : 1, 'same_b' : 2, 'diff_a' : (1,2), 'diff_b' : (2,1)}
#rel_shape_dict = {'same_a' : (1,1), 'same_b' : (2,2), 'diff_a' : (1,2), 'diff_b' : (2,1)} # this is equivalent

model = Model(rel_shape_dict, 10, 1)

# quickly put some inputs together and bam, you've got a prediction
inputs = {'a': graph.nodes['a'].data['feat'], 'b': graph.nodes['b'].data['feat']}
h = model(graph, inputs)
# one (num_nodes, out_feats) tensor per node type: 'a' -> (6, 1), 'b' -> (5, 1)
print({k: v.shape for k, v in h.items()})
```
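If the feature sizes change, rel_shape_dict can also be derived rather than written by hand. A minimal sketch, assuming each relation name is unique (as it is here, since HeteroGraphConv is keyed by relation name); `build_rel_shape_dict` is a hypothetical helper, not part of DGL. It always produces `(src, dst)` tuples, which matches the commented-out "equivalent" dictionary in the model code.

```python
# Hypothetical helper (not a DGL API): derive the per-relation
# (src_feat_size, dst_feat_size) tuples that SAGEConv accepts as in_feats.
def build_rel_shape_dict(canonical_etypes, feat_sizes):
    """canonical_etypes: iterable of (src_type, etype, dst_type) triples.
    feat_sizes: {node_type: feature size}."""
    return {etype: (feat_sizes[src], feat_sizes[dst])
            for src, etype, dst in canonical_etypes}

# Canonical edge types of the example graph, as passed to dgl.heterograph.
canonical_etypes = [('a', 'same_a', 'a'), ('b', 'same_b', 'b'),
                    ('a', 'diff_a', 'b'), ('b', 'diff_b', 'a')]
print(build_rel_shape_dict(canonical_etypes, {'a': 1, 'b': 2}))
# → {'same_a': (1, 1), 'same_b': (2, 2), 'diff_a': (1, 2), 'diff_b': (2, 1)}
```

With the real graph you could pass `graph.canonical_etypes` and `{nt: graph.nodes[nt].data['feat'].shape[1] for nt in graph.ntypes}` instead of the hard-coded values.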


Please note this is by no means a good model or graph; it is merely intended for demonstration purposes. If you’d like to add to it or improve it, please feel free!


@mufeili I have a question about the aggregation behaviour in the example above.

Given the model above, I was checking how the features are transformed during aggregation.

Looking at the source code of the PyTorch implementation of SAGEConv, it appears that if the source feature size is larger than the output feature size (self._out_feats), a linear transform is applied before message passing:

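```python
# excerpt from SAGEConv.forward (DGL, PyTorch backend)
lin_before_mp = self._in_src_feats > self._out_feats  # <--- this line

if self._aggre_type == 'mean':
    # project the source features first if that makes the messages smaller
    graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
    graph.update_all(msg_fn, fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
    if not lin_before_mp:
        # otherwise aggregate the raw features and project afterwards
        h_neigh = self.fc_neigh(h_neigh)
```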
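Swapping the order like this only works because mean aggregation and a linear map commute: the mean of W·x over a node's neighbours equals W applied to the mean of the x. A small pure-Python check of that identity, with toy numbers and exact arithmetic via `fractions.Fraction`; `linear` and `mean_aggregate` are illustrative stand-ins, not DGL functions, and the projection is assumed to be bias-free.

```python
from fractions import Fraction

def linear(W, x):
    # bias-free linear map: y = W x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def mean_aggregate(edges, src_vectors, num_dst):
    # average the vectors sent along each (src, dst) edge into its destination;
    # assumes every destination has at least one incoming edge
    sums = [None] * num_dst
    counts = [0] * num_dst
    for s, d in edges:
        vec = src_vectors[s]
        sums[d] = list(vec) if sums[d] is None else [a + b for a, b in zip(sums[d], vec)]
        counts[d] += 1
    return [[Fraction(v, counts[d]) for v in sums[d]] for d in range(num_dst)]

# toy data: 3 source nodes with D = 3 features, W maps D -> D' = 2
feats = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
W = [[1, 0, -1], [2, 1, 0]]
edges = [(0, 0), (1, 0), (2, 1), (0, 1)]  # (src, dst) pairs, 2 destination nodes

# like lin_before_mp=True: project, then send D'-sized messages
before = mean_aggregate(edges, [linear(W, x) for x in feats], 2)
# like lin_before_mp=False: send D-sized messages, then project the mean
after = [linear(W, h) for h in mean_aggregate(edges, feats, 2)]
print(before == after)  # → True
```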

My questions are: why is this done only when the source features are larger than the output features, and what happens when the output features are larger than the source?

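This is mostly a complexity consideration. Since the mean aggregation and the linear layer fc_neigh are both linear, applying the layer before or after message passing does not change the result; only the cost differs. Message passing costs O(|E|D') if the features are projected first and O(|E|D) otherwise, where |E| is the number of edges, D is the input (source) feature size and D' is the output size. So projecting first only pays off when D' < D; when the output is larger than the input, it is cheaper to aggregate the smaller raw features first and project afterwards.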
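To make the cost argument concrete, here is a rough sketch that only counts the feature values carried along edges during message passing (it ignores the cost of the linear projection itself). `message_cost` and `choose_order` are hypothetical helpers, not DGL code; the sizes are those of the same_b relation (5 edges) in the example model, which maps 2 -> 10 in conv1 and 10 -> 1 in conv2.

```python
def message_cost(num_edges, in_feats, out_feats, lin_before_mp):
    # Values carried along edges: out_feats per edge if the projection runs
    # first, otherwise the raw in_feats-sized source feature.
    return num_edges * (out_feats if lin_before_mp else in_feats)

def choose_order(in_feats, out_feats):
    # Same rule as the SAGEConv excerpt: project first only if it shrinks messages.
    return in_feats > out_feats

for d_in, d_out in [(2, 10), (10, 1)]:
    first = choose_order(d_in, d_out)
    print(d_in, d_out, first,
          message_cost(5, d_in, d_out, first),       # chosen order
          message_cost(5, d_in, d_out, not first))   # the other order
# prints:
# 2 10 False 10 50
# 10 1 True 5 50
```

In both layers the rule picks the cheaper order, which is why the condition compares the source size against the output size in both directions.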