Argmax to reduce multiple dimensions

The current version of PyTorch's argmax() only reduces one dimension at a time, chosen with the dim parameter. Is there any way to reduce multiple dimensions at once? For example, if x has shape [a, b, c, d] and I call something like argmax(x, dims=[1, 2, 3]), the returned value should have shape [a].
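One common workaround, sketched here under the assumption that a flat index (plus optional per-dimension coordinates) is acceptable: move the dimensions to keep to the front and the dimensions to reduce to the back, flatten the reduced ones into a single axis, and call argmax on that axis. For the trailing-dims case in the question this collapses to x.flatten(1).argmax(dim=1). The helper name argmax_multi is hypothetical, not part of PyTorch. Recent PyTorch releases also provide torch.unravel_index, which could replace the manual unravel loop below.

```python
import torch


def argmax_multi(x: torch.Tensor, dims):
    """Hypothetical helper: argmax over several dims at once.

    Returns (flat_idx, coords):
      flat_idx has shape x.shape with `dims` removed and indexes into the
      row-major flattening of the reduced dims;
      coords has one extra trailing axis holding the index along each
      reduced dim, in ascending order of `dims`.
    """
    dims = sorted(d % x.dim() for d in dims)
    keep = [d for d in range(x.dim()) if d not in dims]
    kept_shape = [x.shape[d] for d in keep]
    reduced_shape = [x.shape[d] for d in dims]

    # Kept dims first, reduced dims last, then merge the reduced dims.
    flat = x.permute(*keep, *dims).reshape(*kept_shape, -1)
    flat_idx = flat.argmax(dim=-1)

    # Unravel the flat index (row-major) into one index per reduced dim.
    coords = []
    rem = flat_idx
    for size in reversed(reduced_shape):
        coords.append(rem % size)
        rem = torch.div(rem, size, rounding_mode="floor")
    coords.reverse()
    return flat_idx, torch.stack(coords, dim=-1)


x = torch.randn(2, 3, 4, 5)
flat_idx, coords = argmax_multi(x, [1, 2, 3])
print(flat_idx.shape)  # torch.Size([2])
print(coords.shape)    # torch.Size([2, 3])
```

Permuting before reshaping keeps the helper general, so the reduced dims do not have to be the trailing ones (for example dims=[0, 2]). Ties resolve however torch.argmax resolves them on the flattened axis.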