Unable to move the U-net model to the cuda device

I am building a CNN model with a residual-network architecture for input images of varying dimensions. At the same time, I want to make sure the output has the same dimensions as the input.

In the forward pass of the Generator object, I declare a list object size_list to store the dimensions of the tensors during the down-sampling steps, so that when the tensor is up-sampled, I can pass the required size from the list to the forward function of ConvTranspose2d via the output_size argument.

However, once I added this list object, I ran into this error, even though I explicitly move both the input and the model to my cuda device.

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

The error simply does not show up if I don't use the list object, or if I run the model on the cpu. I don't know how the list object prevents my model from getting onto the cuda device. Any help is appreciated. Thank you.

import math

import torch
import torch.nn as nn

# Gen_Block, Res_Block, INPUT_CHANNELS, NUM_RES and DEVICE are defined
# elsewhere in my code.


class Transpose_Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, output_size, padding=1):
        super(Transpose_Block, self).__init__()
        self.output_size = output_size
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def centercrop(self, x):
        dimension = x.shape
        top = 0 + math.floor((dimension[-2] - self.output_size[0]) / 2)
        bottom = dimension[-2] - math.ceil((dimension[-2] - self.output_size[0]) / 2)
        left = 0 + math.floor((dimension[-1] - self.output_size[1]) / 2)
        right = dimension[-1] - math.ceil((dimension[-1] - self.output_size[1]) / 2)
        print(top, bottom, left, right)
        return x[:,:, top:bottom, left:right]

    def forward(self, x):
        try:
            out = self.block(x, output_size=self.output_size)
        except ValueError:
            out = self.block(x)
            out = self.centercrop(out)
            print('manual crop')
        return out


class Generator(nn.Module):
    def __init__(self, input_channels=INPUT_CHANNELS, channels_list=[64, 128, 256, 512], n_res=NUM_RES):
        super(Generator, self).__init__()
        self.input_channels = input_channels
        self.channels_list = channels_list
        self.list = []

        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, channels_list[0], kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(channels_list[0]),
            nn.ReLU(),
        )

        self.down_layers = nn.ModuleList()
        self.res_layers = nn.ModuleList()

        self.down_layers.append(self.initial)
        for down in range(len(channels_list)-1):
            self.down_layers.append(Gen_Block(channels_list[down], channels_list[down + 1]))
        for res in range(n_res):
            self.res_layers.append(Res_Block(channels_list[-1]))


    def forward(self, x):
        # size_tensor = torch.empty((len(self.channels_list),2))
        for (index, down) in enumerate(self.down_layers):
            # size_tensor[index] = torch.tensor(x.size()[-2:])
            self.list.append(x.size()[-2:])
            x = down(x)
        for res in self.res_layers:
            x = res(x)
        for up in range(len(self.channels_list)-1, 0, -1):
            x = Transpose_Block(self.channels_list[up], self.channels_list[up-1], kernel_size=3, stride=2, output_size=self.list[up], padding=1)(x)
            x = nn.InstanceNorm2d(self.channels_list[up-1])(x)
            x = nn.ReLU()(x)
        x = Transpose_Block(self.channels_list[0], self.input_channels, kernel_size=7, stride=1, output_size=self.list[0], padding=3)(x)
        x = nn.Tanh()(x)

        return x


def gen_test():
    x = torch.randn(1, 1, 100, 150).to(DEVICE)
    model = Generator().to(DEVICE)
    print(model)
    print(model(x).shape)



if __name__ == '__main__':
    gen_test()

Hi Alan!

When you move your model (a Generator) to your gpu, pytorch goes
through (at least roughly speaking) your model’s properties, and moves
the various layers, etc., to the gpu.

The problem is that when you do this, no instance of Transpose_Block
has yet been created, so there is nothing for pytorch to move to the gpu
(nor even to know about).

Instead, in Generator's forward() method, you create Transpose_Block
on the fly, so it (and its "layer," ConvTranspose2d) is created, by default,
on the cpu and is never moved to the gpu. Then, when the gpu tensors try
to flow through it, you get your error.

As an aside, every time you create your Transpose_Block on the fly,
you instantiate a new ConvTranspose2d with new, different, randomly
initialized values for its weight and bias. (These parameters are never
added to your optimizer, so they never get updated; but that wouldn't
matter, because you discard them anyway.)
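Here is a minimal sketch (a toy module, not your Generator) that illustrates
the difference between a submodule registered in __init__() and one created
on the fly in forward():

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.registered = nn.Conv2d(1, 1, 3, padding=1)   # registered: moves with .to()

    def forward(self, x):
        on_the_fly = nn.Conv2d(1, 1, 3, padding=1)        # fresh cpu module on every call
        print(self.registered.weight.device)              # cuda:0 (after .to('cuda'))
        print(on_the_fly.weight.device)                   # cpu
        return on_the_fly(x)                              # cuda input, cpu weight

if torch.cuda.is_available():
    demo = Demo().to('cuda')
    demo(torch.randn(1, 1, 8, 8, device='cuda'))   # raises your RuntimeError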

Best.

K. Frank

Thank you so much for your reminder!! I hadn't thought about this.

I guess the Transpose_Block cannot be instantiated inside the forward() method either, because otherwise it would be re-created in every forward pass? In that case, where I can only get the size_list after running the forward method, what is the solution if I still want to pass the output_size argument when instantiating the Transpose_Block? Do I have to manually calculate the size_list from the kernel_size, padding, etc. for every step of the down-sampling first, before the forward method?

Hi Alan!

Yes, that is correct.

In your current code you do pass in an output_size argument when you
instantiate Transpose_Block (but it doesn't make sense to do it that way).

However, you do not pass in output_size when you instantiate the
ConvTranspose2d block property of Transpose_Block. Instead,
you pass in output_size as an additional argument when you apply
the instance of ConvTranspose2d to x in Transpose_Block’s forward()
method.

So you don't need to know output_size when you instantiate
Transpose_Block or ConvTranspose2d; you just need to know it when
you apply an already-instantiated Transpose_Block to x in the for up
loop of Generator's forward() method.
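For example, here is a minimal sketch (with made-up channel counts and
shapes) showing output_size being supplied at call time; pytorch works out
the needed output_padding from it:

import torch
import torch.nn as nn

tconv = nn.ConvTranspose2d(8, 4, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 8, 50, 75)

# output_size is not a constructor argument; it is given per call
out = tconv(x, output_size=(100, 150))
print(out.shape)   # torch.Size([1, 4, 100, 150])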

No, I don't believe so. I don't fully understand your logic, but I believe
that you can continue to compute the list of output_sizes in the
for (index, down) loop of Generator's forward() method, and then pass
those output_sizes in as a newly-added argument to Transpose_Block's
forward() method, which will then pass them along into the call where
you apply that Transpose_Block's ConvTranspose2d to x.

If you store Generator’s Transpose_Blocks in a ModuleList that is
a property of Generator, and each ConvTranspose2d is a property of
its enclosing Transpose_Block, then when you call
model = Generator().to(DEVICE) the Transpose_Blocks and their
contained ConvTranspose2ds should all be properly moved to the gpu
(resolving your original error).
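Putting it together, here is a minimal sketch of that restructuring. The
Gen_Block and Res_Block below are simple stand-ins (hypothetical, just so the
sketch runs on its own; yours are defined elsewhere), and I've made size_list
a local variable in forward() so that it is rebuilt on every pass rather than
growing forever as an instance attribute:

import torch
import torch.nn as nn

class Gen_Block(nn.Module):                       # stand-in down-sampling block
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.block(x)

class Res_Block(nn.Module):                       # stand-in residual block
    def __init__(self, ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.InstanceNorm2d(ch),
            nn.ReLU(),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.InstanceNorm2d(ch),
        )
    def forward(self, x):
        return x + self.block(x)

class Transpose_Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=1):
        super().__init__()
        # no output_size at construction time ...
        self.block = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride, padding=padding)
    def forward(self, x, output_size):
        # ... it is supplied per call instead
        return self.block(x, output_size=output_size)

class Generator(nn.Module):
    def __init__(self, input_channels=1, channels_list=[64, 128, 256, 512], n_res=2):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, channels_list[0], kernel_size=7,
                      stride=1, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(channels_list[0]),
            nn.ReLU(),
        )
        self.down_layers = nn.ModuleList([self.initial])
        for down in range(len(channels_list) - 1):
            self.down_layers.append(Gen_Block(channels_list[down], channels_list[down + 1]))
        self.res_layers = nn.ModuleList(Res_Block(channels_list[-1]) for _ in range(n_res))

        # the up-sampling blocks are now created and registered in __init__,
        # so Generator().to(DEVICE) moves them to the gpu with everything else
        self.up_layers = nn.ModuleList()
        self.up_norms = nn.ModuleList()
        for up in range(len(channels_list) - 1, 0, -1):
            self.up_layers.append(Transpose_Block(channels_list[up], channels_list[up - 1],
                                                  kernel_size=3, stride=2, padding=1))
            self.up_norms.append(nn.InstanceNorm2d(channels_list[up - 1]))
        self.final = Transpose_Block(channels_list[0], input_channels,
                                     kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        size_list = []                    # local, so it is rebuilt every pass
        for down in self.down_layers:
            size_list.append(x.size()[-2:])
            x = down(x)
        for res in self.res_layers:
            x = res(x)
        for i, (up, norm) in enumerate(zip(self.up_layers, self.up_norms)):
            x = up(x, output_size=size_list[-1 - i])   # same indexing as your self.list[up]
            x = torch.relu(norm(x))
        x = self.final(x, output_size=size_list[0])
        return torch.tanh(x)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Generator().to(device)
x = torch.randn(1, 1, 100, 150, device=device)
print(model(x).shape)   # torch.Size([1, 1, 100, 150])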

Best.

K. Frank


Yes, you are right. I got lost and confused mid-way through my thinking. I just need to give my Transpose_Block its own forward() signature, separate from the others, so that the output_size argument can be passed to it inside Generator's forward() method. Thank you so much.