"Kinda" breaking backprop gradient with discrete values?

I have been trying to implement a “color quantization module” for a project, but I am stuck.
To explore what is actually possible in PyTorch, I have reduced the technique to a minimal example.

Consider:

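```python
from torch.nn.functional import l1_loss

n = 10  # resolution: outputs are snapped to multiples of 1/n

for input, target in dataloader:  # model and dataloader are defined elsewhere
    output = model(input)
    output = (n * output).round() / n   # <--- quantization step
    loss = l1_loss(output, target)
    loss.backward()
```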

Here, output is snapped to multiples of 1/n, so it can only take a few unique values (n + 1 of them if output lies in [0, 1]). This makes the values discrete, and autograd raises no error about a broken gradient, BUT training seems pretty much impossible.

Is there any way to train models this way? Eventually, I want to vary n per sample.
Maybe torch.quantization has a use case for this? I am not very familiar with that module and have been under the impression that it exists mainly for performance and uses a fixed resolution (n).
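For reference, and as far as I know, the quantization-aware-training side of PyTorch does not backpropagate through a plain round: its fake-quantization ops pass the gradient straight through for values inside the representable range and zero it outside. A minimal sketch with `torch.fake_quantize_per_tensor_affine`, assuming outputs in [0, 1] and choosing scale = 1/n, zero_point = 0, quant_min = 0, quant_max = n so that the forward pass matches `(n*output).round() / n`. Note the single scale per tensor, which is the fixed resolution mentioned above:

```python
import torch

torch.manual_seed(0)
n = 10                                # resolution, as in the snippet above
x = torch.rand(8, requires_grad=True)  # assumed to lie in [0, 1]

# Forward: snap to multiples of 1/n, with the integer levels clamped to [0, n]
y = torch.fake_quantize_per_tensor_affine(x, 1.0 / n, 0, 0, n)
y.sum().backward()

print(y)       # multiples of 0.1
print(x.grad)  # ones: inside the range, the gradient passes straight through
```

For a resolution that differs along one dimension, `torch.fake_quantize_per_channel_affine` takes a tensor of scales along a chosen axis; whether that fits a per-sample n is worth checking against its documentation.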

The round operation kills the gradient (almost) everywhere, so it is indeed not useful for training, as seen here:

```python
import torch

x = torch.randn(10, requires_grad=True)
y = torch.round(x)
y.mean().backward()
print(x.grad)
# tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
```
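The same argument in terms of the quantizer from the question, with f_θ standing for the model and θ for its parameters (notation introduced here for illustration): q_n is piecewise constant, so its derivative is zero wherever it exists, and the chain rule multiplies every upstream gradient by that zero:

$$
q_n(x) = \frac{\operatorname{round}(n x)}{n}, \qquad
q_n'(x) = 0 \quad \text{for } n x \notin \mathbb{Z} + \tfrac{1}{2},
$$

$$
\frac{\partial \mathcal{L}}{\partial \theta}
= \frac{\partial \mathcal{L}}{\partial q_n} \, q_n'\big(f_\theta(x)\big) \, \frac{\partial f_\theta(x)}{\partial \theta} = 0 .
$$

At the jump points (where n·x is a half-integer) the derivative does not exist; autograd simply reports 0 there as well.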

You could try to implement this approach posted by @tom, which uses different “paths” in the forward and backward pass.
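The linked post is not reproduced here, so this is not necessarily @tom's exact code, but one common way to get different forward and backward paths is a straight-through estimator: round in the forward pass, treat the op as the identity in the backward pass. A minimal sketch as a custom `torch.autograd.Function` that also takes one resolution per sample, to cover varying n per sample; the class name `RoundSTE` and the layout of `n` (one value per row of the batch) are assumptions for illustration:

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Hypothetical straight-through quantizer: round(n*x)/n forward, identity backward."""

    @staticmethod
    def forward(ctx, x, n):
        # n has shape (batch,); reshape it so it broadcasts over the remaining dims
        n = n.to(x.dtype).view(-1, *([1] * (x.dim() - 1)))
        return (n * x).round() / n

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through unchanged for x; n gets no gradient
        return grad_output, None


x = torch.rand(2, 4, requires_grad=True)
n = torch.tensor([2, 10])  # coarse resolution for sample 0, fine for sample 1
y = RoundSTE.apply(x, n)
y.sum().backward()
print(y)
print(x.grad)  # all ones instead of all zeros
```

An equivalent trick without a custom Function is `x + (q - x).detach()`, where `q` is the quantized tensor: the forward value is `q`, while the backward pass only sees `x`.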

Thanks ptr. Is there a list of functions like this, which “almost always” break the gradient, or do we just have to experiment?

min() and max() most likely behave similarly to round(), but they do not throw a “gradient broken” error like argmin() and argmax() do.
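A quick way to test that guess is to compare the two directly. In my understanding, the full reduction max() differs from round(): it routes the gradient to the selected element and gives zero to the others, whereas argmax() returns an integer index tensor with no grad_fn, which is why calling backward() on it fails. A small check (the input values are arbitrary):

```python
import torch

x = torch.tensor([0.5, 2.0, -1.0], requires_grad=True)

# max(): the gradient flows to the largest element only
x.max().backward()
print(x.grad)  # tensor([0., 1., 0.])

# argmax(): returns integer indices, which are not differentiable
idx = x.argmax()
print(idx.dtype, idx.requires_grad)  # torch.int64 False
try:
    idx.float().backward()
except RuntimeError as err:
    print("argmax:", err)
```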

I don’t know if there is a list collecting these operations, but you could “draw” the function over a range of inputs (mentally or with any plotting library) and check what its derivative looks like.
For these “almost always” breaking methods, the gradient is zero almost everywhere except at a few specific points.
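One way to do this “drawing” numerically is to evaluate the derivative that autograd reports on a grid of inputs. A rough sketch; the helper name `autograd_derivative` is hypothetical, and the operations listed are only examples:

```python
import torch


def autograd_derivative(fn, xs):
    """Elementwise derivative of fn at each point of xs, as reported by autograd."""
    x = xs.clone().requires_grad_(True)
    # fn is elementwise, so d(sum)/dx_i equals fn'(x_i)
    fn(x).sum().backward()
    return x.grad


xs = torch.linspace(-2.0, 2.0, 9)  # -2.0, -1.5, ..., 2.0
ops = {
    "round": torch.round,
    "floor": torch.floor,
    "sign": torch.sign,
    "clamp(-1, 1)": lambda t: t.clamp(-1.0, 1.0),
    "tanh": torch.tanh,
}
for name, fn in ops.items():
    print(f"{name:>12}: {autograd_derivative(fn, xs).tolist()}")
```

round, floor and sign come out as all zeros, which is the “almost always broken” signature; clamp is only flat outside its range, and tanh has a nonzero derivative at every grid point.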
