GNNExplainer import

Hi! I am trying to import the GNNExplainer as explained in the example using from dgl.nn import GNNExplainer but I get the error:

---------------------------------------------------------------------------

ImportError Traceback (most recent call last)

<ipython-input-76-3d847f1fa71f> in <module>()
----> 1 from dgl.nn import GNNExplainer

ImportError: cannot import name 'GNNExplainer' from 'dgl.nn.pytorch' (/usr/local/lib/python3.7/dist-packages/dgl/nn/pytorch/__init__.py)

I think I have the most up to date DGL installation. Should I install something else? Thanks in advance!

I can import it on DGL v0.8. What's your current version?

import dgl
print(dgl.__version__)

Hi @neo, thanks for this! I actually have version 0.6, but when I try to install version 0.8 using pip install dgl==0.8.0 I get:

ERROR: Could not find a version that satisfies the requirement dgl==0.8.0 (from versions: 0.0.1, 0.1.0, 0.1.2, 0.1.3, 0.2, 0.3, 0.3.1, 0.4rc190819, 0.4rc190821, 0.4rc190822, 0.4rc190823, 0.4rc190824, 0.4rc190826, 0.4rc190902, 0.4rc190903, 0.4rc190904, 0.4rc190905, 0.4rc190906, 0.4rc190908, 0.4rc190909, 0.4rc190910, 0.4rc190911, 0.4rc190912, 0.4rc190915, 0.4rc190916, 0.4rc190917, 0.4rc190918, 0.4rc190920, 0.4rc190921, 0.4rc190923, 0.4rc190924, 0.4rc190927, 0.4rc190928, 0.4rc190929, 0.4rc191001, 0.4rc191003, 0.4rc191004, 0.4rc191005, 0.4, 0.4.1, 0.4.2, 0.4.3, 0.4.3.post1, 0.4.3.post2, 0.5.0, 0.5.1, 0.5.2, 0.5.3, 0.6.0, 0.6.0.post1, 0.6.1) ERROR: No matching distribution found for dgl==0.8.0

Would you mind advising on how to install v0.8, please? I have also tried cloning the repo using git clone --recurse-submodules https://github.com/dmlc/dgl.git but I am unsure how to import it as a package to then call dgl.nn. Thanks in advance!

For CUDA 11.1 on Colab:
pip install dgl-cu111 dglgo -f https://data.dgl.ai/wheels/repo.html

You can check the install commands for other configurations.
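After installing, it's worth restarting the Colab runtime and verifying both the version and the import, e.g.:

# quick sanity check after reinstalling (restart the Colab runtime first)
import dgl
print(dgl.__version__)            # should now print 0.8.x
from dgl.nn import GNNExplainer   # available from DGL 0.8 onwards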


Thank you! It works perfectly now!

I am now trying to run explainer = GNNExplainer(model, num_hops=1) and feat_mask, edge_mask = explainer.explain_graph(G, features), and I get the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-01786a92a76c> in <module>()
----> 1 feat_mask, edge_mask = explainer.explain_graph(G, features)

1 frames
/usr/local/lib/python3.7/dist-packages/dgl/nn/pytorch/explain/gnnexplainer.py in explain_graph(self, graph, feat, **kwargs)
    369         # Get the initial prediction.
    370         with torch.no_grad():
--> 371             logits = self.model(graph=graph, feat=feat, **kwargs)
    372             pred_label = logits.argmax(dim=-1)
    373 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'graph'

Does this also happen to you? Thanks again :blush:

No. Are you using the code as in the example in the commit?

The forward function of your model did not have a graph argument, which is required. You might have called it g instead.
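As the traceback shows, the explainer calls the model as model(graph=graph, feat=feat, **kwargs), so the model's forward signature needs to match those keyword names. A minimal sketch:

# the explainer passes graph and feat as keyword arguments,
# so the parameter names must match exactly
def forward(self, graph, feat, eweight=None):
    ...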

Oh, thanks for your help @neo @mufeili! I am using a custom model, so I needed to change the argument names of the forward function as suggested (graph, feat, eweight). Now it works well :blush:

I will be working on this so I might reopen with extra issues if there were any! Thanks for your work, I was very excited to see GNNExplainer supported on DGL.


Thanks @paulamartingonzalez. We might add more explainability models in the future; let us know if there are models that you find particularly useful/important.


Hi again @mufeili and @neo, I am trying the code on a custom model which is as follows:

class Classifier_gen(nn.Module):
    def __init__(self, in_dim, hidden_dim_graph, hidden_dim1, n_classes, dropout, num_layers, pooling):
        super(Classifier_gen, self).__init__()
        if num_layers == 1:
            self.conv1 = GraphConv(in_dim, hidden_dim1)
        if num_layers == 2:
            self.conv1 = GraphConv(in_dim, hidden_dim_graph)
            self.conv2 = GraphConv(hidden_dim_graph, hidden_dim1)
        if pooling == 'att':
            pooling_gate_nn = nn.Linear(hidden_dim1, 1)
            self.pooling = GlobalAttentionPooling(pooling_gate_nn)
        self.classify = nn.Sequential(nn.Linear(hidden_dim1, hidden_dim1), nn.Dropout(dropout))
        self.classify2 = nn.Sequential(nn.Linear(hidden_dim1, n_classes), nn.Dropout(dropout))
        self.out_act = nn.Sigmoid()
        self.num_layers = num_layers
        self.mypooling = pooling

    def forward(self, graph, feat, eweight=None):
        pooling = self.mypooling
        h = feat.float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(graph, h))
        h = F.relu(self.conv2(graph, h))  # note: assumes num_layers == 2

        graph.ndata['h'] = h
        # Calculate graph representation by readout over the node representations.
        if pooling == "max":
            hg = dgl.max_nodes(graph, 'h')
        elif pooling == "mean":
            hg = dgl.mean_nodes(graph, 'h')
        elif pooling == "sum":
            hg = dgl.sum_nodes(graph, 'h')
        elif pooling == 'att':
            hg = self.pooling(graph, h)

        a2 = self.classify(hg)
        a3 = self.classify2(a2)
        return self.out_act(a3)

I have two questions:

  1. If my model doesn't do edge convolutions or take edge weights into account, are the returned edge importances valid?
  2. I’ve run the GNNExplainer on my data (which has 19 node features) and all the features seem to have similar importances. For example, for a particular graph, these are the feature importances:
[0.29416263, 0.28345844, 0.31799579, 0.27244017, 0.24714269,
        0.280285  , 0.27397498, 0.27854383, 0.25854427, 0.30813652,
        0.28295702, 0.27138951, 0.26840273, 0.2793256 , 0.27505299,
        0.24631442, 0.28361943, 0.22539981, 0.29995206]

And across the whole graph set, the feature importance is around the same (~0.2-0.3) for each feature, with the mean always around 0.26-0.27 for all features. I see this is also the case for the provided examples, and I wonder: how can we find the more relevant features here? Is this behaviour normal? I see that changing the number of hops doesn't change much (neither on my data nor on the example), but changing the lr does. Is some hyperparameter tuning needed?

  3. More conceptually, I see that only the graph and its features are used as input to the explainer. I thought we were looking for perturbations that don't affect the prediction much, so I find it intuitively hard to understand why we don't use the label?
  1. If my model doesn't do edge convolutions or take edge weights into account, are the returned edge importances valid?

It must consider edge weights via an optional argument eweight; otherwise, the edge importance scores are just the values randomly initialized from a normal distribution.
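A minimal sketch of that pattern, assuming DGL >= 0.8 where GraphConv accepts an optional edge_weight argument:

# sketch: thread the explainer's edge mask through each conv layer
def forward(self, graph, feat, eweight=None):
    h = F.relu(self.conv1(graph, feat, edge_weight=eweight))
    h = F.relu(self.conv2(graph, h, edge_weight=eweight))
    ...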

And across the whole graph set, the feature importance is around the same (~0.2-0.3) for each feature, with the mean always around 0.26-0.27 for all features. I see this is also the case for the provided examples, and I wonder: how can we find the more relevant features here? Is this behaviour normal? I see that changing the number of hops doesn't change much (neither on my data nor on the example), but changing the lr does. Is some hyperparameter tuning needed?

Honestly I’m not sure about it. Did you also change the number of GNN layers when changing the number of hops? You can definitely try tuning the learning rate. You can also try adding fake or trivial node features and see if GNNExplainer manages to assign low importance scores to them.
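For instance, a rough version of that check (a sketch; the constant dummy column, and retraining the model with in_dim + 1, are assumptions of this illustration):

# sketch: append a constant dummy feature column; a faithful explainer
# should assign it a low importance score
import torch
feat_aug = torch.cat([features, torch.ones(features.shape[0], 1)], dim=1)
# note: the model must be (re)trained with in_dim + 1 for the shapes to match
feat_mask, edge_mask = explainer.explain_graph(g, feat_aug)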

  3. More conceptually, I see that only the graph and its features are used as input to the explainer. I thought we were looking for perturbations that don't affect the prediction much, so I find it intuitively hard to understand why we don't use the label?

GNNExplainer aims at identifying a subgraph and a subset of the node features that allow approximating the original model predictions. In this sense, it does use labels, but the labels are the predictions of the GNN being explained.

Thanks so much @mufeili!

GNNExplainer aims at identifying a subgraph and a subset of the node features that allow approximating the original model predictions. In this sense, it does use labels, but the labels are the predictions of the GNN being explained.

This makes a lot of sense, thanks for explaining it!

Did you also change the number of GNN layers when changing the number of hops?

With this do you refer to the layers in the GNN model? If yes, then yes I did change them :blush:

You can also try adding fake or trivial node features and see if GNNExplainer manages to assign low importance scores to them.

I have been investigating this a little bit further both with my data and with the one used in the example (my data has 19 node features and the example one has 7).

  1. I first tried zeroing out all the features except a few in my own data:

When I use the original features, without zeroing anything out yet, I get the following for a particular graph:

feat_mask = tensor([0.2796, 0.2358, 0.2982, 0.2932, 0.2907, 0.2584, 0.2548, 0.2890, 0.2668,
        0.2450, 0.2933, 0.2431, 0.2656, 0.2756, 0.2456, 0.2537, 0.2591, 0.2613,
        0.2796])

When I zero out some of the features, I still get very similar results:

A = np.zeros_like(g.ndata['h_n'].detach().numpy())
A[:,5]=g.ndata['h_n'].detach().numpy()[:,5]
A[:,10]=g.ndata['h_n'].detach().numpy()[:,10]
A[:,18]=g.ndata['h_n'].detach().numpy()[:,18]

features = torch.tensor(A)
feat_mask, edge_mask = explainer.explain_graph(g, features)

feat_mask = tensor([0.2763, 0.2770, 0.2616, 0.2513, 0.2592, 0.2937, 0.2610, 0.2516, 0.2467,
        0.2352, 0.2506, 0.2497, 0.2800, 0.2693, 0.2958, 0.2284, 0.2671, 0.3004,
        0.2791])
  2. I tried something similar with the data from the example:
explainer = GNNExplainer(model, num_hops=1,lr = 0.01)
g, _ = data[0]
features = g.ndata['attr']
feat_mask, edge_mask = explainer.explain_graph(g, features)
feat_mask = tensor([0.2844, 0.2770, 0.2995, 0.2504, 0.2628, 0.2609, 0.2833]) 

and when zeroing some of the features out:

A = np.zeros_like(g.ndata['attr'].detach().numpy())
A[:,5]=g.ndata['attr'].detach().numpy()[:,5]
A[:,6]=g.ndata['attr'].detach().numpy()[:,6]
A[:,3]=g.ndata['attr'].detach().numpy()[:,3]


features = torch.tensor(A)
feat_mask, edge_mask = explainer.explain_graph(g, features)
feat_mask = tensor([0.2777, 0.2612, 0.2789, 0.2848, 0.2802, 0.2496, 0.3019])

Is this normal? I was trying to upload an image showing the histogram of the feature importances in all the nodes, but I can't for some reason. They show what I mentioned yesterday: all the importances were between 0.2 and 0.3, with very similar means.
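For reference, such a histogram can be produced directly from the mask with matplotlib (a small sketch):

# sketch: histogram of the learned feature importance scores
import matplotlib.pyplot as plt
plt.hist(feat_mask.detach().numpy(), bins=20)
plt.xlabel('feature importance')
plt.ylabel('count')
plt.show()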

It must consider edge weights via an optional argument eweight; otherwise, the edge importance scores are just the values randomly initialized from a normal distribution.

Lastly, for this I have two questions. This is a pre-trained model that I have the weights saved for. Can I just add something like the below? (It is based on the provided example.)


class Classifier_gen(nn.Module):
    def __init__(self, in_dim, hidden_dim_graph, hidden_dim1, n_classes, dropout, num_layers, pooling):
        super(Classifier_gen, self).__init__()
        if num_layers == 1:
            self.conv1 = GraphConv(in_dim, hidden_dim1)
        if num_layers == 2:
            self.conv1 = GraphConv(in_dim, hidden_dim_graph)
            self.conv2 = GraphConv(hidden_dim_graph, hidden_dim1)
        if pooling == 'att':
            pooling_gate_nn = nn.Linear(hidden_dim1, 1)
            self.pooling = GlobalAttentionPooling(pooling_gate_nn)
        self.classify = nn.Sequential(nn.Linear(hidden_dim1, hidden_dim1), nn.Dropout(dropout))
        self.classify2 = nn.Sequential(nn.Linear(hidden_dim1, n_classes), nn.Dropout(dropout))
        self.out_act = nn.Sigmoid()
        self.num_layers = num_layers
        self.mypooling = pooling

    def forward(self, graph, feat, eweight=None):
        pooling = self.mypooling
        h = feat.float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(graph, h))
        if eweight is None:
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        else:
            graph.edata['w'] = eweight
            graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))

        h = F.relu(self.conv2(graph, h))  # note: assumes num_layers == 2
        if eweight is None:
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        else:
            graph.edata['w'] = eweight
            graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))

        graph.ndata['h'] = h
        # Calculate graph representation by readout over the node representations.
        if pooling == "max":
            hg = dgl.max_nodes(graph, 'h')
        elif pooling == "mean":
            hg = dgl.mean_nodes(graph, 'h')
        elif pooling == "sum":
            hg = dgl.sum_nodes(graph, 'h')
        elif pooling == 'att':
            hg = self.pooling(graph, h)

        a2 = self.classify(hg)
        a3 = self.classify2(a2)
        return self.out_act(a3)

It is giving me an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-71-dd81e9fa2c48> in <module>()
     1 g = trainsetMB[1][0]
     2 features = g.ndata['h_n']
----> 3 feat_mask, edge_mask = explainer.explain_graph(g, features)

5 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
  1846     if has_torch_function_variadic(input, weight, bias):
  1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
  1849 
  1850 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (907x19 and 7x2)

And I don’t know what is really happening.

So sorry for my very-long comment and many thanks for all your help!

With this do you refer to the layers in the GNN model? If yes, then yes I did change them :blush:

Yes. You are right.

Is this normal? I was trying to upload an image showing the histogram of the feature importances in all the nodes, but I can't for some reason. They show what I mentioned yesterday: all the importances were between 0.2 and 0.3, with very similar means.

I tried it myself. It seems that in general the feature importance scores decrease with a bigger lr and number of epochs. Meanwhile, the score differences among them are small.
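If you want to reproduce this, a quick sweep is easy to run (a sketch; num_epochs as the constructor argument for the number of optimization epochs is an assumption here):

# sketch: compare feature masks across learning rates
for lr in (0.001, 0.01, 0.1):
    explainer = GNNExplainer(model, num_hops=1, lr=lr, num_epochs=100)
    feat_mask, edge_mask = explainer.explain_graph(g, features)
    print(lr, feat_mask.min().item(), feat_mask.max().item())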

Lastly, for this I have two questions. This is a pre-trained model that I have the weights saved for. Can I just add something the below? (It is based on the provided example):

I don’t understand this. Did you also modify the model architecture? Did you re-initialize the GNNExplainer object?

Hi @mufeili , thanks again!

I am very puzzled by this, even on the example MUTAG dataset:

I tried it myself. It seems that in general the feature importance scores decrease with a bigger lr and number of epochs. Meanwhile, the score differences among them are small.

By taking a look at the GNNExplainer paper, where they also predict mutagenicity (not on the MUTAG dataset, but some node features are shared), they show somewhat “big” differences → no scale is given, so I don't really know how big.

I think then, in the MUTAG dataset with node labels:

  0  C
  1  N
  2  O
  3  F
  4  I
  5  Cl
  6  Br

We would expect to see O, C and N being more important than F, I, Br, but the differences we see are so small that I don't know to what extent we can claim this. I am not an expert on GNNExplainer nor on its implementation, but could there be a small bug in the implementation? :persevere: Don't you think it is also strange that even when masking a set of features with zeros, they are given an importance comparable to the real ones that were kept?

Regarding this:

I don’t understand this. Did you also modify the model architecture? Did you re-initialize the GNNExplainer object?

Yes, I re-loaded everything for this case :blush:

Many thanks and so sorry for my many messages, I am just very eager to use GNNExplainer in my data :innocent:

By taking a look at the GNNExplainer paper, where they also predict mutagenicity (not on the MUTAG dataset, but some node features are shared), they show somewhat “big” differences → no scale is given, so I don't really know how big.

You need to check the authors’ code or contact the authors. Quite often the paper does not include all the details and it can be tricky to reproduce the reported results.

We would expect to see O, C and N being more important than F, I, Br, but the differences we see are so small that I don't know to what extent we can claim this. I am not an expert on GNNExplainer nor on its implementation, but could there be a small bug in the implementation? :persevere: Don't you think it is also strange that even when masking a set of features with zeros, they are given an importance comparable to the real ones that were kept?

There might be bugs that I'm not aware of. If you can identify them, I'm happy to fix them. Meanwhile, it's possible that the method assigns similar importance scores even to trivial features. In that case, you might need to increase the coefficients related to the node feature mask here and see if that helps.

Thanks @mufeili !

There might be bugs that I'm not aware of. If you can identify them, I'm happy to fix them. Meanwhile, it's possible that the method assigns similar importance scores even to trivial features. In that case, you might need to increase the coefficients related to the node feature mask here and see if that helps.

Actually, this recommendation works great and now I see greater differences in node feature importance. Indeed, in the GNNExplainer paper the authors mention that the penalisations in the loss can be changed to impose the constraints needed for different problems. I see that the differences increase when I reduce the “node_size” coefficient and increase the “node_entropy” one. Is that as expected? I have no ground truth for my problem; have you seen anyone using it in this setting? Should I just play with some parameters?

Lastly, I have tried this:

Lastly, for this I have two questions. This is a pre-trained model that I have the weights saved for. Can I just add something like the below? (It is based on the provided example.)

(the same Classifier_gen code with the eweight branches, as posted above)

But the edge features remain as if I didn't have any edge weights (the result is just a Gaussian distribution). Does something look odd to you in the above? Although I used the model provided in the example, I've never worked with edge weights in this setting, so I might be missing some steps?

Thanks again for your help! Very much appreciated :slight_smile:

Actually, this recommendation works great and now I see greater differences in node feature importance. Indeed, in the GNNExplainer paper the authors mention that the penalisations in the loss can be changed to impose the constraints needed for different problems. I see that the differences increase when I reduce the “node_size” coefficient and increase the “node_entropy” one. Is that as expected?

I’ve opened a PR to expose these arguments and added documentation for them. You can take a look at it. Thank you for inspiring me to improve this module.
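Once merged, these should be settable at construction time, along the lines of the sketch below (the names beta1/beta2 for the node feature mask size/entropy coefficients follow later DGL releases and are an assumption here):

# sketch: strengthen the node feature mask regularization terms
explainer = GNNExplainer(model, num_hops=1, lr=0.01,
                         beta1=0.5,  # node feature mask size coefficient (assumed name)
                         beta2=0.5)  # node feature mask entropy coefficient (assumed name)
feat_mask, edge_mask = explainer.explain_graph(g, features)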

I have no ground truth for my problem; have you seen anyone using it in this setting? Should I just play with some parameters?

This is a common problem for using these GNN explanation methods in practice. I do not have a good answer yet. You may find this paper helpful.

But the edge features remain as if I didn't have any edge weights (the result is just a Gaussian distribution). Does something look odd to you in the above? Although I used the model provided in the example, I've never worked with edge weights in this setting, so I might be missing some steps?

Having an eweight argument in the forward function only matters for the generated edge mask explanations.

@mufeili thanks for your enormous help! It would be great if those arguments could be changed directly, and the paper is very relevant to my research! Thanks! :smiling_face_with_three_hearts:

regarding this:

Q: But the edge features remain as if I didn't have any edge weights (the result is just a Gaussian distribution). Does something look odd to you in the above? Although I used the model provided in the example, I've never worked with edge weights in this setting, so I might be missing some steps?

A: Having an eweight argument in the forward function only matters for the generated edge mask explanations.

I meant that, using the model above that takes the eweight argument into account, the edge importances remain unchanged. I just get the same as in the case where eweight is unused (just a Gaussian distribution). I was wondering if something looks odd in the model above? Or should I take something else into account that I am missing?

You need to pass eweight to self.conv1 and self.conv2. The lines you added related to eweight did not come into effect.
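That is, something like this sketch (relying on GraphConv's optional edge_weight argument in DGL >= 0.8):

# sketch: feed the mask directly into the conv layers
h = F.relu(self.conv1(graph, h, edge_weight=eweight))
h = F.relu(self.conv2(graph, h, edge_weight=eweight))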

Ah, that makes a lot of sense, thanks!!

I’ve updated the forward function as:

    def forward(self, graph, feat, eweight=None):
        pooling = self.mypooling
        h = feat.float()
        # Perform graph convolution and activation function,
        # normalizing the explainer's edge mask when it is provided.
        if eweight is None:
            h = F.relu(self.conv1(graph, h))
        else:
            graph.edata['w'] = eweight
            norm = self.w(graph, graph.edata['w'])
            h = F.relu(self.conv1(graph, h, edge_weight=norm))

        if eweight is None:
            h = F.relu(self.conv2(graph, h))
        else:
            norm = self.w(graph, graph.edata['w'])
            h = F.relu(self.conv2(graph, h, edge_weight=norm))

        graph.ndata['h'] = h
        # Calculate graph representation by readout over the node representations.
        if pooling == "max":
            hg = dgl.max_nodes(graph, 'h')
        elif pooling == "mean":
            hg = dgl.mean_nodes(graph, 'h')
        elif pooling == "sum":
            hg = dgl.sum_nodes(graph, 'h')
        elif pooling == 'att':
            hg = self.pooling(graph, h)

        a2 = self.classify(hg)
        a3 = self.classify2(a2)
        return self.out_act(a3)
with self.w = EdgeWeightNorm() defined in __init__. Is that what you meant? I do see changes in the edge mask (it is not a Gaussian anymore), so I think it is working! :smiling_face_with_three_hearts:
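One caveat worth noting (based on the EdgeWeightNorm documentation; not confirmed in this thread): when normalizing edge weights externally with EdgeWeightNorm(norm='both'), the GraphConv layers should be built with norm='none' so the symmetric normalization is not applied twice, e.g.:

# sketch: pair EdgeWeightNorm with un-normalized GraphConv layers in __init__
self.w = EdgeWeightNorm(norm='both')
self.conv1 = GraphConv(in_dim, hidden_dim_graph, norm='none')
self.conv2 = GraphConv(hidden_dim_graph, hidden_dim1, norm='none')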

I will be working on this in my data for the next weeks so I might come back with more questions. But thank you so much for all your help already!