I have an RGCN model implemented as follows. Basically, it uses the HeteroGraphConv class.
How can I use HeteroGNNExplainer
https://docs.dgl.ai/en/1.0.x/generated/dgl.nn.pytorch.explain.HeteroGNNExplainer.html
to obtain the edge masks? Thanks.
You need to modify the model class to incorporate the use of edge weights, as suggested in the example of HeteroGNNExplainer.
I am a bit confused by the weight and edge_weight arguments of the GraphConv class (see dgl.nn.pytorch.conv.graphconv in the DGL 1.0.2 documentation).
Below is how the RGCN layer is defined:
    # Layer construction: one GraphConv per relation, without its own weight matrix
    self.conv = dglnn.HeteroGraphConv(
        {
            rel: dglnn.GraphConv(
                in_feat, out_feat, norm="right", weight=False, bias=False
            )
            for rel in rel_names
        }
    )
    # One weight matrix per relation, stored outside the GraphConv modules
    self.weight = nn.Parameter(
        th.Tensor(len(self.rel_names), in_feat, out_feat)
    )
    nn.init.xavier_uniform_(
        self.weight, gain=nn.init.calculate_gain("relu")
    )

    # Forward pass: hand each relation its weight matrix through mod_kwargs
    weight = self.basis() if self.use_basis else self.weight
    wdict = {
        self.rel_names[i]: {"weight": w.squeeze(0)}
        for i, w in enumerate(th.split(weight, 1, dim=0))
    }
    hs = self.conv(g, inputs, mod_kwargs=wdict)
So the weight here is basically W_r in the RGCN formula, and it is passed to GraphConv, whose forward takes two parameters: weight=None and edge_weight=None.
I am confused about which parameter actually receives this wdict in your example. It seems to be weight. If that is the case, can I simply change wdict to
    wdict = {
        self.rel_names[i]: {"edge_weight": mask}
        for i, mask in enumerate(masks)
    }
and then pass this to the trained model?
The weight argument is for supplying an alternative weight matrix W in the message passing AXW. In the example for HeteroGNNExplainer, eweight holds the edge weights, or equivalently the nonzero entries of A. Your modification looks fine to me.
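To make the distinction concrete, here is a minimal sketch in plain NumPy rather than DGL. It computes AXW with a dense adjacency matrix, so it only illustrates the roles of the two arguments. The helper dense_graph_conv, the toy graph, and the relation names rel_a and rel_b are all hypothetical, and the degree normalization that norm="right" applies is left out.

```python
import numpy as np


def dense_graph_conv(edges, num_nodes, X, W, edge_weight=None):
    """Compute A @ X @ W, where A[dst, src] holds the weight of edge src -> dst.

    Unlike GraphConv(norm="right"), no degree normalization is applied here.
    """
    if edge_weight is None:
        edge_weight = np.ones(len(edges))  # unweighted: every nonzero entry of A is 1
    A = np.zeros((num_nodes, num_nodes))
    for (src, dst), w in zip(edges, edge_weight):
        A[dst, src] += w
    return A @ X @ W


# Toy data (hypothetical): two edges pointing into node 2.
edges = [(0, 2), (1, 2)]
X = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # node features
W = np.array([[2.0, 0.0], [0.0, 3.0]])  # plays the role of W_r

out_full = dense_graph_conv(edges, 3, X, W)
# Masking the second edge drops node 1's message; W itself is unchanged.
out_masked = dense_graph_conv(edges, 3, X, W, edge_weight=np.array([1.0, 0.0]))

# Per-relation kwargs that carry both W_r and an edge mask (illustrative names).
rel_names = ["rel_a", "rel_b"]
weights = {"rel_a": W, "rel_b": W.T}
masks = {"rel_a": np.array([1.0, 0.0]), "rel_b": np.array([0.5, 0.5])}
wdict = {
    rel: {"weight": weights[rel], "edge_weight": masks[rel]}
    for rel in rel_names
}
```

One point to check when adapting the layer from the thread: its GraphConv modules are built with weight=False, so a wdict that holds only edge_weight no longer passes W_r to them. Keeping both keys per relation, as in the last lines of the sketch, would presumably preserve the trained weights while applying the mask.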