I want to train a network using a modified loss function that combines a typical classification loss (e.g. `nn.CrossEntropyLoss`) with a penalty on the Frobenius norm of the end-to-end Jacobian (i.e. $\nabla_x f(x)$, where $f(x)$ is the output of the network).
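Written out for a single input $x$ with label $y$, and with $\lambda$ standing for the scalar coefficient on the penalty (a symbol introduced here for notation only), the intended objective is:

$$
\mathcal{L}(x, y) = \mathrm{CE}\big(f(x), y\big) + \lambda \,\lVert J(x) \rVert_F,
\qquad
J(x)_{ij} = \frac{\partial f_i(x)}{\partial x_j},
\qquad
\lVert J(x) \rVert_F = \Big(\sum_{i,j} J(x)_{ij}^2\Big)^{1/2}.
$$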

I’ve implemented a model that learns successfully with `nn.CrossEntropyLoss` alone. However, when I add the second loss term (which requires two backward passes), my training loop runs but the model never learns. Furthermore, if I compute the end-to-end Jacobian but don’t include it in the loss function, the model still never learns. At a high level, my code does the following:

- Forward pass to get the predicted classes, `yhat`, from the inputs `x`.
- Call `yhat.backward(torch.ones(appropriate shape), retain_graph=True)`.
- Jacobian norm = `x.grad.data.norm(2)`.
- Set the loss equal to classification loss + scalar coefficient * Jacobian norm.
- Run `loss.backward()`.
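To see what the `backward(torch.ones(...))` and `x.grad.data.norm(2)` steps above actually compute, here is a small check on a hypothetical bias-free linear model $f(x) = Wx$, whose Jacobian is exactly $W$. It is an illustrative sketch, not the code below. It shows that `x.grad` ends up holding the vector-Jacobian product $J^\top \mathbf{1}$ (the column sums of $W$) rather than the full Jacobian, that the same call also writes into the parameters’ `.grad`, and that `x.grad.data` carries no autograd history.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model: f(x) = W x with no bias, so the Jacobian df/dx is exactly W.
layer = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(4, requires_grad=True)

yhat = layer(x)  # shape (3,)
yhat.backward(torch.ones_like(yhat), retain_graph=True)

# x.grad holds J^T @ 1, i.e. the column sums of W, not the 3x4 Jacobian itself.
print(torch.allclose(x.grad, layer.weight.sum(dim=0)))  # True

# The same backward call also accumulated into the parameter gradients:
# d(sum_i yhat_i)/dW[i, j] = x[j] for every row i.
print(torch.allclose(layer.weight.grad, x.detach().expand(3, 4)))  # True

# .data returns a tensor with no autograd history, so a loss term built
# from it cannot send gradients back to the weights.
penalty = x.grad.data.norm(2)
print(penalty.requires_grad)  # False
```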

I suspect that I’m misunderstanding how `backward()` works when it is called twice, but I haven’t been able to find any good resources that clarify this.
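A minimal illustration of calling `backward()` twice through the same graph, using a hypothetical scalar parameter `w`: `retain_graph=True` on the first call keeps the graph alive so it can be traversed again, and the gradients from both calls are summed into `w.grad` rather than replacing each other.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)

y = w * x  # dy/dw = x = 3

# First pass; retain_graph=True keeps the graph of y for a later pass.
y.backward(retain_graph=True)
print(w.grad)  # tensor(3.)

loss = y ** 2  # d(loss)/dw = 2 * y * x = 2 * 6 * 3 = 36
loss.backward()  # traverses y's graph again
print(w.grad)  # tensor(39.): 3 from the first call + 36 from the second
```

The accumulated value only resets when `.grad` is cleared explicitly, which is what `optimizer.zero_grad()` does for the parameters it manages.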

A complete working example would require too much code, so I’ve tried to extract the relevant parts:

```python
import torch


def train_model(model, train_dataloader, optimizer, loss_fn, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (batch_input, batch_target) in enumerate(train_dataloader):
        batch_input, batch_target = batch_input.to(device), batch_target.to(device)
        optimizer.zero_grad()
        # Track gradients w.r.t. the inputs so the Jacobian can be computed.
        batch_input.requires_grad_(True)
        model_batch_output = model(batch_input)
        # loss_fn combines the classification loss and the Jacobian penalty.
        loss = loss_fn(model_output=model_batch_output, model_input=batch_input,
                       model=model, target=batch_target)
        train_loss += loss.item()  # sum up batch loss
        loss.backward()
        optimizer.step()
```
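One way to write a `loss_fn` that accepts the keyword arguments passed in `train_model` and computes the full per-example Frobenius norm is to take one gradient per output class with `torch.autograd.grad(..., create_graph=True)`, which keeps the penalty differentiable with respect to the weights. This is a sketch under stated assumptions, not a tested drop-in: the name `jacobian_penalized_loss` and the `coeff` default are hypothetical, `model_output` is assumed to be logits of shape `(batch, num_classes)` with integer class targets, and each output row is assumed to depend only on its own input row (e.g. no BatchNorm in train mode). It costs one extra backward pass per class.

```python
import torch
import torch.nn.functional as F


def jacobian_penalized_loss(model_output, model_input, model, target, coeff=0.01):
    """Hypothetical loss_fn: cross-entropy + coeff * mean per-example ||J||_F.

    `model` is unused; it is accepted only to match the call in train_model.
    `coeff=0.01` is an arbitrary placeholder value.
    """
    ce = F.cross_entropy(model_output, target)

    num_classes = model_output.shape[1]
    sq_norms = model_input.new_zeros(model_input.shape[0])
    for i in range(num_classes):
        # Row i of every example's Jacobian: d f_i(x_b) / d x_b, shaped like
        # model_input. Summing over the batch is valid only under the
        # per-example independence assumption above.
        # create_graph=True makes this gradient itself differentiable.
        (row,) = torch.autograd.grad(
            model_output[:, i].sum(), model_input, create_graph=True
        )
        sq_norms = sq_norms + row.pow(2).flatten(start_dim=1).sum(dim=1)

    frob = sq_norms.sqrt()  # per-example Frobenius norm ||J(x_b)||_F
    return ce + coeff * frob.mean()
```

Passed as `loss_fn` to `train_model`, the single `loss.backward()` then backpropagates through both terms. Because the gradient of the square root is unbounded at zero, penalizing the squared norm (`sq_norms.mean()`) is a common alternative.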

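```python
import torch


def end_to_end_jacobian_loss(model_output, model_input):
    # Backpropagate an all-ones vector from the output, which fills model_input.grad.
    model_output.backward(
        torch.ones_like(model_output),  # same shape, dtype and device as the output
        retain_graph=True)
    # Read the gradient that the backward call stored on the input.
    jacobian = model_input.grad.data
    jacobian_norm = jacobian.norm(2)
    return jacobian_norm
```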
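A common way to keep such a penalty inside the autograd graph is `torch.autograd.grad` with `create_graph=True` in place of `backward()` followed by `.grad.data`. The sketch below is a hypothetical variant of `end_to_end_jacobian_loss` with the same arguments (the name `end_to_end_jacobian_loss_differentiable` is made up for illustration). It still computes the norm of the vector-Jacobian product with an all-ones vector, i.e. the same quantity as the function above, not the full Frobenius norm.

```python
import torch


def end_to_end_jacobian_loss_differentiable(model_output, model_input):
    """Hypothetical variant: ||J^T 1||_2 over the whole batch, kept differentiable."""
    (vjp,) = torch.autograd.grad(
        outputs=model_output,
        inputs=model_input,
        grad_outputs=torch.ones_like(model_output),  # same device/dtype as the output
        create_graph=True,  # build a graph for the gradient itself
    )
    # Unlike backward(), autograd.grad returns the gradient instead of
    # accumulating it into .grad, so parameter .grad fields are untouched here.
    return vjp.norm(2)
```

With this variant, the final `loss.backward()` in the training loop is the only call that writes into the parameters’ `.grad`.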