Hi,
I’m trying to train a simple audio classification model on Colab, but my GPU memory usage (on a 16 GB instance) keeps growing and gets out of control every few epochs. Here is the model definition and a minimal snippet of my training code:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ELU, applied in that order.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution (and
            normalised by the batch norm).
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv2d``.
        stride: Convolution stride, forwarded to ``nn.Conv2d``.
        padding: Convolution padding, forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # The attribute name ``layers`` is part of the state-dict keys
        # ("layers.0.weight", ...), so it must not be renamed.
        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
        )

    def forward(self, x):
        """Run ``x`` through the convolution, batch norm and ELU."""
        return self.layers(x)
class AudioClassifier(nn.Module):
    """Convolutional-recurrent binary classifier for raw audio.

    Pipeline: ``MelspectrogramStretch`` -> three ``ConvBlock``/``MaxPool2d``/
    ``Dropout`` stages -> 2-layer GRU -> ``Linear(…, 1)`` + ``Sigmoid``.
    The GRU output at the last valid time step of each sequence is the one
    that is classified.

    Args:
        stereo: If true the first convolution expects 2 input channels,
            otherwise 1.
        dropout: Drop probability of the ``nn.Dropout`` after each pooling.
    """

    def __init__(self, stereo=True, dropout=0.1):
        super().__init__()
        in_channels = 2 if stereo else 1
        self.spec = MelspectrogramStretch(hop_length=None,
                                          num_mels=128,
                                          fft_length=2048,
                                          norm='whiten',
                                          stretch_param=[0.4, 0.4])
        self.features = nn.Sequential(
            # Honour ``stereo`` instead of hard-coding two input channels.
            ConvBlock(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1),
            nn.MaxPool2d(3, 3),
            nn.Dropout(p=dropout),
            ConvBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.MaxPool2d(4, 4),
            nn.Dropout(p=dropout),
            ConvBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.MaxPool2d(4, 4),
            nn.Dropout(p=dropout),
        )
        self.min_len = 80896
        self.gru_hidden_size = 64
        self.gru_layers = 2
        # GRU input size 128 is presumably 64 channels x 2 mel bins left after
        # the conv stack (128 -> 126 -> 42 -> 40 -> 10 -> 8 -> 2) -- TODO
        # confirm against the actual output of ``self.spec``.
        self.rnn = nn.GRU(128, self.gru_hidden_size, num_layers=self.gru_layers)
        self.ret = nn.Sequential(nn.Linear(self.gru_hidden_size, 1), nn.Sigmoid())
        # Last GRU hidden state, detached; informational only (see forward).
        self.hidden = None

    def modify_lengths(self, lengths):
        """Map per-sample time lengths through every conv / pooling layer.

        Applies ``(L + 2*padding - kernel)//stride + 1`` for each
        ``nn.Conv2d`` and ``nn.MaxPool2d`` found in ``self.features`` and
        returns the result with a lower bound of 1.

        Args:
            lengths: Integer tensor with one length per batch element.

        Returns:
            ``torch.long`` tensor of the same shape, every entry >= 1.
        """
        def safe_param(elem):
            # ``forward`` treats the last spatial axis as time, so use the
            # last entry of a (height, width) tuple.
            return elem if isinstance(elem, int) else elem[-1]

        # ``modules()`` recurses, so convolutions wrapped inside ConvBlock are
        # found too; ``named_children()`` only saw the poolings.
        for layer in self.features.modules():
            if isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
                p, k, s = map(safe_param, [layer.padding, layer.kernel_size, layer.stride])
                lengths = ((lengths + 2 * p - k) // s + 1).long()
        return lengths.clamp(min=1)

    def _many_to_one(self, t, lengths):
        """Pick, for each batch row of ``t``, the step at ``lengths - 1``."""
        rows = torch.arange(t.size(0), device=t.device)
        return t[rows, lengths - 1]

    def init_hidden(self, batch_size, device):
        """Return an all-zero tensor shaped (gru_layers, batch_size, gru_hidden_size)."""
        return torch.zeros(self.gru_layers, batch_size, self.gru_hidden_size, device=device)

    def forward(self, wave, lengths):
        """Classify a batch of variable-length waveforms.

        Args:
            wave: Waveform batch; dims 1 and 2 are swapped before it is handed
                to ``self.spec`` (assumes (batch, samples, channels) -- TODO
                confirm against the data loader).
            lengths: Per-sample lengths, passed to ``self.spec`` together with
                the waveform.

        Returns:
            Tensor of shape (batch, 1) with sigmoid outputs.
        """
        xt = wave.float().transpose(1, 2)
        xt, lengths = self.spec(xt, lengths)
        xt = self.features(xt)
        lengths = self.modify_lengths(lengths)

        # Move the last (time) axis next to the batch axis and flatten the rest.
        x = xt.transpose(1, -1)
        batch, n_steps = x.size()[:2]
        x = x.reshape(batch, n_steps, -1)
        lengths = lengths.clamp(max=n_steps)

        # Handle variable input size. pack_padded_sequence needs the lengths
        # on the CPU; enforce_sorted=False lifts the sorted-batch requirement.
        x_pack = torch.nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        x_pack, hidden = self.rnn(x_pack)
        # The hidden state is never fed back into the GRU, so keep only a
        # detached copy: the module must not hold on to the autograd graph.
        self.hidden = hidden.detach()
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x_pack, batch_first=True)
        x = self._many_to_one(x, lengths)
        return self.ret(x)
def train():
    """Run the training loop and return the mean loss of every epoch.

    Reads the module-level names ``epochs``, ``model``, ``pbar``,
    ``optimizer``, ``loss_fn`` and ``device``. ``pbar`` must yield
    ``(wave, lengths, lbl)`` batches and is iterated once per epoch.

    Returns:
        List with one float per epoch: the mean of the batch losses, or
        ``nan`` for an epoch in which ``pbar`` yielded nothing.
    """
    epoch_losses = []
    for epoch in range(1, epochs + 1):
        model.train()
        batch_losses = []
        for wave, lengths, lbl in pbar:
            optimizer.zero_grad()
            # Size the reset by the real batch: the last batch of an epoch
            # can be smaller than the configured batch size.
            model.hidden = model.init_hidden(wave.size(0), device)
            # squeeze(-1) drops only the trailing unit dimension, so a batch
            # of one keeps shape (1,) instead of collapsing to a scalar.
            pred = model(wave.to(device), lengths.to(device)).squeeze(-1)
            loss = loss_fn(pred, lbl.to(device))
            loss.backward()
            optimizer.step()
            # .item() yields a plain Python float, so nothing stored here
            # keeps the autograd graph alive.
            batch_losses.append(loss.item())
            del loss, pred, wave, lengths, lbl
        if batch_losses:
            epoch_losses.append(sum(batch_losses) / len(batch_losses))
        else:
            epoch_losses.append(float("nan"))
    return epoch_losses
I have tried deleting the prediction and input variables on every iteration, made sure I detach every variable I keep, and added an init_hidden() function to reset the hidden state.
These changes got me to around 3-4 epochs before crashing, but the crash still happens. I’m running a batch size of just 8, and the input audio files are at most 15-20 seconds long.
Is there anything I can do without reducing the batch size even further? Sorry if there are any stupid mistakes here, but I’m a CV guy just getting into audio processing. The classifier code was heavily inspired by https://github.com/ksanjeevan/crnn-audio-classification.