How to modify message function in R-GCN tutorial to use neighbouring node features as well in formulating the node representation?

Hello,
I would like to add my own node features and use them as well when forming the node representations in the R-GCN.
How should I modify the message function to achieve that?

Thanks in advance.

Hi, could you please elaborate on how you would like to incorporate node representation? A minimal block of code or pseudocode will be very helpful.


In this equation, we need to add another term, which is given below (GCN):
[Equation image: equation2]
I have node features and relations that are both significant, so I thought of incorporating both to obtain the node representations.

This sounds like a reasonable attempt. Take the message function defined below as an example:

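def msg_func(edges):
    # Pick one weight matrix per edge by relation type: (E, in_feat, out_feat)
    w = weight.index_select(0, edges.data['type'])
    # squeeze(1) drops only the singleton middle dimension, so a batch with a single edge keeps its shape
    msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze(1)
    # Scale each message by its per-edge normalization constant
    msg = msg * edges.data['norm']
    return {'msg': msg}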

With the message function above, we only consider the features of the source nodes along with the edge types. To incorporate each node's own features, I would not modify msg_func, but instead use apply_node_func.

Message passing in DGL can be performed in three stages: send messages with msg_func, reduce the received messages with reduce_func, and optionally apply an additional function to the node features with apply_node_func.
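To make the three stages concrete, here is a minimal sketch in plain PyTorch that does not use DGL at all. The toy graph, the tensor shapes, the normalization values, and the name `W_self` are illustrative assumptions, not part of the tutorial.

```python
import torch

torch.manual_seed(0)

# Toy graph (hypothetical): 3 nodes and 4 directed edges src -> dst.
src = torch.tensor([0, 1, 2, 2])
dst = torch.tensor([1, 2, 0, 1])
etype = torch.tensor([0, 1, 0, 1])                  # relation id of each edge
norm = torch.tensor([[0.5], [1.0], [1.0], [0.5]])   # per-edge scaling, arbitrary values

num_nodes, num_rels, in_feat, out_feat = 3, 2, 4, 3
h = torch.randn(num_nodes, in_feat)                 # node features
weight = torch.randn(num_rels, in_feat, out_feat)   # one matrix per relation
W_self = torch.randn(in_feat, out_feat)             # assumed self-connection weight

# Stage 1 (message): one message per edge, built from its source node.
w = weight.index_select(0, etype)                   # (E, in_feat, out_feat)
msg = torch.bmm(h[src].unsqueeze(1), w).squeeze(1)  # (E, out_feat)
msg = msg * norm

# Stage 2 (reduce): sum the messages arriving at each destination node.
h_new = torch.zeros(num_nodes, out_feat).index_add_(0, dst, msg)

# Stage 3 (apply): combine the aggregate with the node's own features.
h_out = torch.matmul(h, W_self) + h_new
```

In DGL these stages correspond to the three arguments of `g.update_all`; the sketch only mirrors the arithmetic.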

This does not require much work. Previously we had

g.update_all(msg_func, fn.sum(msg='msg', out='h'), None)

Now we have

def update_nodes(nodes):
    # W_self is the W_0 you mentioned; define it somewhere as an (in_feat, out_feat) parameter.
    # W_0 is a weight matrix, so use a matrix product rather than the element-wise `*`.
    new_feats = torch.matmul(nodes.data['h'], W_self) + nodes.data['h_new']
    return {'h': new_feats}

g.update_all(msg_func, fn.sum(msg='msg', out='h_new'), update_nodes)
g.ndata.pop('h_new')

Alternatively, you can also add self loops if that’s easy for your use case. Hope this helps.
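One way to read the self-loop alternative is sketched below in plain PyTorch, under an assumption that is not from the thread: every node gets one extra edge to itself, with its own relation id and a normalization of 1, and `W_self` becomes the weight matrix of that extra relation. Under that assumption the result matches the apply-function version. The helper `aggregate` and the toy graph are hypothetical.

```python
import torch

def aggregate(h, weight, src, dst, etype, norm):
    """Per node, sum norm * (h[src] @ weight[etype]) over its incoming edges."""
    w = weight.index_select(0, etype)
    msg = torch.bmm(h[src].unsqueeze(1), w).squeeze(1) * norm
    return torch.zeros(h.shape[0], weight.shape[2]).index_add_(0, dst, msg)

torch.manual_seed(0)
num_nodes, num_rels, in_feat, out_feat = 3, 2, 4, 3
src = torch.tensor([0, 1, 2, 2])
dst = torch.tensor([1, 2, 0, 1])
etype = torch.tensor([0, 1, 0, 1])
norm = torch.tensor([[0.5], [1.0], [1.0], [0.5]])
h = torch.randn(num_nodes, in_feat)
weight = torch.randn(num_rels, in_feat, out_feat)
W_self = torch.randn(in_feat, out_feat)

# Variant A: relational aggregation plus an explicit self term.
with_apply = aggregate(h, weight, src, dst, etype, norm) + torch.matmul(h, W_self)

# Variant B: add one self-loop per node as an extra relation (id = num_rels).
loop = torch.arange(num_nodes)
src2 = torch.cat([src, loop])
dst2 = torch.cat([dst, loop])
etype2 = torch.cat([etype, torch.full((num_nodes,), num_rels, dtype=torch.long)])
norm2 = torch.cat([norm, torch.ones(num_nodes, 1)])
weight2 = torch.cat([weight, W_self.unsqueeze(0)])
with_loops = aggregate(h, weight2, src2, dst2, etype2, norm2)
```

If the self-loop edges were normalized differently, or shared a relation type with real edges, the two variants would no longer coincide.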

@mufeili's approach adds a subsequent step after the R-GCN update, but my interpretation is somewhat different. @Ronit-j, correct me if I am wrong, but it looks to me that besides the R-GCN update, you also want to incorporate a vanilla GCN update (which does not use the edge relation type at all).

There are two cases:

  1. If your extra term is inside the large parentheses (i.e., before the activation), then you only need one small change in the message function:
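def message_func(edges):
    w = weight[edges.data['rel_type']]
    # squeeze(1) drops only the singleton middle dimension, even when there is a single edge
    msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze(1)
    msg = msg * edges.data['norm']

    # incorporate normalized src node representation without edge type
    # (the addition requires in_feat == out_feat, since no weight is applied to this term)
    msg = msg + edges.src['h'] * edges.data['norm']

    return {'msg': msg}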
  2. If your extra term is outside the parentheses (i.e., after the activation), you can change the message function to:
def message_func(edges):
    w = weight[edges.data['rel_type']]
    # squeeze(1) drops only the singleton middle dimension, even when there is a single edge
    msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze(1)
    msg = msg * edges.data['norm']

    # calculate an extra message which does not use edge relation
    msg_extra = edges.src['h'] * edges.data['norm']

    return {'msg': msg, 'msg_extra': msg_extra}

Here the message function returns two messages, so we also need to change the g.update_all call to reduce msg_extra:

g.update_all(message_func, [fn.sum(msg='msg', out='h'),
                            fn.sum(msg='msg_extra', out='h_extra')], None)

I am providing two reduce functions (by putting them in a Python list and specifying a different output feature name for each). After g.update_all has executed, the node feature h_extra stores the aggregated result that does not use the edge relation.
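The difference between the two cases can be sketched in plain PyTorch, without DGL. The toy graph, the helper `reduce_sum`, and the choice of ReLU as the activation are assumptions made for illustration; the feature size is kept equal on input and output so that the two terms can be added.

```python
import torch

torch.manual_seed(0)
num_nodes, num_rels, feat = 3, 2, 4    # in_feat == out_feat so the terms can be added
src = torch.tensor([0, 1, 2, 2])
dst = torch.tensor([1, 2, 0, 1])
etype = torch.tensor([0, 1, 0, 1])
norm = torch.tensor([[0.5], [1.0], [1.0], [0.5]])
h = torch.randn(num_nodes, feat)
weight = torch.randn(num_rels, feat, feat)

def reduce_sum(msg):
    """Sum per-edge messages into their destination nodes."""
    return torch.zeros(num_nodes, msg.shape[1]).index_add_(0, dst, msg)

# Relation-aware message and the extra relation-free message.
msg = torch.bmm(h[src].unsqueeze(1), weight[etype]).squeeze(1) * norm
msg_extra = h[src] * norm

h_rel = reduce_sum(msg)            # plays the role of node feature 'h'
h_extra = reduce_sum(msg_extra)    # plays the role of node feature 'h_extra'

# Case 1: extra term inside the parentheses, before the activation.
case_1 = torch.relu(h_rel + h_extra)

# Case 2: extra term outside the parentheses, after the activation.
case_2 = torch.relu(h_rel) + h_extra
```

Because the sum reduction is linear, case 1 gives the same result whether the two messages are added per edge (as in the first message function) or reduced separately and added afterwards.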

Yes, I wanted the first case. Also, would it be effective to modify the equation as I described?

I am not sure… You definitely have more information now. But I can’t say how that’s going to affect the accuracy. You probably need to play with it.

Okay, thank you so much for helping, @lingfan and @mufeili.