Intel OpenFL - RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x512 and 2048x4096)

I am trying to run my notebook (which works fine on Google Colab and similar platforms) on Intel OpenFL, Intel's new framework for federated learning.
I am using MNIST with this transformation:

from torchvision import transforms

trf = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

and this is my net:

import torch
import torch.nn as nn

classes = 10  # assumed: MNIST has 10 classes (defined elsewhere in the original notebook)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      # (1(32-1) - 32 + 3)/2 = 1
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, classes)
        )

        # Kaiming-uniform init for all conv and linear weights, zero biases
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                # nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    m.bias.detach().zero_()

        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        # x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.classifier(x)
        return x

But I get this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/var/folders/pm/nj7yy3px76n6b62knyrn8r_40000gn/T/ipykernel_10337/666012611.py in <module>
----> 1 final_model = fx.run_experiment(collaborators,{'aggregator.settings.rounds_to_train':5})

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/openfl/native/native.py in run_experiment(collaborator_dict, override_config)
    282         for col in plan.authorized_cols:
    283             collaborator = collaborators[col]
--> 284             collaborator.run_simulation()
    285 
    286     # Set the weights for the final model

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/openfl/component/collaborator/collaborator.py in run_simulation(self)
    170                 self.logger.info(f'Received the following tasks: {tasks}')
    171                 for task in tasks:
--> 172                     self.do_task(task, round_number)
    173                 self.logger.info(f'All tasks completed on {self.collaborator_name} '
    174                                  f'for round {round_number}...')

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/openfl/component/collaborator/collaborator.py in do_task(self, task, round_number)
    245             round_num=round_number,
    246             input_tensor_dict=input_tensor_dict,
--> 247             **kwargs)
    248 
    249         # Save global and local output_tensor_dicts to TensorDB

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/openfl/federated/task/runner_pt.py in validate(self, col_name, round_num, input_tensor_dict, use_tqdm, **kwargs)
    106                 data, target = pt.tensor(data).to(self.device), pt.tensor(
    107                     target).to(self.device, dtype=pt.int64)
--> 108                 output = self(data)
    109                 # get the index of the max log-probability
    110                 pred = output.argmax(dim=1, keepdim=True)

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/var/folders/pm/nj7yy3px76n6b62knyrn8r_40000gn/T/ipykernel_10337/3611293808.py in forward(self, x)
    125         # x = self.avgpool(x)
    126         x = x.view(x.size(0), -1)
--> 127         x = self.classifier(x)
    128         return x

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     92 
     93     def forward(self, input: Tensor) -> Tensor:
---> 94         return F.linear(input, self.weight, self.bias)
     95 
     96     def extra_repr(self) -> str:

/opt/anaconda3/envs/my_env/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1751     if has_torch_function_variadic(input, weight):
   1752         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753     return torch._C._nn.linear(input, weight, bias)
   1754 
   1755 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x512 and 2048x4096)

The shape mismatch seems to be raised in the first linear layer, which expects an input with 2048 features, while the passed activation has 512.
Change in_features to 512 and it should work.
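
For what it's worth, the two feature counts can be reproduced from the architecture alone. Below is a minimal sketch that just applies the standard conv/pool output-size formula to the four blocks posted above (the helper names are made up for illustration). It gives 2048 features for a 32x32 input and 512 for a 28x28 one, which is the native MNIST resolution. That would be consistent with the Resize(32) transform not being applied to the data that reaches the model, but that part is a guess and not something I have verified in OpenFL.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    # Standard output-size formula: floor((w - k + 2p) / s) + 1
    return (size - kernel + 2 * padding) // stride + 1


def pool_out(size, kernel=2, stride=2):
    # Max-pooling without padding: floor((w - k) / s) + 1
    return (size - kernel) // stride + 1


def flattened_features(input_size, convs_per_block=(2, 2, 3, 3), last_channels=512):
    """Features seen by the first Linear layer for a square input.

    convs_per_block mirrors block_1..block_4 of the posted Net:
    each block is N 3x3 convs (padding=1) followed by a 2x2 max-pool.
    """
    size = input_size
    for n_convs in convs_per_block:
        for _ in range(n_convs):
            size = conv_out(size)
        size = pool_out(size)
    return last_channels * size * size


print(flattened_features(32))  # 2048
print(flattened_features(28))  # 512
```

So changing in_features to 512 only helps if every input really is 28x28; with 32x32 inputs the layer needs 2048.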

You mean here?
self.classifier = nn.Sequential( nn.Linear(2048, 4096),
Changing this to:
self.classifier = nn.Sequential( nn.Linear(512, 4096),
If I do that and then run:

model = Net()
output = model(test_x)
output.shape

I get the same kind of error:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x2048 and 512x4096)
So I don't think that is the solution; the first layer has to keep 2048 in_features: self.classifier = nn.Sequential( nn.Linear(2048, 4096),
I think the problem is in Intel OpenFL, since I don't know how it runs the training. Indeed, the original error is raised while the experiment is running, not when the model is instantiated.

Note that the error message changed, which could point to inputs with different shapes. The initial error showed an activation shape of [128, 512], while it now shows [1, 2048]. You might want to check the input shapes passed to the model and either make sure they are constant or use an adaptive pooling layer to create a defined activation shape.
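
As a rough illustration of the adaptive pooling idea, here is a small sketch. TinyNet is a made-up stand-in for the posted model (far fewer layers so it runs quickly), and the (2, 2) target size is just one possible choice. The point is that nn.AdaptiveAvgPool2d fixes the spatial size before flattening, so the first Linear layer sees the same number of features for 28x28 and 32x32 inputs.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical, much smaller stand-in for the VGG-style Net in this thread."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Always outputs a 2x2 spatial map, whatever the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
        self.classifier = nn.Linear(16 * 2 * 2, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, 16 * 2 * 2)
        return self.classifier(x)


model = TinyNet().eval()
with torch.no_grad():
    out_28 = model(torch.randn(4, 1, 28, 28))  # native MNIST size
    out_32 = model(torch.randn(4, 1, 32, 32))  # size after Resize(32)

print(out_28.shape, out_32.shape)
```

Keep in mind this only makes the model tolerant of both resolutions. It does not explain why the inputs differ in the first place, so checking the data pipeline is still worthwhile.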

But this code runs perfectly on Google Colab and on my personal computer.
The problem only shows up on Intel OpenFL…

Assuming you’ve checked the input shapes and thus verified that the shape mismatch is not caused by the input, then it might indeed be a library issue (I can’t test it as I’m not familiar with OpenFL).
Could you please create an issue on GitHub so that it can be tracked and fixed?
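
In case it helps with checking the input shapes: one way to see what actually reaches the model, without touching the framework's training loop, is a forward pre-hook. This is only a sketch on a small made-up model; record_input_shape and seen_shapes are names I chose for illustration. For the real case the hook would have to be registered wherever the Net instance is constructed, and I don't know how OpenFL builds or copies the model internally, so a temporary print(x.shape) at the top of forward() may be the simpler option.

```python
import torch
import torch.nn as nn

seen_shapes = []


def record_input_shape(module, inputs):
    # `inputs` is the tuple of positional arguments passed to forward()
    seen_shapes.append(tuple(inputs[0].shape))


# Small placeholder model; in practice this would be the Net instance.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 10),
)

handle = model.register_forward_pre_hook(record_input_shape)

with torch.no_grad():
    model(torch.randn(2, 1, 28, 28))
    model(torch.randn(2, 1, 32, 32))

handle.remove()  # detach the hook once the shapes have been collected
print(seen_shapes)  # [(2, 1, 28, 28), (2, 1, 32, 32)]
```

If the recorded shapes show 28x28 inputs during the OpenFL run, the mismatch comes from the data pipeline rather than from the model definition.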

Yes, I am going to open one. But maybe it is a problem with the new OpenFL framework rather than with PyTorch.