Custom attention model

Hi all. I’m hoping someone could please give me some feedback on my attention model. The idea here is that I have a runner node (features pertaining to a horse running in the prediction race, e.g. weight, starting gate, etc.) with incoming run edges (run info for a historical race, e.g. starting gate, finishing time, etc.) connecting it to historical race nodes (features such as distance, year, month, etc.).

I am trying to aggregate all historical races for a runner in the current race using an attention module that combines features of the runner, each incoming run edge, and the corresponding race node to calculate the “relevance” of previous performances.
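For context, a toy version of the graph this module operates on would look something like the snippet below (all sizes and feature values here are made up for illustration; my real features come from upstream embedding layers):

import dgl
import torch

# 4 historical races feeding 2 runners in the prediction race
g = dgl.heterograph({
    ("race", "run", "runner"): (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 0, 0, 1]))
})

g.nodes["race"].data["h"] = torch.randn(4, 16)    # race features, race_embed_dim = 16
g.nodes["runner"].data["h"] = torch.randn(2, 32)  # runner features, runner_embed_dim = 32
g.edges["run"].data["h"] = torch.randn(4, 8)      # run edge features, run_edge_dim = 8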

import dgl
import torch
import torch.nn as nn


class RunnerUpdateModule(nn.Module):
    def __init__(self, runner_embed_dim, race_embed_dim, run_edge_dim):
        super(RunnerUpdateModule, self).__init__()
        
        # MLP mapping the attention-weighted sum of (race + run-edge) messages to a runner-sized update
        self.get_node_updates = nn.Sequential(
            nn.Linear(race_embed_dim + run_edge_dim, runner_embed_dim),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(runner_embed_dim, runner_embed_dim),
            nn.Dropout(0.4)
        )
        self.normalise_nodes = nn.BatchNorm1d(runner_embed_dim)
        attn_input_size = run_edge_dim + race_embed_dim + runner_embed_dim
        self.attention_fc = nn.Sequential(
            nn.Linear(attn_input_size, attn_input_size // 2),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(attn_input_size // 2, 1),
            nn.Dropout(0.4),
            nn.LeakyReLU()
        )

    def forward(self, g: dgl.DGLHeteroGraph):
        def compute_attention_score(edges):
            concat_features = torch.cat([edges.src["h"], edges.dst["h"], edges.data["h"]], dim=-1)
            scores = self.attention_fc(concat_features)
            return {"a_raw": scores}
        
        g.apply_edges(compute_attention_score, etype=("race", "run", "runner"))
    
        def message_func(edges):
            concat_features = torch.cat([edges.src["h"], edges.data["h"]], dim=-1)
            return {"concat_features": concat_features, "a_raw": edges.data["a_raw"]}
        
        def learnable_reduce_func(nodes):
            attn_norm = torch.softmax(nodes.mailbox["a_raw"], dim=1)
            h_sum = torch.sum(attn_norm * nodes.mailbox["concat_features"], dim=1)
            return {"h_agg": self.get_node_updates(h_sum)}

        g.multi_update_all(
            {("race", "run", "runner"): (message_func, learnable_reduce_func)},
            "sum"
        )

        g.nodes["runner"].data["h"] = self.normalise_nodes(g.nodes["runner"].data["h"] 
                                                           + g.nodes["runner"].data["h_agg"])

        return g

Is this an appropriate implementation of attention that uses the src node, edge, and dst node features in its calculations? I’ve been finding that the model basically outputs uniform importance across all edges.
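For reference, this is roughly how I’m instantiating and calling the module on the toy graph above (eval mode just so dropout/batch norm don’t add noise to the check; the dims are the made-up ones from the toy graph):

module = RunnerUpdateModule(runner_embed_dim=32, race_embed_dim=16, run_edge_dim=8)
module.eval()  # disables dropout and uses BatchNorm running stats for this check

with torch.no_grad():
    g = module(g)

print(g.nodes["runner"].data["h"].shape)   # updated runner embeddings, (2, 32)
print(g.edges["run"].data["a_raw"].shape)  # raw (pre-softmax) attention scores, (4, 1)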

Hi @BarclayII - I’m sorry for tagging you directly. I’m really stuck on this and would be very grateful for your opinion.

What do you mean by “importance”? Looking at the code, I don’t see an immediate problem.

Basically, I’m trying to write an attention module that will learn how to prioritise different historical race nodes based on features of the current runner node (and the associated run edge). Imagine there are N historical races for a runner; there would then be N (race, run, runner) tuples. From domain expertise, I suspect that some tuples will be more relevant than others when updating the current runner. For example, more recent races might be more important than earlier races, or races where the runner started from the same gate might be more important, etc. You can see that I compute a sum of these incoming messages weighted by the attention scores.

I’m just not sure whether the above is how one should correctly implement such an attention module, because when I inspected the attention scores after ~20 epochs the model was applying a roughly uniform importance/attention score to all historical races. It’s possible that the feature embeddings themselves are the issue and that the attention component simply isn’t given a good signal with which to prioritise historical races; I just want to rule out the mechanism itself (I’ve included a sketch of how I inspected the scores below). I hope I communicated that clearly. Much appreciated.
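In case it helps, this is roughly how I pulled the scores out (a sketch; here I’m reusing the toy graph and module from my first post in place of my real trained model):

with torch.no_grad():
    g = module(g)  # populates g.edges["run"].data["a_raw"]

src, dst = g.edges(etype=("race", "run", "runner"))
raw = g.edges["run"].data["a_raw"].squeeze(-1)
for runner_id in dst.unique():
    mask = dst == runner_id
    # mirror the per-runner softmax done inside learnable_reduce_func
    print(runner_id.item(), torch.softmax(raw[mask], dim=0))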

In which layer are you inspecting the attention scores? The attention scores are solely determined by the incident node and edge features, and if those features are not distinguishable, then the attention scores indeed won’t differ much.
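One quick sanity check (just a sketch, reusing the names from your snippet): see whether the concatenated inputs to attention_fc actually vary across the incoming edges of a single runner, e.g.

src, dst = g.edges(etype=("race", "run", "runner"))
mask = dst == 0  # pick one runner with several historical races
attn_input = torch.cat([g.nodes["race"].data["h"][src[mask]],
                        g.nodes["runner"].data["h"][dst[mask]],
                        g.edges["run"].data["h"][mask]], dim=-1)
print(attn_input.std(dim=0).mean())  # near zero => inputs (and hence scores) are nearly identical

If that spread is close to zero, the uniform attention is coming from the features upstream rather than from the attention mechanism itself.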
