Maybe like this (as suggested by @ptrblck)? Note this only works if the column sets cols1 and cols2 are disjoint (otherwise cols2 overwrites cols1) and together cover every column index from 0 to num_columns - 1 (any column not listed stays zero). The code preserves the original column order. I'm not sure it is memory-efficient, though; see the alternative sketch at the end.
import torch
import torch.nn as nn

act1 = torch.tanh
act2 = torch.relu

class MyFuncX(nn.Module):
    def __init__(self, cols1, cols2):
        super().__init__()
        # cols1 and cols2 must be disjoint and together cover all columns
        self.cols1 = cols1
        self.cols2 = cols2

    def forward(self, x):
        # output is allocated with the same shape, dtype and device as x
        out = torch.zeros_like(x)
        out[:, self.cols1] = act1(x[:, self.cols1])
        out[:, self.cols2] = act2(x[:, self.cols2])
        return out

model = MyFuncX((1, 3), (0, 2, 4))
x = torch.normal(0, 1, (16, 5))
out = model(x)
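The disjointness and coverage requirements fail silently (overlapping columns just get overwritten, uncovered ones stay zero), so a quick up-front check may be worthwhile. This is a sketch of my own, not part of the original suggestion:

# Optional sanity check (my own addition, not from the original answer)
cols1, cols2 = (1, 3), (0, 2, 4)
assert set(cols1).isdisjoint(cols2), "column sets must not overlap"
assert set(cols1) | set(cols2) == set(range(x.shape[1])), "every column must be covered"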
Running the example gives this output:
tensor([[ 0.5571, -0.7352, 0.0000, 0.2542, 0.0263],
[ 0.0000, -0.5948, 0.1626, 0.1253, 0.0000],
[ 0.0000, 0.7736, 0.5284, -0.5993, 0.2219],
[ 0.5151, -0.3131, 1.3649, 0.9967, 0.3191],
[ 0.5829, -0.2705, 0.0000, 0.2713, 0.0000],
[ 0.0997, 0.9836, 1.5062, 0.5994, 0.0000],
[ 0.1977, 0.7754, 0.0000, -0.6180, 0.0000],
[ 0.0000, -0.2901, 0.5313, -0.8883, 0.0000],
[ 0.1594, -0.0371, 0.3020, 0.5257, 0.0000],
[ 1.7526, -0.8893, 0.0000, 0.5998, 0.0000],
[ 0.0000, 0.2289, 0.1292, 0.9542, 0.0000],
[ 0.0000, 0.9498, 1.9529, 0.8740, 0.0000],
[ 0.0000, -0.8019, 0.0000, 0.7863, 0.6394],
[ 0.0000, 0.0918, 0.5882, -0.2390, 0.0000],
[ 1.0211, -0.4522, 0.0000, -0.5290, 0.0000],
[ 0.6756, 0.7053, 0.4621, -0.8896, 0.8450]])
Clearly, relu zeros out the negative inputs in columns 0, 2 and 4, while tanh is applied to columns 1 and 3, as intended. zeros_like also takes care of creating the output tensor with the same dtype and on the same device as your input data.
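On the memory question: an alternative is to build a boolean column mask once and pick between the two activations with torch.where. This is a sketch of my own (the class name MaskedAct is made up, not from the original post); it sidesteps the advanced-indexing copies and makes the disjointness/coverage constraints impossible to violate, at the cost of evaluating both activations over the full tensor:

import torch
import torch.nn as nn

class MaskedAct(nn.Module):  # hypothetical name, my own sketch
    def __init__(self, cols1, num_cols):
        super().__init__()
        mask = torch.zeros(num_cols, dtype=torch.bool)
        mask[list(cols1)] = True  # True -> tanh, False -> relu
        # registered as a buffer so the mask moves along with model.to(device)
        self.register_buffer("mask", mask)

    def forward(self, x):
        # both activations are computed everywhere, then merged column-wise
        return torch.where(self.mask, torch.tanh(x), torch.relu(x))

model = MaskedAct((1, 3), num_cols=5)
out = model(torch.normal(0, 1, (16, 5)))

Whether this is actually faster or lighter than the indexing version likely depends on the split sizes; with only two activations the extra compute is usually cheap.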