Documentation bug and feature suggestion for ChannelShuffle

Hi Forum!

As I read the ChannelShuffle documentation, it implies that the [-3]
dimension (the “C” dimension) of the input tensor will be shuffled. But
this is not the observed behavior: the shuffle appears to act on the [1]
dimension, however many dimensions the input has. For example:

>>> import torch
>>> torch.__version__
'2.3.1'
>>> torch.nn.ChannelShuffle (2) (torch.ones (1, 2, 1, 1))      # agrees with doc
tensor([[[[1.]],

         [[1.]]]])
>>> torch.nn.ChannelShuffle (2) (torch.ones (1, 2, 1, 1, 1))   # disagrees with doc
tensor([[[[[1.]]],


         [[[1.]]]]])
>>> torch.nn.ChannelShuffle (2) (torch.ones (1, 2, 1))         # disagrees with doc
tensor([[[1.],
         [1.]]])
>>> torch.nn.ChannelShuffle (2) (torch.ones (1, 1, 2, 1, 1))   # disagrees with doc
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<path_to_pytorch_install>\torch\nn\modules\channelshuffle.py", line 54, in forward
    return F.channel_shuffle(input, self.groups)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Number of channels must be divisible by groups. Got 1 channels and 2 groups.

I’m not sure what behavior is intended, but either way, the documentation
should match the behavior.
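
Since torch.ones() looks the same however it is shuffled, here is a minimal sketch that makes the shuffled dimension visible. It assumes the usual reshape / transpose / flatten definition of channel shuffle, under which groups = 2 on four channels gives the order [0, 2, 1, 3], and it reflects the behavior I see on the build above, not necessarily other versions.

```python
import torch

shuffle = torch.nn.ChannelShuffle(2)
perm = [0, 2, 1, 3]  # assumed channel order for groups=2 on 4 channels

# Both dim 1 and dim -3 have size 4, so either could be shuffled without error.
x = torch.arange(16.0).reshape(1, 4, 4, 1, 1)
y = shuffle(x)

# Compare the result against an explicit permutation of each candidate dim.
matches_dim_1 = torch.equal(y, x[:, perm])
matches_dim_minus_3 = torch.equal(y, x[:, :, perm])
print("shuffled along dim 1: ", matches_dim_1)
print("shuffled along dim -3:", matches_dim_minus_3)
```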

(As noted in GitHub issue #123053, the example code in the documentation
is incorrect: specifically, randn() doesn’t produce integer values.)

The feature suggestion is to make ChannelShuffle more flexible (and
in my mind simpler). Let ChannelShuffle operate on tensors with
arbitrary numbers of dimensions, and add a dim argument that specifies
which dimension to shuffle. (One can shuffle any desired dimension by
using the appropriate unsqueeze()s and permute()s, but this is less
readable and a nuisance.)
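
For reference, the workaround might look something like the sketch below. The helper name shuffle_along is made up for illustration (it is not a PyTorch function), and I've used movedim() in place of an explicit permute(). It assumes the input has more than two dimensions and that channel_shuffle() acts on dimension [1], as observed above.

```python
import torch
import torch.nn.functional as F

def shuffle_along(x, groups, dim):
    # Hypothetical helper (not a PyTorch API): move `dim` into position 1,
    # where channel_shuffle was observed to act, shuffle, then move it back.
    # Assumes x has more than two dimensions.
    return F.channel_shuffle(x.movedim(dim, 1), groups).movedim(1, dim)

# Shuffle the [-3] dimension of a 5-D tensor whose dim 1 has size 1.
x = torch.arange(4.0).reshape(1, 1, 4, 1, 1)
y = shuffle_along(x, groups=2, dim=-3)
print(y.flatten().tolist())
```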

If the new argument dim is given the default value dim = -3, the default
behavior will be backward-compatible with the documented behavior,
although not with the observed behavior.
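
To make the suggestion concrete, here is a rough sketch of what such a module could look like. The class name ShuffleDim and its error messages are my own placeholders, not an existing or planned PyTorch API, and it uses a plain reshape / transpose rather than the built-in kernel.

```python
import torch

class ShuffleDim(torch.nn.Module):
    """Hypothetical sketch of the proposal; not part of PyTorch."""

    def __init__(self, groups, dim=-3):
        super().__init__()
        self.groups = groups
        self.dim = dim

    def forward(self, x):
        if not -x.dim() <= self.dim < x.dim():
            raise IndexError(f"dim {self.dim} is out of range for a {x.dim()}-D tensor")
        dim = self.dim % x.dim()  # normalize a negative dim
        size = x.shape[dim]
        if size % self.groups != 0:
            raise ValueError(f"size {size} of dim {self.dim} is not divisible by groups={self.groups}")
        shape = list(x.shape)
        # Split the chosen dim into (groups, size // groups), swap the pair, flatten back.
        split = x.reshape(*shape[:dim], self.groups, size // self.groups, *shape[dim + 1:])
        return split.transpose(dim, dim + 1).reshape(shape)

# The 5-D input that raised the RuntimeError above: its [-3] dimension has size 2.
out = ShuffleDim(2)(torch.ones(1, 1, 2, 1, 1))
print(out.shape)
```

With dim = 1 this should reproduce the currently observed behavior, so existing code could opt into it explicitly.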

(Sorry for not posting this as a GitHub issue – Google nuked the account
I had been using to post on GitHub.)

Best.

K. Frank

(btw this should have been fixed on main!)