If you just want to reshape your input data, you can use torch.reshape.
Let’s assume you have some data, data = torch.randn(B,CH,W,H),
and (for the sake of clarity) B=2, CH=2, W=4, H=4.
This gives:
>>> data=torch.randn(B,CH,W,H)
>>> data
tensor([[[[-0.7966, -0.0109, 0.8348, 1.4254],
[ 0.8557, -0.5183, 0.8383, -0.1150],
[-0.5543, -0.5676, -0.4311, -0.2642],
[ 1.5170, -0.3768, -0.0584, -0.7107]],
[[-0.4866, 0.6485, 1.0904, -1.7764],
[ 0.8951, 1.1638, 1.8855, -0.5628],
[-0.3668, 0.6140, 1.0155, -2.7003],
[-1.1971, -1.7511, 0.3515, -1.3025]]],
[[[ 1.0640, 0.4468, -0.1497, 1.6121],
[-0.4338, 0.9484, 1.0433, -0.6663],
[ 0.0659, -0.6567, -0.8577, -2.0923],
[-1.6651, 0.8021, -1.0585, -0.2755]],
[[-0.0414, -0.2734, -2.4475, -0.1633],
[-0.4347, 1.5978, 1.4764, 0.6679],
[-0.8553, 0.0685, 0.6054, 0.1157],
[-0.2803, -0.9768, 0.4805, -1.0493]]]])
If we do data.reshape(B,CH,W*H), we get the desired result:
>>> data.reshape(B,CH,W*H)
tensor([[[-0.7966, -0.0109, 0.8348, 1.4254, 0.8557, -0.5183, 0.8383,
-0.1150, -0.5543, -0.5676, -0.4311, -0.2642, 1.5170, -0.3768,
-0.0584, -0.7107],
[-0.4866, 0.6485, 1.0904, -1.7764, 0.8951, 1.1638, 1.8855,
-0.5628, -0.3668, 0.6140, 1.0155, -2.7003, -1.1971, -1.7511,
0.3515, -1.3025]],
[[ 1.0640, 0.4468, -0.1497, 1.6121, -0.4338, 0.9484, 1.0433,
-0.6663, 0.0659, -0.6567, -0.8577, -2.0923, -1.6651, 0.8021,
-1.0585, -0.2755],
[-0.0414, -0.2734, -2.4475, -0.1633, -0.4347, 1.5978, 1.4764,
0.6679, -0.8553, 0.0685, 0.6054, 0.1157, -0.2803, -0.9768,
0.4805, -1.0493]]])
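Here is a minimal self-contained version of the example above that you can copy and paste. The values are random, so the printed numbers will differ from the output shown. The script also checks where an element ends up in the merged dimension. It compares the result with torch.flatten(data, start_dim=2), which is an extra check and not part of the approach above.

```python
import torch

# Sizes from the example above
B, CH, W, H = 2, 2, 4, 4
data = torch.randn(B, CH, W, H)

# Merge the last two (neighbouring) dimensions W and H into one of size W*H
merged = data.reshape(B, CH, W * H)
print(merged.shape)  # torch.Size([2, 2, 16])

# Element (b, c, w, h) ends up at index w*H + h of the merged dimension
assert torch.equal(merged[1, 0, 2 * H + 3], data[1, 0, 2, 3])

# torch.flatten over dims 2..end gives the same result here
assert torch.equal(merged, torch.flatten(data, start_dim=2))
```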
As you can see, W and H have been merged into a single dimension of size W*H. Note that this approach only works for neighbouring dimensions (W and H are the 3rd and 4th dimensions here). If you want something like B,W,CH*H, where CH and H are not neighbours, you’d have to transpose your data beforehand!
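The sketch below shows the B,W,CH*H case and assumes the same sizes as above. transpose(1, 2) swaps CH and W, which makes CH and H neighbours. reshape can then merge them. For this case, data.permute(0, 2, 1, 3) would do the same as the transpose.

```python
import torch

B, CH, W, H = 2, 2, 4, 4
data = torch.randn(B, CH, W, H)

# Swap CH and W: (B, CH, W, H) -> (B, W, CH, H), so CH and H become neighbours
swapped = data.transpose(1, 2)

# Now merge CH and H: (B, W, CH, H) -> (B, W, CH*H)
out = swapped.reshape(B, W, CH * H)
print(out.shape)  # torch.Size([2, 4, 8])

# Element (b, c, w, h) ends up at out[b, w, c*H + h]
assert torch.equal(out[1, 2, 1 * H + 3], data[1, 1, 2, 3])

# transpose returns a non-contiguous view, so reshape has to copy the data;
# .view() cannot do that and raises an error instead
assert not swapped.is_contiguous()
try:
    swapped.view(B, W, CH * H)
except RuntimeError:
    print("view() fails on the transposed tensor, reshape() works")
```

This is also why reshape is the safer choice here. It returns a view when the memory layout allows one and falls back to a copy otherwise.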
Hopefully this helps!