Find "bounding box" around ones in batch of masks

I have a batch of masks (values 0.0/1.0) with shape (b, h, w), where b is the batch dimension.

For each mask, I need the coordinates of the bounding box containing all the ones. I thought about using various operators (torch.where/torch.nonzero), but I can’t get them to work in a batched setting.

Ok, I think I’ve solved it:

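import torch

# Get batched bounding-box bounds for PyTorch masks of shape (b, h, w).
# The max bounds are exclusive, so img[i, rmin:rmax, cmin:cmax] is the box.
def get_bounds_t(img):
    b, h, w = img.shape[:3]
    # True for every row / column that contains at least one nonzero value
    rows = torch.any(img != 0, dim=2)  # (b, h)
    cols = torch.any(img != 0, dim=1)  # (b, w)
    # argmax returns the index of the first maximal value, i.e. the first True
    rmins = torch.argmax(rows.float(), dim=1)
    # flipping finds the last True, counted from the end
    rmaxs = h - torch.argmax(rows.float().flip(dims=[1]), dim=1)
    cmins = torch.argmax(cols.float(), dim=1)
    cmaxs = w - torch.argmax(cols.float().flip(dims=[1]), dim=1)

    # Note: an all-zero mask is not detected; argmax gives 0, so it yields (0, h, 0, w)
    return rmins, rmaxs, cmins, cmaxs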
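
Here is a small sanity check of the same any/argmax/flip idea on a toy batch, as a sketch rather than a definitive version. The name mask_bounds and the extra nonempty flag are my own additions for illustration (they are not part of the function above); the flag marks masks with no ones at all, whose bounds are meaningless.

```python
import torch

def mask_bounds(masks):
    """Hypothetical helper: batched bounds for masks of shape (b, h, w).

    Returns rmin, rmax, cmin, cmax (max bounds exclusive) plus a boolean
    `nonempty` flag per mask. Bounds of empty masks should be ignored.
    """
    b, h, w = masks.shape
    m = masks != 0
    rows = m.any(dim=2)  # (b, h): rows containing at least one nonzero
    cols = m.any(dim=1)  # (b, w): columns containing at least one nonzero
    rmin = rows.float().argmax(dim=1)
    rmax = h - rows.float().flip(dims=[1]).argmax(dim=1)
    cmin = cols.float().argmax(dim=1)
    cmax = w - cols.float().flip(dims=[1]).argmax(dim=1)
    nonempty = rows.any(dim=1)  # False where the mask has no ones
    return rmin, rmax, cmin, cmax, nonempty

masks = torch.zeros(3, 5, 6)
masks[0, 1:3, 2:5] = 1.0  # rows 1..2, cols 2..4 -> bounds (1, 3, 2, 5)
masks[1, 4, 0] = 1.0      # single pixel        -> bounds (4, 5, 0, 1)
# masks[2] is left empty on purpose

rmin, rmax, cmin, cmax, nonempty = mask_bounds(masks)

# Cropping with the exclusive bounds should give a block of all ones for mask 0
crop = masks[0, rmin[0]:rmax[0], cmin[0]:cmax[0]]
print(crop.shape, bool(crop.all()), nonempty.tolist())
```

Because the max bounds are exclusive, they can be used directly in slices; subtract 1 if inclusive pixel coordinates are needed.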