GATConv: mismatch between nodes and features

Hey,

a newbie question here: I’m trying to define a simple GAT network for batched graph classification using the GATConv module from DGL:

class GATClassifier(nn.Module):
    def __init__(self, in_dim, out_dim, n_heads, n_classes):
        super(GATClassifier, self).__init__()

        # Add GATConv layers
        self.gat_1 = GATConv(in_feats=in_dim, out_feats=out_dim,
                             num_heads=n_heads)
        self.gat_2 = GATConv(in_feats=out_dim, out_feats=out_dim,
                             num_heads=n_heads)

        # Add linear classifier
        self.classify = nn.Linear(out_dim, n_classes)

    def forward(self, g, features):

        x = features

        x = self.gat_1(g, x)
        x = self.gat_2(g, x)

        g.ndata['h'] = x

        hg = dgl.mean_nodes(g, 'h')

        return self.classify(hg)

This returns the following error message:

dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 406 and 203 instead.

Each node in my batched graph has a 9-dimensional feature vector, hence the input size is 9 x 203.

What am I doing wrong?

Probably this is due to an inappropriate handle of attention heads? For example, maybe you used two attention heads, after the first GATConv, you got something like a list of two tensors, each with shape (203, 9). And then you by accident concatenate them in a wrong dimension, yielding a tensor of shape (406, 9)?

Sounds plausible, as I defined two attention heads for the model.

How should I account for this in the model definition?

See the examples here. The output of a GATConv is of shape (N, H, M), N for the number of nodes, H for the number of heads and M for the output size of each head. Basically you need to flatten the output or take an average of the output over the multi-head results. If you choose to flatten the output, then the input size to the next layer need to be multiplied by the number of heads in the current layer.

1 Like

As always, thanks for your help @mufeili! These examples make a lot of sense – I was looking at the code for the GAT tutorial, which lacks the averaging / flattening during the forward pass.

1 Like

Glad that helps. In fact the tutorial does the trick with the code block below:

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

Unfortunately, the nn modules were not ready by the time of the tutorial.

1 Like

Hi, I am curious how to solve this issue since the example link is no longer valid.
Could you provide an example for addressing the multiple head issue here?

Thank you.

See if this helps.

import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GATConv

class GATLayer(nn.Module):
    r"""Single GAT layer from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features
    out_feats : int
        Number of output node features
    num_heads : int
        Number of attention heads
    feat_drop : float
        Dropout applied to the input features
    attn_drop : float
        Dropout applied to attention values of edges
    alpha : float
        Hyperparameter in LeakyReLU, which is the slope for negative values.
        Default to 0.2.
    residual : bool
        Whether to perform skip connection, default to True.
    agg_mode : str
        The way to aggregate multi-head attention results, can be either
        'flatten' for concatenating all-head results or 'mean' for averaging
        all head results.
    activation : activation function or None
        Activation function applied to the aggregated multi-head results, default to None.
    """
    def __init__(self, in_feats, out_feats, num_heads, feat_drop, attn_drop,
                 alpha=0.2, residual=True, agg_mode='flatten', activation=None):
        super(GATLayer, self).__init__()

        self.gat_conv = GATConv(in_feats=in_feats, out_feats=out_feats, num_heads=num_heads,
                                feat_drop=feat_drop, attn_drop=attn_drop,
                                negative_slope=alpha, residual=residual)
        assert agg_mode in ['flatten', 'mean']
        self.agg_mode = agg_mode
        self.activation = activation

    def forward(self, bg, feats):
        """Update node representations

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              out_feats in initialization if self.agg_mode == 'mean' and
              out_feats * num_heads in initialization otherwise.
        """
        feats = self.gat_conv(bg, feats)
        if self.agg_mode == 'flatten':
            feats = feats.flatten(1)
        else:
            feats = feats.mean(1)

        if self.activation is not None:
            feats = self.activation(feats)

        return feats

