Hi, I am running a slightly modified version of resnet18 (I just added one extra conv and batchnorm layer at the beginning of the network). When I start iterating over my dataset it trains fine at first, but after some iterations I run out of memory. If I reduce the batch size, training runs for some more iterations, but it always ends up running out of memory.
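In case the architecture matters, the change looks roughly like this (a minimal sketch; the real channel counts and kernel size are different, so treat these numbers as placeholders):

    import torch.nn as nn
    import torchvision.models

    class ModifiedResNet18(nn.Module):
        """Sketch of the modification: one extra conv + batchnorm in front of a stock resnet18."""
        def __init__(self):
            super(ModifiedResNet18, self).__init__()
            # Placeholder sizes -- not my exact configuration.
            self.extra_conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
            self.extra_bn = nn.BatchNorm2d(3)
            self.backbone = torchvision.models.resnet18(pretrained=False)

        def forward(self, x):
            x = self.extra_bn(self.extra_conv(x))
            return self.backbone(x)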
Could you help me find my memory leak?
Oh, one more thing: if I select one batch and always iterate over that same batch, the network trains just fine, so it seems to be a problem with the data loader. It looks like it keeps references to memory that aren't getting cleaned up, or something like that.
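To confirm it really is GPU memory that keeps growing (and not host RAM), I was planning to log the allocation every iteration with something like this (assuming my PyTorch version exposes torch.cuda.memory_allocated; otherwise I just watch nvidia-smi in another terminal):

    import torch

    def log_gpu_memory(step):
        # torch.cuda.memory_allocated reports the bytes currently held by tensors on the GPU.
        allocated_mb = torch.cuda.memory_allocated() / (1024 ** 2)
        print("step {}: {:.1f} MiB allocated on the GPU".format(step, allocated_mb))

Calling this at the end of every loop iteration should show whether the allocation climbs batch after batch. Here is my training loop and data loader: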
import os
import random

import torch
from torch.autograd import Variable

import lalo.resnet2

# CHECKPOINT_OUTPUT_FILE, PROCESSED_FOLDERS and BATCH_SIZE are constants defined elsewhere in the project.

def train():
    second_convnet = lalo.resnet2.resnet18(pretrained=False)
    # Resume from a checkpoint if one exists.
    if os.path.isfile(CHECKPOINT_OUTPUT_FILE):
        checkpoint = torch.load(CHECKPOINT_OUTPUT_FILE)
        second_convnet.load_state_dict(checkpoint)
        print("Checkpoint found, continuing with training...")
    else:
        print("No checkpoint found, training from scratch...")
    second_convnet.cuda()
    second_convnet.train()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    optimizer = torch.optim.SGD(second_convnet.parameters(), learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)
    for i, (input, target) in enumerate(data_loader(PROCESSED_FOLDERS['training'], BATCH_SIZE)):
        output = second_convnet(input)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Batch {} processed successfully".format(i))
def data_loader(folder, batch_size):
    """ Our dataset is very unbalanced, so I am forcing our data loader
    to load the same amount of positive and negative samples.
    """
    patients_list = []
    labels_list = []
    while True:
        for _ in range(batch_size):
            label = random.choice(['0', '1'])
            patient_ids = os.listdir(os.path.join(folder, label))
            patient_id = random.choice(patient_ids)
            patient_path = os.path.join(folder, label, patient_id)
            patients_list.append(torch.load(patient_path))
            label = torch.Tensor([int(label)])
            labels_list.append(label)
        batch_variable = Variable(torch.stack(patients_list),
                                  requires_grad=False)
        batch_labels = torch.squeeze(Variable(torch.stack(labels_list),
                                              requires_grad=False))
        yield batch_variable, batch_labels.long().cuda()
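For reference, the loader expects the processed training folder to be split by label, roughly like this (the file names here are made up):

    PROCESSED_FOLDERS['training']/
        0/
            patient_0001.pt
            patient_0002.pt
            ...
        1/
            patient_0003.pt
            ...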