Functorch jacrev is taking too much memory

I’m trying to get the per-sample Jacobian of a pretrained VGG net on CIFAR-10. The code is the following:

import torch
from functorch import make_functional, vmap, jacrev

device = torch.device('cuda')

net = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11', pretrained=True).to(device)
net.eval()
fnet, params = make_functional(net)

def fnet_single(params, x):
    # fnet expects a batch; add/remove the batch dim so vmap can map over single samples
    return fnet(params, x.unsqueeze(0)).squeeze(0)

def get_ntk_feature(fnet_single, params, x1):
    # Compute J(x1): jacrev differentiates the network output w.r.t. every
    # parameter, and vmap maps that computation over the batch dimension
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]
    return jac1

x, y = next(iter(cifar_trainLoader))  # cifar_trainLoader is my CIFAR-10 DataLoader
o = get_ntk_feature(fnet_single, params, x.cuda())
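
For reference, this is the shape bookkeeping as I understand it (a sketch, not run to completion; the 1000-way output dimension is because the checkpoint keeps its ImageNet classifier head):

out_dim = fnet(params, x[:1].cuda()).shape[-1]  # 1000 for the stock ImageNet head
n_params = sum(p.numel() for p in params)       # ~133M for vgg11
# jacrev returns one tensor per parameter, shaped (*output_shape, *p.shape);
# after vmap over the batch and flatten(2), each entry of jac1 is (batch, out_dim, p.numel())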

However, CUDA tries to allocate 152.59 GB of memory. Here x is a batch of size 10, so shouldn’t the memory be roughly 10 * sizeof(net)?
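
Here is my back-of-envelope for that expectation (a rough sketch; assumes float32 weights):

n_params = sum(p.numel() for p in params)  # ~133M for torchvision vgg11
print(10 * n_params * 4 / 1e9, "GB")       # 10 * sizeof(net) ≈ 5.3 GB, nowhere near 152.59 GB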