Using Swin Transformers from the timm library for image segmentation

How do I modify the output shape of a timm model for image segmentation in the medical domain, using the Kvasir-SEG dataset and PyTorch? I set num_classes=0 in timm.create_model, but during training the logits have shape torch.Size([32, 768]). I need the model output to match the mask shape, torch.Size([32, 3, 224, 224]), so that I can compute the loss. Is there a way to do this within the timm library, or do I need to modify the model architecture myself?

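import timm
import torch.nn as nn
import torch.optim as optim

# Load the model (num_classes=0 removes the classification head)
model = timm.create_model("swin_s3_tiny_224", pretrained=True, num_classes=0)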

# Set the model to evaluation mode
model.eval()

# Move the model to the device
model.to(device)

# Define the loss function and the optimizer
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

The training code:

num_epochs = 2

# Set the lists to store the training and validation losses
train_losses = []
valid_losses = []
# Loop over the epochs
for epoch in range(num_epochs):
    # Set the model to training mode
    model.train()

    # Set the training loss to 0
    train_loss = 0

    # Loop over the training data
    for image, mask in dataloader_train:

        # Move the data to the device
        image = image.to(device)
        mask = mask.to(device)

        # image shape: torch.Size([32, 3, 224, 224])
        # mask shape:  torch.Size([32, 3, 224, 224])

        # Zero the gradients
        optimizer.zero_grad()

        # Compute the logits
        logits = model(image)
        # logits shape: torch.Size([32, 768])

        # Compute the loss
        loss = loss_fn(logits, mask)

        # Backpropagate the loss
        loss.backward()

        # Update the parameters
        optimizer.step()

        # Accumulate the training loss
        train_loss += loss.item()
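
To make the shape problem concrete, here is a minimal sketch of the kind of change being asked about. It is not something timm provides out of the box as far as I know. With num_classes=0 the model returns pooled features, which is why the output is [32, 768]. The sketch assumes the unpooled feature map (for example from model.forward_features) is available as a (B, 7, 7, 768) tensor; the exact layout depends on the timm version, so a random stand-in tensor is used here in place of the real backbone. SimpleSegHead and its layer sizes are hypothetical names and values chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleSegHead(nn.Module):
    """Hypothetical decoder: maps a coarse feature map to a full-resolution mask."""

    def __init__(self, in_channels=768, out_channels=3, out_size=(224, 224)):
        super().__init__()
        self.out_size = out_size
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, out_channels, kernel_size=1),
        )

    def forward(self, feats):
        # feats: (B, C, H, W) coarse feature map, e.g. (B, 768, 7, 7)
        x = self.proj(feats)
        # Upsample the coarse prediction to the mask resolution
        return F.interpolate(x, size=self.out_size, mode="bilinear", align_corners=False)


# Stand-in for the backbone's unpooled output (assumed layout: B, H, W, C)
feats_nhwc = torch.randn(2, 7, 7, 768)
feats = feats_nhwc.permute(0, 3, 1, 2)  # NHWC -> NCHW for Conv2d

head = SimpleSegHead()
logits = head(feats)

# The output now matches the mask shape, so the loss can be computed
mask = torch.rand(2, 3, 224, 224)
loss = nn.MSELoss()(logits, mask)
print(logits.shape)
```

In a real setup the head would sit on top of the Swin backbone, and both sets of parameters would be passed to the optimizer. A single bilinear upsampling from 7x7 to 224x224 is very coarse, so a multi-stage decoder (U-Net or FPN style) would probably be needed for usable masks.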