Consider the following example:
import torch
import torch.nn as nn
import torch.optim as optim
class RegressionModel(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Linear(5, 20),
nn.CELU(),
nn.Linear(20, 20),
nn.CELU(),
nn.Linear(20, 1)
)
def forward(self, x):
return self.network(x)
n_samples = 10**4
mini_batch = 500
d_in = 5
torch.manual_seed(1234)
x_train = torch.randn(n_samples,d_in) # Evaluate the probability density function at each sample point
y_train = torch.norm(x_train, dim=1, p = 4)
dataset_train = torch.utils.data.TensorDataset(x_train,y_train)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=mini_batch, shuffle=True)
I want to compute gradients with respect to the input argument during training. I had originally though tot implement it as:
x_train.requires_grad_(True)
n_epochs = 100
n_print =10
lr = 0.01
torch.manual_seed(567)
u = RegressionModel()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(u.parameters(), lr=lr)
u.train() # Set the model to training mode
for epoch in range(n_epochs):
for batch_x, batch_y in dataloader_train:
optimizer.zero_grad()
y_pred = u(batch_x)
loss = mse_loss(y_pred,batch_y)
loss.backward()
optimizer.step()
y_pred = u(x_train)
gradients = torch.autograd.grad(outputs=y_pred, inputs=x_train, grad_outputs=torch.ones_like(y_pred))[0]
with torch.no_grad():
loss_train = mse_loss(y_pred, y_train).item()
if (epoch+1) % n_print == 0:
print(f'epoch {epoch+1}: train = {loss_train:>4g}')
But I found this to be really slow. If, in contrast, I use this code:
n_epochs = 100
n_print =10
lr = 0.01
torch.manual_seed(567)
u = RegressionModel()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(u.parameters(), lr=lr)
u.train() # Set the model to training mode
for epoch in range(n_epochs):
for batch_x, batch_y in dataloader_train:
optimizer.zero_grad()
y_pred = u(batch_x)
loss = mse_loss(y_pred,batch_y)
loss.backward()
optimizer.step()
x_train.requires_grad_(True) # turn on gradient computation
y_pred = u(x_train)
gradients = torch.autograd.grad(outputs=y_pred, inputs=x_train, grad_outputs=torch.ones_like(y_pred))[0]
x_train.requires_grad_(False) # turn off gradient computation
with torch.no_grad():
loss_train = mse_loss(y_pred, y_train).item()
if (epoch+1) % n_print == 0:
print(f'epoch {epoch+1}: train = {loss_train:>4g}')
It is much faster. But I’m not sure if this is “right”, in the sense that I need to constantly switch requires_grad_
back and forth. Any advice would be appreciate.