Loss function for unsupervised heterogeneous GraphSAGE

I am trying to adapt unsupervised GraphSAGE to a heterogeneous graph.
My graph has one node type and three edge types.
Do you have any suggestions on how to compute the loss in this case?
I tried computing the loss for each edge type separately and aggregating the results (e.g. averaging the loss across all edge types), but I couldn’t do it in a way that keeps the result a tensor I can call backward on.


This sounds quite reasonable. How did you compute the loss? Normally, if you just aggregate the per-edge-type losses with tensor operations, the result should still support backpropagation.
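To illustrate, here is a minimal sketch in plain PyTorch (no DGL) of averaging per-edge-type losses while keeping the autograd graph intact. The edge-type names, the edge lists, and the `edge_scores` helper are made up for the example; the random `h` stands in for the model's node embeddings.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the node embeddings produced by the model.
h = torch.randn(6, 4, requires_grad=True)

# Hypothetical edge lists per edge type: (source ids, destination ids).
pos_edges = {
    "etype_a": (torch.tensor([0, 1]), torch.tensor([1, 2])),
    "etype_b": (torch.tensor([2, 3]), torch.tensor([3, 4])),
    "etype_c": (torch.tensor([4]), torch.tensor([5])),
}
neg_edges = {
    "etype_a": (torch.tensor([0, 1]), torch.tensor([5, 4])),
    "etype_b": (torch.tensor([2, 3]), torch.tensor([0, 1])),
    "etype_c": (torch.tensor([4]), torch.tensor([2])),
}

def edge_scores(h, edges):
    # Dot product of source and destination embeddings for each edge.
    src, dst = edges
    return (h[src] * h[dst]).sum(dim=1)

losses = []
for etype in pos_edges:
    pos = edge_scores(h, pos_edges[etype])
    neg = edge_scores(h, neg_edges[etype])
    score = torch.cat([pos, neg])
    label = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)])
    losses.append(F.binary_cross_entropy_with_logits(score, label))

# Averaging with ordinary tensor arithmetic keeps the result differentiable.
total = sum(losses) / len(losses)
total.backward()
```

After `total.backward()`, `h.grad` is populated, which means gradients flowed through every per-edge-type loss.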

I tried again and it worked! I used torch.stack instead of cat.
Here is my function in case anyone is interested:

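import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        # Score each positive edge as the dot product of its endpoint embeddings.
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            for etype in pos_graph.canonical_etypes:
                pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            # With several edge types, edata returns a dict keyed by canonical etype.
            pos_score = pos_graph.edata['score']
        # Same for the negative (sampled) edges.
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            for etype in neg_graph.canonical_etypes:
                neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            neg_score = neg_graph.edata['score']
        loss_list = []
        for etype in neg_graph.canonical_etypes:
            score = th.cat([pos_score[etype], neg_score[etype]])
            # Positive edges are labeled 1, negative edges 0.
            label = th.cat([th.ones_like(pos_score[etype]), th.zeros_like(neg_score[etype])])
            # Collect the per-edge-type loss so it can be averaged below.
            loss_list.append(F.binary_cross_entropy_with_logits(score, label))
        # Each loss is a 0-dim tensor, so stack (not cat) them before averaging.
        mean_loss = th.mean(th.stack(loss_list))
        return mean_loss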
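
For anyone wondering why stack works where cat does not: each per-edge-type loss is a 0-dimensional tensor, and torch.cat refuses to concatenate 0-dimensional tensors, while torch.stack creates a new dimension for them. A small sketch with two made-up scalar "losses" (the values and names are only for illustration):

```python
import torch

a = torch.tensor(0.7, requires_grad=True)
b = torch.tensor(0.3, requires_grad=True)

# Two 0-dim tensors standing in for per-edge-type losses.
per_type_losses = [a * 2, b * 4]

# torch.cat needs tensors with at least one dimension, so this raises.
try:
    torch.cat(per_type_losses)
    cat_failed = False
except RuntimeError:
    cat_failed = True

# torch.stack adds a new leading dimension, giving a tensor of shape (2,).
stacked = torch.stack(per_type_losses)
mean_loss = stacked.mean()
mean_loss.backward()
```

If you would rather keep cat, reshaping each loss to one dimension first (for example with `loss.view(1)`) should also work, and so does `sum(loss_list) / len(loss_list)`.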