Masking a forward pass for back propagation


I want to do the following: I run my model N times on some input data and then update the model by backpropagating a loss. In this backward pass, I want only specific runs out of my N forward passes to contribute. To be clear: I don't want to mask individual parameters, but individual forward passes through the parameters. I thought this might be possible with some hook, but I can't figure out how. Is it possible?
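As a rough sketch of the hook idea, assuming every output is finite: stack the N runs as rows of one batch, then register a tensor hook on the output that multiplies the incoming gradient by a per-run mask. The layer sizes and the `keep` mask below are hypothetical. Excluded runs then send zero gradient to the parameters, and the result matches backpropagating through the kept runs only.

```python
import torch

torch.manual_seed(0)

layer = torch.nn.Linear(3, 1)
X = torch.randn(5, 3)                                 # 5 forward passes, one per row
keep = torch.tensor([True, True, False, True, False])  # hypothetical: runs to keep

res = layer(X)                                         # shape (5, 1)
# Tensor hook: scale the incoming gradient row-wise by the mask, so the
# excluded runs contribute zero gradient to the layer's parameters.
res.register_hook(lambda grad: grad * keep.unsqueeze(1).to(grad.dtype))
res.sum().backward()
grad_hook = layer.weight.grad.clone()

# Reference: backpropagate through the kept runs only.
layer.zero_grad()
layer(X[keep]).sum().backward()
grad_ref = layer.weight.grad.clone()

print(torch.allclose(grad_hook, grad_ref))
```

One caveat: this only works when the excluded runs produce finite values. For `Linear`, the weight gradient is `grad_output^T @ input`, and `0 * nan` is still `nan`, so a zeroed gradient does not remove NaNs that sit in the inputs.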

Here is a minimal script:

import torch
import numpy as np

# Define a simple model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.layer = torch.nn.Linear(3, 1)

    def forward(self, x):
        return self.layer(x)

def loss(X):
    return torch.mean(torch.sum(X, dim=0))

X = torch.randn(5, 3)  # 5 forward passes (rows), 3 features each
# X[4] = np.nan

M = Model()
res = M(X)
L = loss(res)
L.backward()
print(M.layer.weight.grad)

If you uncomment X[4] = np.nan, you get NaN gradients for the model weights. In this case, however, I just want to exclude this row from the backward pass, as if I had only run the model on X[:4].
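A minimal sketch of one way to get exactly the "as if I had only run on X[:4]" behavior, assuming invalid runs can be detected by NaNs in their input row: replace those rows with zeros *before* the forward pass (masking only the output is not enough, because `0 * nan` is still `nan` in the weight gradient), then leave the invalid rows out of the loss. The `valid` mask and `X_safe` are names introduced here for illustration.

```python
import torch

torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    def forward(self, x):
        return self.layer(x)

def loss(out):
    return torch.mean(torch.sum(out, dim=0))

X = torch.randn(5, 3)
X[4] = float("nan")  # one invalid forward pass

# True for rows that should contribute to the gradient.
valid = ~torch.isnan(X).any(dim=1)

# 1) Sanitize invalid rows before the forward pass so no NaN enters the graph.
X_safe = torch.where(valid.unsqueeze(1), X, torch.zeros_like(X))

M = Model()
res = M(X_safe)

# 2) Exclude the invalid rows from the loss, so they get zero gradient.
L = loss(res[valid])
L.backward()
grad_masked = M.layer.weight.grad.clone()

# Reference: run the same model on the valid rows only (here equal to X[:4]).
M.zero_grad()
loss(M(X[valid])).backward()
grad_ref = M.layer.weight.grad.clone()

print(torch.allclose(grad_masked, grad_ref))
```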

Background: The reason for this is that I have a model that is supposed to compute the probabilities of a categorical distribution. I set it up so that, for a given sample in my data, I run the network M times on M different input feature vectors. The categorical distribution is over these M different inputs. I use the same MLP for each of the M inputs, which gives me an (unnormalized) scalar value for each of them. I then normalize these values with a Softmax layer to obtain proper probabilities.
Now, in my batch of K samples, some samples actually have M different input feature vectors, while others have fewer; in that case the "superfluous" feature vectors are filled with NaNs, so the forward pass also produces NaNs for them. To ignore them in the Softmax layer, I manually set them to 0, but I then also want to ignore these values in the backward pass.
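For the padded categorical setup, a common pattern is a masked softmax. The sketch below makes a few assumptions: the sizes, the `scorer` MLP, and the `target` indices are hypothetical, and a validity mask is built from the number of real candidates per sample. Padded inputs are zeroed before the shared MLP, and the padded logits are filled with `-inf` rather than 0. With 0, `exp(0) = 1` would still add to the softmax normalizer; with `-inf`, the padded slots get exactly zero probability and zero gradient.

```python
import torch

torch.manual_seed(0)

K, M_MAX, D = 3, 4, 5  # batch size, max candidates per sample, feature dim (assumed)

# Hypothetical shared scorer: the same MLP is applied to every candidate vector.
scorer = torch.nn.Sequential(
    torch.nn.Linear(D, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)

# Padded batch: sample k has n_valid[k] real candidates, the rest are NaN.
feats = torch.randn(K, M_MAX, D)
n_valid = torch.tensor([4, 2, 3])
valid = torch.arange(M_MAX).unsqueeze(0) < n_valid.unsqueeze(1)  # (K, M_MAX) bool
feats[~valid] = float("nan")

# Zero the padding before the forward pass so no NaN enters the graph.
feats_safe = feats.masked_fill(~valid.unsqueeze(-1), 0.0)

logits = scorer(feats_safe).squeeze(-1)  # (K, M_MAX)

# Padded slots get -inf: zero probability after softmax and zero gradient.
# (Each sample needs at least one valid slot, otherwise the row becomes NaN.)
logits = logits.masked_fill(~valid, float("-inf"))
log_probs = torch.log_softmax(logits, dim=1)

# Hypothetical targets: index of the "correct" candidate, always a valid slot.
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.nll_loss(log_probs, target)
loss.backward()
```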
