Custom loss and its gradient

Hi,

I am trying to implement a simple neural network using a custom loss. My loss is defined as -

\sum_{i=1}^{n} \left( f_i(x_i, y_i, \theta) \cdot \mu + \psi(x_i, y_i, \theta, \mu) \right) \quad \text{s.t.} \quad \psi(\dots) = \nu \ \text{ where } \ g(\nu, x_i, y_i, \theta, \mu) = 1

Here, \theta are the network parameters up to the second-to-last layer (the feature mapping) and \mu are the parameters of the last layer (which combine with the features to produce the logits). As you can see, the loss is complicated (it has constraints), but I can obtain the gradient for the last layer (corresponding to \mu) using implicit differentiation.
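
In case it is useful, the implicit differentiation is just the implicit function theorem applied to the constraint (assuming g is smooth and \partial g / \partial \nu \neq 0): differentiating g(\nu_i, x_i, y_i, \theta, \mu) = 1 with respect to \mu gives

\frac{\partial g}{\partial \nu} \frac{\partial \nu_i}{\partial \mu} + \frac{\partial g}{\partial \mu} = 0 \quad \Rightarrow \quad \frac{\partial \psi_i}{\partial \mu} = \frac{\partial \nu_i}{\partial \mu} = -\left( \frac{\partial g}{\partial \nu} \right)^{-1} \frac{\partial g}{\partial \mu}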

So, basically, my questions are:

  1. How do I pass the gradients? From what I read, it seems that doing model.mu.grad = ... does the job.
  2. After I pass the gradient for \mu (the weights of the last layer), can I call .backward() to compute the gradients for the rest of the network weights? (A minimal sketch of what I mean is below.)
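
Roughly, this is the pattern I have in mind (a minimal sketch with a toy two-layer model and a placeholder gradient, not my actual MRC gradient):

```python
import torch
import torch.nn as nn

# Toy stand-in for my network: hidden feature mapping followed by a last layer (mu)
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 10)
outputs = model(x)  # forward pass

optimizer.zero_grad()
last = model[2]
# Step 1: assign the externally computed gradient to the last layer by hand
last.weight.grad = torch.randn_like(last.weight)  # placeholder for the implicit-differentiation gradient
last.bias.grad = torch.randn_like(last.bias)
# Step 2: is it valid to now call outputs.backward(...) so autograd fills in the
# hidden-layer gradients, without clobbering the grads set above?
optimizer.step()
```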

Do the above two steps correctly train the network? I ran a small test and it seems like it is not working: the loss is not decreasing over the epochs. Here is the code that I am using -

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch.nn.functional as F
from Optimization.utils import *

class CustomNN(nn.Module):
    def __init__(self, input_size, n_features, n_classes):
        super().__init__()
        self.hidden = nn.Linear(input_size, n_features)
        self.relu = nn.ReLU()
        self.output = nn.Linear(n_features, n_classes)  # we'll override this layer's gradient

    def forward(self, x):
        x = self.relu(self.hidden(x))
        return self.output(x)

# Compute the gradient with respect to mu in the last layer (this corresponds to MRCs)
def mrcs_grad(phi, mu, phi_mu, labels, n_classes, beta):
    """
    phi: Feature mapping from the neural network, of size (batch_size, n_features)
    mu: MRC parameters given by the last layer of the neural network, of size (n_classes, n_features)
    phi_mu: Output of the last layer of the network, i.e. phi @ mu.T + bias, of size (batch_size, n_classes)
    labels: Labels corresponding to the batch of samples, of size (batch_size,)
    n_classes: Number of classes
    beta: Conjugate exponent beta = alpha / (alpha - 1)
    """
    # Prepend an intercept column to the feature mapping phi
    batch_size = phi.shape[0]
    bias = torch.ones(batch_size, 1, dtype=phi.dtype, device=phi.device)
    phi_with_intercept = torch.cat((bias, phi), dim=1)
    tau_ = []
    for y_i in range(n_classes):
        tau_.append((1 / batch_size) * phi_with_intercept[labels == y_i, :].sum(dim=0))

    tau_ = torch.stack(tau_)

    psi_grad = torch.zeros((phi_with_intercept.shape[1], n_classes), dtype=torch.float64)
    for i, phi_mu_i in enumerate(phi_mu):
        # Use bisection to find the nu values
        nu_i, iters = bisection((phi_mu_i.detach().numpy(), beta))

        # Compute the psi function
        psi_beta = ((phi_mu_i + nu_i) / beta) + 1

        psi_grad_i = torch.zeros((phi_with_intercept.shape[1], n_classes), dtype=torch.float64)
        mask = psi_beta > 0
        sum_psi_beta = torch.sum(torch.pow(psi_beta[mask], beta - 1))
        # sum_psi_beta = np.sum(np.power(psi_beta[np.clip(psi_beta, 0., None) != 0], beta - 1))

        for j in range(n_classes):
            if psi_beta[j] > 0:
                psi_grad_i[:, j] -= (torch.pow(psi_beta[j], beta - 1) * phi_with_intercept[i, :]) / sum_psi_beta

        psi_grad = psi_grad + psi_grad_i

    psi_grad = ((1 / batch_size) * psi_grad)

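    # Gradient w.r.t. mu (one row per class): the batch-averaged intercept-augmented
    # features per class enter with a minus sign (-tau_) and the averaged psi-term
    # gradient is added; the first column corresponds to the intercept/bias.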
    return -tau_ + psi_grad.T

if __name__=="__main__":

    alpha = 1.05
    beta = alpha / (alpha - 1)

    # Transform to convert PIL images to tensors
    transform = transforms.ToTensor()

    # Load full MNIST training set
    full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    n_classes = len(full_dataset.classes)
    sample, label = full_dataset[0]
    a, b, c = sample.shape
    input_dim = a * b * c

    # Take only the first 320 samples
    small_dataset = Subset(full_dataset, range(320))

    # DataLoader for mini-dataset
    train_loader = DataLoader(small_dataset, batch_size=32, shuffle=True)

    # Define the model
    # Note that the outputs are logits and the output size will be the number of classes
    model = CustomNN(input_size = input_dim, n_features = 128, n_classes = n_classes)

    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Train the model with custom loss
    num_epochs = 5

    # Training with MRCs loss using custom gradients.
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        for inputs, labels in train_loader:
            # Reset gradients
            optimizer.zero_grad()

            # Flatten each image in the batch into a vector
            X = torch.flatten(inputs, start_dim=1)

            # Forward pass
            outputs = model(X)  # this is phi @ mu.T + bias, i.e. the logits

            # Backprop only to get intermediate gradients (not final layer, later we override that)
            outputs.backward(torch.zeros_like(outputs), retain_graph=True)

            # Forward pass until second last layer to obtain the feature mapping phi
            phi = model.relu(model.hidden(X))
            mu = model.output.weight

            # Compute the gradient of the last layer that corresponds to mu
            grad_mu = mrcs_grad(phi, mu, outputs, labels, n_classes, beta)

            # Assign manually computed gradients to last layer
            # weights
            model.output.weight.grad = grad_mu[:, 1:].clone().detach().to(dtype=model.output.weight.dtype, device=model.output.weight.device)
            # biases
            model.output.bias.grad = grad_mu[:, 0].clone().detach().to(dtype=model.output.weight.dtype, device=model.output.weight.device)

            # Call optimizer step
            optimizer.step()

            # Check the loss
            loss = F.cross_entropy(outputs, labels)
            running_loss = running_loss + loss.item()
        print(f"Epoch {epoch + 1}, Log Loss: {running_loss / len(train_loader):.4f}")
        print(torch.norm(grad_mu))
        print(model.output.weight.requires_grad)

> but I can obtain the gradient for the last layer (corresponding to \mu) using implicit differentiation.

If you have custom gradients, you can use a custom autograd.Function (see Extending PyTorch — PyTorch 2.7 documentation) to have it interoperate with the rest of autograd.
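
For example, here is a minimal sketch of that pattern (the helpers `mrc_loss_value` and `mrc_loss_grads` are hypothetical stand-ins for your bisection / implicit-differentiation code): the forward computes the custom loss, the backward returns your custom gradients, and autograd handles everything before the feature mapping.

```python
import torch

class MRCLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, phi, weight, bias, labels):
        ctx.save_for_backward(phi, weight, bias, labels)
        # Hypothetical helper: evaluates the constrained loss as a scalar tensor
        return mrc_loss_value(phi, weight, bias, labels)

    @staticmethod
    def backward(ctx, grad_output):
        phi, weight, bias, labels = ctx.saved_tensors
        # Hypothetical helper: gradients of the loss w.r.t. phi, weight and bias
        # (e.g. your implicit-differentiation result for weight/bias)
        grad_phi, grad_w, grad_b = mrc_loss_grads(phi, weight, bias, labels)
        # One return value per input to forward; labels need no gradient
        return grad_output * grad_phi, grad_output * grad_w, grad_output * grad_b, None
```

Then the training loop becomes the usual pattern, e.g. `phi = model.relu(model.hidden(X))`, `loss = MRCLoss.apply(phi, model.output.weight, model.output.bias, labels)`, followed by `loss.backward()` and `optimizer.step()`, so the hidden-layer gradients come from autograd through `grad_phi`.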