Hi all. I’m hoping someone could give me some feedback on my attention model.

**Graph structure**
- A `runner` node holds features of a horse running in the race being predicted (e.g. weight, starting gate).
- The runner has incoming `run` edges. Each edge describes one historical run (e.g. starting gate, finishing time).
- Each `run` edge connects to a historical `race` node, which holds features such as distance, year and month.

**What I'm trying to do**
I want to aggregate all of a runner's historical races for the current race using attention. The attention model uses three inputs to calculate the “relevance” of each previous performance:
- the features of the `runner` node,
- each incoming `run` edge,
- the corresponding `race` node.
class RunnerUpdateModule(nn.Module):
    """Update runner embeddings by attending over each runner's historical runs.

    Every ``("race", "run", "runner")`` edge links a historical race node
    (source) to a runner node (destination) and carries per-run features in
    ``edges.data["h"]``. An MLP scores each edge from the concatenated race,
    runner and run features. The scores are softmax-normalised over each
    runner's incoming edges and used to weight the concatenated race + run
    features. The weighted sum is projected to ``runner_embed_dim``, added
    residually to the runner's ``"h"`` and batch-normalised.

    Args:
        runner_embed_dim: width of ``g.nodes["runner"].data["h"]``.
        race_embed_dim: width of ``g.nodes["race"].data["h"]``.
        run_edge_dim: width of the run edges' ``"h"`` feature.
        dropout: dropout probability used inside both MLPs.
    """

    # Canonical edge type this module reads from and aggregates over.
    _ETYPE = ("race", "run", "runner")

    def __init__(self, runner_embed_dim, race_embed_dim, run_edge_dim, dropout=0.4):
        super().__init__()
        self.get_node_updates = nn.Sequential(
            nn.Linear(race_embed_dim + run_edge_dim, runner_embed_dim),
            nn.Dropout(dropout),
            # nn.ReLU's only argument is `inplace`; do not pass a size here.
            nn.ReLU(),
            nn.Linear(runner_embed_dim, runner_embed_dim),
            nn.Dropout(dropout),
        )
        self.normalise_nodes = nn.BatchNorm1d(runner_embed_dim)

        attn_input_size = run_edge_dim + race_embed_dim + runner_embed_dim
        self.attention_fc = nn.Sequential(
            nn.Linear(attn_input_size, attn_input_size // 2),
            nn.Dropout(dropout),
            nn.ReLU(),
            # The final layer must emit a raw, unbounded logit. Dropout here
            # would randomly zero scores, and a (Leaky)ReLU would squash
            # negative logits towards 0. Both drive the softmax to uniform.
            nn.Linear(attn_input_size // 2, 1),
        )

    def _score(self, race_h, runner_h, run_h):
        """Return one unnormalised attention logit per edge, shape ``(E, 1)``."""
        return self.attention_fc(torch.cat([race_h, runner_h, run_h], dim=-1))

    def _aggregate(self, scores, feats):
        """Softmax-weight and sum a mailbox, then project it to runner space.

        ``scores`` is ``(N, D, 1)`` and ``feats`` is ``(N, D, race + run)``,
        where D is the in-degree of the node bucket DGL hands to the reducer.
        The softmax runs over dim 1, i.e. across each node's incoming edges.
        """
        weights = torch.softmax(scores, dim=1)
        return self.get_node_updates((weights * feats).sum(dim=1))

    def forward(self, g: "dgl.DGLHeteroGraph"):
        """Update ``g.nodes["runner"].data["h"]`` and return ``g``.

        Side effects on ``g``: writes edge logits to ``"a_raw"`` on the run
        edges and the aggregated message to ``"h_agg"`` on runner nodes.
        NOTE(review): runners with no incoming run edges rely on DGL's
        default zero fill for ``"h_agg"``; confirm for your DGL version.
        """
        etype = self._ETYPE

        def compute_attention_score(edges):
            logits = self._score(edges.src["h"], edges.dst["h"], edges.data["h"])
            return {"a_raw": logits}

        g.apply_edges(compute_attention_score, etype=etype)

        def message_func(edges):
            return {
                "concat_features": torch.cat([edges.src["h"], edges.data["h"]], dim=-1),
                "a_raw": edges.data["a_raw"],
            }

        def learnable_reduce_func(nodes):
            mailbox = nodes.mailbox
            return {"h_agg": self._aggregate(mailbox["a_raw"], mailbox["concat_features"])}

        g.multi_update_all({etype: (message_func, learnable_reduce_func)}, "sum")

        runner_data = g.nodes["runner"].data
        runner_data["h"] = self.normalise_nodes(runner_data["h"] + runner_data["h_agg"])
        return g
Is this an appropriate implementation of attention that uses the source node, the edge and the destination node in its calculations? I’ve found that the model assigns nearly uniform attention weights across all incoming edges.