The memory consumption starts at around 4 GB but keeps increasing, until the process OOMs midway through training. I use VGG16 to extract features.
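To narrow down whether the growth happens on the GPU or in host RAM, I plan to log both per batch, along these lines (just a sketch; psutil is an extra dependency, and torch.cuda.memory_allocated is only meaningful on CUDA):

import psutil, torch

def log_memory(tag):
    # host RSS in MB, plus currently allocated CUDA memory in MB (0 when on CPU)
    rss = psutil.Process().memory_info().rss / 2**20
    gpu = torch.cuda.memory_allocated() / 2**20 if torch.cuda.is_available() else 0
    print(f"{tag}: rss={rss:.0f} MB, cuda={gpu:.0f} MB")

Here is the full extraction code, including the imports it needs: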
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.models import vgg16
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# min_max_lists and scale_minmax are my own normalization helpers (omitted here)
def get_features(data, N, body, scaler):
    data = data[0]
    MIN, MAX = min_max_lists(data, N)
    for i in range(len(data)):
        data[i] = scale_minmax(data[i], MIN[i], MAX[i], min=0, max=1)
    loader = DataLoader(data, shuffle=False, batch_size=32)
    features = []
    # each picture has to be upscaled to 224x224, since VGG16 expects that
    # input size; the features are taken from the avgpool layer
    for batch in tqdm(loader):
        # the batch holds raw vectors of n_mels * n_frames, which are
        # reshaped into 2D form here
        batch = batch.reshape(batch.shape[0], 64, 64)
        # the vectors need to be transposed to actually resemble a slice
        # of a mel spectrogram
        batch = np.transpose(batch, (0, 2, 1))
        #batch = batch / 255
        # VGG16 takes RGB tensors (3 channels), so the single channel is copied
        batch = np.repeat(batch[..., np.newaxis], 3, -1)
        # transpose so the tensor has shape (batch_size, channels, height, width)
        batch = np.transpose(batch, (0, 3, 1, 2))
        batch = scaler(batch).float()
        #batch = torch.from_numpy(batch)
        inp = torch.nn.functional.interpolate(batch, size=(224, 224), mode='bilinear')
        inp = inp.float().to(device)
        with torch.no_grad():
            out = body(inp)
        out = torch.flatten(out['avgpool'], 1)
        features.append(out.detach())
        del out
    # concatenate the per-batch tensors into one big tensor
    features = torch.cat(features)
    features = features.detach()
    return features
m = vgg16().to(device)
body = create_feature_extractor(m, return_nodes={'avgpool':'avgpool'})
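One thing I noticed while staring at the code: out lives on the GPU, so the features list accumulates one GPU tensor per batch for the whole dataset. Would moving each batch's features to the CPU before appending, roughly like below, be the right fix? (A sketch of just the loop body; prepare_batch is a placeholder for the reshape/repeat/scale steps above, not a real helper in my code.)

features = []
for batch in tqdm(loader):
    inp = prepare_batch(batch)  # placeholder for the reshape/repeat/scale steps above
    with torch.no_grad():
        out = body(inp.to(device))
    # .cpu() copies the activations to host memory and drops the reference to
    # the GPU tensor, so only one batch of activations lives on the GPU at a time
    features.append(torch.flatten(out['avgpool'], 1).cpu())
features = torch.cat(features)

Even if the growth turns out to be in host RAM instead, I suppose the accumulation would still briefly double the footprint at the end, since the list of per-batch chunks and the concatenated tensor exist at the same time during torch.cat.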