Maximum pixel value of RGB channel

The image is read as a tensor of shape [1, 3, 256, 256] (batch, channel, height, width).
(1) How can I find the maximum pixel value in each channel (the max value of each RGB channel), and
(2) the pixel value with the maximum RGB value?

For example, if I have two pixels with values (255,250,190) and (240,240,240), then in the first case the output should be (255,250,240) and in the second case (240,240,240).
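The two definitions can be checked on the example pixels without any tensors. This plain-Python sketch assumes that "maximum RGB value" means the largest R + G + B sum. That is the reading under which (240,240,240) is the expected answer, since its sum is 720 versus 695 for (255,250,190).

```python
# Two example pixels as (R, G, B) tuples
pixels = [(255, 250, 190), (240, 240, 240)]

# (1) Maximum of each channel, taken independently across all pixels
channel_max = tuple(max(channel) for channel in zip(*pixels))

# (2) The single pixel whose R + G + B sum is largest (assumed meaning)
brightest = max(pixels, key=sum)

print(channel_max)  # (255, 250, 240)
print(brightest)    # (240, 240, 240)
```

Note that (1) can mix values from different pixels, while (2) always returns one pixel that actually exists in the image.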

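import torch

# Example from the question: two pixels stored as a [1, 3, 2, 1] tensor (N, C, H, W)
pixels = torch.tensor([[255, 250, 190], [240, 240, 240]], dtype=torch.int32)
image = pixels.T.reshape(1, 3, 2, 1)  # reshape rather than view: the transpose is not contiguous

# 1. Maximum Pixel Value in Each RGB Channel: reduce over batch, height and width
max_per_channel = torch.amax(image, dim=(0, 2, 3))

# 2. Pixel Value with Maximum RGB Value (largest R + G + B sum)
# Flatten the single image to [3, H*W] so that each column is one pixel
flat = image[0].reshape(3, -1)
sum_per_pixel = flat.sum(dim=0)
# Index of the pixel with the largest sum (the first one wins on ties)
max_rgb_index = torch.argmax(sum_per_pixel)
# Retrieve that pixel's (R, G, B) values
max_rgb_pixel = flat[:, max_rgb_index]

# Expected: (255, 250, 240) and (240, 240, 240)
print(max_per_channel, max_rgb_pixel)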
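The code above assumes a batch of one image. If the tensor holds N images, a minimal sketch that keeps the batch dimension might look like the following. The function name `channel_max_and_brightest_pixel` is hypothetical, "brightest" again assumes the largest R + G + B sum, and it additionally returns the (row, col) position of that pixel.

```python
import torch


def channel_max_and_brightest_pixel(images: torch.Tensor):
    """Hypothetical helper. images: [N, 3, H, W].

    Returns per-image channel maxima [N, 3], the pixel with the largest
    R + G + B sum in each image [N, 3], and its (row, col) position [N, 2].
    """
    n, c, h, w = images.shape
    # Per-image, per-channel maximum: reduce only over height and width
    channel_max = torch.amax(images, dim=(2, 3))             # [N, 3]
    flat = images.reshape(n, c, h * w)                        # [N, 3, H*W]
    # Sum over channels, then pick the pixel with the largest sum
    # (argmax returns the first index on ties)
    best = flat.sum(dim=1).argmax(dim=1)                      # [N]
    # Gather the winning column of each image
    idx = best.view(n, 1, 1).expand(n, c, 1)                  # [N, 3, 1]
    brightest = flat.gather(2, idx).squeeze(2)                # [N, 3]
    # Convert the flat pixel index back to (row, col)
    rows = torch.div(best, w, rounding_mode="floor")
    cols = best % w
    positions = torch.stack([rows, cols], dim=1)              # [N, 2]
    return channel_max, brightest, positions


# Usage on the two-pixel example from the question
example = torch.tensor([[255, 250, 190], [240, 240, 240]], dtype=torch.int32).T.reshape(1, 3, 2, 1)
channel_max, brightest, positions = channel_max_and_brightest_pixel(example)
print(channel_max.tolist(), brightest.tolist(), positions.tolist())
```

Using `gather` with the per-image index avoids a Python loop over the batch; for a [1, 3, 256, 256] input it reduces to the same result as the single-image code.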