I am reproducing a physics-informed neural network (PINN) architecture from the paper by Andres Beltran-Pulido et al., "Physics-Informed Neural Networks for Solving Parametric Magnetostatic Problems" (IEEE Journals & Magazine, IEEE Xplore).
I have implemented the Encoding layer as described, and the layer works as expected in the forward pass.
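```python
import torch
from torch import nn


class Encoding(nn.Module):
    def __init__(self, m=3):
        super().__init__()
        self.m = m  # number of Fourier harmonics per coordinate

    def forward(self, data, norm_params):
        # data: a single point (x, y); the last two entries of norm_params are Lx, Ly
        x, y = data[0], data[1]
        Lx, Ly = norm_params[-2], norm_params[-1]
        # feature vectors [1, cos(w_1 x), sin(w_1 x), ..., cos(w_m x), sin(w_m x)], w_j = 2*pi*j/Lx
        phi_x = data.new_ones(size=(self.m*2+1,))
        phi_y = data.new_ones(size=(self.m*2+1,))
        j = torch.arange(1, self.m+1)  # harmonic indices 1..m
        phi_x[2*j-1] = torch.cos(2*torch.pi*x*j/Lx)
        phi_x[2*j] = torch.sin(2*torch.pi*x*j/Lx)
        phi_y[2*j-1] = torch.cos(2*torch.pi*y*j/Ly)
        phi_y[2*j] = torch.sin(2*torch.pi*y*j/Ly)
        # row-major flattening of phi_y phi_x^T equals vec(phi_x phi_y^T), i.e. its columns stacked
        phi_vec = torch.outer(phi_y, phi_x).view(-1)
        return phi_vec
```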
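The comment on the flattening line relies on the identity that row-major flattening of `outer(phi_y, phi_x)` is the same as column-stacking `outer(phi_x, phi_y)`. A quick check with small hypothetical vectors `a` (standing in for `phi_x`) and `b` (standing in for `phi_y`):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])  # hypothetical stand-in for phi_x
b = torch.tensor([10.0, 20.0])     # hypothetical stand-in for phi_y

# What Encoding does: row-major flattening of b a^T
row_major = torch.outer(b, a).view(-1)
# vec(a b^T): stack the columns of a b^T (transpose, then flatten row-major)
vec_col_stack = torch.outer(a, b).t().reshape(-1)

print(row_major.tolist())  # [10.0, 20.0, 30.0, 20.0, 40.0, 60.0]
```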
Because this is a PINN for solving a PDE, I need the Hessian of the network output with respect to the input, computed for each sample individually. So, for a batch size of 32, I want a 32x2x2 Hessian tensor `h`, where the slice `h[i, :, :]` is the 2x2 Hessian for `data[i, :]`.
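For reference, the per-sample layout described above is what composing `vmap` outside `hessian` produces. This sketch uses a hypothetical scalar function `f(x, y) = sin(x) * y**2` as a stand-in for the network output, so the result can be checked against the analytic Hessian.

```python
import torch
from torch.func import hessian, vmap


def f(point):
    # hypothetical scalar stand-in for the network output u(x, y)
    x, y = point[0], point[1]
    return torch.sin(x) * y**2


data = torch.rand(32, 2)
h = vmap(hessian(f))(data)  # one 2x2 Hessian per row of data
print(h.shape)  # torch.Size([32, 2, 2])
```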
I call the model's forward pass through `vmap` so that I can `vmap` the Hessian call as well. The following call works as expected:
```python
layer = torch.func.vmap(Encoding())
data = torch.rand(32, 2)
params = torch.rand(32, 4)
layer(data, params).shape  # torch.Size([32, 49]) for m=3
```
However, computing the Hessian of this layer with `h = torch.func.hessian(layer)(data, params)` fails at the line `phi_x[2*j-1] = ...` with the following error:

```
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Double for the source.
```
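Separately from the dtype error, calling `torch.func.hessian` on an already-vmapped function differentiates every batch output with respect to the whole batch input, so it does not give the 32x2x2 layout; for the vmapped `Encoding` layer the analogous shape would be (32, 49, 32, 2, 32, 2). The sketch below contrasts the two compositions on the same hypothetical toy function `f` (not the PINN itself); the per-sample Hessians are the diagonal blocks of the full one.

```python
import torch
from torch.func import hessian, vmap


def f(point):
    # hypothetical scalar stand-in for the network output
    x, y = point[0], point[1]
    return torch.sin(x) * y**2


data = torch.rand(32, 2)

batched_f = vmap(f)                  # maps (32, 2) -> (32,)
full = hessian(batched_f)(data)      # d^2 out[i] / d data[k, :] d data[l, :]
per_sample = vmap(hessian(f))(data)  # d^2 out[i] / d data[i, :] d data[i, :]

print(full.shape)        # torch.Size([32, 32, 2, 32, 2])
print(per_sample.shape)  # torch.Size([32, 2, 2])
```

Besides the shape, `hessian(vmap(f))` does 32 times more work here, since most of the blocks it computes are cross-sample terms that are identically zero.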
I have double-checked the dtypes of `x`, `Lx`, `phi_x`, `phi_x[2*j-1]`, and `torch.cos(2*torch.pi*x*j/Lx)`, and they are all `torch.float32` (`j` is an integer index tensor, `torch.int64`, which is what `torch.arange` returns for integer arguments). I have also tried explicitly casting every floating-point variable in this snippet with `.float()`, and it still throws the same error.
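A sketch worth trying, assuming (not confirmed) that the in-place indexed writes into `phi_x`/`phi_y` are what the forward-mode transform trips over: build the same features out-of-place with `torch.stack`/`torch.cat`, then take per-sample Hessians with `vmap(hessian(...))`. The function name `encode`, its argument layout, and the `+ 0.5` shift on `params` are hypothetical choices for this demo.

```python
import torch
from torch.func import hessian, vmap


def encode(point, norm_params, m=3):
    """Hypothetical out-of-place variant of Encoding.forward.

    point: shape (2,) holding (x, y); the last two entries of norm_params are Lx, Ly.
    Returns the (2m+1)**2 tensor-product Fourier features.
    """
    x, y = point[0], point[1]
    Lx, Ly = norm_params[-2], norm_params[-1]
    # Harmonic indices 1..m, created in the input's floating dtype
    j = torch.arange(1, m + 1, dtype=point.dtype, device=point.device)
    ax = 2 * torch.pi * x * j / Lx  # shape (m,)
    ay = 2 * torch.pi * y * j / Ly  # shape (m,)
    one = torch.ones_like(ax[:1])   # constant leading feature
    # Interleave to [1, cos_1, sin_1, ..., cos_m, sin_m] with stack/cat
    # instead of writing into a preallocated tensor by index
    phi_x = torch.cat([one, torch.stack([torch.cos(ax), torch.sin(ax)], dim=1).reshape(-1)])
    phi_y = torch.cat([one, torch.stack([torch.cos(ay), torch.sin(ay)], dim=1).reshape(-1)])
    return torch.outer(phi_y, phi_x).reshape(-1)


data = torch.rand(32, 2)
params = torch.rand(32, 4) + 0.5  # keeps Lx, Ly away from zero for this demo

features = vmap(encode)(data, params)    # shape (32, 49)
h = vmap(hessian(encode))(data, params)  # shape (32, 49, 2, 2), w.r.t. point only
```

If this variant runs where the in-place version fails, that points at the indexed assignment rather than at the arithmetic itself.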
Am I misreading what the error message says about dtypes? Why does it fail in the Hessian call (according to the traceback, inside the `jacfwd` call, so it doesn't even get to the `jacrev` call) but not in the forward pass? Any insight into what is happening under the hood would be much appreciated.