How to convert submodules to a different dtype simply

I want to train a model where some modules are in bf16 and others in float32.

If the model is in bf16, then it seems to me that using torch.autocast(enabled=False) for the relevant modules is not enough, as their weights will still be in bf16.

But I can’t just iterate through named_modules() and change the dtype, because .to() creates a copy.

Ideally I could use an in-place version like .to_(), but that doesn’t seem to exist.

Is there a simple way to do this (one that doesn’t involve something like keeping track of the parent modules and then reassigning the attributes)?

That’s not the case for nn.Modules: to() is applied in-place and recursively to all registered submodules, parameters, and buffers.

from torchvision import models

model = models.resnet18()

for name, module in model.named_modules():
    # convert submodules whose name contains "bn" (the batchnorm layers) in-place
    if "bn" in name:
        module.double()
        
print([(name, p.dtype) for name, p in model.named_parameters()])
# [('conv1.weight', torch.float32), ('bn1.weight', torch.float64), ('bn1.bias', torch.float64), ('layer1.0.conv1.weight', torch.float32), ('layer1.0.bn1.weight', torch.float64), ('layer1.0.bn1.bias', torch.float64), ('layer1.0.conv2.weight', torch.float32), ('layer1.0.bn2.weight', torch.float64), ('layer1.0.bn2.bias', torch.float64), ('layer1.1.conv1.weight', torch.float32), ('layer1.1.bn1.weight', torch.float64), ('layer1.1.bn1.bias', torch.float64), ('layer1.1.conv2.weight', torch.float32), ('layer1.1.bn2.weight', torch.float64), ('layer1.1.bn2.bias', torch.float64), ('layer2.0.conv1.weight', torch.float32), ('layer2.0.bn1.weight', torch.float64), ('layer2.0.bn1.bias', torch.float64), ('layer2.0.conv2.weight', torch.float32), ('layer2.0.bn2.weight', torch.float64), ('layer2.0.bn2.bias', torch.float64), ('layer2.0.downsample.0.weight', torch.float32), ('layer2.0.downsample.1.weight', torch.float32), ('layer2.0.downsample.1.bias', torch.float32), ('layer2.1.conv1.weight', torch.float32), ('layer2.1.bn1.weight', torch.float64), ('layer2.1.bn1.bias', torch.float64), ('layer2.1.conv2.weight', torch.float32), ('layer2.1.bn2.weight', torch.float64), ('layer2.1.bn2.bias', torch.float64), ('layer3.0.conv1.weight', torch.float32), ('layer3.0.bn1.weight', torch.float64), ('layer3.0.bn1.bias', torch.float64), ('layer3.0.conv2.weight', torch.float32), ('layer3.0.bn2.weight', torch.float64), ('layer3.0.bn2.bias', torch.float64), ('layer3.0.downsample.0.weight', torch.float32), ('layer3.0.downsample.1.weight', torch.float32), ('layer3.0.downsample.1.bias', torch.float32), ('layer3.1.conv1.weight', torch.float32), ('layer3.1.bn1.weight', torch.float64), ('layer3.1.bn1.bias', torch.float64), ('layer3.1.conv2.weight', torch.float32), ('layer3.1.bn2.weight', torch.float64), ('layer3.1.bn2.bias', torch.float64), ('layer4.0.conv1.weight', torch.float32), ('layer4.0.bn1.weight', torch.float64), ('layer4.0.bn1.bias', torch.float64), ('layer4.0.conv2.weight', torch.float32), ('layer4.0.bn2.weight', torch.float64), ('layer4.0.bn2.bias', torch.float64), ('layer4.0.downsample.0.weight', torch.float32), ('layer4.0.downsample.1.weight', torch.float32), ('layer4.0.downsample.1.bias', torch.float32), ('layer4.1.conv1.weight', torch.float32), ('layer4.1.bn1.weight', torch.float64), ('layer4.1.bn1.bias', torch.float64), ('layer4.1.conv2.weight', torch.float32), ('layer4.1.bn2.weight', torch.float64), ('layer4.1.bn2.bias', torch.float64), ('fc.weight', torch.float32), ('fc.bias', torch.float32)]

Ah cool! Thank you!

What if I wanted to convert specific parameters? When I try to assign it like

module.weight = module.weight.to(torch.bfloat16)

I get

TypeError: cannot assign 'torch.cuda.BFloat16Tensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Create a new nn.Parameter, as the error message explains.
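
For example, something along these lines (a minimal sketch using a stand-alone nn.Linear as a placeholder for your module; detach() avoids carrying autograd history into the new parameter):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

# .to() returns a converted copy, so wrap it in a new nn.Parameter
# before assigning it back to the module attribute.
layer.weight = nn.Parameter(layer.weight.detach().to(torch.bfloat16))

print(layer.weight.dtype)  # torch.bfloat16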


Might it make more sense to apply .to() to weight.data? Would that allow me to skip recreating the parameter?

module.weight = nn.Parameter(module.weight.to(torch.bfloat16), requires_grad=False)

This doesn’t seem to work as expected: I get an OOM while doing this (which implies it is not overwriting the old parameter weights).

Using .data seems to do the trick.
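
For reference, the .data variant would look roughly like this (a sketch; swapping .data replaces the tensor stored inside the existing Parameter rather than registering a new one, but it bypasses autograd's checks, so it is best done before creating the optimizer):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

# Swap the underlying tensor of the existing Parameter in place;
# the Parameter object and its registration on the module stay the same.
layer.weight.data = layer.weight.data.to(torch.bfloat16)

print(layer.weight.dtype)  # torch.bfloat16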