[new NN pytorch module] GAT layer with full-fledged edge features (EGAT)

Hi,

I wrote a variation of the GAT layer in which edges carry full-fledged features, not only attention scores. We recently used it to process protein structures in our paper "Rossmann-toolbox: a deep learning-based protocol for the prediction and design of cofactor specificity in Rossmann fold proteins" (PubMed). There we show an improvement over regular GAT and GCN layers when additional features (residue-residue interactions) are supplied on the edges, especially as graph density increases.
I think it may be useful beyond this particular case. I could contribute it to the DGL repository if other users find it helpful; following the DGL docs, I am posting about it here first.
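To make the idea concrete, here is a minimal, hypothetical sketch in plain PyTorch (no DGL). It assumes an EGAT-style update in which each edge's new feature is computed from the concatenation [h_src || f_edge || h_dst], an attention score is read off that new edge feature, and each node's output is an attention-weighted sum of projected source-node features. The class name `EdgeGATSketch` and its arguments are illustrative only; this is not the EdgeGat or DGL code, and the exact formulation in the paper's supplementary data may differ (for example, the optional activation is omitted here).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EdgeGATSketch(nn.Module):
    """Hypothetical EGAT-style layer on an edge list (src[k] -> dst[k])."""

    def __init__(self, in_node, in_edge, out_node, out_edge, num_heads):
        super().__init__()
        self.num_heads, self.out_node, self.out_edge = num_heads, out_node, out_edge
        # Edge update: [h_src || f_edge || h_dst] -> per-head edge features.
        self.edge_proj = nn.Linear(2 * in_node + in_edge, num_heads * out_edge)
        # Node projection used to build the messages sent along edges.
        self.node_proj = nn.Linear(in_node, num_heads * out_node, bias=False)
        # One attention vector per head, applied to the updated edge features.
        self.attn = nn.Parameter(0.1 * torch.randn(num_heads, out_edge))

    def forward(self, src, dst, num_nodes, node_feats, edge_feats):
        num_edges = src.shape[0]
        # Updated edge features, shape (E, H, out_edge).
        z = torch.cat([node_feats[src], edge_feats, node_feats[dst]], dim=-1)
        new_edge = F.leaky_relu(self.edge_proj(z)).view(
            num_edges, self.num_heads, self.out_edge)
        # Unnormalised attention score per edge and head, shape (E, H).
        score = (new_edge * self.attn).sum(dim=-1)
        # Softmax over the incoming edges of each destination node. Subtracting
        # one max per head is the same shift for every group, so the
        # normalised weights are unchanged.
        score = torch.exp(score - score.max(dim=0, keepdim=True).values)
        denom = torch.zeros(num_nodes, self.num_heads).index_add_(0, dst, score)
        alpha = score / denom[dst]
        # Attention-weighted sum of projected source features, shape (N, H, out_node).
        msg = self.node_proj(node_feats).view(-1, self.num_heads, self.out_node)[src]
        out = torch.zeros(num_nodes, self.num_heads, self.out_node)
        out.index_add_(0, dst, alpha.unsqueeze(-1) * msg)
        return out, new_edge


# Toy graph: 3 nodes, 4 directed edges.
src = torch.tensor([0, 1, 2, 2])
dst = torch.tensor([1, 2, 0, 1])
layer = EdgeGATSketch(in_node=5, in_edge=4, out_node=10, out_edge=10, num_heads=3)
node_out, edge_out = layer(src, dst, 3, torch.randn(3, 5), torch.randn(4, 4))
print(node_out.shape, edge_out.shape)  # torch.Size([3, 3, 10]) torch.Size([4, 3, 10])
```

Returning the updated edge features alongside the node features is what lets stacked layers keep refining edge information, instead of using edges only to weight neighbours.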

A detailed description of the layer (with an image, of course :)) is available in the supplementary data: https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbab371/6375059#supplementary-data

The source code is based on dgl.nn.pytorch.conv.graphconv (DGL 0.8 documentation).
Example usage:

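```python
import torch as th

# MultiHeadEGATLayer comes from my EdgeGat repository (linked later in this thread)
egat = MultiHeadEGATLayer(in_node_feats=num_node_feats,   # input node feature size
                          in_edge_feats=num_edge_feats,   # input edge feature size
                          out_node_feats=10,              # output node feature size
                          out_edge_feats=10,              # output edge feature size
                          num_heads=3,
                          activation=th.nn.functional.leaky_relu)  # optional; omit if not needed
```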

Forward pass:

```python
new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
# new_node_feats.shape == (num_nodes, num_heads, out_node_feats)
# new_edge_feats.shape == (num_edges, num_heads, out_edge_feats)
```
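Because both outputs keep a separate head dimension, they usually need reshaping before a downstream layer or readout. The minimal sketch below uses random stand-in tensors with the shapes from the forward-pass comments (the sizes 6, 8, 3 and 10 are arbitrary assumptions) and shows two common options: concatenating heads, as GAT typically does in hidden layers, or averaging them, as it typically does in the output layer.

```python
import torch

num_nodes, num_edges, num_heads, out_feats = 6, 8, 3, 10
# Random stand-ins for the layer outputs; only the shapes matter here.
new_node_feats = torch.randn(num_nodes, num_heads, out_feats)
new_edge_feats = torch.randn(num_edges, num_heads, out_feats)

# Option 1: concatenate heads (e.g. for a hidden layer feeding another layer).
node_concat = new_node_feats.flatten(start_dim=1)  # (num_nodes, num_heads * out_feats)
edge_concat = new_edge_feats.flatten(start_dim=1)  # (num_edges, num_heads * out_feats)

# Option 2: average heads (e.g. for the final layer).
node_mean = new_node_feats.mean(dim=1)             # (num_nodes, out_feats)
```

With concatenation, a following layer would need `in_node_feats = num_heads * out_node_feats` and `in_edge_feats = num_heads * out_edge_feats`.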

Parameter naming and the source code follow DGL's GATConv implementation.

Looking forward to your comments.

Best wishes,
Kamil Kamiński

Really appreciate your contribution to DGL. Let me add more people to comment on this.
@BarclayII @mufeili @minjie please help review this new module.


I forgot to share the repo with the layer :slight_smile: Here it is: GitHub - labstructbioinf/EdgeGat: graph neural network layer for Deep Graph Library with pytorch backend

Sorry for the late reply. This is great work, and I agree it can benefit other researchers and users as well. Your implementation is currently based on an old implementation of GATConv, and we now recommend following the latest implementation instead. Perhaps you could first open a pull request in the DGL repo with your current implementation, and we can discuss there how to update it. It would also be great if you could take a look at the latest implementation of GATConv.