Constrained GAT implementation question


I’m trying to implement the constrained GAT model described here. The authors define two margin-based loss functions: one on the graph structure (Equation 1)

\begin{align} L_{g} = \sum_{i \in V} \sum_{j \in N_{i} \setminus N_{i}^{-}} \sum_{k \in V \setminus N_{i}} \max(0, \phi(v_{i}, v_{k}) + \zeta_{g} - \phi(v_{i}, v_{j})) \end{align}

and one on class boundaries (Equation 2)

\begin{align} L_{b} = \sum_{i \in V} \sum_{j \in N_{i}^{+}} \sum_{k \in N_{i}^{-}} \max(0, \phi(v_{i}, v_{k}) + \zeta_{b} - \phi(v_{i}, v_{j})) \end{align}

where \phi(v_{i},v_{j}) is the attention function LeakyReLU(MLP(v_{i}, v_{j})) between source node i and destination node j. Equation 1 encourages adjacent nodes to have larger attention values than distant nodes. Equation 2 encourages adjacent nodes with the same label to have larger attention values than adjacent nodes with different labels.
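As a quick numeric illustration (plain Python, with hypothetical attention values), the hinge term shared by both losses only contributes when the "far"/negative score comes within the margin of the "near"/positive score:

```python
# Hinge term used in both Equation 1 and Equation 2:
# max(0, phi_ik + zeta - phi_ij), which is zero once phi_ij
# exceeds phi_ik by at least the margin zeta.
def hinge(phi_ij, phi_ik, zeta):
    return max(0.0, phi_ik + zeta - phi_ij)

print(hinge(phi_ij=0.9, phi_ik=0.2, zeta=0.5))  # margin satisfied -> 0.0
print(hinge(phi_ij=0.4, phi_ik=0.2, zeta=0.5))  # margin violated -> ~0.3
```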

For a single source node i, the authors compute all pairwise differences in attention values and then sum them, rather than summing the positive and negative attention values separately and taking the difference of the sums.

I’m a bit stuck on how to implement this pairwise difference step. Given the GATConv layer, the unnormalized attention values e already live as messages on the edges between a source node and its adjacent destination nodes. The pairwise differences might be possible if we compute the edge graph for a source node (or of the entire graph, though that could be exceptionally large) – i.e. treat all edges incident on a source node as adjacent nodes in the edge graph. Does DGL already offer a solution to this “comparison of messages” problem?
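For reference, the edge-graph idea can be sketched in plain Python with a hypothetical toy edge list (DGL also provides dgl.line_graph, though, if I understand its docs correctly, its default convention connects an edge u→v to successor edges v→w rather than to edges sharing a source):

```python
# Edges of a toy graph as (src, dst) pairs.
edges = [(0, 1), (0, 2), (0, 3), (1, 2)]

# In this edge graph, each original edge becomes a node, and two
# edge-nodes are connected when the original edges share a source,
# so their attention values could be compared pairwise.
line_graph_edges = [
    (a, b)
    for a, e1 in enumerate(edges)
    for b, e2 in enumerate(edges)
    if a != b and e1[0] == e2[0]
]
# edges 0, 1, 2 all leave node 0, so they are pairwise adjacent here
print(line_graph_edges)
```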



L290 stores the attention values as an edge feature. Assuming the attention values are of shape (|E|, 1), the pairwise differences would form a tensor of shape (|E|^2, 1). Is this something you want to compute?

Hi @mufeili

I think it would be at most |E|^{2} if the graph were fully dense. If we assume a constant node degree d << |V|, we would have d^{2}|V| pairwise differences in total. I only want to compute pairwise differences between edges incident on a single node, for every node, rather than between all pairs of edges in the graph.

For example, if node i is adjacent to j \in N_{i}^{+} in the positive graph, with positive graph edge weights p_{i,j}, and adjacent to k \in N_{i}^{-} in the negative graph, with negative edge weights n_{i,k}, then I’d want

\sum_{j \in N_{i}^{+}} \sum_{k \in N_{i}^{-}} \max(0, n_{i,k} + \zeta - p_{i,j})

See if this code snippet works for you.

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 1, 2, 3, 1, 2, 3], [0, 0, 0, 3, 3, 3, 1, 2]))
# +1 marks a positive (same-class) edge, -1 a negative (different-class) edge
g.edata['pos_or_neg'] = torch.tensor([-1., -1., 1., 1., -1., 1., 1., -1.])
g.edata['weight'] = torch.randn(g.num_edges())
# pack both features into one message so a single copy_e suffices
g.edata['comb_weight'] = torch.stack([g.edata['pos_or_neg'], g.edata['weight']], dim=1)

zeta = 0.1  # margin

def reduce(nodes):
    # mailbox['m'] has shape (num_nodes, deg, 2) after degree bucketing
    msg = nodes.mailbox['m']
    pos_or_neg = msg[:, :, 0]
    weight = msg[:, :, 1]
    pos_weight = weight.unsqueeze(2)  # (N, deg, 1), indexed by j
    neg_weight = weight.unsqueeze(1)  # (N, 1, deg), indexed by k
    # broadcast for pairwise hinge terms max(0, n_k + zeta - p_j)
    hinge = (neg_weight + zeta - pos_weight).clamp(min=0)
    # keep only (j positive, k negative) pairs; masking after the clamp
    # avoids spurious terms when attention values are negative
    pair_mask = (pos_or_neg == 1).float().unsqueeze(2) * (pos_or_neg == -1).float().unsqueeze(1)
    diff = hinge * pair_mask
    return {'diff_sum': diff.sum(-1).sum(-1, keepdim=True)}

g.update_all(fn.copy_e('comb_weight', 'm'), reduce)
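A torch-only sanity check (hypothetical weights and margin, independent of DGL) that the broadcasted pairwise computation matches an explicit double loop over the positive/negative edges of a single node:

```python
import torch

zeta = 0.5  # hypothetical margin
# signs (+1 positive, -1 negative) and attention values for one node's edges
sign = torch.tensor([1., -1., 1., -1.])
weight = torch.tensor([0.8, 0.6, 0.1, -0.3])

# broadcasted version: hinge over all (j, k) pairs, then mask to
# (j positive, k negative) pairs and sum
hinge = (weight.unsqueeze(0) + zeta - weight.unsqueeze(1)).clamp(min=0)
pair_mask = (sign == 1).float().unsqueeze(1) * (sign == -1).float().unsqueeze(0)
broadcast_loss = (hinge * pair_mask).sum()

# reference: explicit loop over the same (j positive, k negative) pairs
loop_loss = sum(
    max(0.0, float(weight[k]) + zeta - float(weight[j]))
    for j in range(4) if sign[j] == 1
    for k in range(4) if sign[k] == -1
)
assert torch.isclose(broadcast_loss, torch.tensor(loop_loss))
```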

Hi @mufeili

Thanks for the code – this works well. Using some other code provided in this post, and tweaking your code, I was able to define a constrained GAT layer.