Normalized cut criterion

Hi

I’m interested in implementing a normalized cut criterion as an additional loss function, and I want this loss to take part in back-propagation. For example, if the network outputs logits that are passed through F.softmax(logits), I want to apply the normalized cut criterion to the predicted labels from the softmax output.

How can I go about doing this? Other forums mention that PyTorch does not explicitly compute the one-hot encoding matrix for losses like CrossEntropy. Below is an implementation that computes the ncut criterion. I’m just not sure whether it can be added to a combined loss such as LOSS = CrossEntropy() + NCuts() and then take part in the backprop computation.

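import torch

def normalized_cut(graph, one_hot):

    """
    Parameters:
    - - - - - - - - -
    graph: DGL graph
        graph structure of data
    one_hot: torch float tensor, shape (n_nodes, n_classes)
        soft label assignments, e.g. F.softmax(logits, dim=1);
        must stay attached to the autograd graph

    Returns:
    - - - -
    loss: torch float tensor
        negative sum of per-class normalized associations
    """

    # Adjacency matrix; assumed here to be a torch sparse tensor.
    # It is a constant with respect to the network weights.
    A = graph.adjacency_matrix()
    # Node degrees as a dense vector of shape (n_nodes,).
    d = torch.sparse.sum(A, dim=0).to_dense()

    # Use one_hot as passed in. Re-wrapping it with torch.tensor(...)
    # creates a new leaf and cuts the gradient path back to the network.
    assoc = torch.matmul(one_hot.t(), torch.sparse.mm(A, one_hot))
    degree = torch.matmul(d, one_hot)

    # Per-class association divided by per-class volume; NaNs are skipped.
    loss = torch.nansum(assoc.diag() / degree)

    return -loss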

Any help is appreciated. Perhaps this question would be more appropriate for the PyTorch forum? Thanks.

What do you want to update? Do you want to update the adjacency matrix?

No, I would like to use the NCuts loss, in addition to the CrossEntropy loss, to update the weights in a GCN.

The normalized cuts should help promote the constraint that the labels of nodes are contiguous in the graph domain – for my purposes (segmentation), this is important.
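For reference, here is one common way to write the relaxed objective. This is only a sketch: the symbols S (soft assignment matrix with columns S_k), A (adjacency), d (degree vector), and K (number of classes) are introduced here for illustration, and A is assumed to be symmetric with non-negative weights.

```latex
\mathrm{Ncut}(S)
  = \sum_{k=1}^{K} \frac{S_k^{\top} A \,(\mathbf{1} - S_k)}{d^{\top} S_k}
  = K - \sum_{k=1}^{K} \frac{S_k^{\top} A \, S_k}{d^{\top} S_k},
  \qquad d = A\,\mathbf{1}
```

The second sum is the quantity accumulated by assoc.diag() / degree in the function above, so returning -loss minimizes the normalized cut up to the constant K.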

Assuming your implementation of normalized_cut is correct, then I guess you can pass the model prediction to a softmax function and then use the result for one_hot?
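To make that suggestion concrete, here is a minimal sketch of how the two losses could be combined so that gradients from both reach the model. It uses a small dense adjacency matrix instead of a DGL graph so that it runs with PyTorch alone. The names soft_ncut_loss, adj, and eps, the toy graph, and the 0.1 weight are all assumptions made for illustration, and the random logits stand in for the GCN output.

```python
import torch
import torch.nn.functional as F


def soft_ncut_loss(adj, probs, eps=1e-8):
    """Negative sum of per-class normalized associations (dense sketch)."""
    degree = adj.sum(dim=0)                     # (n_nodes,)
    assoc = (probs * (adj @ probs)).sum(dim=0)  # diag(P^T A P), (n_classes,)
    volume = degree @ probs                     # (n_classes,)
    # eps guards against an empty class producing a 0/0 division.
    return -(assoc / (volume + eps)).sum()


torch.manual_seed(0)

# Toy graph: two triangles (nodes 0-2 and 3-5) joined by the edge (2, 3).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
adj = torch.zeros(6, 6)
for i, j in edges:
    adj[i, j] = 1.0
    adj[j, i] = 1.0

# Stand-in for the GCN output: logits that require gradients.
logits = torch.randn(6, 2, requires_grad=True)
labels = torch.tensor([0, 0, 0, 1, 1, 1])

# The softmax output is used directly, so it stays in the autograd graph.
probs = F.softmax(logits, dim=1)
ce = F.cross_entropy(logits, labels)
ncut = soft_ncut_loss(adj, probs)
loss = ce + 0.1 * ncut  # 0.1 is an arbitrary illustrative weight
loss.backward()

# Check that the NCut term on its own sends a gradient back to the logits.
ncut_only = soft_ncut_loss(adj, F.softmax(logits, dim=1))
ncut_grad = torch.autograd.grad(ncut_only, logits)[0]
```

The key point is that the softmax output is passed in as a tensor and never converted to NumPy or re-wrapped with torch.tensor(...), since either step would detach it from the network. Taking an argmax to get hard labels would also block the gradient, which is why the soft probabilities are used in place of a true one-hot matrix.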
