Weights aren't updating on backward pass

I have a model that takes a tensor representing the difference between two images and outputs coordinates used to make them more alike. I then calculate the loss as the MSE between the drawn image and the original image, but when I run a backward pass no weights seem to update, and the loss remains constant (although not None) throughout all epochs.

Is this because the loss isn’t calculated directly on the model output?

When reading other posts, I saw that requires_grad was sometimes responsible for this problem. I’ve tried setting requires_grad = True, but I’m still unsure whether I’m using it correctly.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10 * 10 * 3, 240)
        self.fc2 = nn.Linear(240, 240)
        self.fc3 = nn.Linear(240, 240)
        self.fc4 = nn.Linear(240, 7)  # [x1, y1, x2, y2, r, g, b]

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

def array_to_flatten_tensor(array_to_reshape):
    prepared_image = torch.FloatTensor(array_to_reshape)
    prepared_image.requires_grad = True
    prepared_image.retains_grad = True
    return torch.flatten(prepared_image)

    image = cv2.imread("flower.png")
    image = cv2.resize(image, (10, 10))
    image = torch.FloatTensor(image)
    image.requires_grad = True
    image.retains_grad = True

    image = torch.flatten(image)

    for image_pass in range(1000):
        running_loss = 0
        # Create empty image
        drawn_image = array_to_png_test.generate_empty_RGB_array(10, 10)

        for _ in range(10):

            optimizer.zero_grad()

            # Pass (image - net drawn) image to net
            difference_array = image - array_to_flatten_tensor(drawn_image)

            output = net_to_train(difference_array)

            # Converts the output to two points and colour
            point_1 = (float(output[0]), 0, float(output[1]))
            point_2 = (float(output[2]), float(output[3]))
            colour = round(max(float(output[4]), 0)), round(max(float(output[5]), 0)), round(max(float(output[6]), 0))

            point_1, point_2 = sorted([point_1, point_2])

            # Draws a line to the empty image
            drawn_image = array_to_png_test.draw_line(point_1, point_2, colour, drawn_image)

            # Calculate loss as mean square difference of image and drawn image
            drawn_image_tensor = array_to_flatten_tensor(drawn_image)

            current_loss = criterion(drawn_image_tensor, image)

            running_loss += float(current_loss)

            current_loss.backward()

            optimizer.step()

        print("epoch " + str(image_pass) + " : " + str(running_loss / 250))

I’m not sure I fully understand what you’re doing, but there is clearly a problem with the way you compute the loss before calling backward.

First, you should not force requires_grad or retains_grad: the problem already exists before your function array_to_flatten_tensor is called. In the usual case you don’t need to set requires_grad on intermediate tensors. As for retains_grad, assigning to it by hand does not help here. The related method retain_grad() only asks autograd to keep the .grad of an intermediate (non-leaf) tensor so you can inspect it, while keeping the graph alive for several backward passes or for gradients of gradients is done with the retain_graph and create_graph arguments of backward(). I don’t think you need any of these either.
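As a small illustration of which tensors end up with a .grad after backward, here is a minimal sketch. The tensor names are made up for the example; the point is only that leaf tensors requiring gradients get a .grad automatically, while an intermediate tensor keeps one only if retain_grad() was called on it.

```python
import torch

# Leaf tensor created directly by the user: backward() fills in its .grad.
leaf = torch.ones(3, requires_grad=True)

# Intermediate (non-leaf) tensors are results of operations on the leaf.
kept = leaf * 2
kept.retain_grad()      # ask autograd to keep .grad for this non-leaf tensor
dropped = kept * 3      # no retain_grad(): its .grad is not stored

loss = dropped.sum()    # loss = sum(6 * leaf)
loss.backward()

print(leaf.grad)        # d(loss)/d(leaf) = 6 for every element
print(kept.grad)        # d(loss)/d(kept) = 3 for every element
print(dropped.is_leaf)  # False: this is why dropped.grad would be None
```

So a loss tensor whose .grad is None is expected and does not by itself mean that the backward pass failed; what matters is whether the model parameters received gradients.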

To perform the backward pass correctly, every operation between the model’s output and the loss must be differentiable and tracked by autograd. That is not the case here. The main issue is that you convert your model output to Python’s built-in float type. If you want the backward pass to reach your model, you must keep all intermediate computations on tensors.

To give a simple example with an L2 loss:
This will work (assuming model is an nn.Module in train mode):

prediction = model(input_tensor)             # here prediction is a tensor with requires_grad = True
loss = ((prediction - label)**2).sum()       # here loss is also a tensor with requires_grad = True, label may or may not be a tensor
loss.backward()                              # backward pass working

This won’t:

prediction = model(input_tensor)             # here prediction is a tensor with requires_grad = True
prediction = float(prediction)               # prediction is no longer a tensor! We just lost the computational graph
prediction = torch.tensor(prediction)        # prediction is a tensor again, but detached from the computational graph,
                                             # so it is no longer related to the network output from the backward point of view.
                                             # Furthermore, prediction now has requires_grad = False
prediction.requires_grad = True              # forcing requires_grad to True here won't reconnect it to the model
loss = ((prediction - label)**2).sum()
loss.backward()                              # this backward pass will not reach the model!
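To see the difference concretely, the sketch below runs both variants on a tiny model and then checks whether the parameters received gradients. The nn.Linear model, input and label are hypothetical stand-ins, not your actual network.

```python
import torch

model = torch.nn.Linear(3, 1)   # hypothetical tiny model with a single output
input_tensor = torch.randn(3)
label = torch.tensor([1.0])

# Variant 1: everything between the model output and the loss stays a tensor.
loss = ((model(input_tensor) - label) ** 2).sum()
loss.backward()
attached_reached = all(p.grad is not None for p in model.parameters())

# Reset the gradients before trying the second variant.
for p in model.parameters():
    p.grad = None

# Variant 2: round trip through a Python float, then back to a tensor.
prediction = torch.tensor(float(model(input_tensor)), requires_grad=True)
loss = ((prediction - label) ** 2).sum()
loss.backward()                 # runs without error, but stops at `prediction`
detached_reached = any(p.grad is not None for p in model.parameters())

print(attached_reached)         # True: the parameters got gradients
print(detached_reached)         # False: the model was never reached
```

Note that the second backward call does not raise an error, which is why this kind of bug shows up only as a loss that never changes.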

I invite you to read this guide to understand how backward works.

Hope it helps, good luck with your PyTorch journey :)
Thomas

Thanks for the answer.

I tried reworking my code to keep all my calculations in tensors, and I’ve managed to get a grad for the drawn image tensor. However, the image tensor and the current_loss tensor still have grad None, and therefore the backward pass doesn’t update anything.

Would you have any ideas on what to look at next?

Image loading

    image = cv2.imread("flower.png")
    image = cv2.resize(image, (10, 10))
    image = image.astype(np.float32)
    image = torch.tensor(image, requires_grad=True)

    image = torch.flatten(image, start_dim=1)

I don’t think the problem comes from the data loading. By the way, you don’t need to set requires_grad on the loaded data. In most cases you only need requires_grad on the model’s parameters, which have requires_grad set to True by default.

Simple example to illustrate:

dummy_input = torch.randn(1, 1, 5)
print('input requires_grad ?', dummy_input.requires_grad)
simple_model = torch.nn.Conv1d(1, 1, 3)
print('model param requires_grad ?', simple_model.weight.requires_grad) # weight attribute of conv is a parameter
model_output = simple_model(dummy_input)
print('output requires_grad ?', model_output.requires_grad)
random_target = torch.randn(1, 1, 3)
loss = ((model_output - random_target)**2).sum()
print('loss requires_grad ?', loss.requires_grad)

# Output:
# > input requires_grad ? False
# > model param requires_grad ? True
# > output requires_grad ? True
# > loss requires_grad ? True

In fact, when you run an operation on tensors, the output requires gradients if at least one input does. So if your model has parameters with requires_grad = True, your output should automatically get requires_grad = True (unless you explicitly froze the model or are inside a torch.no_grad context).
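The two exceptions mentioned above can be checked quickly. This is a minimal sketch with a hypothetical nn.Linear layer standing in for a model:

```python
import torch

layer = torch.nn.Linear(4, 2)   # hypothetical stand-in for a model
x = torch.randn(1, 4)           # plain input, requires_grad is False

# Normal case: the parameters require gradients, so the output does too.
out = layer(x)
print(out.requires_grad)            # True

# Inside torch.no_grad() no graph is recorded.
with torch.no_grad():
    out_no_grad = layer(x)
print(out_no_grad.requires_grad)    # False

# Freezing every parameter has the same effect on the output.
for p in layer.parameters():
    p.requires_grad_(False)
out_frozen = layer(x)
print(out_frozen.requires_grad)     # False
```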

I do not know what could have gone wrong after you changed your code, but you should check a few points:

  • How do you initialize your optimizer? It should be something like optimizer = optim.Adam(net_to_train.parameters(), lr=0.0001)
  • Set your input data with image = torch.tensor(image) without the requires_grad argument, and then check if your model output has requires_grad = True.
  • Check at every step from the model’s output to the loss whether the attribute requires_grad becomes False. A few operations on tensors are not differentiable, and you cannot backpropagate through them (this is usually stated in the documentation); for instance, indexing is not differentiable with respect to the indices.
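The last check can be done with a small helper like the one sketched below. The describe function and the nn.Linear model are hypothetical, used only for illustration. The sketch also shows a second way gradients can vanish: torch.round keeps the tensor attached to the graph, but its gradient is zero everywhere, so nothing useful reaches the parameters.

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(4, 3)     # hypothetical stand-in for net_to_train
x = torch.randn(4)
target = torch.zeros(3)

def describe(name, t):
    """Report whether a tensor is still attached to the autograd graph."""
    attached = t.grad_fn is not None
    print(f"{name}: requires_grad={t.requires_grad}, has grad_fn={attached}")

output = net(x)
describe("output", output)      # attached to the graph

detached = output.detach()
describe("detached", detached)  # cut off: requires_grad=False, no grad_fn

rounded = torch.round(output)
describe("rounded", rounded)    # still attached, but round has zero gradient

loss = ((rounded - target) ** 2).sum()
loss.backward()

# Sum of absolute gradients per parameter (a missing gradient counts as 0.0).
grad_sums = {
    name: 0.0 if p.grad is None else p.grad.abs().sum().item()
    for name, p in net.named_parameters()
}
print(grad_sums)                # all zeros: the optimizer has nothing to use
```

If a step reports requires_grad=False, or the parameter gradients come out as all zeros as in this sketch, the operation just before that point is the one to replace with a differentiable alternative.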

If you cannot find the problem, please show me your current training loop and loss computation so I can see what could be going wrong.