Calculate the Jacobian batch-wise of multiple input neural network with complex weights

I have a neural network that takes the input x of shape [batch, timesteps, x_features] and the input p of shape [batch, p_features]. The output is of shape [batch, timesteps, out_features].

What I want to calculate is the Jacobian of the output with respect to p. So the Jacobian should be of shape [batch, timesteps, out_features, p_features].
Note that the differentiation should happen for the out_features at every timestep. I could reshape the output to [batch, timesteps*out_features], but I want to avoid this because of the calculations that follow…

In practice I use the following values:
batch = 1
timesteps = 601
x_features = 20
p_features = 14
out_features = 15

I tried two things:

1st approach:

from torch import autograd

def partial_forward(p):
    # x is held fixed; only p is differentiated
    return model(x, p)

jacobian = autograd.functional.jacobian(partial_forward, p)

This gives me a Jacobian of shape [timesteps, out_features, batch, p_features], which does not seem correct, and the calculation took 16.5 seconds. One forward pass usually takes 0.002 s, and I noticed that 1 x 601 x 14 x 0.002 s = 16.8 s, which seems a bit strange to me since I additionally have 15 out_features. Hence I tried a

2nd approach:

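from torch.func import functional_call, vmap, jacrev

params = dict(model.named_parameters())

def fmodel(params, x, p):
    # functional_call takes the positional model inputs as a tuple
    return functional_call(model, params, (x, p))

result = vmap(jacrev(fmodel, argnums=2), in_dims=(None, None, 0))(params, x, p)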

which I took from AlphaBetaGamma96's answer,

but here I run into the error ‘RuntimeError: jacrev: Expected all inputs to be real but received complex tensor at flattened input idx: 4’. This happens because my model uses the Fourier transform, which I rely on. It also applies the inverse Fourier transform, so the input and the output of the network are real; only some of the weights are complex.

Is the 2nd approach the correct way for what I mean to calculate and is there a workaround for this RuntimeError regarding complex numbers?

Can you share a minimal reproducible example of your model?

Yes of course:

import torch
import torch.nn as nn

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))

    def compl_mul1d(self, input, weights):
        # (batch, in_channels, modes), (in_channels, out_channels, modes) -> (batch, out_channels, modes)
        return torch.einsum('bix,iox->box', input, weights)

    def forward(self, x):
        # x: (batch, in_channels (i.e. width), timesteps)
        batchsize = x.shape[0]

        x_ft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat)
        # (batch, width, modes) * (width, width, modes) -> (batch, width, modes)
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)

        # Return to physical space
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x

class FNO1d(nn.Module):
    def __init__(self, inp_dim_func, inp_dim_scalar, out_dim_func, modes, width,):
        super(FNO1d, self).__init__()

        self.modes1 = modes
        self.width = width
        self.linear_p = nn.Linear(in_features=inp_dim_func, out_features=self.width)
        self.linear_scalar = nn.Linear(in_features=inp_dim_scalar, out_features=self.width)

        self.spect1 = SpectralConv1d(2*self.width, self.width, self.modes1) # produces (batch, width, timesteps)
        self.lin0 = nn.Conv1d(2*self.width, self.width, 1)

        self.linear_q = nn.Linear(self.width, 32)
        self.output_layer = nn.Linear(32, out_dim_func)

        self.activation = torch.nn.Tanh()

    def fourier_layer(self, x, spectral_layer, conv_layer):
        return self.activation(spectral_layer(x) + conv_layer(x))

    def linear_layer(self, x, linear_transformation):
        return self.activation(linear_transformation(x))

    def forward(self, func_inp, scalar_inp):

        # shape (batch, timesteps, inp features)
        x = self.linear_p(func_inp)  # produces (batch, timesteps, width)
        x = x.permute(0, 2, 1) # produces (batch, width, timesteps)
        timesteps = x.shape[2]

        p = self.linear_scalar(scalar_inp)
        # Repeat scalar inputs for each timestep
        p = torch.unsqueeze(p, 1).repeat(1, timesteps, 1)  # Shape (batch_size, timesteps, width)
        p = p.permute(0, 2, 1)
        # Concatenate processed inputs
        x1 = torch.cat((x, p), dim=1)

        x1 = self.fourier_layer(x1, self.spect1, self.lin0)
        x1 = x1.permute(0, 2, 1)

        x1 = self.linear_layer(x1, self.linear_q)
        x1 = self.output_layer(x1)
        return x1

model = FNO1d(inp_dim_func=20, inp_dim_scalar=14, out_dim_func=15, modes=16, width=64)

It relies heavily on the FNO1d model, but this example is specialized to the data shapes I introduced earlier and has fewer layers.
The class FNO1d is the model. As one can see, weights1 is complex. This is because we apply the fast Fourier transform with torch.fft.rfft, which takes us into complex space, so we need complex weights. After the multiplication we transform back to real space and continue with further computations.

We have real inputs and outputs, but some of the weights are complex. There is no way around this, as the whole model relies on the Fourier transform. Apparently autograd.functional.jacobian has no problem with that, but I do not know whether the result is correct, as the shape of the Jacobian is not what I expected.
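One way to gain some confidence here is to check autograd.functional.jacobian on a much smaller function with the same real -> complex -> real structure. The sketch below is not the model above: the size n, the weight w and the function spectral_map are hypothetical stand-ins. Because this toy map is linear in its input, column j of its Jacobian has to equal the map applied to the j-th unit vector, which gives an independent reference.

```python
import torch

torch.manual_seed(0)
n = 8  # number of time samples (hypothetical toy size)

# Complex weights acting on the Fourier modes, as in a spectral layer.
w = torch.randn(n // 2 + 1, dtype=torch.cdouble)

def spectral_map(p):
    # real input -> complex spectrum -> complex multiply -> real output
    return torch.fft.irfft(torch.fft.rfft(p) * w, n=n)

p = torch.randn(n, dtype=torch.double)
jac = torch.autograd.functional.jacobian(spectral_map, p)  # shape (n, n)

# The map is linear in p, so column j of the Jacobian must equal
# the map applied to the j-th unit vector.
unit_vectors = torch.eye(n, dtype=torch.double)
reference = torch.stack([spectral_map(e) for e in unit_vectors], dim=1)

print(jac.shape)
print(torch.allclose(jac, reference))
```

This only checks the complex intermediate step in isolation; for the full nonlinear model, a finite-difference comparison on a few entries would be the analogous check.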

So, one problem I can see initially is that you’re using batch size inside your model, but when you vmap the batch size no longer exists (as your model only works on 1 sample at a time). That might be a problem.

Yes, but it gets set within the forward call to allow for varying batch sizes.
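To make the vmap point concrete, here is a small sketch (assuming PyTorch 2.x with torch.func; per_sample and seen_shapes are made-up names) that records the shape a function sees when it is wrapped in vmap. The batch dimension is hidden from the function, so x.shape[0] inside it is the number of timesteps, not the batch size.

```python
import torch
from torch.func import vmap

seen_shapes = []

def per_sample(x):
    # Record the shape that the wrapped function actually sees.
    seen_shapes.append(tuple(x.shape))
    return x.sum(dim=-1)

x = torch.zeros(4, 601, 20)  # [batch, timesteps, x_features]
out = vmap(per_sample)(x)

print(seen_shapes)  # [(601, 20)]
print(out.shape)    # torch.Size([4, 601])
```

A model whose forward expects a leading batch dimension (for example through permute(0, 2, 1)) would therefore need a singleton batch dimension added back inside the wrapped function, e.g. with unsqueeze(0), and removed again from the output.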

As far as I understand, I am running into the problem discussed here: https://github.com/pytorch/pytorch/issues/94397

So I would not be able to use jacrev, but I would still be able to use autograd.functional.jacobian thanks to its calculation of the Wirtinger derivatives, meaning that my first approach is valid. But this leaves me with the question about the correct shapes.

Is autograd.functional.jacobian able to differentiate an output of shape [1, 601, 15] with respect to an input of shape [1, 14]?
I end up with a Jacobian of shape [601, 15, 1, 14], where I would have expected [1, 601, 15, 14].

It is able to differentiate such an output, and one can get the Jacobian as described in https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/6.
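For reference, autograd.functional.jacobian returns a tensor of shape output.shape + input.shape, so differentiating a [B, T, O] output with respect to a [B, P] input directly gives [B, T, O, B, P]. A common way to get a per-sample Jacobian is to sum the output over the batch dimension first and permute afterwards; whether this matches the linked post in every detail is an assumption. The sketch below uses a hypothetical toy_model in place of the FNO and small made-up sizes, and it relies on each sample's output depending only on that sample's inputs.

```python
import torch

torch.manual_seed(0)
B, T, F, P, O = 2, 5, 3, 4, 6  # small stand-ins for batch, timesteps, x/p/out features

# Hypothetical toy model: each sample's output depends only on that sample's x and p.
lin_x = torch.nn.Linear(F, O).double()
lin_p = torch.nn.Linear(P, O).double()

def toy_model(x, p):
    # x: [B, T, F], p: [B, P] -> [B, T, O]
    return torch.tanh(lin_x(x) + lin_p(p).unsqueeze(1))

x = torch.randn(B, T, F, dtype=torch.double)
p = torch.randn(B, P, dtype=torch.double)

def summed_over_batch(p_):
    # Summing over the batch is valid only because samples do not interact.
    return toy_model(x, p_).sum(dim=0)  # [T, O]

jac = torch.autograd.functional.jacobian(summed_over_batch, p)  # [T, O, B, P]
batch_jac = jac.permute(2, 0, 1, 3)                             # [B, T, O, P]

# Reference: differentiate one sample at a time and stack the results.
reference = torch.stack([
    torch.autograd.functional.jacobian(
        lambda q, i=i: toy_model(x[i:i + 1], q.unsqueeze(0))[0], p[i])
    for i in range(B)
])

print(batch_jac.shape)  # torch.Size([2, 5, 6, 4])
print(torch.allclose(batch_jac, reference))
```

Before the permute, the result has the layout [timesteps, out_features, batch, p_features]; the permute(2, 0, 1, 3) call moves the batch dimension to the front to give [batch, timesteps, out_features, p_features].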