How to compute batch divergence?

Hello all!

I just want to ask a silly question…

How can we compute the divergence of a vector field f(x,v,t) with respect to v only, using jacrev, trace and vmap?
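For reference, the divergence of f with respect to v alone is the trace of the Jacobian of f in v. This assumes, as divergence requires, that for a single sample v has d components and f also returns d components:

$$\nabla_v \cdot f(x, v, t) = \sum_{i=1}^{d} \frac{\partial f_i(x, v, t)}{\partial v_i} = \operatorname{tr}\!\left(\frac{\partial f}{\partial v}(x, v, t)\right)$$

Per sample, jacrev(f, argnums=1) gives the d × d matrix ∂f/∂v, torch.trace reduces it to a scalar, and vmap repeats this over the batch.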

My current implementation is

vecfield= lambda v: f(x,v,t)
div = vmap(torch.trace(jacrev(vecfield)))(x)

and it raises an error XD
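One problem in the snippet above is the order of composition: jacrev(vecfield) returns a function, not a tensor, so torch.trace receives a function object. Also, vmap(...)(x) maps over x, while vecfield takes v. A minimal sketch of the usual pattern follows: compute the Jacobian per sample, trace it, then vmap over every batched argument. The formula for f, the shapes, and the name div_v are hypothetical, chosen only so the result can be checked by hand.

```python
import torch
from torch.func import jacrev, vmap

# Hypothetical vector field (assumption for illustration): for ONE sample,
# x has shape [4], v has shape [d], t is a scalar, and the output has shape [d].
def f(x, v, t):
    return x.sum() * v + t * v ** 2

def div_v(x, v, t):
    # jacrev(...) returns a function; calling it gives the Jacobian tensor
    # df/dv (argnums=1 selects v), of shape [d, d] for one sample.
    J = jacrev(f, argnums=1)(x, v, t)
    # Only now is there a matrix for torch.trace to act on.
    return torch.trace(J)

torch.manual_seed(0)
bs, d = 8, 3
x = torch.randn(bs, 4)
v = torch.randn(bs, d)
t = torch.randn(bs)  # assumed per-sample here

# vmap maps div_v over the leading (batch) dimension of x, v and t.
div = vmap(div_v)(x, v, t)
print(div.shape)  # torch.Size([8])
```

For this particular f the divergence is d * sum(x) + 2 * t * sum(v) per sample, which gives an easy sanity check. If t is a single scalar shared by all samples, it can be passed unbatched with vmap(div_v, in_dims=(0, 0, None)).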

Hi @iamctR,

Do you have a minimal reproducible example?

Hello, here is the example code:

from torch.func import vjp, jvp, vmap, jacrev
import torch

def div_fn(u):
    J = jacrev(u)
    return lambda x: torch.trace(J(x))

def fnc(x,y,z):
    """
    x: [bs, 4]
    y: [bs,1]
    z: [bs,1]

    output: [bs,1]
    """
    x= x.reshape(bs,-1)
    y= y.reshape(bs,-1)
    z= z.reshape(bs,-1)
    return z+y+x.sum(-1)[...,None]


bs=512
x=torch.randn(bs,4)
y=torch.randn(bs,1)
z=torch.randn(bs,1)


vecfield=fnc(x,y,z)
vecfield = lambda yy: fnc(x,yy,z)
div = vmap(div_fn(vecfield))(y)
print(div)

Thank you!

Hi @iamctR,

Here’s a working example. Some parts of your example weren’t 100% clear to me, so this might not be exactly what you want, but it should help you understand how torch.func works.

When working with torch.func, you need to write a function that is independent of the batch dimension, i.e. it should work on a single sample, so that torch.func.vmap can vectorize it over all samples.

So, for example, remove .reshape(bs,-1): inside the vmap call the function sees one sample at a time with the batch dimension removed (with in_dims=0, a [bs, 4] tensor is seen as [4] and a [bs, 1] tensor as [1]), so reshaping to bs rows fails.
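To see what "independent of the batch dimension" means in practice, here is a sketch that records the shapes the function receives inside vmap. It uses the asker's fnc formula with the reshape calls removed; seen_shapes is a hypothetical helper list added only for this illustration.

```python
import torch
from torch.func import vmap

seen_shapes = []  # hypothetical debugging aid: per-sample shapes seen inside vmap

def fnc(x, y, z):
    # Written for ONE sample, so there is no reshape to bs rows.
    seen_shapes.append((tuple(x.shape), tuple(y.shape), tuple(z.shape)))
    return z + y + x.sum(-1)[..., None]  # shape [1] per sample

bs = 512
x = torch.randn(bs, 4)
y = torch.randn(bs, 1)
z = torch.randn(bs, 1)

out = vmap(fnc)(x, y, z)
print(seen_shapes[0])  # ((4,), (1,), (1,))
print(out.shape)       # torch.Size([512, 1])
```

The same function, written once for a single sample, then produces the batched result with the batch dimension restored on the output.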

from torch.func import vmap, jacrev
import torch

def fnc(x, y, z):
    # Per-sample inputs: x [4], y [1], z [1]; returns a scalar
    inp = torch.cat([x, y, z], dim=-1)
    return torch.sum(inp, dim=-1)

def div_fn(x, y, z):
    J = jacrev(fnc, argnums=1)(x, y, z)  # grad w.r.t. y, shape [1]
    return J  # torch.trace(J) would fail here: J is 1-D because fnc returns a scalar

bs = 512
x = torch.randn(bs, 4)
y = torch.randn(bs, 1)
z = torch.randn(bs, 1)

div = vmap(div_fn, in_dims=(0, 0, 0))(x, y, z)
print(div.shape)  # torch.Size([512, 1])
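In the reply's code, fnc sums everything to a scalar, so J has shape [1] and torch.trace, which expects a 2-D matrix, cannot be applied; that is presumably why the trace is commented out. If the per-sample function instead returns a vector of the same size as y, as the asker's fnc did ([1] per sample), the Jacobian is square and its trace is the divergence. A sketch under that assumption, reusing the asker's formula without the reshape calls:

```python
import torch
from torch.func import jacrev, vmap

def fnc(x, y, z):
    # Per-sample vector field in y: x [4], y [1], z [1] -> output [1]
    return z + y + x.sum(-1)[..., None]

def div_fn(x, y, z):
    J = jacrev(fnc, argnums=1)(x, y, z)  # Jacobian w.r.t. y, shape [1, 1]
    return torch.trace(J)                 # divergence w.r.t. y, a scalar

bs = 512
x = torch.randn(bs, 4)
y = torch.randn(bs, 1)
z = torch.randn(bs, 1)

div = vmap(div_fn)(x, y, z)
print(div.shape)  # torch.Size([512])
```

Since the field is y plus terms that do not depend on y, the divergence is 1 for every sample; with a higher-dimensional y the same code returns the sum of the Jacobian's diagonal.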

Thank you sooo much!