Loss does not improve in Flow Matching implementation

I am trying to implement Flow Matching for Generative Modeling (Lipman et al.). I am using the make_moons dataset from scikit-learn to test it out. As far as I can tell, everything is implemented correctly, but the loss does not improve over 20k epochs (each epoch draws a fresh batch), and I see no transformation of the input data toward the target. I have fixed a few bugs that I found, but that hasn't moved the needle much. I would be grateful for any feedback on what I am doing wrong.
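For reference, this is my reading of the optimal-transport conditional path quantities that the code below tries to implement (please point out if I have transcribed any of them incorrectly):

$$\mu_t(x_1) = t\,x_1, \qquad \sigma_t = 1 - (1 - \sigma_{\min})\,t$$

$$x_t = \psi_t(x_0) = \sigma_t\,x_0 + \mu_t(x_1), \qquad u_t(x \mid x_1) = \frac{x_1 - (1 - \sigma_{\min})\,x}{1 - (1 - \sigma_{\min})\,t}$$

and the loss I minimize is $\mathbb{E}_{t,\,x_1,\,x_0}\,\lVert v_\theta(x_t, t) - u_t(x_t \mid x_1) \rVert^2$.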

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

def make_moons_data(batch_size=256, noise=0.1):
    #Draw a fresh batch from the target (two moons) distribution
    X, _ = make_moons(n_samples=batch_size, noise=noise)
    return torch.tensor(X, dtype=torch.float32)

class NeuralVelocityField(nn.Module):
    ##A simple MLP to serve as the neural velocity field in the loss
    def __init__(self, input_dims: int, output_dims=None, hidden=32, time_=True) -> None:
        """
        input_dims: The input dimensions of the data
        output_dims: The dimensions of the computed output
        hidden: The number of neurons in each hidden layer
        time_: Whether time is concatenated to the input (adds one to the input dimension)
        """
        super().__init__()
      
        
        if input_dims is None or input_dims <=1:
            raise AssertionError("Input dimensions cannot be None and must be greater than 1")

        assert isinstance(input_dims, int)
        assert isinstance(time_, bool)
        
        if output_dims is None:
            output_dims = input_dims
        else:
            assert isinstance(output_dims, int)
        
        self.model = nn.Sequential(
            nn.Linear(input_dims + (1 if time_ else 0), hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dims)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Input to the model of type torch.Tensor
        return: The output of the forward pass computation of the model
        """
        return self.model(x)

class ConditionalFlowMatching(nn.Module):

    def __init__(self, sigma=None):
        super().__init__()
        
        self.sigma = sigma if sigma is not None else 0.1

    def sample_base_distribution(self, x1: torch.Tensor) -> torch.Tensor:
        """
        x1: Sample from the target distribution
        x_0: A sample from the base distribution, which in this case is the unit Normal
        """
        #Source distribution -- sample from a zero mean, unit variance Gaussian
        x_0 = torch.randn_like(x1)
        return x_0

    
    def __compute_mu_t(self, t: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """
        t: Time variable from the U[0,1] distribution
        x1: Sample from the target distribution
        mu_t: The time varying mean
        """
        #Equation 20 in Lipman FM paper
        mu_t = t*x1
        return mu_t

    def __compute_sigma_t(self, t: torch.Tensor) -> torch.Tensor:
        """
        t: Time variable from the U[0,1] distribution
        sigma_t: The time varying standard deviation
        """
        #Equation 20 in Lipman FM paper
        sigma_t = 1.-(1.-self.sigma)*t
        return sigma_t

    def compute_transformed_data(self, t: torch.Tensor, x1: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        """
        t: Time variable from the U[0,1] distribution
        x1: Sample from the target distribution
        x0: Sample from the source distribution which in this case is the unit Normal
        x_t: Sample x_0 after the push forward operation at time t
        """
        mu_t = self.__compute_mu_t(t, x1)
        sigma_t = self.__compute_sigma_t(t)
        #Equation 22 -- x_t is the source sample after the push-forward transformation at time t (linear interpolation)
        x_t = x0*sigma_t+mu_t
        return x_t

    def compute_conditional_vel_field(self, t: torch.Tensor, x1: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """
        t: Time variable from the U[0,1] distribution
        x1: Sample from the target distribution
        x_t: Sample x_0 after the push forward operation at time t
        u_t: The conditional velocity field in closed form
        """
        numerator = x1-(1.-self.sigma)*x_t
        denominator = self.__compute_sigma_t(t)
        #Equation 21 in Lipman FM paper
        u_t = torch.div(numerator, denominator)
        return u_t
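
#Quick consistency check (not part of training): plugging x_t = sigma_t*x0 + t*x1 into u_t,
#the sigma_t factors cancel, so u_t should reduce to x1 - (1 - sigma)*x0.
#The snippet below verifies that numerically for the classes above.
cfm_check = ConditionalFlowMatching(sigma=0.1)
x1_check = make_moons_data(8)
x0_check = cfm_check.sample_base_distribution(x1_check)
t_check = torch.rand([x1_check.shape[0], 1])
x_t_check = cfm_check.compute_transformed_data(t_check, x1_check, x0_check)
u_t_check = cfm_check.compute_conditional_vel_field(t_check, x1_check, x_t_check)
print(torch.allclose(u_t_check, x1_check - (1. - 0.1)*x0_check, atol=1e-5))  #should print True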

def plot_flow_trajectory(epoch_number: int, t: torch.Tensor, data_dict: dict) -> None:
   
    assert isinstance(epoch_number, int)
    assert isinstance(t, torch.Tensor)
    assert isinstance(data_dict, dict)
    
    with torch.no_grad():
        
        ##Convert data to numpy format
        for key in data_dict.keys():
            
            if data_dict[key] is None:
                print("Value with key {} is of None type in the data dict".format(key))
                raise AssertionError
        
            data = data_dict[key]
            data = data.numpy()
            data_dict[key] = data
        
        plt.figure()
        plt.title("Source, target and transformed data at epoch {}".format(epoch_number))
            
        ##Plot the source and target first
        plt.plot(data_dict["source"][:, 0], data_dict["source"][:, 1], "b.", label="source")
        #This should be the moons data
        plt.plot(data_dict["target"][:, 0], data_dict["target"][:, 1], "r*", label="target")

        plt.plot(data_dict["transformed_data"][:, 0], data_dict["transformed_data"][:, 1], "kx", label="transformed")
        plt.legend()

        plt.show()

def training(batch_size=128, num_epochs=None, learning_rate = None):

    #TODO: Plot loss
    loss_list = []
    
    num_epochs = 20000 if num_epochs is None else num_epochs
    lr = 1e-3 if learning_rate is None else learning_rate  

    velocity_model = NeuralVelocityField(input_dims=2, hidden=64)
    #Default lr is 1e-3
    optimizer = torch.optim.Adam(velocity_model.parameters(), lr=lr)

    
    cfm = ConditionalFlowMatching(sigma=0.1)

    #Log this when writing .py 
    print("Learning rate is {}".format(lr))
    print("Batch size is {}".format(batch_size))
    print("Model will be trained for {} epochs".format(num_epochs))
    print("Starting training now")
    
    for epoch in range(num_epochs):
        optimizer.zero_grad()       
        #x1 is the target dataset. 
        #In the Lipman Flow Matching paper, it is used to condition the source, i.e., p_t(x0|x1)
        x1 = make_moons_data(batch_size)
    
        #Source distribution -- sample zero mean and unit variance
        x_0 = cfm.sample_base_distribution(x1)
        
        #Sample time from the uniform distribution
        t = torch.rand([x1.shape[0], 1])
       
        

        x_t = cfm.compute_transformed_data(t, x1, x_0)
        assert(x_t.shape == x1.shape)

        #Compute the velocity field
        u_t = cfm.compute_conditional_vel_field(t, x1, x_t)
        
        #Neural network to compute the velocity field
        #v is a function of time and space. Hence the need to compute v(x_t,t)
        v = velocity_model(torch.cat((x_t, t), dim=-1))

        #Compute the mean squared error between the conditional velocity and the neural network
        loss = torch.mean(torch.pow(v-u_t, 2))

        loss.backward()
        optimizer.step()
        
        #if (epoch + 1) % 100 == 0:
        #    for name, param in velocity_model.named_parameters():
        #        if param.grad is None:
        #           print("None valued gradients at epoch {}".format(epoch+1))
        
        data_dict = {"source":x_0, "target":x1, "transformed_data":x_t}        
        if (epoch+1)%1000 == 0:
            print("Finished epoch number {}".format(epoch+1))
            print("Loss is {}".format(loss.item()))
            plot_flow_trajectory(epoch+1, t, data_dict)
                
training()

I have tried varying the learning rate and the capacity of the network (the number of neurons per layer), though I haven't added more layers. I also checked the gradients: they are small, but not None.
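
In case it matters, this is how I was planning to visualize the learned transformation once training works: Euler-integrate the learned velocity field from t=0 to t=1, starting from Gaussian noise. I haven't been able to exercise it yet because the loss never improves, so treat it as a sketch rather than tested code.

def sample_with_euler(velocity_model, n_samples=512, n_steps=100):
    #Push base samples through the learned field with a fixed-step Euler integrator
    with torch.no_grad():
        x = torch.randn(n_samples, 2)
        dt = 1.0/n_steps
        for i in range(n_steps):
            t = torch.full((n_samples, 1), i*dt)
            x = x + velocity_model(torch.cat((x, t), dim=-1))*dt
        return x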