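Hi Omer!

You can `.reshape()` your input tensor and use the `groups` constructor argument of `Conv1d`. Permuting `t` so that its `m` dimension sits just before its channel dimension, and then merging the two, gives `m * c` channels. With `groups = m`, each of the `m` groups of `c` channels gets its own independent set of `c`-in / `c`-out weights: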
```
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> b = 2
>>> c = 3
>>> n = 5
>>> m = 2
>>> t = torch.rand ([b, c, n, m])
>>> u = t.permute (0, 3, 1, 2).reshape (b, m * c, n)
>>> conv = torch.nn.Conv1d (m * c, m * c, 1, groups = m)
>>> conv (u)
tensor([[[-0.0624, -0.3195, -0.3976, -0.0028, -0.3123],
         [ 0.5862,  0.6727,  0.9059,  0.3821,  0.9464],
         [-0.9408, -1.1464, -0.7950, -0.8718, -0.6577],
         [-0.4530, -0.5568, -0.4779, -0.8847, -0.5256],
         [-0.5623, -0.7672, -0.3805, -0.7634, -0.5598],
         [-0.6248, -0.6711, -0.2747, -0.6640, -0.0652]],

        [[-0.3047, -0.4761, -0.3934, -0.1065, -0.0300],
         [ 0.6465,  0.8431,  0.7757,  0.8131,  0.4114],
         [-1.0906, -0.9153, -0.9905, -0.7797, -0.7686],
         [-0.8211, -0.4375, -1.0047, -0.2747, -0.2525],
         [-1.0808, -0.6798, -0.8439, -0.2883, -0.5769],
         [-0.2513, -0.5921, -0.3546, -0.4071, -0.3656]]],
       grad_fn=<SqueezeBackward1>)
>>> t
tensor([[[[0.1304, 0.5134],
          [0.7426, 0.7159],
          [0.5705, 0.1653],
          [0.0443, 0.9628],
          [0.2943, 0.0992]],

         [[0.8096, 0.0169],
          [0.8222, 0.1242],
          [0.7489, 0.3608],
          [0.5131, 0.2959],
          [0.7834, 0.7405]],

         [[0.8050, 0.3036],
          [0.9942, 0.5025],
          [0.3734, 0.0413],
          [0.8387, 0.0604],
          [0.1773, 0.3301]]],


        [[[0.6857, 0.6960],
          [0.8303, 0.5216],
          [0.7438, 0.8290],
          [0.0219, 0.0813],
          [0.0172, 0.1464]],

         [[0.7492, 0.9450],
          [0.6737, 0.1135],
          [0.7421, 0.7810],
          [0.9446, 0.0451],
          [0.4282, 0.2427]],

         [[0.9363, 0.7784],
          [0.5605, 0.5312],
          [0.7132, 0.1075],
          [0.4496, 0.1255],
          [0.6784, 0.6550]]]])
>>> u
tensor([[[0.1304, 0.7426, 0.5705, 0.0443, 0.2943],
         [0.8096, 0.8222, 0.7489, 0.5131, 0.7834],
         [0.8050, 0.9942, 0.3734, 0.8387, 0.1773],
         [0.5134, 0.7159, 0.1653, 0.9628, 0.0992],
         [0.0169, 0.1242, 0.3608, 0.2959, 0.7405],
         [0.3036, 0.5025, 0.0413, 0.0604, 0.3301]],

        [[0.6857, 0.8303, 0.7438, 0.0219, 0.0172],
         [0.7492, 0.6737, 0.7421, 0.9446, 0.4282],
         [0.9363, 0.5605, 0.7132, 0.4496, 0.6784],
         [0.6960, 0.5216, 0.8290, 0.0813, 0.1464],
         [0.9450, 0.1135, 0.7810, 0.0451, 0.2427],
         [0.7784, 0.5312, 0.1075, 0.1255, 0.6550]]])
```
Best.
K. Frank