Hi!
I have a problem with the VGG-11 model and a federated data loader. During training, I get a runtime error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-17-448339389201> in <module>
1 for epoch in range(1, args.epochs + 1):
----> 2 train(args, model, device, federated_train_loader, optimizer, epoch)
<ipython-input-14-9b8111af22ce> in train(args, model, device, train_loader, optimizer, epoch)
5 data, target = data.to(device), target.to(device)
6 optimizer.zero_grad()
----> 7 output = model(data)
8 loss = F.nll_loss(output, target)
9 loss.backward()
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/torchvision/models/vgg.py in forward(self, x)
42 x = self.features(x)
43 x = self.avgpool(x)
---> 44 x = x.view(x.size(0), -1)
45 x = self.classifier(x)
46 return x
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook.py in overloaded_native_method(self, *args, **kwargs)
675 # Send the new command to the appropriate class and get the response
676 method = getattr(new_self, method_name)
--> 677 response = method(*new_args, **new_kwargs)
678
679 # For inplace methods, just directly return self
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook.py in overloaded_pointer_method(self, *args, **kwargs)
511 command = (attr, self, args, kwargs)
512
--> 513 response = owner.send_command(location, command)
514
515 return response
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in send_command(self, recipient, message, return_ids)
425
426 try:
--> 427 ret_val = self.send_msg(codes.MSGTYPE.CMD, message, location=recipient)
428 except ResponseSignatureError as e:
429 ret_val = None
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in send_msg(self, msg_type, message, location)
221
222 # Step 2: send the message and wait for a response
--> 223 bin_response = self._send_msg(bin_message, location)
224
225 # Step 3: deserialize the response
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/virtual.py in _send_msg(self, message, location)
8 class VirtualWorker(BaseWorker, FederatedClient):
9 def _send_msg(self, message: bin, location: BaseWorker) -> bin:
---> 10 return location._recv_msg(message)
11
12 def _recv_msg(self, message: bin) -> bin:
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/virtual.py in _recv_msg(self, message)
11
12 def _recv_msg(self, message: bin) -> bin:
---> 13 return self.recv_msg(message)
14
15 @staticmethod
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in recv_msg(self, bin_message)
252 print(f"worker {self} received {sy.codes.code2MSGTYPE[msg_type]} {contents}")
253 # Step 1: route message to appropriate function
--> 254 response = self._message_router[msg_type](contents)
255
256 # Step 2: Serialize the message to simple python objects
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/workers/base.py in execute_command(self, message)
363 else:
364 try:
--> 365 response = getattr(_self, command_name)(*args, **kwargs)
366 except TypeError:
367 # TODO Andrew thinks this is gross, please fix. Instead need to properly deserialize strings
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook.py in overloaded_native_method(self, *args, **kwargs)
661 except BaseException as e:
662 # we can make some errors more descriptive with this method
--> 663 raise route_method_exception(e, self, args, kwargs)
664
665 else: # means that there is a wrapper to remove
~/miniconda3/envs/pysyft/lib/python3.7/site-packages/syft/frameworks/torch/hook/hook.py in overloaded_native_method(self, *args, **kwargs)
655 try:
656 if isinstance(args, tuple):
--> 657 response = method(*args, **kwargs)
658 else:
659 response = method(args, **kwargs)
RuntimeError: shape '[0, -1]' is invalid for input of size 1605632
I use the UTKFace dataset (https://susanqq.github.io/UTKFace/). Here is what my training data and labels look like (both are NumPy arrays):
Training data (224x224 images with 3 channels):
array([[[[ 25., 32., 36., ..., 70., 72., 74.],
[ 23., 30., 34., ..., 70., 73., 74.],
[ 22., 26., 31., ..., 70., 73., 74.],
...,
[ 5., 4., 2., ..., 3., 2., 1.],
[ 15., 12., 8., ..., 4., 2., 0.],
[ 22., 18., 12., ..., 4., 3., 0.]],
[[ 4., 8., 12., ..., 44., 46., 48.],
[ 2., 6., 10., ..., 44., 47., 48.],
[ 1., 5., 7., ..., 44., 47., 48.],
...,
[ 41., 39., 37., ..., 3., 2., 1.],
[ 51., 47., 43., ..., 4., 2., 0.],
[ 58., 53., 47., ..., 4., 3., 0.]],
[[ 1., 6., 10., ..., 17., 19., 21.],
[ 0., 4., 6., ..., 17., 20., 21.],
[ 0., 0., 3., ..., 17., 20., 21.],
...,
[ 67., 67., 67., ..., 3., 2., 1.],
[ 77., 75., 73., ..., 4., 2., 0.],
[ 84., 81., 75., ..., 4., 3., 0.]]],
[[[ 68., 70., 77., ..., 246., 246., 246.],
[ 70., 72., 79., ..., 246., 246., 246.],
[ 75., 77., 84., ..., 245., 246., 246.],
...,
[194., 194., 194., ..., 238., 238., 238.],
[195., 194., 194., ..., 238., 238., 238.],
[195., 195., 194., ..., 238., 238., 238.]],
[[ 34., 36., 43., ..., 246., 246., 246.],
[ 36., 38., 45., ..., 246., 246., 246.],
[ 41., 43., 50., ..., 245., 246., 246.],
...,
[144., 144., 144., ..., 238., 238., 238.],
[145., 144., 144., ..., 238., 238., 238.],
[145., 145., 144., ..., 238., 238., 238.]],
[[ 24., 26., 33., ..., 248., 248., 248.],
[ 26., 28., 35., ..., 248., 248., 248.],
[ 31., 33., 40., ..., 247., 248., 248.],
...,
[133., 133., 133., ..., 240., 240., 240.],
[134., 133., 133., ..., 240., 240., 240.],
[134., 134., 133., ..., 240., 240., 240.]]],
[[[ 68., 76., 78., ..., 80., 83., 85.],
[ 77., 85., 90., ..., 81., 82., 84.],
[ 83., 90., 95., ..., 79., 79., 80.],
...,
[122., 125., 131., ..., 27., 29., 30.],
[135., 127., 124., ..., 27., 28., 30.],
[135., 121., 117., ..., 27., 28., 30.]],
[[ 57., 65., 67., ..., 59., 62., 64.],
[ 66., 74., 79., ..., 60., 61., 63.],
[ 72., 79., 84., ..., 58., 58., 59.],
...,
[115., 118., 124., ..., 27., 29., 30.],
[128., 120., 118., ..., 27., 28., 30.],
[128., 114., 111., ..., 27., 28., 30.]],
[[ 51., 59., 61., ..., 38., 41., 43.],
[ 60., 68., 73., ..., 39., 40., 42.],
[ 66., 73., 78., ..., 37., 37., 38.],
...,
[122., 125., 131., ..., 27., 29., 30.],
[135., 127., 122., ..., 27., 28., 30.],
[135., 121., 115., ..., 27., 28., 30.]]],
...,
[[[106., 106., 108., ..., 132., 165., 186.],
[102., 104., 106., ..., 128., 162., 186.],
[ 99., 101., 103., ..., 123., 158., 184.],
...,
[180., 185., 181., ..., 162., 160., 162.],
[193., 186., 175., ..., 175., 172., 173.],
[201., 191., 173., ..., 184., 182., 184.]],
[[ 32., 35., 37., ..., 78., 109., 130.],
[ 31., 33., 35., ..., 74., 106., 130.],
[ 28., 29., 31., ..., 67., 102., 128.],
...,
[134., 137., 133., ..., 140., 138., 140.],
[147., 140., 127., ..., 153., 150., 151.],
[155., 145., 127., ..., 162., 160., 162.]],
[[ 3., 5., 7., ..., 52., 82., 103.],
[ 1., 5., 7., ..., 48., 79., 103.],
[ 0., 4., 6., ..., 42., 75., 101.],
...,
[ 98., 101., 97., ..., 143., 141., 143.],
[111., 104., 91., ..., 156., 153., 154.],
[119., 109., 93., ..., 165., 163., 165.]]],
[[[ 28., 29., 29., ..., 12., 11., 10.],
[ 30., 31., 33., ..., 12., 11., 11.],
[ 34., 35., 36., ..., 13., 12., 12.],
...,
[ 58., 55., 49., ..., 102., 117., 139.],
[ 60., 57., 49., ..., 105., 110., 131.],
[ 62., 58., 50., ..., 68., 57., 72.]],
[[ 29., 30., 30., ..., 14., 13., 12.],
[ 31., 32., 34., ..., 14., 13., 13.],
[ 35., 36., 37., ..., 15., 14., 14.],
...,
[ 45., 45., 39., ..., 137., 152., 175.],
[ 47., 47., 39., ..., 140., 145., 166.],
[ 49., 48., 40., ..., 103., 92., 107.]],
[[ 24., 25., 25., ..., 13., 12., 11.],
[ 26., 27., 29., ..., 13., 12., 12.],
[ 30., 31., 32., ..., 14., 13., 13.],
...,
[ 37., 36., 30., ..., 175., 206., 237.],
[ 39., 38., 30., ..., 178., 199., 230.],
[ 41., 39., 31., ..., 141., 146., 171.]]],
[[[252., 252., 253., ..., 57., 57., 57.],
[252., 252., 253., ..., 57., 57., 57.],
[252., 252., 253., ..., 56., 56., 56.],
...,
[246., 246., 245., ..., 170., 170., 170.],
[245., 245., 245., ..., 172., 172., 171.],
[245., 245., 244., ..., 173., 173., 173.]],
[[253., 253., 252., ..., 54., 54., 54.],
[253., 253., 252., ..., 54., 54., 54.],
[253., 253., 252., ..., 53., 53., 53.],
...,
[239., 239., 238., ..., 123., 123., 123.],
[238., 238., 238., ..., 125., 125., 124.],
[238., 238., 237., ..., 126., 126., 126.]],
[[248., 248., 248., ..., 49., 49., 49.],
[248., 248., 248., ..., 49., 49., 49.],
[248., 248., 248., ..., 48., 48., 48.],
...,
[247., 247., 246., ..., 105., 105., 105.],
[246., 246., 246., ..., 107., 107., 106.],
[246., 246., 245., ..., 108., 108., 108.]]]], dtype=float32)
train_data.shape ->
(7820, 3, 224, 224)
labels (gender, 0-male, 1-female):
array([0., 1., 0., ..., 1., 1., 1.], dtype=float32)
Federated data-loader ->
base=sy.BaseDataset(torch.from_numpy(train_data),
torch.from_numpy(train_labels_after))
base_federated=base.federate((bob, alice))
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
base_federated,batch_size=args.batch_size)
If you have any ideas about what might help, please advise…