Torchvision.models with PySyft troubles

Hi!

Recently I’ve started to play around with federated learning. I tried to use ResNet from torchvision.models for image classification on the FashionMNIST dataset.
But I get an error during training:

# Run one federated training pass per epoch, then evaluate centrally.
# NOTE(review): `train`, `test`, `args`, `model`, etc. are defined in the
# linked blog post's code, not shown here — behavior assumed from that source.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)
TypeError: object of type 'NoneType' has no len()

My model:

class resnet101(models.resnet.ResNet):
    """ResNet variant whose stem convolution accepts single-channel
    (grayscale) input, e.g. FashionMNIST images, instead of RGB."""

    def __init__(self, block, layers):
        super().__init__(block, layers)
        # Swap the stock 3-channel stem conv for a 1-channel one.
        # NOTE(review): kernel_size=2 differs from the standard 7x7/stride-2
        # ResNet stem — presumably chosen for 28x28 inputs; confirm intent.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=self.inplanes,
            kernel_size=2,
            stride=1,
            padding=1,
            bias=False,
        )

model = resnet101(models.resnet.Bottleneck, [3, 4, 23, 3], **kwargs)

As a trial code example, I am using the code from this blog post: https://blog.openmined.org/upgrade-to-federated-learning-in-10-lines/

I’ll be really grateful for any advice or help.

The full traceback of my error is:

---------------------------------------------------------------------------
PureTorchTensorFoundError                 Traceback (most recent call last)
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/native.py in handle_func_command(cls, command)
    259             new_args, new_kwargs, new_type, args_type = syft.frameworks.torch.hook_args.hook_function_args(
--> 260                 cmd, args, kwargs, return_args_type=True
    261             )

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook_args.py in hook_function_args(attr, args, kwargs, return_args_type)
    156         # Try running it
--> 157         new_args = hook_args(args)
    158 

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook_args.py in <lambda>(x)
    350 
--> 351     return lambda x: f(lambdas, x)
    352 

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook_args.py in seven_fold(lambdas, args, **kwargs)
    558     return (
--> 559         lambdas[0](args[0], **kwargs),
    560         lambdas[1](args[1], **kwargs),

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook_args.py in <lambda>(i)
    328         # Last if not, rule is probably == 1 so use type to return the right transformation.
--> 329         else lambda i: forward_func[type(i)](i)
    330         for a, r in zip(args, rules)  # And do this for all the args / rules provided

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook_args.py in <lambda>(i)
     55     if hasattr(i, "child")
---> 56     else (_ for _ in ()).throw(PureTorchTensorFoundError),
     57     torch.nn.Parameter: lambda i: i.child

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook_args.py in <genexpr>(.0)
     55     if hasattr(i, "child")
---> 56     else (_ for _ in ()).throw(PureTorchTensorFoundError),
     57     torch.nn.Parameter: lambda i: i.child

PureTorchTensorFoundError: 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-15-7c854db39ed0> in <module>
      1 for epoch in range(1, args.epochs + 1):
----> 2     train(args, model, device, federated_train_loader, optimizer, epoch)
      3     test(args, model, device, test_loader)

<ipython-input-5-9b8111af22ce> in train(args, model, device, train_loader, optimizer, epoch)
      5         data, target = data.to(device), target.to(device)
      6         optimizer.zero_grad()
----> 7         output = model(data)
      8         loss = F.nll_loss(output, target)
      9         loss.backward()

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/torchvision/models/resnet.py in forward(self, x)
    190 
    191     def forward(self, x):
--> 192         x = self.conv1(x)
    193         x = self.bn1(x)
    194         x = self.relu(x)

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook.py in overloaded_func(*args, **kwargs)
    715             cmd_name = f"{attr.__module__}.{attr.__name__}"
    716             command = (cmd_name, None, args, kwargs)
--> 717             response = TorchTensor.handle_func_command(command)
    718             return response
    719 

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/native.py in handle_func_command(cls, command)
    268             new_command = (cmd, None, new_args, new_kwargs)
    269             # Send it to the appropriate class and get the response
--> 270             response = new_type.handle_func_command(new_command)
    271             # Put back the wrappers where needed
    272             response = syft.frameworks.torch.hook_args.hook_response(

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/pointers/object_pointer.py in handle_func_command(cls, command)
     86 
     87         # Send the command
---> 88         response = owner.send_command(location, command)
     89 
     90         return response

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in send_command(self, recipient, message, return_ids)
    425 
    426         try:
--> 427             ret_val = self.send_msg(codes.MSGTYPE.CMD, message, location=recipient)
    428         except ResponseSignatureError as e:
    429             ret_val = None

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in send_msg(self, msg_type, message, location)
    221 
    222         # Step 2: send the message and wait for a response
--> 223         bin_response = self._send_msg(bin_message, location)
    224 
    225         # Step 3: deserialize the response

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/virtual.py in _send_msg(self, message, location)
      8 class VirtualWorker(BaseWorker, FederatedClient):
      9     def _send_msg(self, message: bin, location: BaseWorker) -> bin:
---> 10         return location._recv_msg(message)
     11 
     12     def _recv_msg(self, message: bin) -> bin:

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/virtual.py in _recv_msg(self, message)
     11 
     12     def _recv_msg(self, message: bin) -> bin:
---> 13         return self.recv_msg(message)
     14 
     15     @staticmethod

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in recv_msg(self, bin_message)
    252             print(f"worker {self} received {sy.codes.code2MSGTYPE[msg_type]} {contents}")
    253         # Step 1: route message to appropriate function
--> 254         response = self._message_router[msg_type](contents)
    255 
    256         # Step 2: Serialize the message to simple python objects

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in execute_command(self, message)
    383                 command = getattr(command, path)
    384 
--> 385             response = command(*args, **kwargs)
    386 
    387         # some functions don't return anything (such as .backward())

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook.py in overloaded_func(*args, **kwargs)
    715             cmd_name = f"{attr.__module__}.{attr.__name__}"
    716             command = (cmd_name, None, args, kwargs)
--> 717             response = TorchTensor.handle_func_command(command)
    718             return response
    719 

~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/native.py in handle_func_command(cls, command)
    285             # in the execute_command function
    286             if isinstance(args, tuple):
--> 287                 response = eval(cmd)(*args, **kwargs)
    288             else:
    289                 response = eval(cmd)(args, **kwargs)

RuntimeError: weight should have at least three dimensions


Could you post the shape of your input?
The model should generally work for [batch_size, 1, 224, 224]-shaped inputs — at least if you remove the kwargs argument, as I’m not sure whether you are passing something in it.