Per-row/column clamping


I am wondering what the most efficient way is to implement per-row/column clamping. For instance:

input = torch.tensor([[0,3,2,-1],[2,1,3,4]])
input_clamp = torch.Tensor([2,3])
output = per_row_clamp(input, input_clamp)

The output should then be [[0, 2, 2, -1], [2, 1, 3, 3]], i.e. each row is clamped from above by its own value.

One approach I can think of is to divide each row by its clamp value, clamp everything to a maximum of 1, and then multiply each row by its clamp value again (this only works if every clamp value is positive).
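A minimal sketch of that divide/clamp/multiply idea, assuming PyTorch and a hypothetical helper name `per_row_clamp_div`. It only works when every bound is strictly positive (dividing by a negative value flips the comparison, and a zero bound produces inf/nan), and the division and multiplication can introduce small floating-point round-off in values that were not clamped.

```python
import torch

def per_row_clamp_div(x, max_per_row):
    """Upper-clamp row i of x to max_per_row[i] via divide / clamp / multiply.

    Hypothetical helper. Assumes every entry of max_per_row is > 0.
    Returns a float tensor; x itself is not modified.
    """
    x = x.float()
    # (rows,) -> (rows, 1) so each bound broadcasts across its row
    scale = max_per_row.float().unsqueeze(1)
    # After dividing, "exceeds the bound" becomes "exceeds 1"
    return (x / scale).clamp(max=1.0) * scale

x = torch.tensor([[0, 3, 2, -1], [2, 1, 3, 4]])
print(per_row_clamp_div(x, torch.tensor([2, 3])))
```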


You can try the code below. It uses boolean indexing: compare every element with its row's clamp value, then overwrite the elements that exceed it with that value.

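import torch

input = torch.tensor([[0, 3, 2, -1], [2, 1, 3, 4]], dtype=torch.float)
input_clamp = torch.tensor([2., 3.])

# Repeat each row's clamp value across all columns: shape (2,) -> (2, 4)
tiled = input_clamp.repeat(input.size(1), 1).t()

# Mask of elements larger than their row's clamp value
bigger = input > tiled
# Overwrite those elements in place with the clamp value
input[bigger] = tiled[bigger]
print(input)  # values: [[0, 2, 2, -1], [2, 1, 3, 3]]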
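For comparison, a sketch of a broadcasting alternative (swapped in here, not the method from the answer above): `torch.minimum` compares `x` against the bounds reshaped so they broadcast along the right dimension, which avoids building the tiled tensor and leaves the input unchanged. `per_row_clamp` (the name from the question) and `per_col_clamp` are hypothetical helpers; the sketch assumes PyTorch 1.7 or later, where `torch.minimum` exists, and casts the bounds to `x`'s dtype, so fractional bounds are truncated for integer inputs.

```python
import torch

def per_row_clamp(x, max_per_row):
    """Upper-clamp row i of x to max_per_row[i]; returns a new tensor."""
    # (rows,) -> (rows, 1) so each bound broadcasts across its row
    bound = max_per_row.to(x.dtype).unsqueeze(1)
    return torch.minimum(x, bound)

def per_col_clamp(x, max_per_col):
    """Upper-clamp column j of x to max_per_col[j]; returns a new tensor."""
    # (cols,) already broadcasts against (rows, cols) along the last dimension
    return torch.minimum(x, max_per_col.to(x.dtype))

x = torch.tensor([[0, 3, 2, -1], [2, 1, 3, 4]])
print(per_row_clamp(x, torch.tensor([2., 3.])))
print(per_col_clamp(x, torch.tensor([1, 1, 2, 0])))
```

Lower bounds work the same way with `torch.maximum` in place of `torch.minimum`.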