Efficient Gradient Penalty

I want to optimise the following loss function efficiently with respect to the parameters of f:

loss(x,y) = criterion(f(x),y) + lambda*max_i ||grad_(x_i) f(x_i)||_2

That is, I want to optimise the criterion regularised by the maximum L2 norm of the gradient of f(x_i) with respect to the input x_i, where x_i denotes a single element of the batch x. How can I formulate this? I am struggling to construct a single computational graph. Here x and y are batches of data.

This is not in the context of a GAN; f is an ordinary classifier.
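
For concreteness, here is a minimal sketch of one way the loss above could be built as a single graph. It assumes PyTorch, which the post does not name. The function name `penalized_loss` and the argument `lam` are hypothetical. Because a classifier's output f(x_i) is a vector, the sketch also assumes that the "gradient of f(x_i)" means the gradient of the summed outputs of sample i; a different scalar (for example the per-sample loss) could be substituted. It relies on each output row depending only on its own input row, so one `torch.autograd.grad` call over the batch sum yields all per-sample gradients.

```python
import torch
import torch.nn as nn


def penalized_loss(f, criterion, x, y, lam):
    """criterion(f(x), y) + lam * max_i ||grad_{x_i} sum(f(x_i))||_2 (hypothetical helper)."""
    # Make the inputs a leaf that autograd tracks.
    x = x.detach().clone().requires_grad_(True)
    out = f(x)

    # Assumption: sample i's output depends only on x_i (no batch coupling),
    # so the gradient of the batch sum w.r.t. x holds each per-sample gradient.
    # create_graph=True keeps the gradient differentiable w.r.t. the parameters.
    (grad_x,) = torch.autograd.grad(out.sum(), x, create_graph=True)

    # One L2 norm per sample, then the maximum over the batch.
    grad_norms = grad_x.flatten(start_dim=1).norm(p=2, dim=1)
    penalty = grad_norms.max()

    return criterion(out, y) + lam * penalty


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
    x = torch.randn(5, 4)
    y = torch.randint(0, 3, (5,))
    loss = penalized_loss(model, nn.CrossEntropyLoss(), x, y, lam=0.1)
    loss.backward()  # parameter gradients include the penalty term
```

Two caveats on this sketch: layers that mix samples within a batch (such as batch normalisation in training mode) break the per-sample assumption, and the max passes a gradient only through the sample with the largest norm at each step.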