How to use `softmax()` on 5D tensors without `.view()` or a runtime CUDA error

I need to apply the softmax function to variables containing 5D tensors (batch_size × nChannels × W × D × H). This layout is common when working with 3D images of size W × D × H.
Since `softmax()` doesn't support 5D input, I was reshaping the variable as follows:
`var = var.view(batch_size, nChannels, -1)`
However, running `softmax(var)` throws a runtime error when W × D × H exceeds a certain size. Please see here for an example that reproduces this error.
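The reshape workaround described above can be sketched roughly as follows. This is a minimal sketch that assumes a recent PyTorch version using plain tensors instead of the older `Variable` wrapper, and assumes the softmax should be taken over the channel dimension. The helper name `channel_softmax_5d` and the tensor sizes are hypothetical, and the sketch does not attempt to reproduce the CUDA error.

```python
import torch
import torch.nn.functional as F

def channel_softmax_5d(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax over channels of an (N, C, W, D, H) tensor.

    Mirrors the workaround in the post: flatten the three spatial
    dimensions, normalize across channels (dim=1), then restore the shape.
    """
    n, c = x.shape[:2]
    # reshape acts like view on contiguous tensors but also accepts
    # non-contiguous ones (copying if needed).
    flat = x.reshape(n, c, -1)        # (N, C, W*D*H)
    out = F.softmax(flat, dim=1)      # per voxel, channel scores sum to 1
    return out.reshape(x.shape)       # back to (N, C, W, D, H)

# Hypothetical sizes: batch of 2, 4 channels, an 8 x 8 x 8 volume.
x = torch.randn(2, 4, 8, 8, 8)
y = channel_softmax_5d(x)
print(y.shape)
```

In recent PyTorch releases, `F.softmax` takes an explicit `dim` argument and accepts tensors of any rank, so `F.softmax(x, dim=1)` on the 5D tensor should give the same result without the reshape.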

Is there a better way to make `softmax()` and other functions work on 3D images, which end up as 5D tensors once the batch and channel dimensions are included?
