'None' gradients when passing output from one network to the input of another

Hi all,

I’m trying to build a custom network that uses the output of one network to transform the input of a second network. The reason I’m doing this is that I want the network to warp the input in time and then find the best fit, learning both simultaneously.

However, the gradients of the first section are turning into ‘None’ values somewhere.

More precisely, I have an input matrix X and an index vector z = 0, 1, ..., X.shape[0] - 1.

Specifically, the first stage is ReLU(Linear(z)) and outputs a tensor ‘p’.
Then X and p are passed to the second stage, where each row of X, x[i], is shifted to the right by the value p[i], and the result is reduced to a vector via a column-wise mean. Basically, it’s taking the matrix X and creating a vector x_shift that is warped in time.

From there, I pass the shifted vector ‘x_shift’ through feed-forward layers to create the final output y. The loss function is a maximum-likelihood function based on both y and p.

The problem is that the gradients for the first network stage are ‘None’. I don’t get any errors, just None values. However, when I split the networks apart and train each one separately, everything is fine. I’m a bit stuck here; any help would be much appreciated!
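For reference, this is roughly how I check the gradients after loss.backward() (every parameter of the first stage comes back as None):

for name, param in p_mod.named_parameters():
    print(name, param.grad)  # prints None for each first-stage parameter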

Here is the relevant code:

# Training Loop
p_mod = PerturbationNetwork(1, 5, t=n)
model = ShiftNetwork(n=1, hidden_size=5, t=n)
optimizer = optim.Adam(list(model.parameters()) + list(p_mod.parameters()), lr=0.01)
for epoch in range(1000):
    # Forward pass 1
    p_out = p_mod(p)
    y_pred = model(X, p_out)
    loss = model.log_likelihood(y_pred, y, p_out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Here is the forward pass for the first stage:

    def forward(self, x):
        # Create the perturbation profile
        p_i = self.relu(self.fc0(x))
        p_i = self.p_out(p_i)

        # Rescale
        p_scale = p_i*float(p_i.shape[0])
        p_scale = torch.tensor(p_scale, dtype=int)

        return p_scale

And here is the forward pass for the second stage:

    def forward(self, X, p):

        x_shift = self.shift_rows(X, p)
        x_shift = self.nan_mean(x_shift)

        x_shift = warp.fill_nan_with_last_value(x_shift.squeeze(0))
        x_shift = x_shift[::self.sr]
        x_shift = x_shift[:len(p)]
        x_shift = x_shift.unsqueeze(1)
  
        h1 = self.relu(self.fc1(x_shift))
        h2 = self.relu(self.fc2(h1))
        h3 = self.relu(self.fc3(h2))
        y = self.fc4(h3)
        return y.squeeze(1)

The helper functions are built from torch operations, and I’ve tested each of them separately. They are:

    def shift_rows(self, X, p):
        # Get the dimensions of the input matrix
        num_rows, num_cols = X.size()

        # Create an empty tensor of the same shape as the input matrix
        shifted_X = torch.full((num_rows, num_cols), float('nan'))

        # Iterate over each row of the input matrix
        for i in range(0, (len(p)-1)):
            # Get the shift value for the current row
            shift_amount = int(p[i].item())  # Convert tensor to scalar

            # Perform the row shifting operation
            shifted_X[i*self.sr, :] = torch.roll(X[i*self.sr, :], shifts=shift_amount, dims=0)

        return shifted_X
    
    def nan_mean(self, input):
        # Create a mask for NaN values
        mask = torch.isnan(input)
        
        # Replace NaN values with zeros
        input_zeros = torch.where(mask, torch.zeros_like(input), input)
        
        # Compute the sum and count of non-NaN values along each column
        col_sum = torch.sum(input_zeros, dim=0)
        col_count = torch.sum(~mask, dim=0, dtype=torch.float32)
        
        # Compute the column-wise mean, ignoring NaN values
        output = torch.where(col_count > 0, col_sum / col_count, torch.tensor(float('nan')))
        
        return output

def fill_nan_with_last_value(x):
    mask = torch.isnan(x)

    last = x[0].item()
    for i in range(1, x.shape[0]):
        if mask[i].item():
            x[i] = last
        else:
            last = x[i].item()
    return x

Each piece works and has been tested individually; it’s just when I put it all together that I lose the gradients somewhere.

Thanks again for any support!
Aaron

This line:
p_scale = torch.tensor(p_scale, dtype=int)
will detach p_scale, since you are (a) creating a brand-new tensor with no gradient history and (b) using an integer dtype, which does not support gradients.
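
As a minimal sketch of one way to fix it (assuming you only need integer values at the point where the shift is actually applied), you could keep the rescaled output as a float tensor and leave it attached to the graph:

    def forward(self, x):
        # Create the perturbation profile
        p_i = self.relu(self.fc0(x))
        p_i = self.p_out(p_i)

        # Rescale without creating a new tensor or casting to int,
        # so p_scale keeps its gradient history
        p_scale = p_i * p_i.shape[0]

        return p_scale

Note that any later conversion to an integer (e.g. int(p[i].item()) inside shift_rows) is still non-differentiable, so in this setup the first stage would only receive gradients through the parts of the loss that use p_out directly as a float tensor.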

Thanks very much! That seems to have been the main problem.