Fine-tuning Features from BERT Model


I’m new to DGL and GNNs and have a question about using BERT embeddings as features in a GNN:

I want to build a graph that uses BERT embeddings as node features. I understand that I can pre-compute these and assign them as node features, as long as every node has the feature. As the GNN is trained (for example, on a classification task), I’d like these initial embeddings to be updated as well. Is it possible to link this back to fine-tune the BERT model through back-prop? If so, how?

Hi, if you want to fine-tune BERT during training, just don’t freeze the BERT model, and compute the initial node features “online” (i.e., inside the forward pass instead of pre-computing them).

A useful trick is to use different learning rates for BERT and the GNN; a lower learning rate for BERT may make training more stable.
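The per-module learning rates can be set with PyTorch optimizer parameter groups. A minimal sketch, using small `nn.Linear` layers as stand-ins for the BERT and GNN submodules (the module names and learning-rate values are illustrative):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(4, 4)  # stand-in for the BERT encoder
        self.gnn = nn.Linear(4, 2)   # stand-in for the GNN

model = TinyModel()

# One parameter group per submodule, each with its own learning rate
optimizer = torch.optim.AdamW([
    {"params": model.bert.parameters(), "lr": 2e-5},  # lower lr for BERT
    {"params": model.gnn.parameters(), "lr": 1e-3},   # higher lr for GNN
])
```

Both groups are stepped together by a single `optimizer.step()`, so no extra training-loop changes are needed.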


Thanks, that makes sense. Is there a tutorial or quick setup showing how to compute the initial node features online? Or could you give me an idea of where this would go? What I’ve seen so far only covers assigning the features when building the graph.

Just use BERT as a submodule in your model.

Suppose each node is a text sequence, and the tensor input_ids contains the token ids with shape N*L, where N is the number of nodes and L is the sequence length. Then the following code would work.

class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.gnn = GNN()

    def forward(self, input_ids, graph):
        # use the hidden vector at the '[CLS]' position as the node features
        node_feats = self.bert(input_ids).last_hidden_state[:, 0, :]
        # set node_feats as the initial node features in the GNN
        graph.ndata['h'] = node_feats
        return self.gnn(graph)
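To see that back-prop really reaches the feature extractor, here is a self-contained sketch with stand-ins for BERT and the GNN (an `nn.Embedding` and an `nn.Linear`, so neither transformers nor DGL is required). Because the encoder runs inside `forward()`, the loss gradient flows into its weights:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):  # stand-in for BertModel
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 8)

    def forward(self, input_ids):
        # mean-pool token embeddings: (N, L) -> (N, 8)
        return self.emb(input_ids).mean(dim=1)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.head = nn.Linear(8, 2)  # stand-in for the GNN

    def forward(self, input_ids):
        node_feats = self.encoder(input_ids)  # computed "online"
        return self.head(node_feats)

model = Model()
input_ids = torch.randint(0, 100, (5, 7))  # N=5 nodes, L=7 tokens
labels = torch.randint(0, 2, (5,))

loss = nn.functional.cross_entropy(model(input_ids), labels)
loss.backward()

# the encoder received gradients, so it gets fine-tuned by optimizer.step()
assert model.encoder.emb.weight.grad is not None
```

If you instead pre-compute the features and assign them with `graph.ndata['h']` outside the model, they are plain tensors detached from BERT, and no gradient reaches the encoder.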
