How to efficiently normalize a batch of tensors to [0, 1]

@ptrblck Here it is

import torch

# Demo: per-sample min-max normalization of a batch to [0, 1].
# Seeded so the loop-vs-vectorized comparison below is reproducible.
torch.manual_seed(0)

batch_size = 3
height = 2
width = 2

A = torch.randint(2, 11, (batch_size, height, width)).float()
AA = A.clone()
print(A)

# Reference solution: per-sample loop — subtract each sample's min,
# then divide by its (post-subtraction) max.
# NOTE: produces NaN if a sample is constant (max becomes 0 after the
# subtraction); kept as-is since that is the posted reference behavior.
for i in range(batch_size):
    A[i] -= torch.min(A[i])
    A[i] /= torch.max(A[i])

# Vectorized equivalent. The original used .min(1, keepdim=True)[0],
# which reduces only over dim 1 (height), not over each whole sample —
# that is why A and AA differed. Reducing over BOTH spatial dims with
# amin/amax (keepdim=True so the result broadcasts back) matches the
# per-sample loop exactly.
AA -= AA.amin(dim=(1, 2), keepdim=True)
AA /= AA.amax(dim=(1, 2), keepdim=True)

print(A)
print(AA)  # A and AA are now identical