How can I train only a subset of model weights using masks with AMP in PyTorch (NVIDIA DL/ResNet implementation)?

Hello all, I am trying to train a ResNet-50 model on ImageNet. I need to train only a subset of the weights, selected by a mask. I have done this successfully with standard ResNet models on CIFAR-10 using the following code:

        optimizer.zero_grad()
        loss.backward()
        # Zero the gradients of frozen entries (mask == 0) before the update.
        for name, p in model.named_parameters():
            if 'weight' in name and p.grad is not None:
                p.grad.mul_(masks[name])
        optimizer.step()

Here, masks is a dictionary that maps each parameter name to a layer-wise mask of 1s and 0s.
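
For reference, the recipe above can be reduced to a minimal, self-contained sketch. The tiny `nn.Linear` model, the random data, and the randomly generated `masks` are placeholders made up for illustration, and the sketch assumes plain SGD with no momentum or weight decay.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder model and data; stand-ins for the real ResNet and CIFAR-10 batches.
model = nn.Linear(8, 4)
inputs = torch.randn(16, 8)
targets = torch.randint(0, 4, (16,))
criterion = nn.CrossEntropyLoss()

# Hypothetical masks: 1 = trainable entry, 0 = frozen entry.
masks = {
    name: (torch.rand_like(p) > 0.5).float()
    for name, p in model.named_parameters()
    if 'weight' in name
}

# Plain SGD: a zero gradient means a zero update.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

before = {name: p.detach().clone() for name, p in model.named_parameters()}

for _ in range(3):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    for name, p in model.named_parameters():
        if 'weight' in name and p.grad is not None:
            p.grad.mul_(masks[name])
    optimizer.step()

# Entries with mask == 0 should be unchanged; the rest should have moved.
frozen = masks['weight'] == 0
frozen_unchanged = torch.equal(model.weight.detach()[frozen], before['weight'][frozen])
trainable_changed = not torch.equal(model.weight.detach(), before['weight'])
print(frozen_unchanged, trainable_changed)
```

Comparing a snapshot of the weights before and after training, as done at the end, is also how I check that the frozen entries really stay fixed.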

I am now trying to apply the same idea to distributed ImageNet training. I am using NVIDIA’s ResNet repo, which uses AMP. I tried adding the following code block to NVIDIA’s implementation, in this function at line 209 here:

        scaler.scale(loss).backward()

        for name, p in model_and_loss.named_parameters():
            if 'weight' in name:
                p.grad.data = p.grad.data * masks[name]

However, with the above code, all of the weights still change after the gradient step. I think the optimizer step is performed here:

        optimizer_step = ((i + 1) % batch_size_multiplier) == 0
        loss = step(input, target, optimizer_step=optimizer_step)
        if ema is not None:
            ema(model_and_loss, epoch*steps_per_epoch+i)

I am not sure how to achieve this in the NVIDIA implementation, and I am also unsure how the gradient step is actually taken.
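
One thing I have not ruled out is the optimizer itself: with SGD, weight decay is added to the gradient inside `optimizer.step()`, and momentum reuses earlier gradients, so an entry whose gradient is zero can still move. The sketch below illustrates the weight-decay case with a made-up parameter and mask; the hyperparameter values are arbitrary and are not taken from NVIDIA's configuration.

```python
import torch

# Made-up parameter and mask, purely for illustration.
w = torch.nn.Parameter(torch.ones(4))
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 0 = entry that should stay frozen

# Arbitrary hyperparameters; weight decay is applied inside step().
optimizer = torch.optim.SGD([w], lr=0.1, weight_decay=0.1)

before = w.detach().clone()

optimizer.zero_grad()
loss = (w * torch.arange(1.0, 5.0)).sum()
loss.backward()
w.grad.mul_(mask)  # gradients of frozen entries are now exactly zero
optimizer.step()

# Frozen entries still moved, because step() adds weight_decay * w to the gradient.
frozen_moved = not torch.equal(w.detach()[mask == 0], before[mask == 0])

# One possible workaround: restore the saved values of frozen entries after the step.
with torch.no_grad():
    w.copy_(torch.where(mask == 0, before, w))
frozen_restored = torch.equal(w.detach()[mask == 0], before[mask == 0])
print(frozen_moved, frozen_restored)
```

If this is what happens in my case, masking the gradients alone would not be enough, and the frozen weights would need to be restored (or the weights re-masked) after each optimizer step.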

My goal is to zero out a subset of the gradients before calling optimizer.step(). However, with AMP this does not seem straightforward, especially in NVIDIA’s repo.

My question: is there a different way to accomplish this in the training code here?

  1. Moreover, during the AMP gradient update, are the gradients stored in each parameter's .grad the ones that are actually used, or am I missing something?
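
For context on that question: as far as I understand the `torch.cuda.amp` documentation, `scaler.scale(loss).backward()` writes scaled gradients into each parameter's `.grad`, and `scaler.step(optimizer)` unscales those same tensors before calling `optimizer.step()`. The sketch below shows where a mask could be applied in a generic AMP loop, calling `scaler.unscale_(optimizer)` first. It is not NVIDIA's training loop, and the model, data, and `masks` are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Placeholder model and batch standing in for ResNet-50 and ImageNet.
model = nn.Linear(8, 4).to(device)
inputs = torch.randn(16, 8, device=device)
targets = torch.randint(0, 4, (16,), device=device)
criterion = nn.CrossEntropyLoss()

# Hypothetical masks: 1 = trainable, 0 = frozen.
masks = {
    name: (torch.rand_like(p) > 0.5).float()
    for name, p in model.named_parameters()
    if 'weight' in name
}

# Plain SGD here; momentum or weight decay could still move frozen entries.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# With enabled=False (no GPU) the scaler calls become pass-throughs.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

before = model.weight.detach().clone()

for _ in range(3):
    optimizer.zero_grad()
    with torch.autocast(device_type=device.type, enabled=use_cuda):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()   # .grad now holds scaled gradients
    scaler.unscale_(optimizer)      # .grad now holds unscaled gradients
    for name, p in model.named_parameters():
        if 'weight' in name and p.grad is not None:
            p.grad.mul_(masks[name])
    scaler.step(optimizer)          # skips the step if grads contain inf/NaN
    scaler.update()

frozen = masks['weight'] == 0
frozen_unchanged = torch.equal(model.weight.detach()[frozen], before[frozen])
print(frozen_unchanged)
```

Since the masks contain only 0s and 1s, multiplying the still-scaled gradients should give the same result; `unscale_` is called here mainly to make the order of operations explicit.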

Thanks a lot for reading this. Please let me know if I need to clarify anything; I will try to respond quickly. I appreciate the help.