Transfer learning with mixed-precision


I am currently trying to figure out how to facilitate mixed-precision training when using transfer learning.

Typically, here’s how I am building my custom model using a pre-trained base model:

# load up the ResNet50 model
model = models.resnet50(pretrained=True)

# append a new classification top to our feature extractor and pop it
# on to the current device
modelOutputFeats = model.fc.in_features
model.fc = nn.Linear(modelOutputFeats, len(trainDS.classes))
model =

I am also using nn.DataParallel. As demonstrated here, simply putting the model inside an autocast context won’t help. So, is there a way I could do this using the practices recommended here?

The following works for me, but I am still unsure whether the autocast state is being sent to multiple devices. @ptrblck could you advise?

for (i, (x, y)) in enumerate(trainLoader):
	with torch.cuda.amp.autocast(enabled=True):
		# send the input to the device
		(x, y) = (,

		# perform a forward pass and calculate the training loss
		pred = model(x)
		loss = lossFunc(pred, y)

	# calculate the gradients

I’m not sure I understand the issue correctly. Could you explain what autocast state would refer to and what the initial issue was that you were seeing?
The linked docs give an example of how to use nn.DataParallel. Are you seeing any errors with it?

@ptrblck the linked doc (first one) discusses AMP when using nn.DataParallel. It states that the forward method of my model needs to be inside an autocast context.

My question is how do I achieve that when I am using a pre-trained model? Would the snippet I provided work in that case?

The autocast state would refer to the with torch.cuda.amp.autocast(enabled=True): context.

You could write a custom model, decorate the forward method of the custom model with the autocast context, and call the pretrained model inside.
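A minimal sketch of that suggestion, not a definitive implementation. `AmpWrapper` and the tiny stand-in `base` network are hypothetical names used only for illustration; in practice the pretrained model would be passed in instead. `torch.autocast(device_type="cuda")` is used here as the generic form of the `torch.cuda.amp.autocast` context mentioned in the thread, and it is disabled when no GPU is present so the sketch also runs on CPU.

```python
import torch
from torch import nn


class AmpWrapper(nn.Module):
    """Hypothetical wrapper that runs a wrapped model's forward under autocast."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    # Decorating forward means autocast is entered inside every
    # nn.DataParallel replica thread, not only in the main thread.
    @torch.autocast(device_type="cuda", enabled=torch.cuda.is_available())
    def forward(self, x):
        return self.base_model(x)


# Stand-in for a pretrained network (an assumption, to keep the sketch small).
base = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
model = AmpWrapper(base)
x = torch.randn(2, 3, 8, 8)

if torch.cuda.is_available():
    model = nn.DataParallel(model).cuda()
    x = x.cuda()

out = model(x)
print(out.shape, out.dtype)
```

On a GPU the output of the linear layer should come back in half precision, which is one quick way to check that autocast is active inside the replicas.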

This is what I am thinking as well, but I am not clear on how I would extract only the backbone of a pre-trained model. Mentally, I have the following forward pass:

def forward(self, x):
    with autocast(enabled=True):
        features = self.base_model(x)
        logits = self.classifier(features)
    return logits

My question is how I should extract the base_model from, let’s say, torchvision.models.resnet50(pretrained=True), throwing out the fc block from it.

Using the backbone of a pretrained model seems to be orthogonal to the usage of amp and you could call the model directly in your custom model wrapper.

However, if you want to use only the backbone, either replace the other layers with nn.Identity layers or check the source code of the model to see if you could simply call the backbone directly.
Some models are implemented in this way; others are more complicated and would need a bit more attention.

Okay. Thank you for the points. Appreciate the help.

Since you are working with a ResNet, have a look at its forward. As you can see, there is unfortunately no model.features attribute, so in this case it would be easier to set model.fc = nn.Identity() and just call the model directly.


So, here’s what I did:

from torch.cuda.amp import autocast
from torch import nn

class CustomClassifier(nn.Module):
	def __init__(self, baseModel, numClasses):
		super().__init__()

		# initialize the base model and the classification layer
		self.baseModel = baseModel
		self.classifier = nn.Linear(baseModel.classifier.in_features, numClasses)

		# set the classifier of our base model to produce outputs
		# from the last convolution block
		self.baseModel.classifier = nn.Identity()

	# we decorate the `forward()` method with `autocast()` to enable
	# mixed-precision training in a distributed manner
	@autocast()
	def forward(self, x):
		# pass the inputs through the base model and then obtain the
		# classifier outputs
		features = self.baseModel(x)
		logits = self.classifier(features)

		# return the classifier outputs
		return logits

I am using a densenet121, hence I changed it to classifier from fc. However, I don’t see any speed-up in the total model training time after making these changes.