class GAT(nn.Module):
    r"""GAT from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
        ``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
    num_heads : list of int
        ``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
        ``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
        for each GAT layer.
    feat_drops : list of float
        ``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
        ``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
        all GAT layers.
    attn_drops : list of float
        ``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
        layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
        for all GAT layers.
    alphas : list of float
        Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
        gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
        number of GAT layers. By default, this will be 0.2 for all GAT layers.
    residuals : list of bool
        ``residual[i]`` decides if residual connection is to be used for the i-th GAT layer.
        ``len(residual)`` equals the number of GAT layers. By default, residual connection
        is performed for each GAT layer.
    agg_modes : list of str
        The way to aggregate multi-head attention results for each GAT layer, which can be either
        'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
        ``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
        GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
        all-head results for each GAT layer.
    activations : list of activation function or None
        ``activations[i]`` gives the activation function applied to the aggregated multi-head
        results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
        By default, no activation is applied for each GAT layer.
    """
    def __init__(self, in_feats, hidden_feats=None, num_heads=None, feat_drops=None,
                 attn_drops=None, alphas=None, residuals=None, agg_modes=None, activations=None):
        super(GAT, self).__init__()

        if hidden_feats is None:
            hidden_feats = [32, 32]

        n_layers = len(hidden_feats)
        if num_heads is None:
            num_heads = [4 for _ in range(n_layers)]
        if feat_drops is None:
            feat_drops = [0. for _ in range(n_layers)]
        if attn_drops is None:
            attn_drops = [0. for _ in range(n_layers)]
        if alphas is None:
            alphas = [0.2 for _ in range(n_layers)]
        if residuals is None:
            residuals = [True for _ in range(n_layers)]
        if agg_modes is None:
            agg_modes = ['flatten' for _ in range(n_layers - 1)]
            agg_modes.append('mean')
        if activations is None:
            activations = [F.elu for _ in range(n_layers - 1)]
            activations.append(None)
        lengths = [len(hidden_feats), len(num_heads), len(feat_drops), len(attn_drops),
                   len(alphas), len(residuals), len(agg_modes), len(activations)]
        assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, num_heads, ' \
                                       'feat_drops, attn_drops, alphas, residuals, ' \
                                       'agg_modes and activations to be the same, ' \
                                       'got {}'.format(lengths)
        self.hidden_feats = hidden_feats
        self.num_heads = num_heads
        self.agg_modes = agg_modes
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(GATLayer(in_feats, hidden_feats[i], num_heads[i],
                                            feat_drops[i], attn_drops[i], alphas[i],
                                            residuals[i], agg_modes[i], activations[i]))
            if agg_modes[i] == 'flatten':
                in_feats = hidden_feats[i] * num_heads[i]
            else:
                in_feats = hidden_feats[i]

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_sizes[-1] if agg_modes[-1] == 'mean' and
              hidden_sizes[-1] * num_heads[-1] otherwise.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats

It’s very helpful.
Thank you very much~!

I’m revisiting this question as I am trying to implement a classifier that uses GATConv. I am using the GATConv as explained here and I get the same error (" Expect number of features to match number of nodes"). I am unsure of how to implement the mentioned solutions when using the GATConv layer. Any advise please? For either merging or concatenating the results of each head?

I am currently using this:

 class Classifier_GAT(nn.Module):
     def __init__(self, in_dim, hidden_dim_graph,hidden_dim1, n_classes,dropout,num_heads):
         super(Classifier_GAT, self).__init__()
         self.conv1 = GATConv(in_dim, hidden_dim_graph,num_heads,dropout)
         self.conv2 = GATConv(hidden_dim_graph, hidden_dim1,num_heads,dropout)        
         self.classify = nn.Sequential(nn.Linear(hidden_dim1, n_classes),nn.Dropout(dropout))
         self.out_act = nn.Sigmoid()
     def forward(self, g,readout_method):
         h = g.ndata['h_n'].float()
         h = F.relu(self.conv1(g, h))
         h = F.relu(self.conv2(g, h))
         g.ndata['h'] = h
         hg = dgl.mean_nodes(g, 'h')
         a3=self.classify(hg)
         return self.out_act(a3)

Also, I would like to reproduce the image with edge attentions as shown in the tutorial. How can I return the edge attentions? And could I return the graph features (hg) as well somehow? Many thanks in advance!

1 Like

The output of F.relu(self.conv1(g, h)) is of shape (V, K, H), where V is the number of nodes, K is the number of attention heads, H is the hidden size. Before passing it to the next GATConv, you need to aggregate the results of the multiple attention heads, e.g. h = h.mean(dim=1) or h = h.flatten(1). In the second case, you need to set self.conv2 = GATConv(hidden_dim_graph * num_heads, hidden_dim1, num_heads, dropout).

For the last GAT layer, i.e. self.conv2 in your case, you can simply perform h = h.mean(dim=1).

For attention visualization, see this thread.

To return the graph features as well, just replace return self.out_act(a3) with return self.out_act(a3), hg.

5 Likes

Many thanks Mufeili, it works perfectly! :smiley:

Hi @mufeili

In the @paulamartingonzalez example, she use a relu after each GATConv layer. I wonder if it is correct, because GATConv use a leakyrelu as activation function and she is use two activations after each layer.

It’s correct. leakyrelu is used only for attention computation and you can apply an activation function to the output of message passing.