@ptrblck Here it is
import torch

# Per-sample min-max normalization: each A[i] is shifted and scaled using
# its own min and max taken over *all* of its elements.
batch_size = 3
height = 2
width = 2
A = torch.randint(2, 11, (batch_size, height, width)).float()
AA = A.clone()
print(A)

# Reference: explicit loop over the batch dimension.
for i in range(batch_size):
    A[i] -= torch.min(A[i])
    A[i] /= torch.max(A[i])  # max is taken after the shift, as in the loop above

# Vectorized equivalent. The min/max must be reduced over every non-batch
# dimension (here height *and* width). Reducing over dim 1 alone, as in
# `AA.min(1, keepdim=True)[0]`, normalizes each column of each sample
# separately, which is why those results differed from the loop.
# keepdim=True keeps a (batch_size, 1, 1) shape so the result broadcasts
# back against AA.
# NOTE: a sample whose elements are all equal yields 0 / 0 = nan here,
# exactly as in the loop version.
sample_dims = tuple(range(1, AA.dim()))
AA -= AA.amin(dim=sample_dims, keepdim=True)
AA /= AA.amax(dim=sample_dims, keepdim=True)
print(A)  # A and AA now match
print(AA)