Differentiable argmax

I need to put argmax in the middle of my network, and thus I need it to be differentiable using the straight-through estimator. That is: during the forward pass I want to do the usual argmax, and during the backward pass, since argmax is not differentiable, I would like to pass the incoming gradient through instead of zero gradients. This is what I came up with:

class ArgMax(torch.autograd.Function):

	@staticmethod
	def forward(ctx, input):
		ctx.mark_dirty(input)
		idx = torch.argmax(input.clone(), 1)
		return idx


	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone()
		return grad_input

However I get this error:
RuntimeError: Function ArgMaxBackward returned an invalid gradient at index 0 - expected shape [1, 10] but got [1]

The problem is that the input to the forward function is a tensor of shape [1, 10], but because of the argmax its output becomes a tensor of size [1]. I don't know the mechanics of autograd very well. Is there a solution for this problem?


You’d need to return a zero gradient tensor with the argmax entry set to the incoming gradient; this can be achieved with zeros followed by assignment, scatter_, or a similar function.

Note that what you do makes very limited sense mathematically (because you trade a value for an index), and the argmax output is an integer tensor which cannot have a gradient, so you’d need to cast it to float before doing much with it, …
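In isolation, the scatter version of the backward would look roughly like this (made-up shapes, matching your [1, 10] input):

import torch

# made-up stand-ins for what forward/backward would see
input = torch.randn(1, 10)
idx = torch.argmax(input, dim=1)       # shape [1]
grad_output = torch.randn(1)           # pretend incoming gradient, shape [1]

# zero gradient with the incoming gradient placed at the argmax entry
grad_input = torch.zeros_like(input)   # shape [1, 10]
grad_input.scatter_(1, idx.unsqueeze(1), grad_output.unsqueeze(1))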

Best of luck

Thomas

Thanks,
I know that I can use scatter to create a one-hot vector or something like that, and in fact I have already done that:

idx = idx.data.new(input.size()).zero_().scatter_(-1, idx.view(-1, 1), 1.0)

but the problem is that the output of the argmax will be fed into an embedding layer, which doesn't accept one-hot vectors, so in any case I will need a single scalar value as the output.

I could have used the Gumbel softmax instead, but it has the same issue: it outputs a one-hot vector.

The cool people (and me) don’t use data or new+zero_ anymore, it’s called torch.zeros or torch.zeros_like now. :wink:

So you can't get around the invariant that the input gradient must have the same shape as the input, but you can use the zeros + scatter_ approach in the backward to satisfy it.

From what I understand of your problem, you need to merge the sampling and the embedding into a single Function because otherwise you won’t get a grad_out to propagate.
That would be roughly

import torch
class ArgMaxEmbed(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, embedding_matrix):
        idx = torch.argmax(input, 1)
        # stash what is needed to build an input-shaped gradient in backward
        ctx._input_shape = input.shape
        ctx._input_dtype = input.dtype
        ctx._input_device = input.device
        ctx.save_for_backward(idx)
        return torch.nn.functional.embedding(idx, embedding_matrix)

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        grad_input = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        # sum the incoming gradient over the embedding dimension and scatter it
        # to each row's argmax position
        grad_input.scatter_(1, idx[:, None], grad_output.sum(1, keepdim=True))
        return grad_input, None  # no gradient for embedding_matrix

(but it doesn’t do gradients on the weight).
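For illustration, a hypothetical way to call it (names and sizes are made up):

embedding = torch.nn.Embedding(10, 32)             # hypothetical vocab size / dim
logits = torch.randn(4, 10, requires_grad=True)    # pre-argmax scores
out = ArgMaxEmbed.apply(logits, embedding.weight)  # shape [4, 32]
out.sum().backward()
print(logits.grad.shape)                           # torch.Size([4, 10])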

You’d probably only use something like that in situations where matmul is prohibitively expensive…
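(Otherwise, multiplying a one-hot vector by the embedding weight matrix gives the same rows as an index lookup and is differentiable out of the box; a rough sketch with made-up sizes:)

emb = torch.nn.Embedding(10, 32)
one_hot = torch.zeros(1, 10)
one_hot[0, 3] = 1.0                            # pretend argmax picked index 3
via_matmul = one_hot @ emb.weight              # [1, 32], differentiable path
via_lookup = emb(torch.tensor([3]))            # [1, 32], same values
print(torch.allclose(via_matmul, via_lookup))  # True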

Best regards

Thomas


Thanks Thomas,
that really helped me a lot. My embedding layer has to be outside of the function, based on the architecture I use. I changed the code slightly, like this:

class ArgMax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        idx = torch.argmax(input, 1)
        ctx._input_shape = input.shape
        ctx._input_dtype = input.dtype
        ctx._input_device = input.device
        ctx.save_for_backward(idx)
        return idx.float()

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        grad_input = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        # place each row's incoming gradient at that row's argmax entry
        grad_input.scatter_(1, idx[:, None], grad_output[:, None])
        return grad_input

There are gradients propagating back, but I'm not 100% sure they are correct; I need to check them. Also, I did not understand what you mean by “it doesn’t do gradients on the weight”.
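A quick way to check would be something like this (made-up shapes):

x = torch.randn(2, 10, requires_grad=True)
y = ArgMax.apply(x)                  # shape [2], the argmax indices as floats
(y * torch.tensor([1.0, 2.0])).sum().backward()
print(x.grad)                        # zero everywhere except at each row's argmax,
                                     # where 1.0 and 2.0 should land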

If you use the result for indexing / embedding, the index itself doesn't get a gradient, so I can probably learn a trick if you share how you use the index once it works.

By “it doesn’t do gradients on the weight” I just meant that my function above won’t return gradients for the embedding_matrix (weight in the embedding layer…), sorry for the confusion!

Best regards

Thomas

OK sorry I should have clarified this earlier.

I’m trying to replicate some experiments as described here: https://arxiv.org/abs/1810.13107

Basically the idea is like this: we have two seq2seq models chained together. The first is an automatic speech recognition (ASR) encoder-decoder model that takes a speech signal and outputs the corresponding text. The text is then given to a text-to-speech (TTS) encoder-decoder model to generate the speech signal back from the generated text. You can think of it as a seq2seq autoencoder. There are two losses: one after the ASR, which means we should backpropagate gradients from the middle of the network to the beginning, and a second one at the end, which should backpropagate gradients from the end all the way to the beginning (including through the TTS model).

The problem is that, in order to generate the text, we need to use argmax in the middle, and since we need to backpropagate, we have to use the straight-through estimator to avoid sending zero gradients. The argmax is needed because the two seq2seq models will be decoupled later and each one should be usable independently.

Hope this clarifies,

Thanks for your help

Thank you for explaining your use case! I do think that it is likely easiest to integrate argmax and embedding into a single autograd Function.
To use it, I'd probably adapt the “later model” to (optionally) take inputs before the argmax and do the integrated argmax + embedding itself (my personal style would be to pass a flag on construction saying whether that should be the training mode, and then check that flag and self.training during forward), as in the sketch below.
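A rough sketch of what I mean (all names are hypothetical, using the ArgMaxEmbed Function from above):

class TTSInput(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim, straight_through=False):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, emb_dim)
        self.straight_through = straight_through  # flag passed on construction

    def forward(self, x):
        # x: pre-argmax scores when training end-to-end, token ids otherwise
        if self.straight_through and self.training:
            return ArgMaxEmbed.apply(x, self.embedding.weight)
        return self.embedding(x)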

Best regards

Thomas

Hey Thomas,
Thank you very much for your help. After going through the whole thing over and over again, I realized there is a big problem with the architecture I was planning: the gradients cannot backpropagate through the embedding layer either. This means that even if I solve the argmax problem, I will still get stuck at the embeddings. I guess that's the reason you proposed to merge the two. I checked the paper again and I think they don't use embeddings at all, they just use one-hot vectors (in their first paper they use embeddings without full backpropagation: https://arxiv.org/abs/1707.04879).

In this case I will just use something like this:

class ArgMax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # hard one-hot of the argmax in the forward pass
        idx = torch.argmax(input, 1, keepdim=True)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1.0)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: pass the incoming gradient unchanged
        return grad_output
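(The same hard-forward / identity-backward behaviour can also be written without a custom Function by adding and subtracting a detached copy, which is essentially the trick torch.nn.functional.gumbel_softmax uses for hard=True; a rough sketch with made-up sizes:)

scores = torch.randn(4, 10, requires_grad=True)  # made-up pre-argmax scores
one_hot = torch.zeros_like(scores).scatter_(1, scores.argmax(1, keepdim=True), 1.0)
output = (one_hot - scores).detach() + scores    # forward: one_hot, backward: identity w.r.t. scores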

Or I will simply use the Gumbel softmax: https://pytorch.org/docs/stable/nn.html?highlight=gumbel#torch.nn.functional.gumbel_softmax

I am a newbie to PyTorch, so can you tell me if this class will also update the embedding_matrix weights?

import torch
h = torch.randn(1, 2, 5, requires_grad=True); print(h)
val,idx = h.max(1, keepdim=True)
print(idx)
z = torch.nn.functional.gumbel_softmax(h, tau=2, hard=False, dim=1); print(z)
z = torch.nn.functional.gumbel_softmax(h, tau=2, hard=True, dim=1); print(z)

outputs are:

tensor([[[-0.0259, -0.9393, -0.2825, 0.6466, -1.0658],[-0.6078, 1.2127, 0.1509, -0.9749, -1.4952]]], requires_grad=True)
tensor([[[0, 1, 1, 0, 0]]])
tensor([[[0.7145, 0.1478, 0.8365, 0.5046, 0.4407],[0.2855, 0.8522, 0.1635, 0.4954, 0.5593]]], grad_fn=<...>)
tensor([[[1., 0., 0., 1., 0.],[0., 1., 1., 0., 1.]]], grad_fn=<...>)

The Gumbel softmax output looks weird…

What is the trick you have for this? I believe @ram’s problem is similar to the one I have here: Predict a categorical variable and then embed it (one-hot?)

Hi, could you please explain a bit about implementing it using the Gumbel softmax? I'm having a similar problem and I'm a bit unsure about my current implementation.

Thanks!