Change all Conv2d and BatchNorm2d to their 3d counterpart

Hi,

I have a pretrained ResNet model and want to change all of its Conv2d and BatchNorm2d layers to their respective 3d counterparts.

Is there an “automatic”, iterative way of doing this? I don’t want to replace every module by hand, because I also want to try different architectures.

So far I tried this, but it does not change the modules:

# First attempt: inflate every Conv2d into a Conv3d and swap every
# BatchNorm2d for a BatchNorm3d while walking model.modules().
# NOTE: this loop does NOT modify `model`. Each `module = nn.Conv3d(...)` /
# `module = nn.BatchNorm3d(...)` below only rebinds the local loop variable;
# the parent module still references the original 2d layer. The new layer
# must instead be assigned on its parent module (e.g. with setattr).
for module in model.modules():
    if(isinstance(module, nn.Conv2d)):
        # Only index 0 of each setting is read, so non-square kernels,
        # strides or paddings are not reproduced.
        kernel_size = module.kernel_size[0]
        stride = module.stride[0]
        padding = module.padding[0]
        # Inflate the 2d kernel: insert a depth axis at dim 2, divide by
        # kernel_size, then concatenate kernel_size copies along that axis,
        # so the copies sum back to the original 2d kernel.
        weight = module.weight.unsqueeze(2) / kernel_size
        weight = torch.cat([weight for _ in range(0, kernel_size)], dim=2)
        bias = module.bias

        # NOTE(review): per the PyTorch Conv2d docs the weight is laid out as
        # (out, in/groups, kH, kW), so weight.shape[1] is in_channels only when
        # groups == 1 -- confirm for grouped architectures.
        if(bias is None):
            module = nn.Conv3d(in_channels=module.weight.shape[1], out_channels=module.weight.shape[0],
                               kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
        else:
            module = nn.Conv3d(in_channels=module.weight.shape[1], out_channels=module.weight.shape[0],
                               kernel_size=kernel_size, padding=padding, stride=stride, bias=True)
            module.bias = bias

        # Runs for both branches: the inflated kernel replaces the new
        # Conv3d's freshly initialised weight.
        module.weight.data = weight

    elif(isinstance(module, nn.BatchNorm2d)):
        # Only the affine weight and bias are carried over; running_mean and
        # running_var (and eps/momentum) are not copied into the new layer.
        weight = module.weight
        bias = module.bias
        module = nn.BatchNorm3d(weight.shape[0])
        module.weight = weight
        module.bias = bias

Afterwards, the model still contains only Conv2d and BatchNorm2d layers. Is there another way to iterate over the modules so that the replacement actually takes effect?

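Alright, I found the problem. Assigning a new layer to the loop variable `module` only rebinds a local name, so the model keeps referencing the original layer. The new layer has to be assigned on its parent module instead. The `model.named_modules()` method also returns each layer's dotted name (e.g. `layer1.0.conv1`), which makes that possible. Here is the modified code for the automatic replacement, in case anyone is interested: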

def _to_like(new_module, old_module):
    """Move *new_module* to the device and floating dtype of *old_module*.

    The reference tensor is the first floating-point parameter or buffer of
    *old_module*. If *old_module* has none, *new_module* is returned unchanged.
    """
    for tensor in list(old_module.parameters()) + list(old_module.buffers()):
        if tensor.is_floating_point():
            return new_module.to(device=tensor.device, dtype=tensor.dtype)
    return new_module


def _inflate_conv2d(conv2d):
    """Return an nn.Conv3d built from *conv2d* by inflating its kernel along depth.

    The depth kernel size, stride, padding and dilation are taken from index 0
    of the corresponding 2d setting. The spatial settings are copied
    unchanged, so non-square kernels keep their shape. in_channels,
    out_channels, groups, padding_mode and the presence of a bias are copied
    as well.

    The 2d kernel is repeated ``kt`` times along the new depth axis and divided
    by ``kt``, so the ``kt`` copies sum back to the original kernel. A
    depth-constant input therefore gives the Conv2d's response, away from the
    zero-padded depth borders.
    """
    kt = conv2d.kernel_size[0]
    padding = conv2d.padding
    if not isinstance(padding, str):
        # Numeric padding gets a depth entry. String paddings such as 'same'
        # or 'valid' are accepted by Conv3d as-is.
        padding = (padding[0], *padding)

    conv3d = nn.Conv3d(
        in_channels=conv2d.in_channels,
        out_channels=conv2d.out_channels,
        kernel_size=(kt, *conv2d.kernel_size),
        stride=(conv2d.stride[0], *conv2d.stride),
        padding=padding,
        dilation=(conv2d.dilation[0], *conv2d.dilation),
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
        padding_mode=conv2d.padding_mode,
    )
    conv3d = _to_like(conv3d, conv2d)

    # Copy under no_grad so the inflation does not record an autograd graph
    # that references the old module's parameters.
    with torch.no_grad():
        # (out, in/groups, kH, kW) -> (out, in/groups, kt, kH, kW)
        conv3d.weight.copy_(conv2d.weight.unsqueeze(2).repeat(1, 1, kt, 1, 1) / kt)
        # Bias-free convs (e.g. every ResNet conv) still get their weights
        # copied above. Only the bias copy is conditional.
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d


def _convert_batchnorm2d(bn2d):
    """Return an nn.BatchNorm3d carrying over all of *bn2d*'s settings and state.

    eps, momentum, affine and track_running_stats are copied. The full state
    dict is loaded into the new layer: affine weight/bias plus running_mean,
    running_var and num_batches_tracked. The new layer also gets the same
    train/eval mode as *bn2d*.
    """
    bn3d = nn.BatchNorm3d(
        bn2d.num_features,
        eps=bn2d.eps,
        momentum=bn2d.momentum,
        affine=bn2d.affine,
        track_running_stats=bn2d.track_running_stats,
    )
    # Both variants use identical state-dict keys and shapes, so the pretrained
    # running statistics transfer directly. strict loading flags any mismatch.
    bn3d.load_state_dict(bn2d.state_dict())
    # A fresh module starts in training mode. Mirror the source layer, so an
    # eval()'d model keeps normalising with its running statistics.
    bn3d.train(bn2d.training)
    return _to_like(bn3d, bn2d)


def convert_2d_to_3d(model):
    """Replace every nn.Conv2d / nn.BatchNorm2d inside *model* with its 3d counterpart.

    *model* is modified in place.

    Returns:
        dict mapping each replaced submodule's dotted name (as yielded by
        ``model.named_modules()``) to the new 3d module.
    """
    # Collect replacements first, then assign them. Swapping children while
    # named_modules() is still walking the tree would mutate it mid-iteration.
    replacements = {}
    for name, module in model.named_modules():
        if not name:
            # The root module itself has no parent to assign a replacement to.
            continue
        if isinstance(module, nn.Conv2d):
            replacements[name] = _inflate_conv2d(module)
        elif isinstance(module, nn.BatchNorm2d):
            replacements[name] = _convert_batchnorm2d(module)

    # Rebinding a loop variable does not change the model. The new layer has
    # to be set as an attribute of its parent, found by walking the dotted path.
    for name, new_module in replacements.items():
        *parent_path, child_name = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        setattr(parent, child_name, new_module)
    return replacements


modules = convert_2d_to_3d(model)