Hi, I'm currently using dglnn.HeteroGraphConv to update the initial node embeddings. In my code, the initial node embeddings are stored in the matrix self.digit_feats, which is not being updated after optimizer.step(). Apparently, dglnn.HeteroGraphConv returns a dictionary of node embeddings with node types as the keys. I'm not sure whether this is the problem, but could it be causing a disconnection in the backward gradient flow?
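To check whether those dict outputs are still attached to the autograd graph, I've been poking at them with something like this (a rough sketch using the names from my forward() below; if grad_fn is None, the tensor got detached somewhere):

    # Sketch of the connectivity check I run inside forward():
    g_out1 = self.gat1(g, (in_feats1, in_feats1))
    for ntype, h in g_out1.items():
        # grad_fn of None would mean this output is cut off from digit_feats
        print(ntype, h.shape, h.requires_grad, h.grad_fn)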
The g is a dgl.heterograph with its node types and edges properly initialized; the node types are digits and numbers.
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class NumberGraph(nn.Module):
    def __init__(self, args):
        super(NumberGraph, self).__init__()
        self.args = args
        self.hidden_size = self.args.hidden_size
        self.num_labels = self.args.num_labels
        self.digit_feats = nn.Parameter(torch.Tensor(10, self.args.hidden_size))
        nn.init.xavier_uniform_(self.digit_feats)  # torch.Tensor(...) allocates uninitialized memory
        # https://docs.dgl.ai/api/python/nn.pytorch.html#dgl.nn.pytorch.HeteroGraphConv
        # TODO: Think about using GCN instead of GATConv
        etypes = ["d1", "d10", "d100", "d1000", "d10000", "d100000", "d1000000", "ge", "lt"]
        self.gat1 = dglnn.HeteroGraphConv(
            {etype: dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1) for etype in etypes},
            aggregate='sum')  # TODO: May need to change aggregate function (test it!) - 'sum', 'max', 'min', 'mean', 'stack'.
        self.gat2 = dglnn.HeteroGraphConv(
            {etype: dglnn.GATConv(self.hidden_size, self.hidden_size, num_heads=1) for etype in etypes},
            aggregate='sum')
        self.output = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size * 2, bias=True),
            nn.ReLU(),
            nn.Dropout(self.args.dropout),
            nn.Linear(self.hidden_size * 2, self.num_labels, bias=True),
        )  # gte vs. lt

    def forward(self, g, dict_map, inputs, labels):
        '''
        Args
            g : dgl.heterograph
            dict_map : (num2idx, idx2num) lookup tables
            inputs : input tuple of numbers (num1, num2)
            labels : 0 = "less than", 1 = "greater than / equal to"
        '''
        num2idx, idx2num = dict_map
        # Initialize node representations
        g.nodes['digits'].data['feature'] = self.digit_feats  # TODO: This is NOT being updated!
        in_feats1 = {"digits": self.digit_feats, "numbers": g.nodes['numbers'].data['feature']}
        g_out1 = self.gat1(g, (in_feats1, in_feats1))
        in_feats2 = {"digits": F.relu(g_out1["digits"]), "numbers": F.relu(g_out1["numbers"])}
        g_out2 = self.gat2(g, (in_feats2, in_feats2))
        num_emb_mat = g_out2['numbers'].squeeze(-2)  # .detach()
        concat_num_pair = num_emb_mat[inputs].view(-1, self.args.hidden_size * 2)
        logits = self.output(concat_num_pair)
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)  # TODO: Add DICE regularizer term
        return loss, logits
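For reference, this is roughly how I'm confirming the parameter doesn't move (a sketch; model, optimizer, and the batch variables come from my training loop, which I've omitted):

    # Snapshot the parameter, take one step, and compare.
    before = model.digit_feats.detach().clone()
    optimizer.zero_grad()
    loss, logits = model(g, dict_map, inputs, labels)
    loss.backward()
    print(model.digit_feats.grad)  # None (or all zeros)?
    optimizer.step()
    print(torch.equal(before, model.digit_feats.detach()))  # prints True, i.e. unchanged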
I just can't figure out what's causing this issue. Any ideas?