Mismatch in dimensions when using Bayesian FPN and RL

I am implementing a computer vision project. I have used an FPN (with ResNet50 backbone) and BayesianFPN in this project. This network is under a Reinforcement Learning Agent. Upon implementation, it throws me a RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3.

  1. What made the code throw this error?
  2. Is it because of the mismatch in the training image (RGB image) and validation image (Binary Mask)?
  3. Is there some internal dimension change I am missing?

Below is the code. This is the BayesianFPNwithRL class I want to implement.

PS: I tried reshaping the weights tensor. It did not work. Even unsqueezing or expanding it did not.

FOR REFERENCE: RGB image dimensions - 1280x720px; Binary mask dimensions: 1280x720px

# Bayesian FPN with RL
class BayesianFPNWithRL(nn.Module):
    def __init__(self, backbone_with_fpn, rl_agent, dropout_p=0.2):
        super(BayesianFPNWithRL, self).__init__()
        self.backbone_with_fpn = backbone_with_fpn
        self.dropout = nn.Dropout(p=dropout_p)
        self.rl_agent = rl_agent

    def forward(self, x, mc_samples=10, train_rl=False):
        fpn_outputs = self.backbone_with_fpn(x)
        keys = list(fpn_outputs.keys())
        features = [fpn_outputs[key] for key in keys]
        common_size = features[0].shape[2:] 
        features = [F.interpolate(f, size=common_size, mode="nearest") for f in features]
        if not self.training:
            sampled_features = []
            for _ in range(mc_samples):
                sampled_features.append([self.dropout(f) for f in features])
            features = [
                torch.mean(torch.stack([sample[i] for sample in sampled_features]), dim=0)
                for i in range(len(features))
        global_features = [f.mean(dim=(2, 3)) for f in features]  
        rl_input = torch.cat(global_features, dim=1) 
        action, log_prob = self.rl_agent.select_action(rl_input)
        weights = torch.zeros(len(features), device=x.device)
        weights[action] = 1.0  
        selected_features = sum(w * f for w, f in zip(weights, features))
        if train_rl:
            return selected_features, log_prob
        return selected_features

if __name__ == "__main__":

    resnet = resnet50(weights = ResNet50_Weights.DEFAULT)
    return_layers = {
        'layer1': '0',
        'layer2': '1',
        'layer3': '2',
        'layer4': '3'
    in_channels_list = [256, 512, 1024, 2048]
    out_channels = 256
    backbone_with_fpn = BackboneWithFPN(resnet, return_layers, in_channels_list, out_channels)
    rl_agent = RLAgent(input_dim=1280, hidden_dim=512, action_space=4)
    bayesian_fpn_rl = BayesianFPNWithRL(backbone_with_fpn, rl_agent).to('cuda')
    optimizer = torch.optim.Adam(bayesian_fpn_rl.parameters(), lr=1e-4)
    for epoch in range(10):  
        for images, ground_truth_masks in dataloader:
            images, ground_truth_masks = images.to('cuda'), ground_truth_masks.to('cuda')
            model_output, log_prob = bayesian_fpn_rl(images, train_rl=True)
            predicted_mask = (model_output > 0.5).int()
            reward = compute_reward(predicted_mask, ground_truth_masks)
            loss = rl_loss(log_prob, reward)
            print(f"Epoch [{epoch + 1}], Loss: {loss.item():.4f}, Reward: {reward:.4f}")