CUDA out of memory when training audio RNN (GRU)

Hi,
I’m trying to train a simple audio classification model on Colab, but my GPU memory (running on a 16GB instance) use keeps expanding and getting out of control every few epochs. Here is the model definition and a minimal snippet of my training code:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
        )

    def forward(self, x):
        x = self.layers(x)
        return x


class AudioClassifier(nn.Module):
    def __init__(self, stereo=True, dropout=0.1):
        super().__init__()
        in_channels = 2 if stereo else 1
        self.spec = MelspectrogramStretch(hop_length=None, 
                                          num_mels=128, 
                                          fft_length=2048, 
                                          norm='whiten', 
                                          stretch_param=[0.4, 0.4])

        self.features = nn.Sequential(*[
            ConvBlock(in_channels=2, out_channels=32, kernel_size=3, stride=1),
            nn.MaxPool2d(3,3),
            nn.Dropout(p=dropout),
            ConvBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.MaxPool2d(4,4),
            nn.Dropout(p=dropout),
            ConvBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.MaxPool2d(4,4),
            nn.Dropout(p=dropout),
        ])
        self.min_len = 80896
        self.gru_hidden_size = 64
        self.gru_layers = 2

        self.rnn = nn.GRU(128, self.gru_hidden_size, num_layers=self.gru_layers) 
        self.ret = nn.Sequential(*[nn.Linear(self.gru_hidden_size,1), nn.Sigmoid()])
  
    
    def modify_lengths(self, lengths):
        def safe_param(elem):
            return elem if isinstance(elem, int) else elem[0]

        for name, layer in self.features.named_children():
            if isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
                p, k, s = map(safe_param, [layer.padding, layer.kernel_size,layer.stride]) 
                lengths = ((lengths + 2*p - k)//s + 1).long()

        return torch.where(lengths > 0, lengths, torch.tensor(1, device=lengths.device))

    def _many_to_one(self, t, lengths):
        return t[torch.arange(t.size(0)), lengths - 1]

    def init_hidden(self, batch_size, device):
        return torch.zeros(self.gru_layers, batch_size, self.gru_hidden_size, device=device)

    def forward(self, wave, lengths):
        x = wave
        raw_lengths = lengths
        xt = x.float().transpose(1,2)
        xt, lengths = self.spec(xt, raw_lengths)
        xt = self.features(xt)
        lengths = self.modify_lengths(lengths)
        x = xt.transpose(1, -1)

        batch, time = x.size()[:2]
        x = x.reshape(batch, time, -1)
        lengths = lengths.clamp(max=x.shape[1])

        # Handle variable input size
        x_pack = torch.nn.utils.rnn.pack_padded_sequence(x, lengths.clamp(max=x.shape[1]), batch_first=True)
        x_pack, self.hidden = self.rnn(x_pack)
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x_pack, batch_first=True)
        x = self._many_to_one(x, lengths)
        x = self.ret(x)
        return x



def train():
    for epoch in range(1,epochs+1):
        model.train()
        batch_losses=[]
        
        for batch_idx, batch in enumerate(pbar):
            optimizer.zero_grad()
            wave, lengths, lbl = batch
            model.hidden = model.init_hidden(BATCH_SIZE, device)
            
            pred = model(wave.to(device), lengths.to(device)).squeeze()
            loss = loss_fn(pred, lbl.to(device))
            
            loss.backward()
            batch_losses.append(loss.detach().item())
            optimizer.step()
            
            del loss, pred, wave, lengths, lbl

I have tried deleting the prediction and input variables every loop, made sure I was detaching every variable I keep, and added an init_hidden() function to refresh the hidden states. These have helped me get to around 3-4 epochs before crashing, but it still happens. I’m running a batch size of just 8, and the input audio files are at most 15-20 seconds long.

Is there anything I can do without reducing the batch size even further? Sorry if there are any stupid mistakes there but I’m a CV guy just getting into audio processing. The code from the classifier was heavily inspired by https://github.com/ksanjeevan/crnn-audio-classification.