Optimizing a Differentiable Loss Function for Uniform Distribution Convergence

My objective is to define a loss function that measures, roughly speaking, how close the model is to convergence, using a statistic that should be uniformly distributed when the model reproduces the data distribution.

I generate K random fictitious observations and count how many of them are smaller than the true observation from the training set. In other words, the model produces K simulations, and for each real observation we record the number of simulations that fall below it. If the model is well trained, this count should be uniformly distributed over {0, 1, ..., K}.

However, since the resulting histogram of counts is not differentiable with respect to the model parameters, I had to approximate this idea using compositions of continuous, differentiable functions.
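For reference, this is roughly the exact, non-differentiable statistic that the smooth functions below approximate (exact_rank is a hypothetical helper written only to illustrate the idea, not part of my training code):

import torch

def exact_rank(y_k, y):
    """Number of simulated values in y_k that fall below the real observation y."""
    # Hard comparison: the count is piecewise constant, so its gradient with
    # respect to the model that produced y_k is zero almost everywhere.
    return (y_k < y).sum()

The smooth version below replaces the hard comparison with a steep sigmoid and the one-hot bin assignment with a narrow Gaussian bump.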

import torch
import torch.optim as optim
from tqdm import tqdm


def _sigmoid(y_hat, y):
    """Smooth indicator of y_hat < y: a steep sigmoid centered at y."""
    # The factor 10.0 sharpens the transition towards a hard comparison.
    # return torch.relu(y - y_hat)
    return torch.sigmoid((y - y_hat) * 10.0)
    # return torch.nn.functional.softplus(y - y_hat)


def psi_m(y, m):
    """Calculate the bump function centered at m using PyTorch."""
    stddev = 0.1
    return torch.exp(-0.5 * ((y - m) / stddev) ** 2)


def phi(y_k, y_n):
    """Smooth count of the generated samples y_k that lie below the real observation y_n."""
    return torch.sum(_sigmoid(y_k, y_n))


def gamma(y_k, y_n, m):
    """Calculate the contribution to the m-th bin of the histogram using PyTorch."""
    # e_m = torch.tensor([1.0 if j == m else 0.0 for j in range(len(y_k) + 1)])
    e_m = torch.zeros(len(y_k) + 1, device=y_k.device)
    e_m[m] = 1.0
    return e_m * psi_m(phi(y_k, y_n), m)


def gamma_batch(y_k, y_n, ms):
    """Calculate the contributions to all bins of the histogram at once using PyTorch."""
    # psi_m broadcasts the smooth count phi(y_k, y_n) against the vector of bin
    # indices ms = torch.arange(K + 1), producing the whole surrogate histogram.
    return psi_m(phi(y_k, y_n), ms)


@torch.compile
def generate_a_k(y_hat, y):
    """Calculate the values of the real observation y in each of the components of the approximate histogram using PyTorch."""
    K = len(y_hat)
    return torch.sum(torch.stack([gamma(y_hat, y, k) for k in range(K+1)]), dim=0)

@torch.compile
def scalar_diff(q):
    """L2 distance between the surrogate histogram q and the uniform distribution vector."""
    uniform_value = 1 / q.size(0)  # each entry of the uniform distribution over the bins
    # return torch.sum((q - uniform_value) ** 2)  # squared-error variant
    return torch.norm(q - uniform_value, p=2)

def invariant_statistical_loss(model, data_loader, hparams):
    optimizer = optim.Adam(model.parameters(), lr=hparams['eta'])
    K = int(hparams['K'])
    losses = []
    for data in tqdm(data_loader):
        optimizer.zero_grad()  # reset gradients once per batch, before accumulating a_k
        a_k = torch.zeros(K + 1)
        for y in data:
            x_k = torch.normal(0.0, 1.0, size=(K, 1))  # K noise inputs for the generator
            y_k = model(x_k)                            # K simulated observations
            a_k += generate_a_k(y_k, y)                 # accumulate the surrogate histogram
        loss = scalar_diff(a_k / torch.sum(a_k))        # distance to the uniform distribution
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
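
In case it helps, this is roughly how I invoke the training loop; the toy model, data, and hyperparameter values below are placeholders rather than my real setup:

model = torch.nn.Sequential(
    torch.nn.Linear(1, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)
)
data = torch.normal(2.0, 1.0, size=(1000,))                 # 1-D target samples
loader = torch.utils.data.DataLoader(data, batch_size=64)
hparams = {'K': 10, 'eta': 1e-3}
losses = invariant_statistical_loss(model, loader, hparams)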

So, generate_a_k(y_k, y) returns a vector of size K+1, where the element at index k is close to 1 when y is greater than exactly k of the elements of the vector y_k.
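
A quick sanity check of that claim (the sample values are made up purely for illustration):

# Four simulated values; two of them are below the real observation y = 0.5,
# so most of the mass of the surrogate histogram should land on bin index 2.
y_k = torch.tensor([[0.1], [0.9], [-0.2], [1.5]])
y = torch.tensor(0.5)
a = generate_a_k(y_k, y)
print(a.shape)     # torch.Size([5])  -> K + 1 bins
print(a.argmax())  # tensor(2)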

Finally, scalar_diff computes the L2 distance between the normalized aₖ vector and the uniform distribution vector.
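
As a small check (again with made-up values), the loss is zero exactly when the normalized histogram is uniform and grows as the mass concentrates in a few bins:

q_uniform = torch.full((5,), 1 / 5)
q_skewed = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])
print(scalar_diff(q_uniform))  # tensor(0.)
print(scalar_diff(q_skewed))   # ~0.894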

I would like to know if there is a way to make this code faster and/or vectorise it; I have not been able to do so myself. Thank you very much in advance!