Jacobian matrix of right shape, but null everywhere

Dear PyTorch forum,
For a project, I was tasked with simulating an MRI sequence using the EPG formalism, which worked well. To optimize the sequence, I need to differentiate the sequence signal with respect to the different parameters at play. For this, I initially used the autograd.grad function, which turned out to be much slower because, if I understood correctly, it requires looping over the outputs (see this link).
Consequently, I moved to torch.autograd.functional.jacobian. After some initial tests, I was able to implement it in my code so as to obtain a 500x3 array, i.e. (time points) x (number of parameters). The problem, though, is that it is zero everywhere. Am I missing something conceptually about how the Jacobian is computed with the PyTorch function? Currently, the output() function returns the whole time-dependent signal array for the entire sequence. Could this be the cause of the issue?
Code is pasted below, and I will gladly provide further clarification should you find my explanation insufficient. Thanks in advance!

def dm_dT(params, func):
    # Collect the parameter values (T1, T2, D) into a flat list of floats
    values_array = []
    for key, value in params.items():
        if isinstance(value, torch.Tensor):
            values_array.append(value.item())
        else:
            values_array.append(value)

    # Differentiate func with respect to the packed parameter vector
    jacobian = torch.autograd.functional.jacobian(
        func=func,
        inputs=torch.tensor(values_array, dtype=torch.float32),
        create_graph=True,
    )
    return jacobian

import numpy as np
import matplotlib.pyplot as plt
import torch




def output(params):
    T1, T2, D = params
    TE = 4
    TR = 15
    TI = 15  
    kg = 35
    T2 = torch.tensor(T2, dtype=torch.float32, requires_grad=True)
    T1 = torch.tensor(T1, dtype=torch.float32, requires_grad=True)
    D = torch.tensor(D, dtype=torch.float32, requires_grad=True)
    TI = torch.tensor(TI, dtype=torch.float32)
    TR = torch.tensor(TR, dtype=torch.float32)
    TE = torch.tensor(TE, dtype=torch.float32)
    kg = torch.tensor(kg, dtype=torch.float32)
    
    magnetization_states  = []
    transversal_data_imag = []
    transversal_data_real = []
    transversal_data_abs  = [] 

    FpFmZ = torch.zeros((3, 500), dtype=torch.complex64)
    FpFmZ[2, 0] = 1.

    # epg (EPG operator implementations) and FA_array (flip-angle train) are defined elsewhere in the project
    FpFmZ = epg.epg_rf(FpFmZ, alpha=torch.tensor(np.pi, dtype=torch.float32), phi=torch.tensor(0, dtype=torch.float32))

    FpFmZ = epg.epg_relax(FpFmZ, T1, T2, TI)

    for a, FA in enumerate(FA_array[:500]):
        alpha = torch.tensor(FA, dtype=torch.float32, requires_grad=True)
        phi = torch.tensor(0, dtype=torch.float32, requires_grad=True)
        
        FpFmZ = epg.epg_rf(FpFmZ, alpha=alpha, phi=phi)    
        FpFmZ = epg.epg_grelax(FpFmZ, T1, T2, TE, kg, D, Gon=0)
        
        transversal_data_abs.append(torch.abs(FpFmZ.clone()[0,0]))
        transversal_data_imag.append(torch.imag(FpFmZ.clone()[0,0]))
        transversal_data_real.append(torch.real(FpFmZ.clone()[0,0]))
        
        magnetization_states.append(transversal_data_abs)    
        FpFmZ = epg.epg_grelax(FpFmZ, T1, T2, TR-TE, kg, D, Gon=1)
    
    # Assuming transversal_data_real is a list, convert it to a tensor
    transversal_data_real_tensor = torch.tensor(transversal_data_real, dtype=torch.float32)

    # Then detach and convert to numpy for plotting
    plt.plot(transversal_data_real_tensor.detach().numpy()) 
    
    return torch.tensor(transversal_data_real)

T1 = 300
T2 = 200
D = 0

grad_r = dm_dT(params={"T1": T1,"T2": T2,"D": D}, func=output)

print(grad_r)
      

Hey!

The problem is in the early part of your output() function.
You break the autograd link between the given “params” and the T1/T2/D used below by re-wrapping them into new Tensors.
Deleting these 3 lines should fix it:


    T2 = torch.tensor(T2, dtype=torch.float32, requires_grad=True)
    T1 = torch.tensor(T1, dtype=torch.float32, requires_grad=True)
    D = torch.tensor(D, dtype=torch.float32, requires_grad=True)
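
As a toy illustration (a minimal sketch, not your EPG code): once the values are copied into a new leaf tensor, jacobian fills in zeros because the output no longer depends on the input it differentiates against.

    import torch
    from torch.autograd.functional import jacobian

    def f_rewrapped(x):
        # re-wrapping copies the values into a new leaf tensor, detached from x
        y = torch.tensor(x.tolist(), dtype=torch.float32, requires_grad=True)
        return y ** 2

    def f_direct(x):
        # uses the input tensor directly, so the graph stays connected to x
        return x ** 2

    inp = torch.tensor([1.0, 2.0, 3.0])
    print(jacobian(f_rewrapped, inp))  # all zeros: the graph is broken
    print(jacobian(f_direct, inp))     # diagonal 2, 4, 6: the expected Jacobian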

Hello Alban, thanks for the reply! And sorry for my late one…
After applying your changes, which made perfect sense to me and which logically should have worked, I unfortunately still get the same result, even when considering one variable in isolation. Could it have something to do with the shape of the output making it impossible to get non-zero values?
The function now takes 3 input parameters (T1, T2, D) and outputs the simulated signal at every time point of the whole sequence (500 elements).
The output Jacobian has the right shape (500 x 3), as the definition of a Jacobian would suggest, but when scrolling through the documentation of the jacobian function, the example functions there also return an array. If this is where my error lies, what kind of changes to the code would you suggest?

Update: I ran a small test to see how torch.autograd.functional.jacobian handles functions with multiple outputs, and it gives the expected result. Consequently, I am now not really sure what could be zeroing out the elements of the computed Jacobian…
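
For reference, the small test was roughly of this form (a toy multi-output function, not the actual EPG sequence):

import torch
from torch.autograd.functional import jacobian

def toy_signal(params):
    # 3 parameters in, 500 "time points" out
    T1, T2, D = params
    t = torch.linspace(0.0, 1.0, 500)
    return torch.exp(-t / T2) * (1 - torch.exp(-t / T1)) + D * t

J = jacobian(toy_signal, torch.tensor([300.0, 200.0, 0.0]))
print(J.shape)              # torch.Size([500, 3])
print(J.abs().sum(dim=0))   # non-zero for each of T1, T2, D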

If the value is still 0, I would expect that something else is “breaking” the autograd graph. You can use torchviz (on PyPI) and pass it the output of func for the given sample inputs to see the autograd graph that gets generated. You can also give it the inputs to check whether they are indeed properly linked into the graph.
Feel free to share the image of the graph here if you need help interpreting it.
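
For example, a minimal sketch of such a call (assuming torchviz has been installed with pip install torchviz; the filename is arbitrary):

import torch
from torchviz import make_dot

params = torch.tensor([300.0, 200.0, 0.0], requires_grad=True)
out = output(params)  # the same func that is passed to jacobian

# draw the autograd graph; params should appear as a leaf node feeding into the output
make_dot(out, params={"params": params}).render("output_graph", format="png")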