Channel-wise proportional threshold

Hello all.

I’d like to implement a function:

  • arguments:
    – x (torch.Tensor): x.shape == (B, C, H, W)
    – recipe (e.g. a Python dict): maps each label to a (start %, end %) range, e.g. {1: (0, 50), 2: (50, 75), 3: (75, 100)}

  • returns:
    – mask (torch.Tensor): mask.shape == (B, C, H, W)

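  • how it works:
    In the figure below, the white matrix represents a single channel of the given x
    (i.e., x[sample_batch_idx][sample_channel_idx]).
    The recipe is given as proportions (percentiles), not as absolute thresholds:
    with the example recipe, the smallest 50% of the values in a channel get label 1,
    the next 25% get label 2, and the largest 25% get label 3.
    The function returns the mask shown as the colored matrix.
    I'd like to apply this process to each channel independently.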
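A minimal rank-based sketch, assuming the recipe ranges are non-overlapping percentages in [0, 100] and that 0 may mark elements covered by no range (`proportional_mask` is a hypothetical name). Each element's rank inside its own channel comes from a double `argsort`, is converted to a percentile position, and is then compared against all ranges at once via broadcasting, so no Python loop runs over batch, channel, or pixels:

```python
import torch


def proportional_mask(x: torch.Tensor, recipe: dict) -> torch.Tensor:
    """Label each element of x by the percentile range it falls into, per (batch, channel).

    Hypothetical helper. Assumes recipe maps an integer label to a non-overlapping
    (start_pct, end_pct) range in [0, 100]; elements covered by no range get 0.
    """
    B, C, H, W = x.shape
    n = H * W
    flat = x.reshape(B, C, n)

    # Rank of each element inside its own channel (0 = smallest).
    # argsort of argsort turns the sort order into per-element ranks.
    ranks = flat.argsort(dim=-1).argsort(dim=-1)

    # Percentile position of each element, in [0, 100).
    pct = ranks.to(torch.float64) * 100.0 / n

    labels = torch.tensor(list(recipe.keys()), dtype=torch.long, device=x.device)      # (K,)
    bounds = torch.tensor(list(recipe.values()), dtype=torch.float64, device=x.device)  # (K, 2)
    lo, hi = bounds[:, 0], bounds[:, 1]

    # Broadcast (B, C, n, 1) against (K,) -> (B, C, n, K): which range holds each element?
    p = pct.unsqueeze(-1)
    inside = (p >= lo) & (p < hi)

    # With non-overlapping ranges at most one entry per element is True,
    # so the weighted sum picks that range's label (or 0 if none matches).
    mask = (inside.long() * labels).sum(dim=-1)
    return mask.reshape(B, C, H, W)
```

For example, with x = torch.arange(16.).reshape(1, 1, 4, 4) and the recipe above, the first two rows of the mask are 1, the third row is 2, and the last row is 3. Because `argsort` splits ties arbitrarily, equal values may receive different labels; in exchange the proportions match the recipe exactly. The intermediate (B, C, H*W, K) tensor grows with the number of recipe entries K.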

I currently have a test implementation that uses for loops, but its performance is very poor.
Is there an efficient way to implement this function?
I suspect that (automatic) broadcasting, rather than a for loop, is the key idea, but I don't know how to apply it here.
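If you prefer to turn the proportions into actual per-channel thresholds first, `torch.quantile` can compute every cut point for every channel in one call, and `torch.searchsorted` accepts a batched (B, C, K-1) boundary tensor, so the comparison broadcasts over batch and channel without a Python loop. A sketch, assuming the ranges are contiguous, cover 0 to 100 as in the example, and number at least two; `proportional_mask_quantile` is a hypothetical name:

```python
import torch


def proportional_mask_quantile(x: torch.Tensor, recipe: dict) -> torch.Tensor:
    """Threshold-based variant: per-channel quantiles as cut points, then searchsorted.

    Hypothetical helper. Assumes the recipe ranges are contiguous, together
    cover 0..100 (like the example), and that there are at least two ranges.
    """
    B, C, H, W = x.shape
    flat = x.reshape(B, C, H * W).to(torch.float64).contiguous()

    # Order the ranges by their start and keep the matching labels.
    items = sorted(recipe.items(), key=lambda kv: kv[1][0])
    labels = torch.tensor([label for label, _ in items], dtype=torch.long, device=x.device)

    # Interior cut points: the upper end of every range except the last, as fractions.
    q = torch.tensor([end / 100.0 for _, (_, end) in items[:-1]],
                     dtype=torch.float64, device=x.device)

    # One threshold per (batch, channel, cut); quantile puts the q dimension first.
    thr = torch.quantile(flat, q, dim=-1)    # (K-1, B, C)
    thr = thr.permute(1, 2, 0).contiguous()  # (B, C, K-1), ascending along last dim

    # Count the thresholds <= each value -> bucket index in 0..K-1 per element.
    idx = torch.searchsorted(thr, flat, right=True)  # (B, C, H*W)
    return labels[idx].reshape(B, C, H, W)
```

Because the cut points come from interpolated quantiles, equal values always receive the same label, but the resulting proportions only approximate the recipe when a channel has ties or few elements.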

Any suggestions are welcome.