Mask edges for RGCN

Hi there, I have a trained RGCN model and would like to mask edges to measure edge importance, as in GNNExplainer. However, I am not sure how to pass the masks in. My current RGCN layer is:

import torch.nn as nn
import dgl.function as fn

class HeteroRGCNLayerSimple(nn.Module):
    def __init__(self, in_size, out_size, etypes, dropout=0):
        super().__init__()
        # One linear projection W_r per relation (edge type); dropout is accepted but unused here
        self.weight = nn.ModuleDict({name: nn.Linear(in_size, out_size) for name in etypes})

    def forward(self, G, feat_dict):
        # The input is a dictionary of node features for each node type
        funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:
            # Compute W_r * h
            if srctype in feat_dict:
                Wh = self.weight[etype](feat_dict[srctype])
                # Save it in the graph for message passing
                G.nodes[srctype].data['Wh_%s' % etype] = Wh
                # Specify per-relation message passing functions: (message_func, reduce_func)
                funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
        # Trigger message passing over all relations, summing across relation types
        G.multi_update_all(funcs, 'sum')
        # Return the updated node feature dictionary
        fdict = {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes if 'h' in G.nodes[ntype].data}
        return fdict

My edge mask for each relation (edge type) has shape (number of edges, 1), and Wh for each node type has shape (number of nodes, out_size).


Please refer to the HeteroGraphConv page of the DGL 0.9.0 documentation for how the HeteroGraphConv module is implemented. You can pass the edge mask as edge_weight in the forward call of each submodule wrapped by HeteroGraphConv (e.g. GraphConv, see its DGL 0.9.0 documentation).

Hi there, could you please be more specific? I am basically following awslabs/sagemaker-graph-fraud-detection on GitHub.

How do I pass my masks through mod_kwargs in forward(g, inputs, mod_args=None, mod_kwargs=None)?


Here is an example.

import dgl
import torch
from dgl.nn import HeteroGraphConv, GraphConv, SAGEConv

edges1 = (torch.tensor([0, 0, 1, 1, 2]), torch.tensor([2, 3, 4, 4, 5]))
edges2 = (torch.tensor([0, 1, 2, 3, 4]), torch.tensor([2, 3, 4, 5, 5]))
# random binary edge masks, one value per edge: shape (num_edges, 1)
mask1 = (torch.rand(5, 1) > 0.5).float()
mask2 = (torch.rand(5, 1) > 0.5).float()

g = dgl.heterograph({
    ('user', 'follows', 'user'): edges1,
    ('user', 'plays', 'game'): edges2,
})
h1 = {
    'user': torch.randn((g.num_nodes('user'), 5)),
    'game': torch.randn((g.num_nodes('game'), 5)),
}
conv = HeteroGraphConv({
    # users 0 and 1 have no incoming 'follows' edges, so GraphConv must allow zero in-degree
    'follows': GraphConv(5, 10, allow_zero_in_degree=True),
    'plays': SAGEConv(5, 10, 'mean'),
}, aggregate='sum')
# pass masks as edge weights: mod_kwargs maps each edge type to the
# keyword arguments forwarded to that relation's module
h2 = conv(g, h1, mod_kwargs={
    'follows': {'edge_weight': mask1},
    'plays': {'edge_weight': mask2},
})