I found that the NGCF model in the DGL examples is trained on the full graph. When the graph is very large, this runs out of memory (OOM). Is there a version based on block (mini-batch) training?
In fact, I tried to implement it using block-based training and encountered this error:
Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
My code:
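# one full-size output buffer per layer, keyed by node type
self.layer_out_dict_arr = []
for i in range(self.num_layers):
    l = {k: torch.zeros_like(v, requires_grad=False).to(device) for k, v in self.feature_dict.items()}
    # assumed: the buffer is appended here, since forward() indexes this list by layer
    self.layer_out_dict_arr.append(l)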
def forward(self, pair_graph, neg_pair_graph, block, user_key, item_key):
    h_dict = {ntype: self.feature_dict[ntype] for ntype in block.ntypes}
    for i, layer in enumerate(self.layers):
        feature_dict, ids_dict = layer(block, h_dict)
        for ntype in block.ntypes:
            for j, id in enumerate(ids_dict[ntype]):
                self.layer_out_dict_arr[i][ntype][id] = feature_dict[ntype][j]
        h_dict = self.layer_out_dict_arr[i]
        # h_dict = {k: v.to(block.device) for k, v in h_dict.items()}
    user_embd = torch.cat([self.layer_out_dict_arr[i][user_key] for i in range(len(self.layers))], 1)
    item_embd = torch.cat([self.layer_out_dict_arr[i][item_key] for i in range(len(self.layers))], 1)
    # pdb.set_trace()
    res = {}
    for srctype, etype, dsttype in pair_graph.canonical_etypes:
        pos_src, pos_dst = pair_graph.edges(etype=(srctype, etype, dsttype))
        neg_src, neg_dst = neg_pair_graph.edges(etype=(srctype, etype, dsttype))
        h_pos = user_embd[pos_src] * item_embd[pos_dst]
        h_neg = user_embd[neg_src] * item_embd[neg_dst]
        res[(srctype, etype, dsttype)] = (h_pos, h_neg)
    return res
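For reference, the same error can be reproduced outside DGL. This is only a minimal sketch, and it assumes (not confirmed) that the cause is the persistent `self.layer_out_dict_arr` buffers: writing layer outputs in place into a tensor that lives across iterations keeps the previous iteration's autograd graph attached to that tensor. The names `weight`, `feats`, `step`, `persistent`, and `fresh` are hypothetical stand-ins for the model parameters, the node features, and the per-layer buffers.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(3, 3, requires_grad=True)  # stands in for a layer parameter
feats = torch.randn(4, 3)                        # stands in for input node features


def step(buffer, rows):
    """Write 'layer outputs' for the given rows into buffer in place, then backprop."""
    out = feats[rows] @ weight
    buffer[rows] = out      # in-place write: buffer now carries autograd history
    loss = buffer.sum()
    loss.backward()         # frees the saved tensors of the graph behind buffer
    return loss


# Case 1: one buffer reused across iterations, like a buffer stored on self.
persistent = torch.zeros(4, 3)
step(persistent, torch.tensor([0, 1]))
try:
    # The second backward also walks the graph of the first step, which was freed.
    step(persistent, torch.tensor([2, 3]))
    second_backward_failed = False
except RuntimeError as err:
    second_backward_failed = True
    print("RuntimeError:", str(err)[:60])

# Case 2: a fresh buffer per iteration, so no graph from an earlier step is reused.
weight.grad = None
for rows in (torch.tensor([0, 1]), torch.tensor([2, 3])):
    fresh = torch.zeros(4, 3)
    step(fresh, rows)
fresh_buffer_ok = weight.grad is not None
```

If this is indeed what happens in my code, then allocating the buffers inside `forward` (or not using full-size buffers at all) might avoid the error, but I am not sure that is the intended way to train with blocks.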
And the original code in https://github.com/dmlc/dgl/blob/master/examples/pytorch/NGCF/NGCF/model.py:
def forward(self, g, user_key, item_key, users, pos_items, neg_items):
    h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
    # obtain features of each layer and concatenate them all
    user_embeds = []
    item_embeds = []
    for layer in self.layers:
        h_dict = layer(g, h_dict)
        # collect this layer's output for both node types
        user_embeds.append(h_dict[user_key])
        item_embeds.append(h_dict[item_key])
    user_embd = torch.cat(user_embeds, 1)
    item_embd = torch.cat(item_embeds, 1)
    # select the embeddings for the sampled users and items
    u_g_embeddings = user_embd[users, :]
    pos_i_g_embeddings = item_embd[pos_items, :]
    neg_i_g_embeddings = item_embd[neg_items, :]
    return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
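As a rough illustration of the concatenation pattern in the original `forward`, here is a toy sketch in plain PyTorch. It assumes stand-in layers (`torch.nn.Linear` followed by ReLU, not the actual NGCF graph layers), made-up sizes, and hypothetical node type names `"user"` and `"item"`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_users, num_items, dim = 5, 7, 4

# Hypothetical stand-ins: linear layers instead of NGCF graph layers.
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])
h_dict = {
    "user": torch.randn(num_users, dim),
    "item": torch.randn(num_items, dim),
}

# The lists are created fresh for this forward pass.
user_embeds, item_embeds = [], []
for layer in layers:
    # Each layer produces new features for every node type.
    h_dict = {ntype: torch.relu(layer(h)) for ntype, h in h_dict.items()}
    user_embeds.append(h_dict["user"])
    item_embeds.append(h_dict["item"])

# Concatenate the per-layer outputs along the feature dimension.
user_embd = torch.cat(user_embeds, 1)
item_embd = torch.cat(item_embeds, 1)

# Select rows for a batch of (user, positive item, negative item) triples.
users = torch.tensor([0, 2])
pos_items = torch.tensor([1, 3])
neg_items = torch.tensor([4, 6])
u_g_embeddings = user_embd[users, :]
pos_i_g_embeddings = item_embd[pos_items, :]
neg_i_g_embeddings = item_embd[neg_items, :]
```

Here the per-layer lists are rebuilt on every call, so nothing from a previous iteration's autograd graph is carried over, which is the main difference I can see from my block version.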