I am trying to implement the Flow Matching for Generative Modelling by Lipman, et.al. I am using the make_moons dataset in scikit_learn to test it out. As far as I can tell, I have everything implemented correctly, but the loss doesn’t improve over 20k epochs, and I can see no transformation in the input data to the target. I fixed a few bugs that I discovered, but the needle has not really moved a lot. Would be grateful for any feedback on what I am doing wrong.
def make_moons_data(batch_size=256, noise=0.1):
X, _ = make_moons(n_samples=batch_size, noise=noise)
return torch.tensor(X, dtype=torch.float32)
class NeuralVelocityField(nn.Module):
## A simple MLP to serve as the NN to use in the loss function
def __init__(self, input_dims: int, output_dims=None, hidden=32, time_=True) -> None:
input_dims: The input dimensions of the data
output_dims: The dimensions of the computed output
hidden: The initial starting size of the number of neurons
time_: Boolean variable to adjust the input_dims if the time is concatenated
if input_dims is None or input_dims <=1:
raise AssertionError("Input dimensions cannot be None and must be greater than 1")
assert isinstance(input_dims, int)
assert isinstance(time_, bool)
if output_dims is None:
output_dims = input_dims
assert isinstance(output_dims, int)
self. model = nn.Sequential(
nn.Linear(2+(1 if time_ else 0), hidden),
nn.Linear(hidden, hidden),
def forward(self, x: torch.Tensor) -> torch.Tensor:
x: Input to the model of type torch.Tensor
return: The output of the forward pass computation of the model
return self.model(x)
class ConditionalFlowMatching(nn.Module):
def __init__(self, sigma):
self.sigma = sigma if sigma is not None else 0.1
def sample_base_distribution(self, x1: torch.Tensor) -> torch.Tensor:
x1: Sample from the target distribution
x_0: A sample from the base distribution which in this case is the unit Normal
#Source distribution -- sample from zero mean and unit variance Gaussian
x_0 = torch.randn_like(x1)
return x_0
def __compute_mu_t(self, t: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
t: Time variable from the U[0,1] distribution
x1: Sample from the target distribution
mu_t: The time varying mean
#Equation 20 in Lipman FM paper
#t = torch.reshape(t, (t.shape[0], 1))
mu_t = t*x1
#print("mu_t is {}".format(mu_t))
return mu_t
def __compute_sigma_t(self, t: torch.Tensor) -> torch.Tensor:
t: Time variable from the U[0,1] distribution
sigma_t: The time varying standard deviation
#t = torch.reshape(t, (t.shape[0], 1))
#Equation 20 in Lipman FM paper
sigma_t = 1.-(1.-self.sigma)*t
#print("sigma_t is {}".format(sigma_t))
return sigma_t
def compute_transformed_data(self, t: torch.Tensor, x1: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
t: Time variable from the U[0,1] distribution
x1: Sample from the target distribution
x0: Sample from the source distribution which in this case is the unit Normal
x_t: Sample x_0 after the push forward operation at time t
mu_t = self.__compute_mu_t(t, x1)
sigma_t = self.__compute_sigma_t(t)
#Equation 22-- x_t is the source data after the push forward transformation considering linear interpolation
x_t = x0*sigma_t+mu_t
return x_t
def compute_conditional_vel_field(self, t: torch.Tensor, x1: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
t: Time variable from the U[0,1] distribution
x1: Sample from the target distribution
x_t: Sample x_0 after the push forward operation at time t
u_t: The conditional velocity field in closed form
numerator = x1-(1.-self.sigma)*x_t
denominator = self.__compute_sigma_t(t)
#Equation 21 in Lipman FM paper
u_t = torch.div(numerator, denominator)
return u_t
def plot_flow_trajectory(epoch_number: int, t: torch.Tensor, data_dict: dict) -> None:
assert isinstance(epoch_number, int)
assert isinstance(t, torch.Tensor)
assert isinstance(data_dict, dict)
with torch.no_grad():
##Convert data to numpy format
for key in data_dict.keys():
if data_dict[key] is None:
print("Value with key {} is of None type in the data dict".format(key))
raise AssertionError
data = data_dict[key]
data = data.numpy()
data_dict[key] = data
plt.title("Source, target and transf: ormed data at epoch {}".format(epoch_number))
##Plot the source and target first
plt.plot(data_dict["source"][:, 0], data_dict["source"][:, 1], "b.", label="source")
#This should be the moons data
plt.plot(data_dict["target"][:, 0], data_dict["target"][:, 1], "r*", label="target")
plt.plot(data_dict["transformed_data"][:, 0], data_dict["transformed_data"][:, 1], "kx", label="transformed")
def training(batch_size=128, num_epochs=None, learning_rate = None):
#TODO: Plot loss
loss_list = []
num_epochs = 20000 if num_epochs is None else num_epochs
lr = 1e-3 if learning_rate is None else learning_rate
velocity_model = NeuralVelocityField(input_dims=2, hidden=64)
#Default lr is 1e-3
optimizer = torch.optim.Adam(velocity_model.parameters(), lr=lr)
cfm = ConditionalFlowMatching(sigma=0.1)
#Log this when writing .py
print("Learning rate is {}".format(lr))
print("Batch size is {}".format(batch_size))
print("Model will be trained for {} epochs".format(num_epochs))
print("Starting training now")
for epoch in range(num_epochs):
#x1 is the target dataset.
#In the Lipman Flow Matching paper, it is used to condition the source, i.e., p_t(x0|x1)
x1 = make_moons_data(batch_size)
#Source distribution -- sample zero mean and unit variance
x_0 = cfm.sample_base_distribution(x1)
#Sample time from the uniform distribution
t = torch.rand([x1.shape[0], 1])
x_t = cfm.compute_transformed_data(t, x1, x_0)
assert(x_t.shape == x1.shape)
#Compute the velocity field
u_t = cfm.compute_conditional_vel_field(t, x1, x_t)
#Neural network to compute the velocity field
#v is a function of time and space. Hence the need to compute v(x_t,t)
v = velocity_model(torch.cat((x_t, t), dim=-1))
#Compute the mean squared error between the conditional velocity and the neural network
loss = torch.mean(torch.pow(v-u_t, 2))
#if (epoch + 1) % 100 == 0:
# for name, param in velocity_model.named_parameters():
# if param.grad is None:
# print("None valued gradients at epoch {}".format(epoch+1))
data_dict = {"source":x_0, "target":x1, "transformed_data":x_t}
if (epoch+1)%1000 == 0:
print("Finished epoch number {}".format(epoch+1))
print("Loss is {}".format(loss.item()))
plot_flow_trajectory(epoch+1, t, data_dict)
I have tried varying the learning rate, the capacity of the network in terms of the number of neurons. I haven’t added more layers. I checked for gradients, they are indeed small, but not None.