Detaching and reattaching model depending on phase

I was wondering: if I explicitly detach the model, the inputs, etc. during the validation phase, do I have to reattach them in the next epoch of the training loop?

You can detach() a tensor, which is attached to the computation graph, but you cannot “detach” a model.
If you don’t disable gradient calculation (e.g. via torch.no_grad()), the forward pass will create the computation graph and the model output tensor will be attached to it. You can check the .grad_fn attribute of the output tensor to see whether it’s attached to a graph (it shows a backward function) or not (it is None).
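As a minimal sketch of the .grad_fn check described above, the snippet below runs the same forward pass with and without torch.no_grad(). The nn.Linear model and the random input are stand-ins chosen only for illustration, not the model from the question.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model and input, used only to illustrate the check.
model = nn.Linear(4, 2)
x = torch.randn(3, 4)

# Regular forward pass: autograd records the graph.
out_train = model(x)

# Forward pass with gradient calculation disabled: no graph is recorded.
with torch.no_grad():
    out_val = model(x)

print(out_train.grad_fn)  # a backward function object
print(out_val.grad_fn)    # None
```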

If you haven’t used the with torch.no_grad() context manager during validation, you might use more memory than necessary, but there are no other side effects and you don’t need to reattach anything (which is not possible anyway).
The next training iteration will work as before.
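To illustrate that nothing needs to be reattached, here is a small sketch that runs a validation pass under torch.no_grad() and then a normal training step on the same model. The model, data, optimizer and loss are placeholders I picked for the example; the model.eval() / model.train() calls only switch layer behavior (dropout, batchnorm) and are independent of gradient tracking.

```python
import torch
import torch.nn as nn

# Placeholder model, data, loss and optimizer for illustration only.
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# Validation pass: no graph is built inside this block.
model.eval()
with torch.no_grad():
    val_loss = criterion(model(x), y)

# Next training iteration: the graph is built again automatically.
model.train()
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()

# Every parameter received a gradient, so training works as before.
grads_ok = all(p.grad is not None for p in model.parameters())
print(grads_ok)
```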

Thank you for the answer!

My question was motivated precisely by the doubling of memory usage during validation!

Despite wrapping it in

with torch.no_grad():

my memory usage doubles as soon as I reach validation, and then increases again when I compute the Class Activation Maps…

Could it be caused by me doing:

for input, label in val_loader:
    input = input.to(self.device)
    label = label.to(self.device)

But as far as I remember, I was getting a

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

error if I didn’t place the inputs and labels on the device…

You need to place the input and target tensors on the same device as the model parameters.
The exception is advanced use cases such as model sharding, which you don’t seem to be using.
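A minimal sketch of keeping inputs and targets on the model's device is shown below. The tiny model and random tensors are assumptions made for the example; it falls back to the CPU when no GPU is available.

```python
import torch
import torch.nn as nn

# Pick the GPU if one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in model, moved to the chosen device.
model = nn.Linear(4, 2).to(device)

# Tensors are created on the CPU by default, as in a typical DataLoader batch.
inputs = torch.randn(3, 4)
labels = torch.randint(0, 2, (3,))

# .to() returns a copy on the target device; reassign the names to use it.
inputs = inputs.to(device)
labels = labels.to(device)

with torch.no_grad():
    outputs = model(inputs)

print(outputs.device)
```

Moving the batch to the device is required for the forward pass and is not by itself the cause of a memory jump; each batch copy is released once nothing references it any more.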

If your validation loop increases the memory usage, you can either delete all unnecessary training tensors manually or wrap the training code in a train() function (and the validation code in a val() function), so that temporary variables are deleted automatically. Python uses function scoping, so temporary variables such as the model output and the loss are released once you leave the function scope.
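The function-scoping suggestion could look roughly like the sketch below. The train() and val() signatures, the tiny model, and the list of random batches standing in for a DataLoader are all assumptions for illustration; only Python floats are returned, so no tensor (and no graph) outlives the call.

```python
import torch
import torch.nn as nn

def train(model, loader, criterion, optimizer, device):
    """Run one training epoch; outputs and loss tensors stay local."""
    model.train()
    total = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        total += loss.item()  # .item() gives a Python float, not a tensor
    return total / len(loader)

def val(model, loader, criterion, device):
    """Run one validation epoch without building a computation graph."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            total += criterion(model(inputs), labels).item()
    return total / len(loader)

# Hypothetical setup: a list of random batches stands in for a DataLoader.
device = torch.device("cpu")
model = nn.Linear(4, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)]

for epoch in range(2):
    train_loss = train(model, loader, criterion, optimizer, device)
    val_loss = val(model, loader, criterion, device)
    print(epoch, train_loss, val_loss)
```

Accumulating loss.item() rather than the loss tensor itself matters here: summing the tensors during training would keep every iteration's graph alive until the sum is released.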