Loss function for unsupervised heterogeneous GraphSAGE

Hi,
I am trying to adapt the unsupervised GraphSAGE example to a heterograph.
My graph has one node type and 3 edge types.
Do you have any suggestions on how to compute the loss in this case?
I tried computing the loss for each edge type separately and then aggregating the results (e.g. averaging the loss across all edge types), but I couldn't manage it while keeping the result as a loss tensor I could call backward() on.

Thanks!

This sounds quite reasonable. How did you compute the loss? Normally, if you aggregate the per-edge-type losses with tensor operations (e.g. summing or averaging them), the result is still backpropagatable.
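A minimal sketch of that point in plain PyTorch, assuming each per-edge-type loss is already a scalar (0-dim) tensor; the parameter `w` and the toy losses are made up for illustration. Averaging with `torch.stack(...).mean()` keeps the result attached to the autograd graph, while `torch.cat` rejects 0-dim tensors:

```python
import torch

# Stand-in parameters; in a real model these would be the GNN weights.
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# One scalar (0-dim) loss per edge type, each depending on w.
per_etype_losses = [w[0] ** 2, w[1] ** 2, w[2] ** 2]

# torch.cat only accepts tensors with at least one dimension.
try:
    torch.cat(per_etype_losses)
    cat_failed = False
except RuntimeError:
    cat_failed = True

# torch.stack adds a new dimension, giving a 1-D tensor of shape (3,),
# which can then be averaged and backpropagated through.
mean_loss = torch.stack(per_etype_losses).mean()
mean_loss.backward()

print(cat_failed)        # True
print(mean_loss.item())  # (1 + 4 + 9) / 3
print(w.grad)            # gradient of mean(w_i^2) is 2 * w_i / 3
```

`sum(per_etype_losses) / len(per_etype_losses)` would work equally well, since Python's `+` on tensors is also tracked by autograd.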

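I tried again and it worked! I used torch.stack instead of torch.cat: each per-edge-type loss is a 0-dim (scalar) tensor, and torch.cat cannot concatenate 0-dim tensors.
Here is my function in case anyone is interested: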

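```python
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        # Score every edge of each edge type with the dot product of its
        # endpoint embeddings (the graphs have a single node type).
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            for etype in pos_graph.canonical_etypes:
                pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            # With several edge types, edata['score'] is a dict keyed by canonical edge type.
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            for etype in neg_graph.canonical_etypes:
                neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            neg_score = neg_graph.edata['score']
        # One binary cross-entropy loss per edge type:
        # positive edges are labelled 1, negative edges 0.
        loss_list = []
        for etype in neg_graph.canonical_etypes:
            score = th.cat([pos_score[etype], neg_score[etype]])
            label = th.cat([th.ones_like(pos_score[etype]), th.zeros_like(neg_score[etype])])
            loss_list.append(F.binary_cross_entropy_with_logits(score, label))
        # Each loss is a 0-dim tensor, so stack (not cat) them before averaging.
        return th.stack(loss_list).mean()
```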
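If a plain average over edge types is not what you want, the same idea extends to a weighted average. Below is a sketch in plain PyTorch; the function name `weighted_etype_loss` and the `weights` argument are hypothetical, and it assumes the score dicts look like what `edata['score']` returns for a graph with several edge types (canonical edge type mapped to a tensor of logits). It also skips edge types with no sampled edges in a minibatch, on the assumption that a mean over an empty tensor would otherwise yield NaN:

```python
import torch
import torch.nn.functional as F


def weighted_etype_loss(pos_score, neg_score, weights=None):
    """Weighted average of per-edge-type BCE losses (hypothetical helper).

    pos_score / neg_score: dicts mapping a canonical edge type to a tensor
    of edge logits, e.g. of shape (num_edges, 1).
    weights: optional dict mapping edge type -> float; if omitted, every
    edge type gets weight 1.0, which reduces to the plain mean.
    """
    losses, ws = [], []
    for etype in pos_score:
        pos, neg = pos_score[etype], neg_score[etype]
        if pos.numel() + neg.numel() == 0:
            # No edges of this type in the minibatch: skip it instead of
            # averaging over an empty tensor.
            continue
        score = torch.cat([pos, neg])
        label = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)])
        losses.append(F.binary_cross_entropy_with_logits(score, label))
        ws.append(1.0 if weights is None else weights[etype])
    if not losses:
        raise ValueError("no edges of any type in this minibatch")
    w = torch.tensor(ws, dtype=losses[0].dtype, device=losses[0].device)
    # torch.stack (not torch.cat) because each loss is a 0-dim tensor.
    return (torch.stack(losses) * w).sum() / w.sum()
```

Inside the `CrossEntropyLoss.forward` above, this would replace the final loop, e.g. `return weighted_etype_loss(pos_score, neg_score)`.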