Hi all, I’m working on a super-resolution CNN model and I’m running into GPU memory issues. My training and validation steps live in separate functions (below), and I take care to detach tensor data where appropriate, to prevent the computational graph from being retained needlessly (as discussed in many other threads on this forum):
```python
import torch
import torch.nn as nn

# (method of my model-handler class)
def run_train(self, x, y, *args, **kwargs):
    if self.eval_mode:
        raise RuntimeError('Model initialized in eval mode, training not possible.')
    self.net.train()  # set model to training mode (activates appropriate behaviour for certain layers)
    x, y = x.to(device=self.device), y.to(device=self.device)
    out = self.run_model(x, **kwargs)  # run data through model
    loss = self.criterion(out, y)  # compute loss
    self.optimizer.zero_grad()  # zero all weight grads left over from the previous training iter
    loss.backward()  # backpropagate to compute gradients for the current iter's loss
    if self.grad_clip is not None:  # gradient clipping
        nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip)
    self.optimizer.step()  # update network parameters
    if self.learning_rate_scheduler is not None:
        self.learning_rate_scheduler.step()
    return loss.detach().cpu().numpy()
```
```python
def run_eval(self, x, y=None, request_loss=False, tag=None, *args, **kwargs):
    self.net.eval()  # set the system to validation mode
    with torch.no_grad():  # disable autograd so no graph is built during validation
        x = x.to(device=self.device)
        out = self.run_model(x, image_names=tag, **kwargs)  # forward the data through the model
        if request_loss:
            y = y.to(device=self.device)
            loss = self.criterion(out, y).detach().cpu().numpy()  # compute loss
        else:
            loss = None
    return out.detach().cpu(), loss
```
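For context, the calling code is essentially the following (a simplified sketch; `model_handler`, `train_loader`, and `val_loader` are placeholder names, not my exact code):

```python
# Simplified sketch of the outer loops (names are placeholders):
for epoch in range(num_epochs):
    train_losses = []
    for x, y in train_loader:
        # run_train returns a NumPy scalar, so no graph should be kept alive here
        train_losses.append(model_handler.run_train(x, y))

    val_losses = []
    for x, y in val_loader:
        out, loss = model_handler.run_eval(x, y, request_loss=True)
        val_losses.append(loss)  # also a NumPy scalar
```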
For some reason, the GPU runs out of memory only partway through a training or validation run, i.e. after a number of images have already passed through the model without issue. This suggests memory is accumulating as the run progresses. I have tried to probe the issue by clearing the PyTorch cache and deleting variables before exiting each function, but nothing helps. The buildup grows to a certain ceiling and then stops, and that ceiling appears to be dependent on the training batch size (the validation batch size is always 1). For example (hypothetical numbers): with a train batch size of 16, GPU memory builds up to ~8 GB during validation and then plateaus there for the remainder of the training run; with a batch size of 8, the buildup plateaus at ~4 GB. This means I have to severely limit my batch size to allow training to occur at all, which is too much of a tradeoff for me.
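For reference, this is roughly what my cache-clearing/probing attempt looked like (a sketch; the exact call sites in my code differ):

```python
import torch

# After each train/eval step (roughly what I tried; it did not help):
del out, loss               # drop local references to GPU tensors
torch.cuda.empty_cache()    # release cached blocks back to the driver

# Probing allocated vs. reserved memory to see where the buildup sits:
print(torch.cuda.memory_allocated() / 1e9, 'GB allocated')
print(torch.cuda.memory_reserved() / 1e9, 'GB reserved')
```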
Do you have any further insight into what could be going wrong? Thanks for your help!