Dynamically set Conv2d based on input channels

I am writing a module that initializes a Conv2d and a BatchNorm2d outside of __init__, because the number of input channels is only known once the first input arrives. The code works fine on the CPU, but when I move the module to CUDA, these conv and bn layers aren't moved along with it. Code:

import torch
import torch.nn as nn

class ExpandChannels(nn.Module):
    def __init__(self, num_classes: int = None):
        super(ExpandChannels, self).__init__()
        self.num_classes = num_classes
        self.conv = None
        self.bn = None

    def reset_parameters(self, x):
        self.conv = nn.Conv2d(x.size(1), self.num_classes, kernel_size=1)
        self.bn = nn.BatchNorm2d(self.num_classes)

    def forward(self, x):
        if self.conv is None:
            self.reset_parameters(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

m = ExpandChannels(100).cuda()
m(torch.randn(4, 3, 28, 28).cuda())

Error: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

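You are creating the new self.conv and self.bn layers inside the forward pass without specifying a device, so they are created on the CPU by default. Calling .cuda() on the module only moves the parameters and buffers that are already registered at that point, and since self.conv and self.bn are still None when you call it, there is nothing for it to move.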
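To see the same mechanism without a GPU, here is a minimal sketch that uses a hypothetical Deferred module and a dtype conversion instead of a device move. Like .cuda(), .double() only converts the parameters and buffers that are registered when it is called, so a layer created later keeps the defaults (CPU, float32) and the first forward pass fails with an analogous type mismatch.

```python
import torch
import torch.nn as nn

class Deferred(nn.Module):
    # Hypothetical minimal module: the layer is only created in forward()
    def __init__(self):
        super().__init__()
        self.conv = None

    def forward(self, x):
        if self.conv is None:
            # No device/dtype given, so the defaults apply (CPU, float32)
            self.conv = nn.Conv2d(x.size(1), 8, kernel_size=1)
        return self.conv(x)

# .double(), like .cuda(), only converts parameters registered at call time
m = Deferred().double()
n_params_before = len(list(m.parameters()))
print(n_params_before)  # 0: there was nothing to convert

raised = False
try:
    m(torch.randn(2, 3, 5, 5, dtype=torch.float64))
except RuntimeError as e:
    raised = True
    print("RuntimeError:", e)

# The layer was created after .double(), so it still uses float32
print(m.conv.weight.dtype)  # torch.float32
```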
To properly push them to the GPU, you could use:

    def reset_parameters(self, x):
        self.conv = nn.Conv2d(x.size(1), self.num_classes, kernel_size=1).to(x.device)
        self.bn = nn.BatchNorm2d(self.num_classes).to(x.device)
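A complete, runnable sketch of the module with this fix applied could look like the following. It falls back to the CPU when CUDA is unavailable, and it additionally matches the input's dtype, which is an extra assumption beyond the fix above.

```python
import torch
import torch.nn as nn

class ExpandChannels(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.conv = None
        self.bn = None

    def reset_parameters(self, x):
        # Create the layers on the input's device and with the input's dtype
        self.conv = nn.Conv2d(x.size(1), self.num_classes, kernel_size=1).to(
            device=x.device, dtype=x.dtype)
        self.bn = nn.BatchNorm2d(self.num_classes).to(
            device=x.device, dtype=x.dtype)

    def forward(self, x):
        if self.conv is None:
            # First call: the number of input channels is now known
            self.reset_parameters(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = ExpandChannels(100).to(device)
out = m(torch.randn(4, 3, 28, 28, device=device))
print(out.shape)  # torch.Size([4, 100, 28, 28])
print(m.conv.weight.device == out.device)  # True
```

Since the layers only exist after the first forward pass, m.parameters() is empty before that call, so an optimizer should be created after running one batch through the model.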

You could also take a look at the lazy modules, e.g. nn.LazyConv2d and nn.LazyBatchNorm2d, which use a similar approach: they infer the input size from the first forward pass.
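A minimal sketch of the same module using lazy modules might look like this. nn.LazyConv2d takes out_channels and infers in_channels, and nn.LazyBatchNorm2d infers num_features. Their parameters start uninitialized and can be moved to a device before the first forward pass, at which point they are materialized on that device.

```python
import torch
import torch.nn as nn

class ExpandChannels(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        # in_channels / num_features are inferred from the first input
        self.conv = nn.LazyConv2d(num_classes, kernel_size=1)
        self.bn = nn.LazyBatchNorm2d()

    def forward(self, x):
        return self.bn(self.conv(x))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Moving uninitialized parameters before the first forward pass is allowed
m = ExpandChannels(100).to(device)
out = m(torch.randn(4, 3, 28, 28, device=device))
print(m.conv.weight.shape)  # torch.Size([100, 3, 1, 1])
print(m.conv.weight.device == out.device)  # True
```

As with the manual approach, the parameters only get their final shapes after the first forward pass, so it is safest to run one dummy batch before creating the optimizer.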
