If you add some debug statements to the `forward` method, you'll see that the processing fails because of the second input, since it cannot be chunked into 8 parts:

```
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Print the device and shape each replica receives
        print('x: {}, {}'.format(x.device, x.shape))
        print('y: {}, {}'.format(y.device, y.shape))
        return x, y
```

Your setup:

```
x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1, 3, 100, 100))
x_out, y_out = net(x, y)
print(x_out.shape)
x: cuda:0, torch.Size([128, 3, 100, 100])
y: cuda:0, torch.Size([1, 3, 100, 100])
torch.Size([128, 3, 100, 100])
```

I'm currently unsure what the expected behavior is, as it seems to fall back to the smallest possible split, so only a single chunk is processed (I would have expected an error to be raised).
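One plausible explanation for this fallback can be sketched on the CPU. `torch.chunk` returns fewer chunks than requested when `dim0` is too small, and if the per-input chunk lists are then paired up with `zip` (which, as far as I can tell, is roughly what the scatter logic does for tuple inputs), the shortest list decides how many replicas receive any data. The helper name `pair_chunks` is made up for this illustration and is not part of PyTorch, and the spatial size is reduced to 4x4 just to keep the tensors small:

```python
import torch

def pair_chunks(inputs, num_devices):
    # Split every input along dim0 into at most num_devices chunks.
    chunked = [torch.chunk(t, num_devices, dim=0) for t in inputs]
    # zip stops at the shortest chunk list, so one small input limits all of them.
    return list(zip(*chunked))

x = torch.ones((1024, 3, 4, 4))
y = torch.ones((1, 3, 4, 4))

pairs = pair_chunks((x, y), num_devices=8)
print(len(pairs))         # number of (x, y) pairs that would be dispatched
print(pairs[0][0].shape)  # shape of the x chunk in the first pair
```

Under this assumption only one `(x, y)` pair survives, and its `x` chunk has 128 samples, which would match the output above.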

If you are using tensors with the same shape in `dim0`, it's working as expected:

```
x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1024, 3, 100, 100))
x_out, y_out = net(x, y)
print(x_out.shape)
x: cuda:0, torch.Size([128, 3, 100, 100])
y: cuda:0, torch.Size([128, 3, 100, 100])
x: cuda:1, torch.Size([128, 3, 100, 100])
y: cuda:1, torch.Size([128, 3, 100, 100])
x: cuda:2, torch.Size([128, 3, 100, 100])
x: cuda:3, torch.Size([128, 3, 100, 100])
y: cuda:2, torch.Size([128, 3, 100, 100])
x: cuda:4, torch.Size([128, 3, 100, 100])
x: cuda:5, torch.Size([128, 3, 100, 100])
y: cuda:4, torch.Size([128, 3, 100, 100])
y: cuda:5, torch.Size([128, 3, 100, 100])
x: cuda:6, torch.Size([128, 3, 100, 100])
x: cuda:7, torch.Size([128, 3, 100, 100])
y: cuda:6, torch.Size([128, 3, 100, 100])
y: cuda:7, torch.Size([128, 3, 100, 100])
y: cuda:3, torch.Size([128, 3, 100, 100])
torch.Size([1024, 3, 100, 100])
```
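Since no error seems to be raised for mismatched batch sizes, a small guard in front of the call can make the problem explicit. This is only a sketch: `check_batch_dim` is a hypothetical helper, not a PyTorch function, and the tensors are shrunk to 4x4 so it runs quickly on the CPU:

```python
import torch

def check_batch_dim(*tensors):
    # Collect the size of dim0 for every input tensor.
    sizes = {t.size(0) for t in tensors}
    if len(sizes) > 1:
        raise ValueError(f"inputs disagree in dim0: {sorted(sizes)}")
    return tensors

x = torch.ones((1024, 3, 4, 4))
y = torch.ones((1, 3, 4, 4))

try:
    check_batch_dim(x, y)
except ValueError as err:
    print(err)
```

With matching inputs you could then call the model as `net(*check_batch_dim(x, y))`, so that a mismatch fails loudly instead of silently processing a single chunk.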