I’d like to minimize

where

.

To be more precise: at each iteration, U is set to the current values of W_2, but U is treated as a constant, so the gradient update at that iteration does not change U.

To implement this in PyTorch, I wrote:

```
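import torch
import torch.nn as nn

x = torch.randn([1, 10])
layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)

# First term: gradients flow to both layer1 and layer2
y1 = layer2(layer1(x))

# Second term: layer2 plays the role of U, evaluated without gradient tracking
z2 = layer1(x)
with torch.no_grad():
    y2 = layer2(z2)

y = y1 + y2
loss = torch.norm(y)
print(y)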
```

But I’m not sure whether this is a correct implementation of the idea above.
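
For comparison, here is a minimal sketch of one alternative. It assumes the intent is that the term involving U should still send gradients back to `layer1`, with only U itself held constant. Under `torch.no_grad()`, `y2` is computed with no graph at all, so that term would contribute no gradient to `layer1` either. The names `U_weight` and `U_bias` are my own, and treating the bias as part of U is also an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 10)
layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)

z = layer1(x)

# First term: gradients flow to both layer1 and layer2.
y1 = layer2(z)

# U is a constant snapshot of layer2's current parameters.
# detach() returns tensors that share the same values but are cut from the graph.
# (Assumption: U also carries layer2's bias.)
U_weight = layer2.weight.detach()
U_bias = layer2.bias.detach()

# Second term: gradients flow to layer1 through z, but not to U.
y2 = F.linear(z, U_weight, U_bias)

loss = torch.norm(y1 + y2)
loss.backward()

print(y2.requires_grad)        # y2 is still attached to the graph through z
print(U_weight.requires_grad)  # the snapshot itself is not trainable
```

Because U is re-derived from `layer2` on every forward pass, it picks up the updated values of W_2 at each iteration without any explicit copy step.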