Size mismatch of input if more than one CUDA device

I’m implementing some RL code and got stuck on what is, in my opinion, weird behaviour.
I use DataParallel and the device tag to move my nets/data to the available device(s).
Using the CPU or one CUDA device, everything works fine, but if I use more than one device, I get the following error:

```
  File "", line 229, in
  File "", line 177, in main
    mae = ddpg.validate(states_val, labels_val, mean_train, std_train).item()
  File "RL/DDPG/", line 415, in validate
    forecast = self.actor_target(state)
  File "env/lib/python3.5/site-packages/torch/nn/modules/", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "env/lib/python3.5/site-packages/torch/nn/parallel/", line 114, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "env/lib/python3.5/site-packages/torch/nn/parallel/", line 124, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "env/lib/python3.5/site-packages/torch/nn/parallel/", line 65, in parallel_apply
    raise output
  File "env/lib/python3.5/site-packages/torch/nn/parallel/", line 41, in _worker
    output = module(*input, **kwargs)
  File "env/lib/python3.5/site-packages/torch/nn/modules/", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "RL/Nets/actor/", line 28, in forward
    x = F.relu(self.input_layer(x))
  File "env/lib/python3.5/site-packages/torch/nn/modules/", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "env/lib/python3.5/site-packages/torch/nn/modules/", line 55, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: size mismatch, m1: [1 x 144], m2: [288 x 256] at /pytorch/aten/src/THC/generic/
```

The input layer of actor_target is a simple linear layer:

```python
self.input_layer = nn.Linear(288, 256)
```

I already checked the batch size, so that the data can be split evenly across the devices.
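For reference, the setup described in the question might look roughly like the following. This is a minimal sketch under assumptions: the small `nn.Sequential` network is a hypothetical stand-in for the actor (only its 288 → 256 input layer comes from the post), and the batch of 8 states is made up. It runs on the CPU, one GPU, or several GPUs; `nn.DataParallel` splits its input along dim 0, so the batch has to be the first dimension.

```python
import torch
import torch.nn as nn

# Pick the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the actor: only the 288 -> 256 input layer is from the post.
net = nn.Sequential(nn.Linear(288, 256), nn.ReLU())

# Wrap in DataParallel only when several GPUs are visible; it scatters the
# input along dim 0 (the batch dimension) across the devices.
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
net = net.to(device)

# A made-up batch of 8 states with 288 features each.
states = torch.randn(8, 288, device=device)
out = net(states)
print(out.shape)  # torch.Size([8, 256])
```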

Could you post your forward function?
I guess you are using view, which might yield the wrong shape.

This is my forward function:

```python
def forward(self, state):
    # no batch normalization if there is no batch
    no_batch = True if len(state[0].shape) == 0 else False
    x = self.dropout_layer_1(state)
    x = F.relu(self.input_layer(x))
    x = self.batch_norm_1(x) if not no_batch else x
    x = self.dropout_layer_2(x)
    x = F.relu(self.hidden_layer_1(x))
    x = self.batch_norm_2(x) if not no_batch else x
    x = self.dropout_layer_3(x)
    x = F.relu(self.hidden_layer_2(x))
    x = self.batch_norm_3(x) if not no_batch else x
    x = self.dropout_layer_4(x)
    x = self.output_layer(x)
    # TODO: clamp between useful values
    return x
```
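The `no_batch` test in `forward` decides whether batch norm runs at all. A small sketch of how it behaves for an unbatched versus a batched state; `is_unbatched` is a hypothetical helper that only repeats the expression from the post:

```python
import torch

def is_unbatched(state: torch.Tensor) -> bool:
    # Same test as in forward(): for a 1-D state, state[0] is a 0-dim scalar.
    return len(state[0].shape) == 0

single = torch.randn(288)       # one state, no batch dimension
batched = single.view(1, -1)    # the same state as a batch of one: [1, 288]

print(is_unbatched(single))   # True  -> batch norm is skipped
print(is_unbatched(batched))  # False -> batch norm runs on a batch of size 1
```

Under `nn.DataParallel` with 4 devices, each replica would receive a 1-D slice of 72 values, so the check still reports "no batch", and `input_layer`, which expects 288 features, fails with a size mismatch.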
Apparently it was a wrong assumption.
Could you print the shape of state in forward?

I inserted

```python
def forward(self, state):
    # no batch normalization if there is no batch
    no_batch = True if len(state[0].shape) == 0 else False
    x = self.dropout_layer_1(state)
```

and got (while running on 4 devices):


After this it crashed ;)

It looks like it fails during validation, where the batch size is 1. That could explain the shape of 72 (72 * 4 = 288), but I called net.eval() beforehand.
Do I have to make any other calls before running forward with a batch size smaller than the number of devices?
Additionally, I’m implementing DDPG, where I have to iteratively evaluate the net on a single state (i.e., a batch of one state).

Could you try to add a batch dimension to your data?
For a batch size of 1, your input shape should be [1, in_features].
I assume 72 is your feature dimension.
If so, nn.DataParallel might split on the wrong dimension.
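To illustrate the wrong split: `nn.DataParallel` scatters its input along dim 0. The sketch below uses `torch.chunk` as a CPU-only stand-in for that split (an approximation of what DataParallel's scatter does, not its actual code path), with the 288 features from the thread:

```python
import torch

n_devices = 4  # as in the 4-GPU run from the thread

# Without a batch dimension, dim 0 is the feature dimension, so splitting it
# hands each device a slice of the features instead of whole states.
state = torch.randn(288)
print([c.shape for c in torch.chunk(state, n_devices, dim=0)])
# [torch.Size([72]), torch.Size([72]), torch.Size([72]), torch.Size([72])]

# With an explicit batch dimension of 1, dim 0 has a single entry, so there is
# nothing to split: a single chunk of shape [1, 288].
batched = state.unsqueeze(0)
print([c.shape for c in torch.chunk(batched, n_devices, dim=0)])
# [torch.Size([1, 288])]
```

With 2 devices the same split would give slices of 144, which would match the `m1: [1 x 144]` in the first traceback.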

My number of features is 288. I used state.view(1, -1) so that the state now has shape [1, 288], but now I get the following error:

```
ValueError: Expected more than 1 value per channel when training, got input size [1, 256]
```

If I use state.view(-1, 1) instead, so that the state has shape [288, 1], I get:

```
RuntimeError: size mismatch, m1: [288 x 1], m2: [288 x 256] at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1524541161962\work\aten\src\th\generic/THTensorMath.c:2033
```
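`nn.Linear` operates on the last dimension of its input, so the features have to come last. A sketch comparing the two reshapes from the post, using a standalone layer of the same size (a stand-in, not the actual actor):

```python
import torch
import torch.nn as nn

layer = nn.Linear(288, 256)  # same shape as the actor's input_layer
state = torch.randn(288)

# view(1, -1): one row with 288 features, which is what nn.Linear expects.
print(layer(state.view(1, -1)).shape)  # torch.Size([1, 256])

# view(-1, 1): 288 rows with 1 feature each, so the last dim (1) != in_features (288).
try:
    layer(state.view(-1, 1))
except RuntimeError as err:
    print("RuntimeError:", err)
```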

I think I can fix this with an eval() call beforehand. I’ll check this…
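The ValueError comes from batch normalization: in training mode it normalizes with statistics computed over the batch, which is not possible with a single sample per channel, while in eval mode it uses its running estimates instead. A minimal sketch with a standalone `nn.BatchNorm1d(256)`, assumed here as a stand-in for `batch_norm_1` after the 256-unit input layer:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(256)   # stand-in for batch_norm_1 after the 288 -> 256 layer
x = torch.randn(1, 256)    # a batch containing a single state

bn.train()
try:
    bn(x)                  # batch statistics need more than one value per channel
except ValueError as err:
    print("train mode:", err)

bn.eval()                  # eval mode uses the running mean/var instead
print("eval mode:", bn(x).shape)  # torch.Size([1, 256])
```

`eval()` is applied recursively to all submodules, so calling it on an `nn.DataParallel` wrapper also switches the wrapped network; it has to be called on the module that is actually run, here `actor_target`.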

It seems to be working, thank you very much!
(It’s running on a cluster now, so it will take some time until I get the log…)
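For the per-step DDPG evaluation mentioned in the question (one state at a time), one possible pattern is a small helper that adds the batch dimension, switches to eval mode, and bypasses `nn.DataParallel`, since a batch of one cannot be split across devices anyway. This is a sketch under assumptions, not the approach confirmed in this thread: `act` is a hypothetical name, and the toy actor is a stand-in.

```python
import torch
import torch.nn as nn

def act(actor: nn.Module, state: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: evaluate the actor on one unbatched state."""
    # Unwrap DataParallel: a batch of one cannot be split across devices anyway.
    net = actor.module if isinstance(actor, nn.DataParallel) else actor
    was_training = net.training
    net.eval()                      # batch norm uses running stats, dropout is off
    with torch.no_grad():
        device = next(net.parameters()).device
        out = net(state.to(device).unsqueeze(0))  # [288] -> [1, 288]
    net.train(was_training)         # restore the previous mode for training steps
    return out.squeeze(0)           # drop the batch dimension again

# Hypothetical toy actor with the same input size as in the thread.
actor = nn.Sequential(nn.Linear(288, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 1))
print(act(actor, torch.randn(288)).shape)  # torch.Size([1])
```

Restoring the previous mode keeps the helper from silently leaving the network in eval mode between training updates.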