I’m having a strange issue where not a single one of my model’s parameters gets updated. I saw another answer describing a similar problem (Model.parameters() is None while training), but the solution given there doesn’t quite solve mine, because in my case every one of my model’s parameter grads is None.
for p in model.parameters():
    if p.grad is not None:
        print(p.grad.data)
In other words, the above loop prints nothing for my model. Even after calling loss.backward() and optimizer.step(), every parameter’s grad remains None.
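For completeness, here is the fuller check I run right after loss.backward(). It’s just a sketch using named_parameters() so each tensor is identifiable:

# Gradient check run immediately after loss.backward()
for name, p in model.named_parameters():
    if p.grad is None:
        print(f"{name}: grad is None (requires_grad={p.requires_grad})")
    else:
        print(f"{name}: grad norm = {p.grad.norm().item():.6f}")

Every single parameter takes the first branch.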
I’ve verified that a loss is being calculated; it just seems that no optimization takes place, even though backward() and step() are definitely being called.
This is my optimization code:
import torch
import torch.nn as nn
from tqdm import tqdm


def step(batch, model, criterion, optimizer=None):
    # let go of old gradients
    model.zero_grad()

    X = batch["X"].to(DEVICE)
    y = batch["y"].to(DEVICE)

    ## Forward Pass ##
    predictions = model(X)

    ## Calculate Loss ##
    loss = criterion(predictions, y)

    if optimizer is not None:
        # backward pass + optimize
        loss.backward()
        optimizer.step()

    return loss


def train_model(model=None, lr=0.01):
    criterion = nn.CrossEntropyLoss().to(DEVICE)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.Adam(params=params, lr=lr)

    for epoch in range(1, N_EPOCHS + 1):
        for i, batch in enumerate(tqdm(train_loader)):
            loss = step(batch, model, criterion, optimizer=optimizer)


model = NeuralNet()
model.to(DEVICE)
train_model(model=model)
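To be concrete about the "loss exists but nothing updates" claim above, this is roughly the kind of check I did. The .head attribute here is a placeholder name, not necessarily a real layer in NeuralNet:

# Sanity check: loss is finite, but the weights never move across a step
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=0.01)
batch = next(iter(train_loader))

before = model.head.weight.detach().clone()
loss = step(batch, model, criterion, optimizer=optimizer)
after = model.head.weight.detach().clone()

print(loss.item())                 # a reasonable, finite loss value
print(torch.equal(before, after))  # True, i.e. the weights did not change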
The only thing I can think of is that it has something to do with how I’ve split my optimization code into functions, rather than keeping it all in the same scope. I was able to get a toy neural net with randn()-generated data to train properly in a single training loop that isn’t spread out across multiple functions (the model was initialized in the same scope it was trained in).
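Roughly, that working toy setup looked like the sketch below (layer sizes and data are made up; the point is only that everything lives in one scope):

import torch
import torch.nn as nn

toy_model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.01)

X_toy = torch.randn(64, 10)          # random features
y_toy = torch.randint(0, 3, (64,))   # random class labels

for _ in range(100):
    toy_optimizer.zero_grad()
    loss = toy_criterion(toy_model(X_toy), y_toy)
    loss.backward()
    toy_optimizer.step()

# Here every p.grad is populated and the loss goes down as expected.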
I have a couple of pre-trained torch.nn modules in NeuralNet(): a pre-trained Embedding layer and a pre-trained ResNet, both of which have frozen weights (I freeze them by setting requires_grad = False on their parameters).
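The freezing itself looks roughly like this inside NeuralNet (module names, the resnet18 choice, and the sizes below are illustrative, not my exact code):

import torch
import torch.nn as nn
import torchvision

class NeuralNetSketch(nn.Module):
    def __init__(self):
        super().__init__()
        fake_vectors = torch.randn(1000, 300)  # stand-in for my real pre-trained vectors
        # frozen pre-trained embedding
        self.embedding = nn.Embedding.from_pretrained(fake_vectors, freeze=True)
        # frozen pre-trained ResNet
        self.resnet = torchvision.models.resnet18(pretrained=True)
        for param in self.resnet.parameters():
            param.requires_grad = False
        # the part I actually expect to train
        self.head = nn.Linear(512, 10)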
What could be causing this issue?