How to compute max of a few tensors?

It shouldn't be difficult, but I'm not sure why I can't make it work!
Let's say I have 3 tensors A, B, C. How can I compute the max between them along dim=1?
I checked the link but still cannot figure it out…

import torch

A = torch.rand(2, 3, 4, 4)
B = torch.rand(2, 3, 4, 4)
C = torch.rand(2, 3, 4, 4)

How do I compute the max along dim=1 and keep the dimension?
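If "keep the dimension" means the reduced axis should stay in the result with size 1, `torch.max` accepts `keepdim=True`. The sketch below (same shapes as above) shows the difference. Note that this reduces over the 3 channels *inside one tensor*, which is a different operation from taking the element-wise max across A, B and C.

```python
import torch

A = torch.rand(2, 3, 4, 4)

# Reducing dim=1 without keepdim drops that axis: (2, 3, 4, 4) -> (2, 4, 4)
values, indices = torch.max(A, dim=1)
print(values.shape)  # torch.Size([2, 4, 4])

# keepdim=True retains the reduced axis with size 1: (2, 3, 4, 4) -> (2, 1, 4, 4)
values_kd, indices_kd = torch.max(A, dim=1, keepdim=True)
print(values_kd.shape)  # torch.Size([2, 1, 4, 4])
```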

I wrote this, though I am not sure if it is what you want:

x = torch.cat((A, B, C), dim=0)
torch.max(x, dim=1)

This gives an output of size torch.Size([6, 4, 4]), but we expect the output to have size [2, 3, 4, 4].
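Printing the intermediate shapes shows why the `torch.cat` snippet gives [6, 4, 4]. Concatenating along dim=0 joins the batch axes, so three (2, 3, 4, 4) tensors become one (6, 3, 4, 4) tensor, and `torch.max(x, dim=1)` then reduces the channel axis of that result rather than an axis that separates A, B and C:

```python
import torch

A = torch.rand(2, 3, 4, 4)
B = torch.rand(2, 3, 4, 4)
C = torch.rand(2, 3, 4, 4)

# cat joins the tensors along an existing axis: 3 x (2, ...) -> (6, ...)
x = torch.cat((A, B, C), dim=0)
print(x.shape)  # torch.Size([6, 3, 4, 4])

# max over dim=1 reduces the 3 channels, leaving (6, 4, 4);
# A, B and C are never compared with each other here
out = torch.max(x, dim=1).values
print(out.shape)  # torch.Size([6, 4, 4])
```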

I figured out that one way is to do the following:

L = [A, B, C]
L = torch.stack(L)
L.max(0)[0]
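The stacking approach works because `torch.stack` inserts a new leading axis of size 3 (one slot per input tensor), so reducing over that new axis compares A, B and C element by element and returns the original (2, 3, 4, 4) shape. A minimal sketch with the shapes printed:

```python
import torch

A = torch.rand(2, 3, 4, 4)
B = torch.rand(2, 3, 4, 4)
C = torch.rand(2, 3, 4, 4)

# stack adds a NEW axis 0 indexing the inputs: 3 x (2, 3, 4, 4) -> (3, 2, 3, 4, 4)
L = torch.stack([A, B, C])
print(L.shape)  # torch.Size([3, 2, 3, 4, 4])

# .max(0) returns (values, indices) over that new axis; [0] picks the values
result = L.max(0)[0]
print(result.shape)  # torch.Size([2, 3, 4, 4])
```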

But I don't understand why I should use L.max(0) to get the max value along dim=1.
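The reduction index depends on where the "which tensor" axis sits. In the original tensors dim=1 is the channel axis, and the axis that separates A, B and C only exists after stacking, where `torch.stack` places it at dim 0 by default. If you prefer to reduce "along dim=1", you can stack on dim=1 instead and reduce that axis; the values come out identical. The returned indices also tell you which input held the max:

```python
import torch

A = torch.rand(2, 3, 4, 4)
B = torch.rand(2, 3, 4, 4)
C = torch.rand(2, 3, 4, 4)

# Stacking on dim=0 gives (3, 2, 3, 4, 4); stacking on dim=1 gives (2, 3, 3, 4, 4)
via_dim0 = torch.stack([A, B, C], dim=0).max(dim=0)
via_dim1 = torch.stack([A, B, C], dim=1).max(dim=1)

# Both reduce the axis that separates A, B and C, so the values match
print(torch.equal(via_dim0.values, via_dim1.values))  # True

# indices records which input held the max: 0 -> A, 1 -> B, 2 -> C
print(via_dim0.indices.shape)  # torch.Size([2, 3, 4, 4])
```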

Also, please let me know if there is any other way to get the max value.
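One alternative is to take the element-wise max two tensors at a time with `torch.maximum` (assuming PyTorch 1.7 or newer, where it was added; older versions offer the two-tensor form `torch.max(A, B)`). For a list of any length, `functools.reduce` folds it pairwise. A minimal sketch that checks both against the stack-based result:

```python
import functools
import torch

A = torch.rand(2, 3, 4, 4)
B = torch.rand(2, 3, 4, 4)
C = torch.rand(2, 3, 4, 4)

# Element-wise max of two tensors at a time
out_pairwise = torch.maximum(torch.maximum(A, B), C)

# Same idea folded over an arbitrary list of same-shape tensors
tensors = [A, B, C]
out_reduce = functools.reduce(torch.maximum, tensors)

# Reference result from the stack-then-reduce approach
out_stack = torch.stack(tensors).max(dim=0).values

print(out_pairwise.shape)  # torch.Size([2, 3, 4, 4])
print(torch.equal(out_pairwise, out_stack))  # True
print(torch.equal(out_reduce, out_stack))  # True
```

The pairwise version never materializes the (3, 2, 3, 4, 4) stacked tensor, which may matter for large inputs; the stack version is handy when you also want the indices.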