Wrap model_A's forward pass in torch.no_grad(), or call .detach() on its outputs (preferably the former: it skips building the autograd graph for model_A entirely, which avoids the extra memory use).

    with torch.no_grad():
        data = model_A(input)
    # rest of the training step is not wrapped
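A minimal sketch of the full pattern, assuming model_A is a frozen feature extractor feeding a trainable model_B (the module shapes, loss, and optimizer here are hypothetical, just to make the example runnable):

```python
import torch
import torch.nn as nn

model_A = nn.Linear(4, 8)   # frozen: we never want gradients for it
model_B = nn.Linear(8, 2)   # trained on model_A's outputs
opt = torch.optim.SGD(model_B.parameters(), lr=0.1)

x = torch.randn(16, 4)
target = torch.randn(16, 2)

with torch.no_grad():        # no graph is built for model_A's forward pass
    feats = model_A(x)
# alternative: feats = model_A(x).detach() — same effect on gradients,
# but the graph for model_A is still built during the forward pass

out = model_B(feats)         # rest of the training step is not wrapped
loss = nn.functional.mse_loss(out, target)
opt.zero_grad()
loss.backward()
opt.step()

# feats carries no autograd history, so model_A is untouched by backward()
assert not feats.requires_grad
assert model_A.weight.grad is None
```

backward() stops at feats because it was produced under no_grad, so only model_B's parameters receive gradients.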