Thanks — the code below works:
# NOTE(review): reaches into DGL's private graph index (`_graph`) to get raw
# edge tensors — fragile across DGL versions; confirm no public API fits.
gidx = graph._graph
# (src, dst, edge-id) tensors for edge type 0 — presumably a homogeneous graph.
src, dst, edge = gidx.edges(0)
# Dense adjacency built from the edge list by a project helper — shape/device
# semantics not visible here; TODO confirm it matches what self.gcn expects.
adj = self.build_adj(src, dst, srcfeat)
# GCN appears to expect (batch, channels, nodes): add a batch dim, move the
# feature axis to dim 1, run the conv, move features back, then keep only the
# rows belonging to the unique destination nodes of this bipartite block.
out = self.gcn(srcfeat.unsqueeze(0).permute(0, 2, 1), adj.to(srcfeat.device)).permute(0, 2, 1)[0, torch.unique(dst).long(), ...]
# Concatenate destination features with their GCN-aggregated neighbourhood
# features along the last (feature) dimension.
gcn_feat = torch.cat([dstfeat, out], dim=-1)
And, some code:
# self.debug_pred_conn()
# output_bipartite = bipartites[-1]
# output_bipartite.srcdata['conv_features'] = neighbor_x_src.to('cpu')
# output_bipartite.dstdata['conv_features'] = center_x_src.to('cpu')
#
# output_bipartite.apply_edges(self.pred_conn)
# output_bipartite.edata['prob_conn'] = F.softmax(output_bipartite.edata['pred_conn'], dim=1)
# output_bipartite.update_all(self.pred_den_msg, fn.mean('pred_den_msg', 'pred_den'))
# return output_bipartite
can be replaced by:
output_bipartite = bipartites[-1]
# print("forward with raw_affine shape=", output_bipartite.edata['raw_affine'].shape, ", data=", output_bipartite.edata['raw_affine'][:9])
gidx = output_bipartite._graph
src, dst, edge = gidx.edges(0)
src_feat = self.src_mlp(neighbor_x_src[src.long()])
### feature alignment
len_center_x_src = center_x_src.shape[0]
len_src_feat = src_feat.shape[0]
repeat = int(len_src_feat / len_center_x_src)
center_x_src_r = torch.repeat_interleave(center_x_src, repeat, dim=0)
###
dst_feat = self.dst_mlp(center_x_src_r)
pred_conn = self.classifier_conn(src_feat + dst_feat)
output_bipartite.edata['pred_conn'] = pred_conn.to('cpu') # 保存用于计算loss
prob = F.softmax(pred_conn, dim=1)
raw_affine = output_bipartite.edata['raw_affine'].to(device)
pred_den_msg = raw_affine * (prob[:, 1] - prob[:, 0])
### 分组mean
g_pred_den_msg = torch.cat([pred_den_msg[i:i + repeat].unsqueeze(0) for i in range(0, len(pred_den_msg), repeat)], dim=0)
###
pred_den = torch.mean(g_pred_den_msg, dim=1)
output_bipartite.dstdata['pred_den'] = pred_den.to('cpu') # 存到cpu上中转,不会丢失梯度
return output_bipartite