How to freeze feature extractor and train only classifier in DistributedDataParallel?

I want to train only the last fc layer of my pretrained CNN model using the DistributedDataParallel module.

I tried putting the whole model in eval mode and then switching the fc layer back to train mode:

model.module.eval()
model.module.fc.train()
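Here is a sketch of the usual freezing pattern, using a hypothetical TinyNet in place of the pretrained model. Besides switching modes, it sets requires_grad=False on the backbone and gives the optimizer only the fc parameters. eval()/train() only change layer behavior (BatchNorm statistics, dropout). They do not stop gradients, so without requires_grad=False the backbone weights would still be updated by an optimizer built from model.parameters(). With DistributedDataParallel, this would be done on the bare model before wrapping it, since DDP decides which parameters it synchronizes when it is constructed.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for the pretrained CNN: a small 3D backbone plus an fc head.
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv3d(3, 8, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)
        self.bn1 = nn.BatchNorm3d(8)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(8, 3)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv(x)))
        x = self.avgpool(x).view(x.size(0), -1)
        return self.fc(x)


def freeze_all_but_fc(model):
    # Stop gradients for every parameter except those of the fc head.
    for name, p in model.named_parameters():
        p.requires_grad = name.startswith('fc.')
    # eval() freezes BatchNorm running statistics (and disables dropout);
    # train() on the head alone switches just that submodule back.
    model.eval()
    model.fc.train()
    # Return only the trainable parameters, to be handed to the optimizer.
    return [p for p in model.parameters() if p.requires_grad]


model = TinyNet()
trainable = freeze_all_but_fc(model)
optimizer = torch.optim.SGD(trainable, lr=0.1, momentum=0.9)
```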

and got the following error message:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/app/train_action_model_apex.py", line 466, in main_worker
    train_model(args, root_dir)
  File "/app/train_action_model_apex.py", line 235, in train_model
    trainer.train_epoch(epoch, use_amp=True)
  File "/app/trainers/action_model_trainer.py", line 202, in train_epoch
    self.optimize_model(loss_dict[self.update_loss_name], use_amp)
  File "/app/trainers/action_model_trainer.py", line 68, in optimize_model
    scaled_loss.backward()
  File "/usr/lib/python3.5/contextlib.py", line 77, in __exit__
    self.gen.throw(type, value, traceback)
  File "/usr/local/lib/python3.5/dist-packages/apex/amp/handle.py", line 117, in scale_loss
    yield (loss.float())*loss_scale
  File "/app/trainers/action_model_trainer.py", line 68, in optimize_model
    scaled_loss.backward()
  File "/usr/local/lib/python3.5/dist-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: expected scalar type Half but found Float

How can I properly fix the problem?

It seems you are using a higher-level wrapper with amp?
Could you post a code snippet to reproduce this issue, please?

@ptrblck, thanks for your reply. Yes, I'm using apex amp.

Here is a code snippet to reproduce the issue.

import torch
import torch.nn as nn
from apex import amp
import torch.distributed as dist


class SomeModel(nn.Module):
    def __init__(self):
        super(SomeModel, self).__init__()
        self.conv = nn.Conv3d(
            3,
            16,
            kernel_size=(1, 3, 3),
            stride=1,
            padding=(0, 1, 1),
            bias=False)
        self.bn1 = nn.BatchNorm3d(16)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



print('init process group')
dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:7001',
                            world_size=1, rank=0)

model = SomeModel().cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, )
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
print('ddp')
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)

print('model train')
# model.train() # works

model.eval()
model.module.fc.train()

x = torch.randn((5, 3, 7, 7, 7), device='cuda')
y = torch.ones((5, ), device='cuda').long()

print('model forward')
outputs = model(x)
print('calculate loss')
loss = criterion(outputs, y)
print('model backward')
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
print('optimizer step')
optimizer.step()

Also, while writing the snippet, I found that the BatchNorm layer triggers the issue.
Without BN, no error is raised, though I'm not sure it works as intended.
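To see what eval() actually changes for a BatchNorm layer, here is a small CPU check (no DDP or apex involved; the input shape mirrors the repro). In eval mode BatchNorm normalizes with its stored running_mean/running_var and leaves them untouched; in train mode it normalizes with batch statistics and updates the running estimates. That is why eval() on the backbone is what keeps a pretrained network's BatchNorm statistics fixed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm3d(16)
x = torch.randn(5, 16, 7, 7, 7) + 3.0  # non-zero mean so the running stats visibly move

bn.eval()
with torch.no_grad():
    bn(x)
# Still all zeros: eval mode uses the stored statistics and does not update them.
mean_after_eval = bn.running_mean.clone()

bn.train()
with torch.no_grad():
    bn(x)
# Moved toward the batch mean (default momentum 0.1).
mean_after_train = bn.running_mean.clone()
```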

@kkjh0723 I couldn't get your original code to work since I kept running into this error:

RuntimeError: Expected tensor for argument #2 'input' to have the same device as tensor for argument #3 'weight'; but device 1 does not equal 0 (while checking arguments for slow_conv_dilated_all_cuda_template)

I added device_ids=[0] to the DistributedDataParallel constructor and the code seems to work fine now:

import torch
import torch.nn as nn
from apex import amp
import torch.distributed as dist


class SomeModel(nn.Module):
    def __init__(self):
        super(SomeModel, self).__init__()
        self.conv = nn.Conv3d(
            3,
            16,
            kernel_size=(1, 3, 3),
            stride=1,
            padding=(0, 1, 1),
            bias=False)
        self.bn1 = nn.BatchNorm3d(16)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



print('init process group')
dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:7001',
                            world_size=1, rank=0)

model = SomeModel().cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, )
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
print('ddp')
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True, device_ids=[0])

print('model train')
# model.train() # works

model.eval()
model.module.fc.train()

x = torch.randn((5, 3, 7, 7, 7), device='cuda')
y = torch.ones((5, ), device='cuda').long()

print('model forward')
outputs = model(x)
print('calculate loss')
loss = criterion(outputs, y)
print('model backward')
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
print('optimizer step')
optimizer.step()
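For anyone without a GPU at hand, here is a CPU-only sketch of the same idea under DistributedDataParallel. It is an assumption-laden variant, not the code from this thread: it uses the gloo backend, a single process, a temporary file for rendezvous, no apex, and a hypothetical SmallModel. Because the backbone is frozen with requires_grad=False before wrapping, every parameter DDP tracks receives a gradient, so find_unused_parameters is not needed here.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn


class SmallModel(nn.Module):
    # Hypothetical CPU-sized stand-in for SomeModel above.
    def __init__(self):
        super(SmallModel, self).__init__()
        self.conv = nn.Conv3d(3, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)
        self.bn1 = nn.BatchNorm3d(16)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv(x)))
        x = self.avgpool(x).view(x.size(0), -1)
        return self.fc(x)


# Single-process group on CPU with gloo; a temp file replaces the TCP rendezvous.
store_path = os.path.join(tempfile.mkdtemp(), 'ddp_store')
dist.init_process_group(backend='gloo', init_method='file://' + store_path,
                        world_size=1, rank=0)

model = SmallModel()
for name, p in model.named_parameters():
    p.requires_grad = name.startswith('fc.')  # freeze the backbone before wrapping in DDP

# CPU module, so device_ids stays None; on GPU one would pass device_ids=[local_rank].
ddp_model = nn.parallel.DistributedDataParallel(model)
ddp_model.eval()             # backbone BatchNorm uses its running statistics
ddp_model.module.fc.train()  # only the head is in train mode

optimizer = torch.optim.SGD([p for p in ddp_model.parameters() if p.requires_grad],
                            lr=0.1, momentum=0.9)
x = torch.randn(5, 3, 7, 7, 7)
y = torch.ones(5, dtype=torch.long)
loss = nn.CrossEntropyLoss()(ddp_model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

dist.destroy_process_group()
```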

@pritamdamania87, thanks for answering.
I ran my original code with only one visible GPU (CUDA_VISIBLE_DEVICES=0). When multiple GPUs are visible, I get the same error as you.

However, I still get the following error after adding device_ids=[0]:
RuntimeError: expected scalar type Half but found Float
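Under opt_level='O2', apex casts the model to FP16 but by default keeps BatchNorm in FP32 (its keep_batchnorm_fp32 setting), so the model holds a mix of Half and Float tensors. That is one place a Half/Float mismatch could come from, which fits the observation that the error goes away without BN; the exact bug in older PyTorch is not identified in this thread. The sketch below does not use apex. It approximates that cast by hand on CPU, with a hypothetical nn.Sequential mirroring SomeModel, only to show the resulting dtypes.

```python
import torch
import torch.nn as nn

# Hypothetical model mirroring the layer types in SomeModel above.
model = nn.Sequential(
    nn.Conv3d(3, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False),
    nn.BatchNorm3d(16),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
    nn.Linear(16, 3),
)

# Hand-rolled approximation of an O2-style cast: floating-point tensors to fp16,
# then BatchNorm layers back to fp32. Integer buffers such as num_batches_tracked stay int64.
model.half()
for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.float()

dtypes = {name: t.dtype for name, t in model.state_dict().items()}
for name, dtype in dtypes.items():
    print(name, dtype)
```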

Could a different PyTorch version be causing the problem?
I'm currently using 1.1.0.

I was using the latest PyTorch release (1.4); I'll try to reproduce this with 1.1.0.

I see the same issue with 1.1.0 and 1.2.0, but it seems to work from 1.3 onwards. Could you try a version >= 1.3?
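If a training script relies on this pattern, a guard like the following could warn early on older installs. The 1.3 threshold comes from the reports in this thread, not from a documented fix, and the helper name torch_version_at_least is hypothetical.

```python
import re
import warnings

import torch


def torch_version_at_least(version_string, minimum):
    # Keep only the leading "major.minor" digits, e.g. "1.4.0+cu101" -> (1, 4).
    match = re.match(r'(\d+)\.(\d+)', version_string)
    if match is None:
        raise ValueError('unrecognized version string: %r' % version_string)
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


if not torch_version_at_least(str(torch.__version__), (1, 3)):
    # Threshold taken from this thread's reports (1.1.0 and 1.2.0 failed, 1.3+ worked).
    warnings.warn('eval() on a frozen backbone with DDP + apex O2 failed on PyTorch < 1.3 '
                  'in reported tests; found ' + str(torch.__version__))
```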

Thanks! I found it works after updating to PyTorch 1.4.