Employing UNet architecture to train my model

I have succeeded in creating heatmaps to annotate the key points on my medical cardiograms. I want to use a UNet architecture to train a model that returns a heatmap for each keypoint in the image. How can I go about creating this model? I have coded it as shown below, but it fails to run without errors. Please advise.

    for layer in self.decoder:
        if isinstance(layer, nn.ConvTranspose2d):
            x = layer(x)
            skip_connection = encoder_outputs.pop()
            # Adjust the size of the skip connection tensor
            skip_connection = F.interpolate(skip_connection, size=x.size()[2:], mode='bilinear', align_corners=False)
            # Ensure the number of channels matches for concatenation
            x = torch.cat([x, skip_connection], dim=1)
        else:
            x = layer(x)
    x = self.final_conv(x)
    return x

    # Define your model with the required arguments
    model = UNet(in_channels=3, out_channels=3, channels=[64, 128, 256, 512], strides=[2, 2, 2])

    # Print the model architecture
    print(model)

    # Set the model to training mode
    model.train()

    # Define optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        for images, true_heatmaps in data_loader:
            optimizer.zero_grad()

            # Cast images to torch.float before passing to the model
            images = images.float()

            # Transpose the input tensor to match the expected shape (batch_size, channels, height, width)
            images = images.permute(0, 3, 1, 2)  # Assuming images has shape (batch_size, height, width, channels)

            # Predict
            pred_heatmaps = model(images)

            # Calculate loss
            loss = criterion(pred_heatmaps, true_heatmaps)

            # Backpropagation
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
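The cast and permute steps in the loop can be checked in isolation before involving the real network. The snippet below is only a sketch under assumptions: the batch size, image size, number of keypoints (3) and the `stand_in` layer are all hypothetical placeholders, not the actual data or the `UNet` class.

```python
import torch
import torch.nn as nn

# Hypothetical batch: 4 RGB images stored channels-last (N, H, W, C),
# as a loader built on NumPy/PIL arrays would typically yield them.
images = torch.randint(0, 256, (4, 32, 48, 3), dtype=torch.uint8)
# Hypothetical targets: one heatmap per keypoint (3 here), channels-first.
true_heatmaps = torch.rand(4, 3, 32, 48)

# Same two steps as in the training loop: cast, then (N, H, W, C) -> (N, C, H, W).
images = images.float().permute(0, 3, 1, 2)

# Stand-in for the real network: any module mapping 3 -> 3 channels
# at the same resolution is enough to exercise the loss.
stand_in = nn.Conv2d(3, 3, kernel_size=3, padding=1)
pred_heatmaps = stand_in(images)

# MSELoss compares element-wise, so prediction and target shapes should
# match exactly; mismatched but broadcastable shapes only trigger a warning.
assert pred_heatmaps.shape == true_heatmaps.shape
loss = nn.MSELoss()(pred_heatmaps, true_heatmaps)
print(images.shape, pred_heatmaps.shape, loss.item())
```

If the assertion fails with the real data, the target heatmaps probably need the same `permute` (or a different number of output channels) before the loss is computed.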

What is the error you are seeing?

    Traceback (most recent call last)
    <ipython-input-14-58faa5b773cb> in <cell line: 66>()
         75
         76 # Predict
    ---> 77 pred_heatmaps = model(images)
         78
         79 # Calculate loss

    6 frames
    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
       1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
       1510 else:
    -> 1511 return self._call_impl(*args, **kwargs)
       1512
       1513 def _call_impl(self, *args, **kwargs):

    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
       1518 or _global_backward_pre_hooks or _global_backward_hooks
       1519 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1520 return forward_call(*args, **kwargs)
       1521
       1522 try:

    in forward(self, x)
         40 x = torch.cat([x, skip_connection], dim=1)
         41 else:
    ---> 42 x = layer(x)
         43 x = self.final_conv(x)
         44 return x

    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
       1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
       1510 else:
    -> 1511 return self._call_impl(*args, **kwargs)
       1512
       1513 def _call_impl(self, *args, **kwargs):

    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
       1518 or _global_backward_pre_hooks or _global_backward_hooks
       1519 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1520 return forward_call(*args, **kwargs)
       1521
       1522 try:

    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in forward(self, input)
        458
        459 def forward(self, input: Tensor) -> Tensor:
    --> 460 return self._conv_forward(input, self.weight, self.bias)
        461
        462 class Conv3d(_ConvNd):

    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
        454 weight, bias, self.stride,
        455 _pair(0), self.dilation, self.groups)
    --> 456 return F.conv2d(input, weight, bias, self.stride,
        457 self.padding, self.dilation, self.groups)
        458

    RuntimeError: Given groups=1, weight of size [256, 512, 3, 3], expected input[4, 768, 192, 256] to have 512 channels, but got 768 channels instead
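Reading the message: a 3x3 convolution declared with 512 input channels received a tensor with 768 channels. The failing call is the `else` branch that runs right after `torch.cat`, so the 768 is most likely the sum of the upsampled tensor and the skip tensor (for example 512 + 256). Concatenation along `dim=1` adds channel counts, so the convolution that follows it has to be declared with that sum. The sketch below shows one way to wire this up. `SketchUNet` and `conv_block` are hypothetical names introduced for illustration; this is not the `UNet` class from the question, whose definition is not shown, and it ignores the `strides` argument.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_block(in_ch, out_ch):
    # Two 3x3 convolutions that keep the spatial size.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class SketchUNet(nn.Module):
    """Hypothetical minimal UNet; not the UNet class from the question."""

    def __init__(self, in_channels=3, out_channels=3, channels=(64, 128, 256, 512)):
        super().__init__()
        self.encoders = nn.ModuleList()
        prev = in_channels
        for ch in channels:
            self.encoders.append(conv_block(prev, ch))
            prev = ch
        self.pool = nn.MaxPool2d(2)

        self.ups = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Walk back up: channels[-1] -> channels[-2] -> ... -> channels[0]
        for skip_ch in reversed(channels[:-1]):
            self.ups.append(nn.ConvTranspose2d(prev, skip_ch, kernel_size=2, stride=2))
            # After torch.cat the tensor has (upsampled + skip) channels,
            # so the convolution is declared with that sum.
            self.decoders.append(conv_block(skip_ch + skip_ch, skip_ch))
            prev = skip_ch
        self.final_conv = nn.Conv2d(prev, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for i, enc in enumerate(self.encoders):
            x = enc(x)
            if i < len(self.encoders) - 1:
                skips.append(x)  # keep for the decoder
                x = self.pool(x)
        for up, dec in zip(self.ups, self.decoders):
            x = up(x)
            skip = skips.pop()
            if skip.shape[2:] != x.shape[2:]:
                # Resize only when odd input sizes cause a spatial mismatch.
                skip = F.interpolate(skip, size=x.shape[2:], mode="bilinear", align_corners=False)
            x = torch.cat([x, skip], dim=1)  # channels add up here
            x = dec(x)
        return self.final_conv(x)


model = SketchUNet(in_channels=3, out_channels=3)
dummy = torch.randn(1, 3, 64, 96)  # height and width divisible by 8
out = model(dummy)
print(out.shape)
```

With an input whose height and width are divisible by 8, the output has the same spatial size as the input and one channel per heatmap. In the original model, the equivalent fix would be to give the convolution that follows each `torch.cat` an `in_channels` equal to the transposed-convolution output channels plus the skip channels; `print(model)` should show which layer is declared with 512 inputs.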
