Hi Mufei,
I will try to explain what I want to do. This code may not run as-is; it is only meant as an explanation.
The following is what I want to do with a single graph.
```python
import math
import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = th.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return th.matmul(p_attn, value), p_attn

# A single 3-node graph; dense attention ignores its edges
g = dgl.graph(([0, 1], [1, 2]))
x = th.randn(3, 10)
g.ndata['x'] = x
q = nn.Linear(10, 20)(g.ndata['x'])
k = nn.Linear(10, 20)(g.ndata['x'])
v = nn.Linear(10, 20)(g.ndata['x'])
# Every node attends to every node of the graph
output, p_attn = attention(q, k, v)
```
However, this does not work with a batched DGL graph:
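```python
g1 = dgl.graph(([0, 1], [1, 2]))  # 3 nodes
g2 = dgl.graph(([0], [1]))        # 2 nodes
x1 = th.randn(3, 10)
x2 = th.randn(2, 10)
g1.ndata['x'] = x1
g2.ndata['x'] = x2
g = dgl.batch([g1, g2])  # nodes 0-2 come from g1, nodes 3-4 from g2
```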
If we do the same as above:
```python
q = nn.Linear(10, 20)(g.ndata['x'])
k = nn.Linear(10, 20)(g.ndata['x'])
v = nn.Linear(10, 20)(g.ndata['x'])
output, p_attn = attention(q, k, v)  # 5 x 5 attention over the whole batch
```
In the attention layer, `q[0]` is then also matched with `k[3]` (a query from `g1` attends to a key from `g2`), which is what needs to be avoided.
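To make the constraint concrete: attention over the batch is only correct if each node attends to nodes of its own graph, i.e. the batch-level attention matrix must be block-diagonal. A minimal sketch, assuming the per-graph node counts are available as a tensor (in recent DGL versions `g.batch_num_nodes()` returns them), builds that mask and feeds it through the existing `mask` argument of `attention`. `block_diag_mask` is a hypothetical helper, not a DGL API, and DGL itself is not needed here.

```python
import math
import torch as th
import torch.nn.functional as F


def attention(query, key, value, mask=None):
    # Same scaled dot-product attention as above (dropout omitted)
    d_k = query.size(-1)
    scores = th.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    return th.matmul(p_attn, value), p_attn


def block_diag_mask(batch_num_nodes):
    # Hypothetical helper. graph_id[i] is the graph that node i belongs to,
    # e.g. counts [3, 2] -> graph_id [0, 0, 0, 1, 1]
    graph_id = th.repeat_interleave(th.arange(len(batch_num_nodes)), batch_num_nodes)
    # mask[i, j] is True only when nodes i and j are in the same graph
    return graph_id[:, None] == graph_id[None, :]


th.manual_seed(0)
batch_num_nodes = th.tensor([3, 2])  # node counts of g1 and g2 (assumed)
q, k, v = th.randn(5, 20), th.randn(5, 20), th.randn(5, 20)

mask = block_diag_mask(batch_num_nodes)
out, p_attn = attention(q, k, v, mask=mask)

# Cross-graph weight q[0] -> k[3]; the -1e9 fill makes it vanish after softmax
print(p_attn[0, 3].item())
```

This keeps a dense score matrix over the whole batch, so memory grows quadratically with the total number of nodes in the batch rather than per graph; the edge-based approach from the tutorial avoids that, at the cost of needing the right edges.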
The current solution in the DGL Transformer tutorial looks like this:
```python
import numpy as np
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn

x1 = th.randn(3, 10)
x2 = th.randn(2, 10)
# Self-loops only, so each node can attend only to itself
g1 = dgl.graph(([0, 1, 2], [0, 1, 2]))
g1.ndata['x'] = x1
g2 = dgl.graph(([0, 1], [0, 1]))
g2.ndata['x'] = x2
g = dgl.batch([g1, g2])
g.ndata['q'] = nn.Linear(10, 20)(g.ndata['x'])
g.ndata['k'] = nn.Linear(10, 20)(g.ndata['x'])
g.ndata['v'] = nn.Linear(10, 20)(g.ndata['x'])

def src_dot_dst(src_field, dst_field, out_field):
    # Edge score: dot product of a source-node field and a destination-node field
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}
    return func

def propagate_attention(g, eids, d_k):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(d_k)), eids)
    # Update node state: wv = sum of score * v, z = sum of scores (per dst node)
    g.send_and_recv(eids, fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv'))
    g.send_and_recv(eids, fn.copy_e('score', 'score'), fn.sum('score', 'z'))
    # Normalise to get the softmax-weighted average of the values
    return g.ndata['wv'] / g.ndata['z']

out = propagate_attention(g, g.edges(form='eid'), d_k=20)
```
This computes a score on each edge by matching the source node's key with the destination node's query, and then aggregates the weighted values over each node's incoming edges. This is a very clever solution. However, it only reproduces full attention on a complete graph: if there is no edge from node 1 to node 0 in `g1`, the information of node 1 never reaches node 0.
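For reference, per destination node $i$ the tutorial code computes the following, where $\mathcal{N}(i)$ is the set of source nodes of $i$'s incoming edges in `eids`, and `wv` and `z` are the two sums:

$$
s_{ji} = \exp\left(\operatorname{clamp}\left(\frac{k_j^\top q_i}{\sqrt{d_k}},\,-5,\,5\right)\right),\qquad
\mathrm{wv}_i = \sum_{j \in \mathcal{N}(i)} s_{ji}\, v_j,\qquad
z_i = \sum_{j \in \mathcal{N}(i)} s_{ji},\qquad
h_i = \frac{\mathrm{wv}_i}{z_i}
$$

Dense attention is the same expression with $\mathcal{N}(i)$ replaced by all nodes of $i$'s graph (and without the clamp), so the two agree only when every node has an incoming edge from every node of its own graph, itself included.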
One solution I am considering is to transform each graph into a complete one before the attention layer and then delete the added edges afterwards (similar to what is done in the generative models tutorial).
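The "make it complete first" idea can be checked without DGL. A minimal sketch in plain PyTorch, assuming per-graph node counts as before: `complete_edges` (a hypothetical helper) lists every intra-graph `(src, dst)` pair including self-loops, i.e. the edge set a complete graph per component would have, and `edge_attention` (also hypothetical) mimics the tutorial's score / `wv` / `z` steps with `index_add_` in place of DGL message passing (the clamp is left out). On those edges it matches dense attention computed separately for each graph.

```python
import math
import torch as th
import torch.nn.functional as F


def complete_edges(batch_num_nodes):
    # Hypothetical helper: all (src, dst) pairs inside each graph, self-loops
    # included. Ids are batch-global, e.g. counts [3, 2] -> nodes 0-2 and 3-4.
    graph_id = th.repeat_interleave(th.arange(len(batch_num_nodes)), batch_num_nodes)
    same_graph = graph_id[:, None] == graph_id[None, :]
    src, dst = same_graph.nonzero(as_tuple=True)
    return src, dst


def edge_attention(q, k, v, src, dst):
    # Hypothetical edge-wise attention: score on edge j -> i is k_j . q_i / sqrt(d_k),
    # normalised over the incoming edges of each destination node i.
    d_k = q.size(-1)
    score = (k[src] * q[dst]).sum(-1) / math.sqrt(d_k)
    # Shift by one global constant for stability; it cancels in wv / z
    score = th.exp(score - score.max())
    wv = th.zeros_like(v).index_add_(0, dst, score[:, None] * v[src])
    z = th.zeros(q.size(0), dtype=q.dtype).index_add_(0, dst, score)
    return wv / z[:, None]


th.manual_seed(0)
batch_num_nodes = th.tensor([3, 2])  # node counts of g1 and g2 (assumed)
q, k, v = th.randn(5, 20), th.randn(5, 20), th.randn(5, 20)

src, dst = complete_edges(batch_num_nodes)
out_edges = edge_attention(q, k, v, src, dst)

# Reference: dense attention run separately on each graph
sizes = batch_num_nodes.tolist()
ref = th.cat([
    F.softmax(qi @ ki.T / math.sqrt(q.size(-1)), dim=-1) @ vi
    for qi, ki, vi in zip(q.split(sizes), k.split(sizes), v.split(sizes))
])
print(th.allclose(out_edges, ref, atol=1e-5))
```

With DGL, the same pairs could be added with `g.add_edges(src, dst)` before the attention layer and removed again afterwards; edges the graph already has (such as the self-loops in `g1` and `g2` above) would need to be filtered out first, otherwise they would be counted twice in `wv` and `z`.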
I hope this makes clear what I want to do.
Thanks.