The situation is as follows. I have a PyTorch training loop with roughly the following structure:
```python
optimizer = get_opt()
train_data_loader = Dataloader()
net = get_model()

for epoch in range(epochs):
    for batch in train_data_loader:
        output = net(batch)
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()
```
My dataset contains images and is very large (in number of items) and takes about half a day per epoch. Now some samples in my dataset can be of bad quality.
I therefore want my training loop to catch errors that occur during the forward pass, log them, and continue training, instead of aborting the whole program because of one bad data element. This is what I had in mind:
```python
optimizer = get_opt()
train_data_loader = Dataloader()
net = get_model()

for epoch in range(epochs):
    for batch in train_data_loader:
        try:
            output = net(batch)
        except Exception as e:
            logging.error(e, exc_info=True)  # log stack trace
            continue
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()
```
This way, if a forward pass fails, the loop just fetches the next batch and training is not interrupted. This works great for the validation loop, but during training I run into problems: GPU memory is not released after the try/except, so I get an OOM error when PyTorch tries to put the next batch on the GPU.
Things I’ve tried: after an error is caught (i.e. within the
- move every parameter in `net.parameters()` to CPU, and/or detach them
- delete every parameter in
- manually run `gc.collect()` (Python garbage collection)
- manually set all gradients of every parameter in
None of these methods made any difference. Is there a way to do this? I suspect it's the gradients/graph not being cleared, as this problem does not happen during validation (e.g. within a `with torch.no_grad()` context).
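
One thing that may be relevant here, shown as a plain-Python sketch with no PyTorch involved: while an exception is being handled, its traceback keeps the frames that raised it alive, together with their local variables. The names `FakeActivation`, `forward`, and `run_one_batch` are hypothetical stand-ins, not part of my actual code. If the same applies to the intermediate tensors created inside `net(batch)`, then cleanup done inside the `except` block could not free them, and it would have to run after the block instead.

```python
import gc
import weakref


class FakeActivation:
    """Hypothetical stand-in for a large intermediate tensor."""


# Weak references let us observe object lifetime without extending it.
probes = []


def forward(batch):
    # A local created during the forward pass, like an intermediate activation.
    activation = FakeActivation()
    probes.append(weakref.ref(activation))
    raise RuntimeError(f"bad batch: {batch!r}")


def run_one_batch(batch):
    alive_inside = None
    try:
        forward(batch)
    except RuntimeError:
        # The exception being handled still references forward()'s frame
        # through its traceback, so collecting here cannot free the local.
        gc.collect()
        alive_inside = probes[-1]() is not None
    # Handling has finished: the exception, its traceback, and the
    # frame locals can now be released.
    gc.collect()
    alive_after = probes[-1]() is not None
    return alive_inside, alive_after


print(run_one_batch("batch-0"))  # (True, False) on CPython
```

If this is what happens in the training loop, one option would be to only set a flag inside the `except` block and do the cleanup and `continue` after it.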