HeteroGraphConv not updating the initial node representations

Hi, I’m currently using dglnn.HeteroGraphConv to update the initial node embeddings. In my code, the initial node embeddings are stored in self.digit_feats.

The problem: self.digit_feats is not being updated after optimizer.step().

Apparently, dglnn.HeteroGraphConv returns a dictionary of node embeddings keyed by node type. I’m not sure if this is the problem, but could that be what is breaking the backward gradient flow?

g is a dgl.heterograph with its node types and edges properly initialized; the node types are digits and numbers.
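For reference, here is a minimal toy sketch of the kind of setup I mean (a hypothetical single edge type and tiny feature sizes, not my actual graph or graph_update code). It builds a tiny digits/numbers heterograph, runs one HeteroGraphConv, and checks whether gradients reach the digit embeddings:

import torch
import dgl
import dgl.nn.pytorch as dglnn

# Toy graph: two digit nodes both feeding one number node via a "d1" relation.
g_toy = dgl.heterograph({
    ('digits', 'd1', 'numbers'): (torch.tensor([0, 1]), torch.tensor([0, 0])),
})

conv = dglnn.HeteroGraphConv(
    {'d1': dglnn.GATConv(4, 4, num_heads=1)}, aggregate='sum')

digit_feats = torch.nn.Parameter(torch.randn(2, 4))  # stands in for self.digit_feats
number_feats = torch.randn(1, 4)

feats = {'digits': digit_feats, 'numbers': number_feats}
out = conv(g_toy, (feats, feats))    # dict keyed by destination node type
print(out.keys())                    # dict_keys(['numbers'])

out['numbers'].sum().backward()
print(digit_feats.grad is not None)  # True in this toy case, so the dict output
                                     # alone does not cut the gradient flow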

class NumberGraph(nn.Module):
    def __init__(self, args):
        super(NumberGraph, self).__init__()
        self.args = args
        self.hidden_size = self.args.hidden_size
        self.num_labels = self.args.num_labels
        self.digit_feats = nn.Parameter(torch.randn(10, self.args.hidden_size))  # one embedding per digit 0~9
        # https://docs.dgl.ai/api/python/nn.pytorch.html#dgl.nn.pytorch.HeteroGraphConv

        # TODO: Think about using GCN instead of GATConv
        self.gat1 = dglnn.HeteroGraphConv({
            "d1" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d10" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d100" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d1000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d10000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d100000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d1000000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "ge" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "lt" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
        }, aggregate='sum')  # TODO: May need to change the aggregate function (test it!) - 'sum', 'max', 'min', 'mean', 'stack'.
        
        self.gat2 = dglnn.HeteroGraphConv({
            "d1" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d10" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d100" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d1000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d10000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d100000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "d1000000" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "ge" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
            "lt" : dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1),
        }, aggregate='sum') 

        self.output = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size * 2, bias=True),
            nn.ReLU(),
            nn.Dropout(self.args.dropout),
            nn.Linear(self.hidden_size * 2, self.num_labels, bias=True),
        )  # gte vs. lt

    def forward(self, g, dict_map, inputs, labels):
        '''
        Args

        g : dgl.heterograph
        dict_map : (num2idx, idx2num) lookup dicts mapping numbers to node indices and back
        inputs : input tuple of numbers (num1, num2)
        labels : 0 = "less than", 1 = "greater than / equal to"
        '''
        num2idx, idx2num = dict_map
        # Initialize node representations

        g.nodes['digits'].data['feature']  = self.digit_feats  # TODO: This is NOT being updated!
        in_feats1 = {"digits": self.digit_feats, "numbers": g.nodes['numbers'].data['feature']}
        g_out1 = self.gat1(g, (in_feats1, in_feats1)) 
        
        in_feats2 = {"digits": F.relu(g_out1["digits"]), "numbers": F.relu(g_out1["numbers"])}
        g_out2 = self.gat2(g, (in_feats2, in_feats2))

        num_emb_mat = g_out2['numbers'].squeeze(-2)  #.detach()
        concat_num_pair = num_emb_mat[inputs].view(-1, self.args.hidden_size * 2)

        logits = self.output(concat_num_pair)

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)  # TODO: Add DICE regularizer term

        return loss, logits

I just couldn’t find what’s causing this issue. Any idea?

Can you provide a code snippet for a toy example that reproduces the issue, including a toy graph?

Here is the code snippet and a brief explanation of what’s being done:

  1. I’m building a digit-to-number graph that constructs arbitrary number representations (e.g., 217, 83, 912) out of digit (0~9) embeddings.

  2. The digit embeddings (in this case, self.g.nodes["digits"].data["feature"]) should be updated after calling optimizer.step().

  3. However, when I print the values of the digit embeddings (self.g.nodes["digits"].data["feature"]) both before and after the update, they simply do not change (a quick gradient check is sketched right after the training code below).

class Trainer(object):
    def __init__(self, args, g=None, graph_out=None, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset
        self.graph_out = graph_out
        self.g = g

        self.base_num = 10
        self.lt = 0  # less than
        self.ge = 1  # greater than or equal to

        self.label_lst = get_label(args)
        self.num_labels = len(self.label_lst)

        self.model = NumberGraph(self.args)

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not self.args.no_cuda else "cpu"
        self.model.to(self.device)

    def train_by_generate(self):
        print("Entering trainer...")
        # Generate dataset -> (num1, num2, rel_label)
        t_total = self.args.total_steps
        # Prepare optimizer and scheduler (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num Epochs = %d", self.args.num_train_epochs)
        logger.info("  Total train batch size = %d", self.args.train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        logger.info("  Logging steps = %d", self.args.logging_steps)
        logger.info("  Save steps = %d", self.args.save_steps)

        global_step = 0
        tr_loss = 0.0
        tr_acc = 0.0
        self.model.zero_grad()

        for e in tqdm(range(int(self.args.num_train_epochs)), desc="Epoch"):
            for step in tqdm(range(t_total), desc="Iteration"):
                self.model.train()
                batch = torch.randint(self.base_num, 1000, (self.args.train_batch_size, 2), device=self.device)
                labels = torch.tensor([self.lt if d[0] < d[1] else self.ge for d in batch], dtype=torch.long, device=self.device)
                assert len(batch) == len(labels)

                self.g, dict_map, batch2idx = graph_update(self.args, self.g, batch)
                print("g.nodes[digits].data[feature] (Before): ", self.g.nodes["digits"].data["feature"])
                print("g.nodes[numbers].data[feature] (Before): ", self.g.nodes["numbers"].data["feature"])

                inputs = {'g': self.g,
                          'dict_map': dict_map,
                          'inputs': batch2idx,
                          'labels': labels}
                outputs = self.model(**inputs)
                loss, logits = outputs

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                preds = logits.detach().cpu().numpy()
                label_ids = labels.detach().cpu().numpy()
                preds = np.argmax(preds, axis=1)
                result = compute_metrics(preds, label_ids)

                tr_loss += loss.item()
                tr_acc += result["acc"]
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule

                    # TODO: This is where the g.nodes["digits"].data["feature"] should have been updated
                    print("g.nodes[digits].data[feature] (After): ", self.g.nodes["digits"].data["feature"])
                    print("g.nodes[numbers].data[feature] (After): ", self.g.nodes["numbers"].data["feature"])

                    self.model.zero_grad()
                    global_step += 1

                    if self.args.logger:
                        neptune.log_metric('Loss', tr_loss / global_step)
                        neptune.log_metric('Accuracy', tr_acc / global_step)

                    if self.args.do_eval and self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                        self.evaluate_by_generate("dev")

                    # if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                    #     self.save_model()

        return global_step, tr_loss / global_step
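One quick diagnostic I could drop into the loop right after loss.backward() (a sketch; check_digit_feats is a hypothetical helper using the attribute names from the snippets above) is to confirm that self.digit_feats is registered as a parameter and actually receives a gradient:

def check_digit_feats(model):
    # Hypothetical helper: is digit_feats registered on the module, and does it
    # get a gradient after loss.backward()?
    registered = any(name == "digit_feats" for name, _ in model.named_parameters())
    print("digit_feats registered as parameter:", registered)
    print("digit_feats grad:", model.digit_feats.grad)

Calling check_digit_feats(self.model) between loss.backward() and optimizer.step() would show whether the gradient ever reaches the embedding table.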

And here are the values of the digit embeddings before and after the update:

g.nodes[digits].data[feature] (Before):  tensor([[ 0.1940,  2.1614, -0.1721,  ..., -0.9226,  0.7893, -1.8010],
        [-1.9351,  1.1007,  0.0251,  ...,  0.1943,  0.3696, -0.8600],
        [ 1.0769, -1.2650, -0.6424,  ..., -0.6821,  0.7974, -0.8484],
        ...,
        [ 0.5791, -1.1815, -0.5890,  ...,  1.4540, -2.7733, -0.8932],
        [-0.4574,  0.8063,  0.1580,  ...,  0.9749,  0.5732,  1.3790],
        [ 0.7398, -0.4790, -0.3156,  ...,  0.4012,  0.2405,  1.1301]],
       device='cuda:0')

g.nodes[digits].data[feature] (After):  tensor([[ 0.1940,  2.1614, -0.1721,  ..., -0.9226,  0.7893, -1.8010],
        [-1.9351,  1.1007,  0.0251,  ...,  0.1943,  0.3696, -0.8600],
        [ 1.0769, -1.2650, -0.6424,  ..., -0.6821,  0.7974, -0.8484],
        ...,
        [ 0.5791, -1.1815, -0.5890,  ...,  1.4540, -2.7733, -0.8932],
        [-0.4574,  0.8063,  0.1580,  ...,  0.9749,  0.5732,  1.3790],
        [ 0.7398, -0.4790, -0.3156,  ...,  0.4012,  0.2405,  1.1301]],
       device='cuda:0')

Could this possibly be because I did not include the embeddings for g.nodes['numbers'].data['feature'] in the model parameters, as I did with self.digit_feats?

Did you include self.digit_feats in optimizer_grouped_parameters?
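For example, a quick check (just a sketch, run inside train_by_generate after the optimizer is built and using the names from your snippet) is to list which named parameters actually end up in the optimizer’s param groups:

# Sketch: which of the model's named parameters does the optimizer actually hold?
ids_in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
for name, p in self.model.named_parameters():
    print(name, "in optimizer:", id(p) in ids_in_optimizer)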

Thanks man. That was exactly the issue.
I included both self.digit_feats and self.number_feats in the model parameters using nn.Parameter(), and it worked.
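Roughly, the registration now looks like this (a sketch; num_number_nodes is a placeholder for however many number nodes the graph has):

import torch
import torch.nn as nn

class NumberGraphFixed(nn.Module):  # illustrative sketch, not the full model
    def __init__(self, hidden_size, num_number_nodes):
        super().__init__()
        # Both embedding tables are nn.Parameter attributes, so they show up in
        # named_parameters() and are picked up by optimizer_grouped_parameters.
        self.digit_feats = nn.Parameter(torch.randn(10, hidden_size))
        self.number_feats = nn.Parameter(torch.randn(num_number_nodes, hidden_size))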

Thank you for the help!
