Hello all!
I just want to ask a silly question…
How can we compute the divergence of a vector field f(x,v,t) with respect to v only, using jacrev, trace and vmap?
My current implementation is

```python
vecfield = lambda v: f(x,v,t)
div = vmap(torch.trace(jacrev(vecfield)))(x)
```

and it raises an error XD
Hi @iamctR,
Do you have a minimal reproducible example?
Hello, here is the example code:
```python
from torch.func import vjp, jvp, vmap, jacrev
import torch

def div_fn(u):
    J = jacrev(u)
    return lambda x: torch.trace(J(x))

def fnc(x,y,z):
    """
    x: [bs, 4]
    y: [bs,1]
    z: [bs,1]
    output: [bs,1]
    """
    x= x.reshape(bs,-1)
    y= y.reshape(bs,-1)
    z= z.reshape(bs,-1)
    return z+y+x.sum(-1)[...,None]

bs=512
x=torch.randn(bs,4)
y=torch.randn(bs,1)
z=torch.randn(bs,1)
vecfield=fnc(x,y,z)
vecfield = lambda yy: fnc(x,yy,z)
div = vmap(div_fn(vecfield))(y)
print(div)
```
Thank you!
Hi @iamctR,
Here's a working example, but some parts of your example weren't 100% clear to me, so this might not be exactly what you want. It should still help you understand how `torch.func` works.

When working with `torch.func`, you need to write a function that is independent of your batch dimension, i.e. the function should work on a single sample, so that you can vectorize it over all samples with `torch.func.vmap`.

So, for example, remove the `.reshape(bs,-1)` calls: inside the `vmap` call the function only sees one sample, without the batch dimension, so reshaping to `bs` rows fails.
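For instance, here is a minimal sketch of your original `fnc` written per sample, with the reshapes removed. Since the output `z + y + x.sum(-1, keepdim=True)` has shape [1] and depends on `y` with slope 1, the Jacobian w.r.t. `y` is the 1x1 matrix [[1.]], so its trace (the divergence w.r.t. `y`) should be 1 for every sample. The helper name `div_wrt_y` is only illustrative.

```python
import torch
from torch.func import jacrev, vmap

def fnc(x, y, z):
    # Per-sample shapes: x: [4], y: [1], z: [1] -> output: [1]
    # No .reshape(bs, -1): inside vmap the batch dimension is not visible.
    return z + y + x.sum(-1, keepdim=True)

def div_wrt_y(x, y, z):
    # Jacobian of the [1]-shaped output w.r.t. y (argnums=1): shape [1, 1]
    J = jacrev(fnc, argnums=1)(x, y, z)
    # Divergence w.r.t. y = trace of that square Jacobian
    return torch.trace(J)

bs = 512
x = torch.randn(bs, 4)
y = torch.randn(bs, 1)
z = torch.randn(bs, 1)

# vmap maps over dim 0 of x, y and z and stacks the per-sample traces
div = vmap(div_wrt_y)(x, y, z)
print(div.shape)  # torch.Size([512])
```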
```python
from torch.func import vmap, jacrev
import torch

def fnc(x,y,z):
    # per-sample shapes: x: [4], y: [1], z: [1] -> scalar output
    inp = torch.cat([x,y,z],dim=-1)
    return torch.sum(inp,dim=-1)

def div_fn(x,y,z):
    J = jacrev(fnc,argnums=1)(x,y,z) #grad w.r.t y, shape [1]
    return J #torch.trace(J) needs a square Jacobian; fnc returns a scalar here

bs=512
x=torch.randn(bs,4)
y=torch.randn(bs,1)
z=torch.randn(bs,1)
div=vmap(div_fn, in_dims=(0,0,0))(x,y,z)
print(div.shape) #returns shape [512,1]
```
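The example above returns the gradient w.r.t. `y` rather than a divergence: its `fnc` maps to a scalar, so the Jacobian is not square and `torch.trace` does not apply. For the original question (divergence of a vector field f(x, v, t) w.r.t. v only), the same pattern should work once f returns a vector of the same size as v. Below is a sketch with a made-up field f(x, v, t) = t * sin(v) + x * v (an assumption for illustration only), whose divergence w.r.t. v is t * sum(cos(v)) + sum(x), so the result can be checked analytically. It also assumes a single time value `t` shared by the whole batch.

```python
import torch
from torch.func import jacrev, vmap

def f(x, v, t):
    # Hypothetical vector field for illustration; per-sample shapes:
    # x: [d], v: [d], t: scalar -> output: [d]
    return t * torch.sin(v) + x * v

def div_v(x, v, t):
    # Jacobian w.r.t. v only (argnums=1); x and t are treated as constants.
    J = jacrev(f, argnums=1)(x, v, t)  # shape [d, d]
    return torch.trace(J)              # sum of dF_i/dv_i

bs, d = 512, 3
x = torch.randn(bs, d)
v = torch.randn(bs, d)
t = torch.tensor(0.5)  # one time value shared by the whole batch

# Map over dim 0 of x and v; in_dims=None passes t to every sample unchanged.
div = vmap(div_v, in_dims=(0, 0, None))(x, v, t)
print(div.shape)  # torch.Size([512])

# Analytic check: div_v f = t * sum(cos(v)) + sum(x)
expected = t * torch.cos(v).sum(-1) + x.sum(-1)
print(torch.allclose(div, expected, atol=1e-5))
```

If each sample has its own time, pass `t` with shape [bs] (or [bs, 1]) and use `in_dims=(0, 0, 0)` instead.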