Hi,
I am trying to implement a simple neural network using a custom loss. My loss is defined as -
\sum_{i=1}^{n} \left( f_i(x_i, y_i, \theta) \cdot \mu + \psi(x_i, y_i, \theta, \mu) \right) \quad \text{s.t.} \quad \psi(\cdot) = \nu, \text{ where } g(\nu, x_i, y_i, \theta, \mu) = 1
Here, \theta are the network parameters up to the second-to-last layer (the feature mapping) and \mu are the parameters of the last layer (which combine with the features to obtain the logits). As you can see, the loss is complicated (with constraints), but I can obtain the gradient for the last layer (corresponding to \mu) using implicit differentiation.
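To be explicit about the implicit-differentiation step: what I mean is the standard implicit function theorem applied to the constraint, assuming g is differentiable in \nu and \mu,

\partial \nu / \partial \mu = - (\partial g / \partial \nu)^{-1} (\partial g / \partial \mu), evaluated at the \nu that solves g(\nu, x_i, y_i, \theta, \mu) = 1,

which gives me \nabla_\mu \psi since \psi(\cdot) = \nu.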
So, basically, my questions are:

- How do I pass the gradients? From what I read, it seems that doing `model.mu.grad = ...` does the job.
- After I pass the gradient for \mu (the weights of the last layer), can I call `.backward()` to compute the gradients for the rest of the network weights? (See the sketch right after this list.)
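For concreteness, here is a minimal sketch of the pattern I have in mind (a toy two-layer model, with a random placeholder tensor standing in for the gradient I actually get from implicit differentiation):

```
import torch
import torch.nn as nn

# Toy stand-in for my network: `hidden` plays the role of theta (feature mapping),
# `last` plays the role of mu (the layer whose gradient I compute by hand).
hidden = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
last = nn.Linear(8, 3)
opt = torch.optim.SGD(list(hidden.parameters()) + list(last.parameters()), lr=0.1)

x = torch.randn(16, 4)

opt.zero_grad()
phi = hidden(x)        # feature mapping (theta part)
logits = last(phi)     # features combined with mu to obtain the logits

# Step 1: assign the manually computed gradient to the last layer's parameters
# (random placeholder here; in my case it comes from implicit differentiation).
last.weight.grad = torch.randn_like(last.weight)
last.bias.grad = torch.randn_like(last.bias)

# Step 2: can I now call backward() so that the hidden (theta) parameters also
# receive gradients, without disturbing the .grad I just set on `last`?
logits.backward(torch.zeros_like(logits))

opt.step()
```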
Do the above two steps correctly train the network? I ran a small test and it does not seem to be working: the loss is not decreasing across epochs. Here is the code I am using -
```
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch.nn.functional as F
from Optimization.utils import *  # provides the bisection() routine used below
class CustomNN(nn.Module):
    def __init__(self, input_size, n_features, n_classes):
        super().__init__()
        self.hidden = nn.Linear(input_size, n_features)
        self.relu = nn.ReLU()
        self.output = nn.Linear(n_features, n_classes)  # we'll override this layer's gradient

    def forward(self, x):
        x = self.relu(self.hidden(x))
        return self.output(x)
# Compute the gradient with respect to mu in the last layer (this corresponds to MRCs)
def mrcs_grad(phi, mu, phi_mu, labels, n_classes, beta):
    """
    phi: Feature mapping from the neural network, of size (batch_size, n_features)
    mu: MRC parameters given by the last layer of the neural network, of size (n_classes, n_features)
    phi_mu: Output of the last layer of the neural network (phi combined with mu, i.e. the logits), of size (batch_size, n_classes)
    labels: Labels corresponding to the batch of samples, of size (batch_size,)
    n_classes: Number of classes
    beta: Conjugate exponent, beta = alpha / (alpha - 1)
    """
    # Add the bias to the feature mapping phi
    batch_size = phi.shape[0]
    bias = torch.tensor([[1.]] * batch_size)
    phi_with_intercept = torch.cat((bias, phi), dim=1)
    tau_ = []
    for y_i in range(n_classes):
        tau_.append((1 / batch_size) * phi_with_intercept[labels == y_i, :].sum(dim=0))
    tau_ = torch.stack(tau_)
    psi_grad = torch.zeros((phi_with_intercept.shape[1], n_classes), dtype=torch.float64)
    for i, phi_mu_i in enumerate(phi_mu):
        # Use bisection to find the nu values
        nu_i, iters = bisection((phi_mu_i.detach().numpy(), beta))
        # Compute the psi function
        psi_beta = ((phi_mu_i + nu_i) / beta) + 1
        psi_grad_i = torch.zeros((phi_with_intercept.shape[1], n_classes), dtype=torch.float64)
        mask = psi_beta > 0
        sum_psi_beta = torch.sum(torch.pow(psi_beta[mask], beta - 1))
        # sum_psi_beta = np.sum(np.power(psi_beta[np.clip(psi_beta, 0., None) != 0], beta - 1))
        for j in range(n_classes):
            if psi_beta[j] > 0:
                psi_grad_i[:, j] -= (torch.pow(psi_beta[j], beta - 1) * phi_with_intercept[i, :]) / sum_psi_beta
        psi_grad = psi_grad + psi_grad_i
    psi_grad = (1 / batch_size) * psi_grad
    return -tau_ + psi_grad.T
if __name__ == "__main__":
    alpha = 1.05
    beta = alpha / (alpha - 1)
    # Transform to convert PIL images to tensors
    transform = transforms.ToTensor()
    # Load the full MNIST training set
    full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    n_classes = len(full_dataset.classes)
    sample, label = full_dataset[0]
    a, b, c = sample.shape
    input_dim = a * b * c
    # Take only the first 320 samples
    small_dataset = Subset(full_dataset, range(320))
    # DataLoader for the mini-dataset
    train_loader = DataLoader(small_dataset, batch_size=32, shuffle=True)
    # Define the model
    # Note that the outputs are logits and the output size is the number of classes
    model = CustomNN(input_size=input_dim, n_features=128, n_classes=n_classes)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # Train the model with the custom loss
    num_epochs = 5
    # Training with the MRCs loss using custom gradients.
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        for inputs, labels in train_loader:
            # Reset gradients
            optimizer.zero_grad()
            # Flatten each image into a vector
            X = inputs.view(inputs.size(0), -1)
            # Forward pass
            outputs = model(X)  # this is the output of phi combined with mu (the logits)
            # Backprop only to get intermediate gradients (not the final layer; we override that below)
            outputs.backward(torch.zeros_like(outputs), retain_graph=True)
            # Forward pass up to the second-to-last layer to obtain the feature mapping phi
            phi = model.relu(model.hidden(X))
            mu = model.output.weight
            # Compute the gradient of the last layer that corresponds to mu
            grad_mu = mrcs_grad(phi, mu, outputs, labels, n_classes, beta)
            # Assign the manually computed gradients to the last layer
            # weights
            model.output.weight.grad = grad_mu[:, 1:].clone().detach().to(dtype=model.output.weight.dtype, device=model.output.weight.device)
            # biases
            model.output.bias.grad = grad_mu[:, 0].clone().detach().to(dtype=model.output.bias.dtype, device=model.output.bias.device)
            # Call the optimizer step
            optimizer.step()
            # Check the loss
            loss = F.cross_entropy(outputs, labels)
            running_loss = running_loss + loss.item()
        print(f"Epoch {epoch + 1}, Log Loss: {running_loss / len(train_loader):.4f}")
    print(torch.norm(grad_mu))
    print(model.output.weight.requires_grad)
```
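In case it is useful, this is the quick check I run after the backward call and optimizer step to see whether the earlier layers receive any gradient at all (just a diagnostic snippet, not part of the training logic):

```
# Print the gradient norm of every parameter; None means no gradient was populated.
for name, param in model.named_parameters():
    grad = param.grad
    print(name, None if grad is None else grad.norm().item())
```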