I have the following problem: I have a bunch of parameters that I need to project into the valid parameter space. I can only compute the valid parameter space efficiently during the forward pass (it changes with the parameters).

So I have computations like this:

``````
lower_bound, upper_bound = compute_bound(input)
param[param < lower_bound] = lower_bound
param[param > upper_bound] = upper_bound
out = f(input, param)
``````

The thing is, I wonder what that means for the gradients of `param`. Do they get blocked by the assignment? I just want to replace every out-of-bounds value with the nearest bound, but still optimise it.
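To make the gradient question concrete, here is a minimal sketch that compares two ways of applying the bounds. It assumes fixed scalar bounds and a toy `f` (both are hypothetical stand-ins for `compute_bound` and the real forward function), so treat it as an illustration rather than the actual setup. Variant A clamps out-of-place inside the graph; variant B projects the parameter in place under `torch.no_grad()` and then runs the forward pass on the projected values.

```python
import torch

# Hypothetical stand-ins: scalar bounds instead of compute_bound(input),
# and a toy f(input, param) that is linear in param.
lower_bound, upper_bound = 0.0, 1.0
x = torch.ones(4)

def f(inp, p):
    return (inp * p).sum()

init = [-2.0, 0.5, 3.0, 0.25]

# Variant A: out-of-place clamp that autograd tracks as part of the graph.
param_a = torch.tensor(init, requires_grad=True)
f(x, param_a.clamp(lower_bound, upper_bound)).backward()
grad_a = param_a.grad.clone()

# Variant B: project the stored values in place without tracking,
# then run the forward pass on the projected parameter.
param_b = torch.tensor(init, requires_grad=True)
with torch.no_grad():
    param_b.clamp_(lower_bound, upper_bound)
f(x, param_b).backward()
grad_b = param_b.grad.clone()

print("clamp in graph:   ", grad_a)
print("no_grad projection:", grad_b)
print("projected values:  ", param_b.detach())
```

In variant A the entries that were out of bounds receive a zero gradient, so they would stop being optimised. In variant B every entry receives a gradient evaluated at its projected value, which seems to be the behaviour asked for here.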

I think your use case should work if you wrap the in-place manipulations of `param` in a `with torch.no_grad():` block, as seen here:

``````
import torch
import torch.nn as nn
import torch.nn.functional as F

lower_bound, upper_bound = 0., 1.
param = nn.Parameter(torch.randn(10, 10) * 10.)
f = F.linear

for epoch in range(10):
    print(epoch)
    input = torch.randn(10, 10)
    print(param.min(), param.max())