Hi there, I have a trained RGCN model and I would like to mask its edges to measure edge importance, as described in GNNExplainer, but I am not sure how to pass the masks to the model. My current RGCN layer is:
class HeteroRGCNLayerSimple(nn.Module):
    """One relational GCN layer over a DGL heterograph.

    For every relation ``(srctype, etype, dsttype)`` whose source type has
    input features, the source features are projected with a
    relation-specific ``nn.Linear`` (``W_r * h``). They are then sent along
    the relation's edges, optionally scaled by a per-edge weight, and
    mean-reduced per destination node. The per-relation results are summed
    across relations (``multi_update_all(..., 'sum')``).

    Pass ``eweight`` to :meth:`forward` to apply soft edge masks, e.g. the
    masks optimized by GNNExplainer.
    """

    def __init__(self, in_size, out_size, etypes, dropout=0):
        """Create one linear projection per relation name.

        Args:
            in_size: Input feature size fed to every relation's ``nn.Linear``.
            out_size: Output feature size of every relation's ``nn.Linear``.
            etypes: Iterable of relation (edge type) names; one ``nn.Linear``
                is created per name and looked up by name in ``forward``.
            dropout: Accepted for interface compatibility; currently unused.
        """
        super().__init__()
        # NOTE(review): ``dropout`` is never applied by this layer; it is kept
        # in the signature so existing constructor calls keep working.
        self.weight = nn.ModuleDict(
            {name: nn.Linear(in_size, out_size) for name in etypes}
        )

    @staticmethod
    def _edge_weight_for(eweight, canonical_etype):
        """Return the edge weight for a relation, or None if none was given.

        Lookup order: the canonical ``(src, etype, dst)`` key first (the
        form DGL's HeteroGNNExplainer uses), then the plain relation name.
        """
        if eweight is None:
            return None
        if canonical_etype in eweight:
            return eweight[canonical_etype]
        return eweight.get(canonical_etype[1])

    def forward(self, G, feat_dict, eweight=None):
        """Run one round of heterogeneous message passing.

        Args:
            G: DGL heterograph. It is left unmodified: all intermediate
                features live only inside ``G.local_scope()``.
            feat_dict: Mapping from node type to input features (presumably
                ``(num_nodes_of_type, in_size)``). Relations whose source
                type is absent are skipped.
            eweight: Optional mapping from relation to per-edge weights, such
                as a GNNExplainer edge mask. Keys may be canonical edge types
                ``(srctype, etype, dsttype)`` or relation names. Values must
                have ``E`` rows, where ``E`` is that relation's edge count.
                Shape ``(E,)`` is reshaped to ``(E, 1)``, and ``(E, 1)``
                broadcasts across the ``out_size`` feature dimension.
                Relations without an entry send unweighted messages, exactly
                as when ``eweight`` is None.

        Returns:
            Dict from node type to its updated features, for every node type
            that holds an ``'h'`` field after message passing.

        Raises:
            ValueError: If an edge weight's first dimension does not match
                the number of edges of its relation.
        """
        # local_scope discards everything written to G on exit. Without it,
        # a previous call's 'h' could be returned for node types that get no
        # messages this time, and the Wh_* tensors would keep the autograd
        # graph alive between calls.
        with G.local_scope():
            funcs = {}
            for c_etype in G.canonical_etypes:
                srctype, etype, _dsttype = c_etype
                if srctype not in feat_dict:
                    continue
                wh_key = 'Wh_%s' % etype
                # Compute W_r * h and store it on the source nodes.
                G.nodes[srctype].data[wh_key] = self.weight[etype](feat_dict[srctype])

                w = self._edge_weight_for(eweight, c_etype)
                if w is None:
                    msg_func = fn.copy_u(wh_key, 'm')
                else:
                    num_edges = G.num_edges(c_etype)
                    if w.shape[0] != num_edges:
                        raise ValueError(
                            'edge weight for %s has %d rows, expected %d'
                            % (c_etype, w.shape[0], num_edges)
                        )
                    if w.dim() == 1:
                        w = w.unsqueeze(-1)  # (E,) -> (E, 1) to broadcast
                    G.edges[c_etype].data['_eweight'] = w
                    # Each message is W_r * h_src scaled by its edge weight.
                    msg_func = fn.u_mul_e(wh_key, '_eweight', 'm')
                # Key by canonical etype so that relation names reused across
                # different (src, dst) pairs are not ambiguous. Note that the
                # mean still divides by the full in-degree, so masked edges
                # count in the denominator.
                funcs[c_etype] = (msg_func, fn.mean('m', 'h'))
            # Trigger per-relation message passing; sum across relations.
            G.multi_update_all(funcs, 'sum')
            return {
                ntype: G.nodes[ntype].data['h']
                for ntype in G.ntypes
                if 'h' in G.nodes[ntype].data
            }
The mask for the edges of each relation (edge type) has shape (number of edges, 1),
and Wh for each node type has shape (number of nodes, out_size).
Thanks.