RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same - Conv VAE

Using PyTorch 1.12 and Python 3.10, I have a Convolutional Variational Autoencoder whose train and target samples each have shape (90, 90, 3). Each sample is permuted to (3, 90, 90) before being passed to the network. I have access to 4 NVIDIA GTX TITAN X cards and am trying to use all of them, with the following code (inside a Jupyter notebook)-

import torch

# Specify GPU(s) to be used-
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0,1,2,3

# Check if there are multiple devices (i.e., GPU cards)-
print(f"Number of GPU(s) available = {torch.cuda.device_count()}")

if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"Current GPU name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("PyTorch does not have access to GPU")
'''
Number of GPU(s) available = 4
Current GPU: 0
Current GPU name: NVIDIA GeForce GTX TITAN X
'''

# Device configuration-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Available device is {device}')
# Available device is cuda
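To actually use all 4 cards, the usual approach is to wrap the model in torch.nn.DataParallel once it runs correctly on a single device; a minimal sketch, assuming the ConvVAE class defined below-

# Sketch only: intended multi-GPU usage once the device-mismatch
# error described below is fixed-
model = ConvVAE(latent_space = 256 * 2 * 2)

if torch.cuda.device_count() > 1:
    # Replicates the module on each visible GPU and splits each
    # input batch along dimension 0-
    model = torch.nn.DataParallel(model)

model = model.to(device)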

The code for the class definition and an initial forward pass is:

class ConvVAE(nn.Module):
    def __init__(self, latent_space = 100):
        super(ConvVAE, self).__init__()
        
        self.latent_space = latent_space
        
        # Define encoder-
        self.conv_encoder1 = ConvEncoder_block(
            ip_channels = 3, op_channels = 64,
            reduce_spatial_dims = False
        )
        self.conv_encoder2 = ConvEncoder_block(
            ip_channels = 64, op_channels = 128,
            reduce_spatial_dims = True
        )
        self.conv_encoder3 = ConvEncoder_block(
            ip_channels = 128, op_channels = 256,
            reduce_spatial_dims = True
        )
        self.conv_encoder4 = ConvEncoder_block(
            ip_channels = 256, op_channels = 512,
            reduce_spatial_dims = True
        )
        self.conv_layer = nn.Conv2d(
            in_channels = 512, out_channels = 1024,
            kernel_size = 3, stride = 2,
            padding = 1, bias = True
        )
        self.conv_layer_bn = nn.BatchNorm2d(num_features = 1024)
        # output shape: torch.Size([32, 1024, 6, 6])
        # 6x6x1024 = 36864

        # Further reduce spatial dimensions-
        self.conv_reduction1 = nn.Conv2d(
            in_channels = 1024, out_channels = 512,
            kernel_size = 3, stride = 1,
            padding = 0, bias = False
        )
        self.conv_reduction1_bn = nn.BatchNorm2d(num_features = 512)
        self.conv_reduction2 = nn.Conv2d(
            in_channels = 512, out_channels = 256,
            kernel_size = 3, stride = 1,
            padding = 0, bias = False
        )
        # Final encoder output shape: (256, 2, 2)
        
        # Define mean & log-variance vectors to represent latent space 'z'-
        self.mu = nn.Linear(in_features = (256 * 2 * 2), out_features = self.latent_space)
        self.log_var = nn.Linear(in_features = (256 * 2 * 2), out_features = self.latent_space)
        
        # Define decoder-
        
        # Upsample/increase spatial dimensions-
        self.conv_upsample1 = nn.ConvTranspose2d(
            in_channels = 256, out_channels = 512,
            kernel_size = 2, stride = 2,
            padding = 0, bias = False
        )
        self.conv_upsample1_bn = nn.BatchNorm2d(num_features = 512)
        self.conv_upsample2 = nn.ConvTranspose2d(
            in_channels = 512, out_channels = 1024,
            kernel_size = 2, stride = 2,
            padding = 1, bias = False
        )
        self.conv_upsample2_bn = nn.BatchNorm2d(num_features = 1024)
        
        
        self.conv_decoder1 = ConvDecoder_block(
            ip_channels = 1024, op_channels = 512,
            padding = 0, stride = 2
        )
        self.conv_decoder2 = ConvDecoder_block(
            ip_channels = 512, op_channels = 256,
            padding = 0, stride = 2
        )
        self.conv_decoder3 = ConvDecoder_block(
            ip_channels = 256, op_channels = 128,
            padding = 1, stride = 2
        )
        self.conv_decoder4 = ConvDecoder_block(
            ip_channels = 128, op_channels = 64,
            padding = 1, stride = 2
        )
        self.final_op = nn.Conv2d(
            in_channels = 64, out_channels = 3,
            kernel_size = 3, stride = 1,
            padding = 1, bias = True
        )
    
   
    def reparameterize(self, mu, log_var):
        '''
        Input arguments:
        1. mu - mean coming from the encoder's latent space
        2. log_var - log variance coming from the encoder's latent space
        '''
        # Compute standard deviation using 'log_var'-
        std = torch.exp(0.5 * log_var)
        
        # 'eps' samples from a standard normal distribution to add
        # stochasticity to the sampling process-
        eps = torch.randn_like(std)
        
        # Reparameterization trick - sample from the latent
        # distribution while keeping the operation differentiable-
        z = mu + (std * eps)
        
        return z
    
    
    def shape_computation(self, x):
        x = x.to(0)
        print(f"Input shape: {x.shape}")
        x = self.conv_encoder1(x)
        print(f"Encoder block1 output shape: {x.shape}")
        x = self.conv_encoder2(x)
        print(f"Encoder block2 output shape: {x.shape}")
        x = self.conv_encoder3(x)
        print(f"Encoder block3 output shape: {x.shape}")
        x = self.conv_encoder4(x)
        print(f"Encoder block4 output shape: {x.shape}")
        x = F.leaky_relu(self.conv_layer_bn(self.conv_layer(x)))
        print(f"Conv layer output shape: {x.shape}")
        x = F.leaky_relu(self.conv_reduction1_bn(self.conv_reduction1(x)))
        print(f"Conv reduction1 output shape: {x.shape}")
        x = F.leaky_relu(self.conv_reduction2(x))
        print(f"Final conv output shape: {x.shape}")
        x = torch.flatten(torch.randn((256, 2, 2)))
        print(f"Flattened encoder output shape: {x.shape}")
        
        # parallel_net = parallel_net.to(0)
        # model.to(device)
        # mu = self.mu(x).to(0)
        # logvar = self.log_var(x).to(0)
        mu = self.mu(x)
        logvar = self.log_var(x)
        
        print(f"mean shape: {mu.shape} & logvar shape: {logvar.shape}")
        z = self.reparameterize(mu = mu, log_var = logvar)
        print(f"lv shape: {z.shape}")
        z = torch.reshape(torch.randn(batch_size, 2 * 2 * 256), (-1, 256, 2, 2))
        print(f"Reshaped lv shape: {z.shape}")
        
        x = F.leaky_relu(self.conv_upsample1_bn(self.conv_upsample1(z)))
        print(f"Conv upsample1 output shape: {x.shape}")
        x = F.leaky_relu(self.conv_upsample2_bn(self.conv_upsample2(x)))
        print(f"Conv upsample2 output shape: {x.shape}")
        x = self.conv_decoder1(x)
        print(f"Decoder block1 output shape: {x.shape}")
        x = self.conv_decoder2(x)
        print(f"Decoder block2 output shape: {x.shape}")
        x = self.conv_decoder3(x)
        print(f"Decoder block3 output shape: {x.shape}")
        x = self.conv_decoder4(x)
        print(f"Decoder block4 output shape: {x.shape}")
        x = self.final_op(x)
        print(f"Final output shape: {x.shape}")
        
        return None
    
    
    def forward(self, x):
        x = self.conv_encoder1(x)
        x = self.conv_encoder2(x)
        x = self.conv_encoder3(x)
        x = self.conv_encoder4(x)
        x = F.leaky_relu(self.conv_layer_bn(self.conv_layer(x)))
        x = F.leaky_relu(self.conv_reduction1_bn(self.conv_reduction1(x)))
        x = F.leaky_relu(self.conv_reduction2(x))
        
        # Flatten output for linear layer-
        x = torch.flatten(torch.randn((256, 2, 2)))
        
        # Get mean & log-var vectors representing latent space distribution-
        mu, logvar = self.mu(x), self.log_var(x)
        
        # Obtain the latent vector 'z' using reparameterization-        
        z = self.reparameterize(mu = mu, log_var = logvar)
        
        # Reshape into a conv input-
        z = torch.reshape(torch.randn(batch_size, 2 * 2 * 256), (-1, 256, 2, 2))

        x = F.leaky_relu(self.conv_upsample1_bn(self.conv_upsample1(z)))
        x = F.leaky_relu(self.conv_upsample2_bn(self.conv_upsample2(x)))
        x = self.conv_decoder1(x)
        x = self.conv_decoder2(x)
        x = self.conv_decoder3(x)
        x = self.conv_decoder4(x)
        
        # Since target is in range [0, 1]-
        x = torch.sigmoid(self.final_op(x))
        return x, mu, logvar


# Initialize ConvVAE model-
model = ConvVAE(latent_space = 256 * 2 * 2)

# Get a batch of training samples-
x, y = next(iter(train_loader))

# Move to CUDA GPU-
# x = x.to(0)
# y = y.to(0)
# x = x.to(device)
# y = y.to(device)

x.shape, y.shape
# (torch.Size([32, 90, 90, 3]), torch.Size([32, 90, 90, 3]))

# Swap dimensions to have channels first (after batch-size)-
x = x.permute((0, 3, 1, 2))
y = y.permute((0, 3, 1, 2))

x.shape, y.shape
# (torch.Size([32, 3, 90, 90]), torch.Size([32, 3, 90, 90]))

# Sanity check-
model.shape_computation(x)
'''
Input shape: torch.Size([32, 3, 90, 90])
Encoder block1 output shape: torch.Size([32, 64, 90, 90])
Encoder block2 output shape: torch.Size([32, 128, 45, 45])
Encoder block3 output shape: torch.Size([32, 256, 23, 23])
Encoder block4 output shape: torch.Size([32, 512, 12, 12])
Conv layer output shape: torch.Size([32, 1024, 6, 6])
Conv reduction1 output shape: torch.Size([32, 512, 4, 4])
Final conv output shape: torch.Size([32, 256, 2, 2])
Flattened encoder output shape: torch.Size([1024])
mean shape: torch.Size([1024]) & logvar shape: torch.Size([1024])
lv shape: torch.Size([1024])
Reshaped lv shape: torch.Size([32, 256, 2, 2])
Conv upsample1 output shape: torch.Size([32, 512, 4, 4])
Conv upsample2 output shape: torch.Size([32, 1024, 6, 6])
Decoder block1 output shape: torch.Size([32, 512, 12, 12])
Decoder block2 output shape: torch.Size([32, 256, 24, 24])
Decoder block3 output shape: torch.Size([32, 128, 46, 46])
Decoder block4 output shape: torch.Size([32, 64, 90, 90])
Final output shape: torch.Size([32, 3, 90, 90])
'''

# Get (random) predictions using model-
out, mu, logvar = model(x)

y.shape, out.shape
# (torch.Size([32, 3, 90, 90]), torch.Size([32, 3, 90, 90]))

mu.shape, logvar.shape
# (torch.Size([1024]), torch.Size([1024]))

But when I try to move all of the data and the model to the GPU by uncommenting the following lines above-

model.to(device)

x = x.to(device)
y = y.to(device)

model.shape_computation(x)

the following output is produced, followed by an error-

Input shape: torch.Size([32, 3, 90, 90])
Encoder block1 output shape: torch.Size([32, 64, 90, 90])
Encoder block2 output shape: torch.Size([32, 128, 45, 45])
Encoder block3 output shape: torch.Size([32, 256, 23, 23])
Encoder block4 output shape: torch.Size([32, 512, 12, 12])
Conv layer output shape: torch.Size([32, 1024, 6, 6])
Conv reduction1 output shape: torch.Size([32, 512, 4, 4])
Final conv output shape: torch.Size([32, 256, 2, 2])
Flattened encoder output shape: torch.Size([1024])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [44], line 1
----> 1 model.shape_computation(x)

Cell In [35], line 135, in ConvVAE.shape_computation(self, x)
    129 print(f"Flattened encoder output shape: {x.shape}")
    131 # parallel_net = parallel_net.to(0)
    132 # model.to(device)
    133 # mu = self.mu(x).to(0)
    134 # logvar = self.log_var(x).to(0)
--> 135 mu = self.mu(x)
    136 logvar = self.log_var(x)
    138 print(f"mean shape: {mu.shape} & logvar shape: {logvar.shape}")

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/pytorch-cuda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found
at least two devices, cpu and cuda:0! (when checking argument for
argument mat2 in method wrapper_mm)

Since I am moving 'x', 'y' and 'model' to 'device', why am I getting this error? I also tried moving 'mu', 'logvar' and 'z' to 'device', but the error persists.
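One way to pinpoint which tensor is still on the CPU is to print the devices involved right before the failing call; a minimal check, using only standard PyTorch attributes-

# Check where the input and the model weights actually live-
print(f"x device: {x.device}")
print(f"model weights device: {next(model.parameters()).device}")

# The layer that raises the error-
print(f"mu layer weight device: {model.mu.weight.device}")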

In your shape_computation method you are creating a new x tensor from randn on the CPU, overwriting the previously computed activation tensor:

x = torch.flatten(torch.randn((256, 2, 2)))

Remove this line of code, or flatten the previously computed x tensor instead, and it should work.
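A minimal sketch of the fix (the same randn pattern also appears in forward and in the z reshape; flattening with start_dim = 1 keeps the batch dimension, so mu and logvar come out as [32, 1024] instead of [1024])-

# Flatten the actual activation tensor, keeping the batch dimension-
x = torch.flatten(x, start_dim = 1)    # [32, 256 * 2 * 2] = [32, 1024]

mu = self.mu(x)
logvar = self.log_var(x)
z = self.reparameterize(mu = mu, log_var = logvar)

# Reshape the actual latent vector (not a fresh randn) into a conv input-
z = z.reshape(-1, 256, 2, 2)           # [32, 256, 2, 2]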
