Confused about memory consumption

I have a problem with memory consumption. If I change the output dimension of my edge classification predictor from five to one, CUDA memory consumption increases abnormally.
Following the tutorial, I wrote this predictor:

```python
from abc import ABC

import torch as th
import torch.nn as nn


class MLPPredictor(nn.Module, ABC):
    def __init__(self,
                 in_units,
                 num_classes,
                 dropout_rate=0.0):
        super(MLPPredictor, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)  # defined but not used below
        self.predictor = nn.Sequential(
            nn.Linear(in_units * 2, in_units, bias=False),
            nn.Tanh(),
            nn.Linear(in_units, num_classes, bias=False),
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier-initialise the weight matrices (biases are disabled).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def apply_edges(self, edges):
        # Concatenate source and destination node features for each edge.
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.predictor(th.cat([h_u, h_v], dim=1))  # (E, num_classes)
        return {'score': score}

    def forward(self, graph, ufeat, ifeat):
        with graph.local_scope():
            # Set features inside local_scope so they don't leak into the caller's graph.
            graph.nodes['movie'].data['h'] = ifeat
            graph.nodes['user'].data['h'] = ufeat
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']
```

With `num_classes` set to five, memory usage after computing the CrossEntropy loss is about 1401 MB.
With `num_classes` set to one and an MSE loss, memory usage rises to 11623 MB.

Has anyone run into the same situation? Can you help me figure out why memory consumption increases, or suggest possible solutions or ways to find the cause?
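One generic way to narrow this down is to check the shapes of the tensors entering the loss; printing `torch.cuda.max_memory_allocated()` before and after the loss call can also show which step allocates the memory. The sketch below uses a hypothetical wrapper, `checked_mse_loss`, and random stand-in tensors (not the poster's data) to fail loudly when prediction and target shapes differ instead of letting the loss silently broadcast them.

```python
import torch
import torch.nn.functional as F


def checked_mse_loss(pred, target):
    """Hypothetical wrapper: raise instead of silently broadcasting shapes."""
    if pred.shape != target.shape:
        raise ValueError(
            f"prediction shape {tuple(pred.shape)} != target shape {tuple(target.shape)}"
        )
    return F.mse_loss(pred, target)


# Random stand-ins for the model output and the labels.
pred = torch.randn(6, 1)      # output of a Linear(..., 1) layer: shape (N, 1)
labels = torch.randn(6)       # labels stored as a flat vector: shape (N,)

try:
    checked_mse_loss(pred, labels)
except ValueError as err:
    print("shape mismatch:", err)

# Dropping the trailing size-1 dimension makes the shapes agree.
loss = checked_mse_loss(pred.squeeze(-1), labels)
print(loss.item())
```

Putting a check like this right before the loss turns a silent memory blow-up into an immediate error that names both shapes.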

Why did you use MSE for edge classification? Have you tried BCEWithLogitsLoss?

Thanks for your reply.

I am reproducing GCMC. The paper uses an RGCN-like model to predict the rating (e.g. 1-5) that a given user gives a given item. In the paper, the authors train the model as a classification task, treating each rating level as a separate class. The difference between using CrossEntropy and using MSE is whether the rating is treated as a discrete or a continuous value.
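The two formulations can be sketched as follows. This is a minimal illustration with random edge representations and plain `nn.Linear` heads, not GCMC's actual decoder; `edge_repr`, `clf_head` and `reg_head` are hypothetical names. It also shows why the five-class version never ran into the problem: `CrossEntropyLoss` is designed to take logits of shape (N, C) with class indices of shape (N,), whereas `MSELoss` expects prediction and target to have the same shape.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_edges, hidden = 8, 16
# Hypothetical edge representations (e.g. concatenated user/item embeddings).
edge_repr = torch.randn(num_edges, hidden)
# Hypothetical ratings in 1..5, one per edge, stored as a flat (N,) vector.
ratings = torch.randint(1, 6, (num_edges,))

# Classification: one logit per rating level. CrossEntropyLoss takes
# logits of shape (N, C) and integer class indices of shape (N,).
clf_head = nn.Linear(hidden, 5)
logits = clf_head(edge_repr)                            # (N, 5)
ce_loss = nn.CrossEntropyLoss()(logits, ratings - 1)    # classes 0..4

# Regression: a single output per edge. MSELoss expects matching shapes,
# so drop the trailing size-1 dimension before comparing with (N,) targets.
reg_head = nn.Linear(hidden, 1)
pred = reg_head(edge_repr).squeeze(-1)                  # (N, 1) -> (N,)
mse_loss = nn.MSELoss()(pred, ratings.float())

print(logits.shape, pred.shape, ce_loss.item(), mse_loss.item())
```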

Ratings range from 1 to 5, so there are at least five classes. I haven't tried BCELoss.

Can you try BCEWithLogitsLoss? Also, could you provide a minimal code snippet that reproduces the issue?

I tried BCELoss and found how to reproduce the odd memory consumption: it happens when the output shape is (N, 1). After squeezing the output to shape (N,), memory consumption is normal.
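Squeezing helps because it brings the prediction back to the same shape as the labels. One caveat, shown here with random stand-in tensors: `.squeeze()` with no argument removes every size-1 dimension, so a batch that happens to contain a single edge loses its batch dimension as well, while `.squeeze(-1)` removes only the trailing one.

```python
import torch

out = torch.randn(4, 1)           # per-edge scores from a Linear(..., 1) layer
print(out.squeeze().shape)        # torch.Size([4])
print(out.squeeze(-1).shape)      # torch.Size([4])

single = torch.randn(1, 1)        # a batch that happens to contain one edge
print(single.squeeze().shape)     # torch.Size([])  -> batch dimension lost too
print(single.squeeze(-1).shape)   # torch.Size([1]) -> still one score per edge
```

Reshaping the labels instead, e.g. `labels.unsqueeze(-1)` to get (N, 1), makes the shapes match just as well; what matters is that both sides of the loss agree.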

My DGL version is 0.6.1 with CUDA 10.2.

Can you also post your code?

Here is the code:

```python
import torch as th
import torch.nn as nn


class MLPPredictor(nn.Module):

    def __init__(self,
                 in_units,
                 do_squeeze=True,
                 dropout_rate=0.0):
        super(MLPPredictor, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Sequential(
            nn.Linear(in_units * 2, in_units, bias=False),
            nn.Tanh(),
            nn.Linear(in_units, 1, bias=False),
        )
        self.do_squeeze = do_squeeze

        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.predictor(th.cat([h_u, h_v], dim=1))  # shape (E, 1)
        if self.do_squeeze:
            # Drop only the trailing size-1 dim: (E, 1) -> (E,).
            # A bare .squeeze() would also drop the edge dim when E == 1.
            score = score.squeeze(-1)
        return {'score': score}

    def forward(self, graph, ufeat, ifeat):
        with graph.local_scope():
            graph.nodes['movie'].data['h'] = ifeat
            graph.nodes['user'].data['h'] = ufeat
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']
```

The dataset is ML-100k with 944 users and 1683 items. Without the squeeze, the code tries to allocate 19.31 GB of memory.

My environment:
dgl=0.62
pytorch=1.6.0

The complete code is available at https://github.com/JarenceSJ/dgi_snippet

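This is because `dataset.train_labels` has shape (N,). If the prediction has shape (N, 1), the loss computation broadcasts the two tensors against each other, pairing every prediction with every label. If you set `reduction='none'` when constructing the `MSELoss` object, you can see this directly: the result has shape (N, N). Squeezing the prediction to (N,), as you found, makes the shapes match and avoids the blow-up.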
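The broadcast can be reproduced on the CPU with a handful of random values standing in for the predictions and labels. PyTorch's `mse_loss` also emits a UserWarning when the target size differs from the input size, which is worth watching for. The second part is a back-of-envelope size check: N = 72,000 is an assumption (80,000 ML-100k training ratings minus a 10% validation split), not a number taken from the posted code, but under that assumption a single (N, N) float32 tensor works out to the 19.31 GB reported above.

```python
import torch
import torch.nn as nn

n = 5
pred = torch.randn(n, 1)    # model output, shape (N, 1)
labels = torch.randn(n)     # stand-in for dataset.train_labels, shape (N,)

# reduction='none' exposes the broadcast shape before any averaging.
per_elem = nn.MSELoss(reduction='none')(pred, labels)
print(per_elem.shape)       # torch.Size([5, 5]): every prediction vs every label

# With matching shapes the loss is computed element by element as intended.
fixed = nn.MSELoss(reduction='none')(pred.squeeze(-1), labels)
print(fixed.shape)          # torch.Size([5])

# Back-of-envelope size of one (N, N) float32 tensor.
# ASSUMPTION: N = 72,000 training edges (80,000 ratings minus a 10% validation split).
n_train = 72_000
bytes_needed = n_train * n_train * 4            # float32 = 4 bytes per element
print(f"{bytes_needed / 2**30:.2f} GiB")        # 19.31 GiB
```

With the default `reduction='mean'` the mismatched call still returns a scalar, so training runs, but that scalar averages over all N × N prediction/label pairs rather than the N intended ones, and for large N the pairwise tensor is where the memory goes.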