Hi, I am doing minibatch training on my own dataset. I am following this tutorial:
https://docs.dgl.ai/en/1.0.x/guide/minibatch-link.html
but I need the node embeddings for all the nodes in the entire graph.
For a traditional link prediction task with full-graph training, for example this tutorial:
https://docs.dgl.ai/en/0.8.x/guide/training-link.html
we can use the following to get embeddings for all nodes:
node_embeddings = model.sage(graph, node_features)
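For reference, in the full-graph setting the GNN forward looks roughly like this (my paraphrase of that tutorial, written with the same conv1/conv2 naming I use below, not the tutorial's exact code):

import torch.nn.functional as F

def forward(self, graph, x):
    # every layer message-passes over the whole graph, so the returned dict
    # already contains an embedding for every node of every type
    x = self.conv1(graph, x)
    x = {k: F.relu(v) for k, v in x.items()}
    x = self.conv2(graph, x)
    return x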
But in minibatch training, model.sage takes (blocks, features) instead of (graph, features). I think I have to do something like node_embedding = model.sage(blocks, features), where this time blocks = [blocks[0], blocks[1]] and both blocks[0] and blocks[1] come from the entire graph; that is, we convert the entire graph into the appropriate blocks and feed them to the forward function of sage, roughly as in the sketch below.
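Concretely, what I imagine is something like this (a very rough sketch; the 'feat' feature name and reusing one full block for both conv layers are my guesses, I am not sure this is valid):

import torch
import dgl

# take every node of every type as destination nodes
all_nodes = {ntype: torch.arange(g.num_nodes(ntype)) for ntype in g.ntypes}
# build one block that covers the whole graph (src == dst == all nodes)
full_block = dgl.to_block(g, dst_nodes=all_nodes)
blocks = [full_block, full_block]   # one entry per conv layer

# input features for every node ('feat' is just a placeholder name here)
node_features = {ntype: g.nodes[ntype].data['feat'] for ntype in g.ntypes}
node_embeddings = model.sage(blocks, node_features)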
Does that make sense? If it does, how do I convert g (the entire graph) into the appropriate blocks? If not, how can I get the embeddings for every node in g? What I did so far is something like:
def forward(self, blocks, x, global_item_embedding_dic, global_user_embedding_dic, epoc_count):
    x = self.conv1(blocks[0], x)
    x = {k: F.relu(v) for k, v in x.items()}
    x = self.conv2(blocks[1], x)
    if epoc_count == num_epoch - 1:
        # in the last epoch, record the output embeddings of this minibatch's
        # destination nodes, keyed by their original node IDs
        item_emb = x['item'].detach().numpy().tolist()
        user_emb = x['user'].detach().numpy().tolist()
        item_id_list = list(blocks[1].dstnodes['item'].data[dgl.NID].numpy())
        user_id_list = list(blocks[1].dstnodes['user'].data[dgl.NID].numpy())
        for iid, emb in zip(item_id_list, item_emb):
            global_item_embedding_dic[iid] = emb
        for uid, emb in zip(user_id_list, user_emb):
            global_user_embedding_dic[uid] = emb
    return x, global_item_embedding_dic, global_user_embedding_dic
# after training (outside the model), build the final embedding tensors,
# sorted by node ID so each row lines up with a node ID
item_tuple_list = [(k, v) for k, v in global_item_embedding_dic.items()]
item_tuple_list = sorted(item_tuple_list)
item_tuple_list = [v for (k, v) in item_tuple_list]
item_feat_tensor = torch.FloatTensor(item_tuple_list)

user_tuple_list = [(k, v) for k, v in global_user_embedding_dic.items()]
user_tuple_list = sorted(user_tuple_list)
user_tuple_list = [v for (k, v) in user_tuple_list]
user_feat_tensor = torch.FloatTensor(user_tuple_list)

node_embeddings = {'user': user_feat_tensor, 'item': item_feat_tensor}
In each minibatch of the final epoch, I use the global dictionaries to update the embedding of each destination node, and at the end I sort by node ID to get a 2-D tensor of final embeddings for each node type.
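For context, this is roughly how I drive it from the training loop (heavily simplified sketch; the dataloader setup, negative sampling, loss and optimizer steps are omitted, and 'feat' is again just a placeholder feature name):

global_item_embedding_dic, global_user_embedding_dic = {}, {}
for epoc_count in range(num_epoch):
    for input_nodes, pos_graph, neg_graph, blocks in dataloader:
        # input features of the source nodes of the first block
        x = {ntype: blocks[0].srcnodes[ntype].data['feat'] for ntype in blocks[0].srctypes}
        h, global_item_embedding_dic, global_user_embedding_dic = model.sage(
            blocks, x, global_item_embedding_dic, global_user_embedding_dic, epoc_count)
        # ... compute the link prediction loss from h, pos_graph, neg_graph and update ...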
Please share some thoughts, thanks in advance!