How to temporarily freeze a layer?

I’d like to minimize

||f_{W_2}(f_{W_1}(x)) + f_U(f_{W_1}(x))||

where f_{W_1} and f_{W_2} are linear layers with weights W_1 and W_2, and U is a frozen copy of W_2.

To be more precise: at each iteration U receives the current values of W_2, but the gradient update at that iteration should not change U.

To implement this in PyTorch, I wrote:

import torch
import torch.nn as nn

x = torch.randn([1, 10])
layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)

# First branch: gradients flow through both layer1 and layer2.
y1 = layer2(layer1(x))

z2 = layer1(x)

# Second branch: y2 is computed under no_grad, so it has no grad_fn
# and contributes no gradients to either layer.
with torch.no_grad():
    y2 = layer2(z2)

y = y1 + y2
loss = torch.norm(y)
print(y)
But I’m not sure whether this is a correct implementation of the above idea.
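For comparison, here is a sketch of a variant that detaches only layer2’s parameters in the second branch, so that gradients can still reach layer1 through it (torch.nn.functional.linear is the functional form of nn.Linear; the names y2_alt, y_alt, loss_alt are just for illustration, and I’m not sure which behaviour matches the objective):

import torch.nn.functional as F

z2 = layer1(x)
# layer2's weight and bias are detached: they act as the constant U here,
# but gradients still flow back to layer1 through z2.
y2_alt = F.linear(z2, layer2.weight.detach(), layer2.bias.detach())

y_alt = y1 + y2_alt
loss_alt = torch.norm(y_alt)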

Here’s a code excerpt:

model = ...  # define your model here
for param in model.parameters():
    param.requires_grad = False
This freezes all the layers of the model. To freeze only a portion of it, you can add a condition inside the loop, for example on the parameter’s name, as in the sketch below.
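A minimal sketch that freezes only the parameters whose names start with "layer1" (the name prefix is just an assumption about how the model is defined; adapt it to your own module names):

# Freeze only the parameters of one submodule, identified by name prefix.
for name, param in model.named_parameters():
    if name.startswith("layer1"):
        param.requires_grad = False

When some parameters are frozen this way, it is also common to pass only the trainable ones to the optimizer, e.g. torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1).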