That’s not the case for nn.Modules, as the to() operation will be applied recursively to all registered submodules, parameters, and buffers.
# Convert only the submodules whose qualified name contains "bn" to float64,
# then list every parameter's dtype to show which layers were affected.
model = models.resnet18()
for name, module in model.named_modules():
    # The filter is on the module *name*, so batch-norm layers registered under
    # a different name (e.g. "downsample.1") are intentionally left untouched.
    if "bn" in name:
        # double() casts this module's floating-point parameters and buffers.
        module.double()

# Pair each parameter name with its dtype to verify the selective conversion.
print([(name, p.dtype) for name, p in model.named_parameters()])
# [('conv1.weight', torch.float32), ('bn1.weight', torch.float64), ('bn1.bias', torch.float64), ('layer1.0.conv1.weight', torch.float32), ('layer1.0.bn1.weight', torch.float64), ('layer1.0.bn1.bias', torch.float64), ('layer1.0.conv2.weight', torch.float32), ('layer1.0.bn2.weight', torch.float64), ('layer1.0.bn2.bias', torch.float64), ('layer1.1.conv1.weight', torch.float32), ('layer1.1.bn1.weight', torch.float64), ('layer1.1.bn1.bias', torch.float64), ('layer1.1.conv2.weight', torch.float32), ('layer1.1.bn2.weight', torch.float64), ('layer1.1.bn2.bias', torch.float64), ('layer2.0.conv1.weight', torch.float32), ('layer2.0.bn1.weight', torch.float64), ('layer2.0.bn1.bias', torch.float64), ('layer2.0.conv2.weight', torch.float32), ('layer2.0.bn2.weight', torch.float64), ('layer2.0.bn2.bias', torch.float64), ('layer2.0.downsample.0.weight', torch.float32), ('layer2.0.downsample.1.weight', torch.float32), ('layer2.0.downsample.1.bias', torch.float32), ('layer2.1.conv1.weight', torch.float32), ('layer2.1.bn1.weight', torch.float64), ('layer2.1.bn1.bias', torch.float64), ('layer2.1.conv2.weight', torch.float32), ('layer2.1.bn2.weight', torch.float64), ('layer2.1.bn2.bias', torch.float64), ('layer3.0.conv1.weight', torch.float32), ('layer3.0.bn1.weight', torch.float64), ('layer3.0.bn1.bias', torch.float64), ('layer3.0.conv2.weight', torch.float32), ('layer3.0.bn2.weight', torch.float64), ('layer3.0.bn2.bias', torch.float64), ('layer3.0.downsample.0.weight', torch.float32), ('layer3.0.downsample.1.weight', torch.float32), ('layer3.0.downsample.1.bias', torch.float32), ('layer3.1.conv1.weight', torch.float32), ('layer3.1.bn1.weight', torch.float64), ('layer3.1.bn1.bias', torch.float64), ('layer3.1.conv2.weight', torch.float32), ('layer3.1.bn2.weight', torch.float64), ('layer3.1.bn2.bias', torch.float64), ('layer4.0.conv1.weight', torch.float32), ('layer4.0.bn1.weight', torch.float64), ('layer4.0.bn1.bias', torch.float64), ('layer4.0.conv2.weight', torch.float32), 
# ('layer4.0.bn2.weight', torch.float64), ('layer4.0.bn2.bias', torch.float64), ('layer4.0.downsample.0.weight', torch.float32), ('layer4.0.downsample.1.weight', torch.float32), ('layer4.0.downsample.1.bias', torch.float32), ('layer4.1.conv1.weight', torch.float32), ('layer4.1.bn1.weight', torch.float64), ('layer4.1.bn1.bias', torch.float64), ('layer4.1.conv2.weight', torch.float32), ('layer4.1.bn2.weight', torch.float64), ('layer4.1.bn2.bias', torch.float64), ('fc.weight', torch.float32), ('fc.bias', torch.float32)]