GPU Out of memory on loss.backward() call

Hey guys, I’m running out of GPU memory on my backward calls. The thing is, I’m already training a single sample at a time. I’m not sure whether operations like torch.cat are causing the issue, and I can’t figure out where a possible memory leak might be happening.

I’m using the torch_geometric package for some graph neural network learning, and combining this with a custom algorithm. Here’s a general overview of what my code does:


import torch
import torch_geometric


class network(torch.nn.Module):
    def __init__(self, in_channels, out_channels, h_size, e_size, n_embed, **kwargs):
        super(network, self).__init__(**kwargs)

        # Feature sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.h_size = h_size
        self.n_embed = n_embed

        # Embeddings
        self.embed = torch.nn.Embedding(n_embed, out_channels)

        # MPNN
        self.wt = torch.nn.Sequential(torch.nn.Linear(e_size, h_size * 2),
                                      torch.nn.ReLU(),
                                      torch.nn.Linear(h_size * 2, in_channels * out_channels))
        self.mpnn = torch_geometric.nn.NNConv(in_channels, out_channels, self.wt)
        self.m_bn = torch.nn.BatchNorm1d(out_channels)
        self.m_relu = torch.nn.ReLU()
        self.gru = torch.nn.GRU(out_channels, out_channels)
        self.g_bn = torch.nn.BatchNorm1d(out_channels)

        # Weight translation layer
        self.weight_transform = torch.nn.Sequential(torch.nn.Linear(out_channels * 2, h_size * 4),
                                                    torch.nn.BatchNorm1d(h_size * 4),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 4, h_size * 2),
                                                    torch.nn.BatchNorm1d(h_size * 2),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 2, 1),
                                                    torch.nn.BatchNorm1d(1))

    def reset_parameters(self):
        self.mpnn.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, data):
        n_nodes = data.d_feat.size(0)
        x = torch.cat((data.p_feat, data.d_feat), dim=0)
        # Make the edge list bidirectional by appending each edge in reverse
        r_edge_index = torch.cat((data.real_edges_original,
                                  data.real_edges_original.index_select(1, torch.LongTensor([1, 0]).to(
                                      data.p_feat.device))), dim=0).t()
        r_edge_attr = torch.cat((data.real_edge_attr, data.real_edge_attr), dim=0).unsqueeze(1)
        # Message passing NN
        for i in range(3):
            m = self.mpnn(x, edge_index=r_edge_index, edge_attr=r_edge_attr)
            m = self.m_bn(m)
            m = self.m_relu(m)
            if i == 0:
                # Initialize the GRU hidden state from the first message-passing output
                h = m.clone().unsqueeze(0)
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)
            out = self.g_bn(out)
        # Get embeddings
        embeds = self.embed(data.p_embed)

        # Stack matrices
        real_p = out[:data.p_feat.size(0), :]
        real_d = out[data.p_feat.size(0):, :]
        all_feat = torch.cat((real_p, embeds, real_d), dim=0)
        # Construct weight matrix
        node_pairs = torch.cat((all_feat[data.assign_edge_index[0, :]], all_feat[data.assign_edge_index[1, :]]), dim=1)
        wts = self.weight_transform(node_pairs)

        # Dense (2n x 2n) weight matrix; this allocation grows quadratically with n_nodes
        wt_mat = torch.zeros((n_nodes * 2, n_nodes * 2), device=data.p_feat.device)
        wt_mat[data.assign_edge_index[0, :], data.assign_edge_index[1, :]] = wts.view(-1)
        wt_mat = wt_mat[:n_nodes, n_nodes:].contiguous()
        # LBP
        b_a, b_b = self.sim_lbp(n_nodes, wt_mat, 2)

        return b_b

The sim_lbp function just keeps a copy of b_a and b_b for its updates and only does max and sum operations.
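For reference, it’s roughly of this shape (a heavily simplified sketch, not the exact code; the actual update rule differs in details):

def sim_lbp(self, n_nodes, wt_mat, num_it):
    # Alternating max/sum belief updates over the weight matrix, keeping
    # copies of the two belief tensors from the previous iteration
    b_a = wt_mat.clone()
    b_b = wt_mat.clone()
    for _ in range(num_it):
        prev_a, prev_b = b_a, b_b
        b_a = wt_mat + prev_b.sum(dim=1, keepdim=True) - prev_b.max(dim=1, keepdim=True)[0]
        b_b = wt_mat + prev_a.sum(dim=0, keepdim=True) - prev_a.max(dim=0, keepdim=True)[0]
    return b_a, b_b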

I’m kinda stumped at this stage. For reference, I’m on torch 1.4.0 with CUDA 10.1, running on an NVIDIA V100 16 GB GPU.

I’m not sure why this should be caused by a memory leak.
The backward pass will use some memory to store all gradients.
Depending on your model architecture and thus the shape of these gradients, the memory might increase by a large amount.
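You can get a rough lower bound by counting parameters; in fp32 the parameter gradients take about the same space again:

# Rough estimate: fp32 gradients need about as much memory as the fp32 parameters
n_params = sum(p.numel() for p in model.parameters())
print('parameters: {:.1f} MiB (gradients add roughly the same again)'.format(
    n_params * 4 / 1024 ** 2))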

If you are seeing the OOM in the first iteration(s), then this is most likely the cause and you could try to use e.g. torch.utils.checkpoint to trade compute for memory.
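E.g. wrapping your message passing step could look roughly like this (a sketch against your posted code; mp_block is just an illustrative local helper inside forward):

from torch.utils.checkpoint import checkpoint

def mp_block(x, edge_index, edge_attr):
    # One message-passing step; with checkpointing its activations are
    # recomputed during backward instead of being stored.
    m = self.mpnn(x, edge_index=edge_index, edge_attr=edge_attr)
    return self.m_relu(self.m_bn(m))

# Inside the loop in forward():
# Caveats: the block runs twice, so BatchNorm running stats get updated twice,
# and at least one input must have requires_grad=True for gradients to flow.
m = checkpoint(mp_block, x, r_edge_index, r_edge_attr)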

However, if you are seeing the OOM after e.g. an epoch, then you might want to check if you are accidentally storing the computation graph somewhere, e.g. via losses.append(loss) (you would have to call loss.detach() in order to avoid storing the whole graph).
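I.e.:

# Bad: each stored loss keeps its entire computation graph alive
losses.append(loss)

# Good: store only the value
losses.append(loss.detach())  # or losses.append(loss.item())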

That’s what’s puzzling me as well. The only place I save the loss is when storing it for tensorboard. Here’s my training loop:

for epoch in tqdm(range(START_EPOCH, NUM_EPOCHS)):
    if ENABLE_LOGS:
        torch.save(
            get_model_dict(model, optim, sched, epoch, cnt, cnt_val),
            'NeuralBP_{}_featsize_{}_batchsize_{}_epoch_{}.pt'.format(NAME, OUT_FEAT_SIZE, BSIZE, epoch))
    # Training loop
    model.train()
    for i, data in tqdm(enumerate(train_loader)):
        optim.zero_grad()

        # Feed-forward
        data = data.to(device)
        out = model(data)
        out = torch.softmax(out, dim=1)
        out = log_sinkhorn(out, num_it=10)

        loss = loss_fn(out.reshape(-1), data.lbl_mat.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        loss_val = float(loss)  # float() detaches, so no graph is stored for logging
        if ENABLE_LOGS:
            writer.add_scalar('loss', loss_val, cnt)
        cnt += 1
    sched.step()
    # Validation loop
    if epoch % VAL_EPOCH == 0:
        model.eval()  # put the BatchNorm layers into eval mode for validation
        with torch.no_grad():
            for j, data in tqdm(enumerate(val_loader)):
                data = data.to(device)
                out = model(data)

                res = torch.softmax(out, dim=1)
                res = log_sinkhorn(res, num_it=10)
                # Binarize: 1.0 at each row-wise max, 0.0 elsewhere
                tmax = torch.max(res, dim=1, keepdim=True)[0]
                mask = res.ge(tmax)
                res[mask] = 1.0
                res[~mask] = 0.0

                y_pred = res.cpu().numpy().reshape(-1)
                y_true = data.lbl_mat.cpu().numpy().reshape(-1)
                # sklearn metrics expect (y_true, y_pred)
                acc = accuracy_score(y_true, y_pred)
                prec = precision_score(y_true, y_pred)
                recall = recall_score(y_true, y_pred)
                f1 = f1_score(y_true, y_pred)
                if ENABLE_LOGS:
                    writer.add_scalar('acc', acc, cnt_train)
                    writer.add_scalar('prec', prec, cnt_train)
                    writer.add_scalar('recall', recall, cnt_train)
                    writer.add_scalar('f1', f1, cnt_train)

                cnt_train += 1
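In the meantime I guess I can narrow it down by logging the allocated memory each iteration and watching where it grows, something like:

# Simple helper to watch CUDA memory usage per iteration
# (torch.cuda.memory_cached() was renamed to memory_reserved() in later versions)
def log_gpu_mem(tag):
    alloc = torch.cuda.memory_allocated() / 1024 ** 2
    cached = torch.cuda.memory_cached() / 1024 ** 2
    print('{}: allocated {:.1f} MiB, cached {:.1f} MiB'.format(tag, alloc, cached))

# e.g. call log_gpu_mem('iter {}'.format(i)) right after optim.step()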