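I tried again and it worked! I used `torch.stack` instead of `torch.cat` to combine the per-edge-type losses, since each loss is a 0-d tensor and `cat` cannot concatenate 0-d tensors.
Here is my function in case anyone is interested: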
class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy link-prediction loss averaged over edge types.

    Despite the name, this is a *binary* cross-entropy. Every edge of
    ``pos_graph`` is labelled 1 and every edge of ``neg_graph`` is labelled 0.
    Each edge is scored by the dot product of the ``'h'`` representations of
    its two endpoints, and those raw scores are treated as logits. One loss is
    computed per canonical edge type, and the per-type losses are averaged
    with equal weight, independent of how many edges each type has.
    """

    @staticmethod
    def _edge_scores(graph, node_repr):
        """Return ``{canonical_etype: edge score tensor}`` for every edge type of *graph*."""
        with graph.local_scope():
            graph.ndata['h'] = node_repr
            for etype in graph.canonical_etypes:
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            scores = graph.edata['score']
        # edata returns a dict keyed by canonical etype only when the graph has
        # several edge types; with a single type it is a plain tensor.
        if not isinstance(scores, dict):
            scores = {graph.canonical_etypes[0]: scores}
        return scores

    def forward(self, block_outputs, pos_graph, neg_graph):
        """Compute the mean per-edge-type BCE loss.

        Args:
            block_outputs: node representations assigned to ``ndata['h']`` of
                both graphs (presumably a tensor, or a dict keyed by node type
                for heterographs; TODO confirm against caller).
            pos_graph: graph whose edges are positive examples (label 1).
            neg_graph: graph whose edges are negative examples (label 0); it
                is expected to share its canonical edge types with
                ``pos_graph``.

        Returns:
            A 0-d tensor: the mean over edge types of the BCE-with-logits loss.

        Raises:
            ValueError: if no edge type has any positive or negative edge.
        """
        pos_score = self._edge_scores(pos_graph, block_outputs)
        neg_score = self._edge_scores(neg_graph, block_outputs)

        loss_list = []
        for etype in neg_graph.canonical_etypes:
            pos, neg = pos_score[etype], neg_score[etype]
            score = th.cat([pos, neg])
            if score.numel() == 0:
                # BCE over an empty tensor is NaN and would poison the mean.
                continue
            label = th.cat([th.ones_like(pos), th.zeros_like(neg)])
            loss_list.append(F.binary_cross_entropy_with_logits(score, label))

        if not loss_list:
            raise ValueError("pos_graph and neg_graph contain no edges to score")
        # stack, not cat: each per-type loss is a 0-d tensor, which cat rejects.
        return th.mean(th.stack(loss_list))