Hi,

I’m working on a torch.nn forward pass that uses only a subset of a trained torch.nn.Parameter tensor (which subset varies with the input), but I haven’t been able to select and use that subset without breaking the computational graph. As I understand it, torch.gather() should accomplish this, yet in my code its output is consistently detached from the graph. Interestingly, the following stand-alone script, which effectively mirrors the relevant portion of my implementation, returns a tensor that does maintain the computational graph. Can anyone shed light on why the two behave differently?

Working Code:

```python
import torch
# Original parameter tensor
params = torch.nn.Parameter(torch.randn((2, 3, 4)))
# Create tensor with values corresponding to index positions. Will be used/explained below.
index_temp = torch.arange(0, 4, 1).unsqueeze(0).unsqueeze(0).expand((5, 3, 4))
# Parameter selection mask. Dim=0 represents unique inputs (e.g., here 5) and dim=1
# (i.e., size 3) and dim=2 (i.e., size 4) correspond to dims 1 and 2 in params.
# Identifies which subset of params are applicable to each input for given forward pass.
mask = torch.Tensor([
[[True, True, True, False],
[False, True, True, True],
[True, True, True, False]],
[[False, True, True, True],
[False, True, True, True],
[True, True, True, False]],
[[True, True, True, False],
[False, True, True, True],
[False, True, True, True]],
[[True, True, True, False],
[False, True, True, True],
[True, True, True, False]],
[[True, True, True, False],
[False, True, True, True],
[True, True, True, False]]
]).bool()
# Transform mask into tensor containing list of indices in last dim identifying
# applicable parameter values
mask = index_temp[mask].view(mask.size(0), mask.size(1), -1)
# Expand mask so dims 1-3 match params dims
mask = mask.unsqueeze(1).expand(-1, 2, -1, -1)
# Original parameter tensor expanded at dim=0 for number of inputs (i.e., 5)
selected_params = params.unsqueeze(0).expand(mask.size(0), -1, -1, -1)
# Select parameters applicable to each input with torch.gather()
selected_params = torch.gather(selected_params, -1, mask)
print(selected_params.requires_grad) # Returns True
```
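For what it’s worth, the working example doesn’t just keep requires_grad=True, it also propagates gradients back to params. Here is a condensed sanity check of the same pattern (expand the Parameter, gather along the last dim, backprop), with the index tensor written out directly rather than built from a boolean mask:

```python
import torch

# Minimal end-to-end check that torch.gather() keeps the graph intact.
params = torch.nn.Parameter(torch.randn(2, 3, 4))

# Index tensor selecting 3 of the 4 values in the last dim for each row,
# broadcast to one copy per input (5) and per params dim-0 slice (2).
index = torch.tensor([[[0, 1, 2], [1, 2, 3], [0, 1, 2]]])  # (1, 3, 3)
index = index.unsqueeze(1).expand(5, 2, -1, -1)            # (5, 2, 3, 3)

expanded = params.unsqueeze(0).expand(5, -1, -1, -1)       # (5, 2, 3, 4)
selected = torch.gather(expanded, -1, index)               # (5, 2, 3, 3)

selected.sum().backward()
print(selected.requires_grad)   # True
print(params.grad is not None)  # True
```

Gradients for the gathered positions are scatter-added back into params.grad, so a parameter value selected for several inputs accumulates a contribution from each.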

Here is the relevant snippet from my implementation, which returns a tensor detached from the computational graph. Note that I have replaced some class attributes (e.g., self.in_features, self.out_features) with literal values for easier comparison with the code above; I mention this in case it is relevant to the issue:

```python
# Original parameter tensor is a torch.nn.Parameter
# Create tensor with values corresponding to index positions.
index_temp = torch.arange(0, 4, 1).unsqueeze(0).unsqueeze(0).expand((5, 3, 4))
# Parameter selection mask; index_floor and index_ceiling are per-input bound
# tensors computed earlier in the forward pass. Results in a bool tensor similar
# to the one explicitly defined in the working example above.
mask = (index_temp >= index_floor) & (index_temp <= index_ceiling)
# Transform mask into tensor containing list of indices in last dim identifying
# applicable parameter values. Note that "4" below equals size of the last dim in
# params.
mask = index_temp[mask].view(mask.size(0), mask.size(1), -1)
# Expand mask so dims 1-3 match params dims
mask = mask.unsqueeze(1).expand(-1, 2, -1, -1)
# Original parameter tensor expanded at dim=0 for number of inputs (i.e., 5)
selected_params = params.unsqueeze(0).expand(mask.size(0), -1, -1, -1)
# Select parameters applicable to each input with torch.gather()
selected_params = torch.gather(selected_params, -1, mask)
```

Prior to the torch.gather() operation, selected_params in my implementation still shows requires_grad=True. After executing torch.gather(), however, that value changes to False.
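Not an answer, but one thing worth ruling out first (this is a guess about your surrounding code, not something visible in the snippet): torch.gather() itself does not detach when its input requires grad, but it will return a requires_grad=False tensor if the call happens to run under torch.no_grad() or torch.inference_mode(). Notably, the input tensor created outside such a block still reports requires_grad=True inside it, which matches the before/after behavior you describe:

```python
import torch

params = torch.nn.Parameter(torch.randn(2, 3, 4))
index = torch.zeros(5, 2, 3, 2, dtype=torch.long)  # arbitrary valid indices
expanded = params.unsqueeze(0).expand(5, -1, -1, -1)

# Outside no_grad: gather preserves the graph.
with_grad = torch.gather(expanded, -1, index)
print(with_grad.requires_grad)    # True

# Inside no_grad: the same call returns a detached result, even though
# `expanded` (created outside the block) still reports requires_grad=True.
with torch.no_grad():
    no_grad_out = torch.gather(expanded, -1, index)
print(expanded.requires_grad)     # True
print(no_grad_out.requires_grad)  # False
```

If your forward pass is invoked from an evaluation loop wrapped in no_grad, or a decorator applies it, you would see exactly this: requires_grad=True before gather, False after.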