Backpropagation issue when using argmax

I have a network with the following architecture:

def forward(self, W, z):

    W1 = self.SM(W)
    W2 = torch.argmax(W1, dim=3).float()

    h_5 = W2.view(-1, self.n_max_atom * self.n_max_atom)
    h_6 = self.leaky(self.dec_fc_5(h_5))
    h_6 = h_6.view(-1, self.n_max_atom, self.n_atom_features)

    return h_6

where self.SM = nn.Softmax(dim=3). When running this, I receive the error:

    File "/Users/Blade/model/VAEtrain.py", line 154, in trainepoch
      loss.backward()
    File "/Users/Blade/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 198, in backward
      torch.autograd.backward(self, gradient, retain_graph, create_graph)
    File "/Users/Blade/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 100, in backward
      allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: Expected isFloatingType(grads[i].scalar_type()) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.) (validate_outputs at /Users/distiller/project/conda/conda-bld/pytorch_1587428061935/work/torch/csrc/autograd/engine.cpp:476)
    frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x112602ab7 in libc10.dylib)
    frame #1: torch::autograd::validate_outputs(std::__1::vector<torch::autograd::Edge, std::__1::allocator<torch::autograd::Edge> > const&, std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> >&, std::__1::function<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > (std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&)> const&) + 5884 (0x1193d614c in libtorch_cpu.dylib)
    frame #2: torch::autograd::Engine::evaluate_function(std::__1::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 1996 (0x1193d18ec in libtorch_cpu.dylib)
    frame #3: torch::autograd::Engine::thread_main(std::__1::shared_ptr<torch::autograd::GraphTask> const&, bool) + 497 (0x1193d08b1 in libtorch_cpu.dylib)
    frame #4: torch::autograd::Engine::thread_init(int) + 152 (0x1193d0648 in libtorch_cpu.dylib)
    frame #5: torch::autograd::python::PythonEngine::thread_init(int) + 52 (0x111b43a04 in libtorch_python.dylib)
    frame #6: void* std::__1::__thread_proxy<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct> >, void (torch::autograd::Engine::*)(int), torch::autograd::Engine*, int> >(void*) + 66 (0x1193dfc32 in libtorch_cpu.dylib)
    frame #7: _pthread_body + 126 (0x7fff6e8632eb in libsystem_pthread.dylib)
    frame #8: _pthread_start + 66 (0x7fff6e866249 in libsystem_pthread.dylib)
    frame #9: thread_start + 13 (0x7fff6e86240d in libsystem_pthread.dylib)

It seems that torch.argmax breaks backpropagation, but I believe it should work fine here. Is there a problem with the model architecture?

Hello Blade!

Yes, argmax() is not usefully differentiable, and so it does, indeed, break backpropagation.

Make sure you understand why argmax() isn’t differentiable, and then
see if you can reformulate what you are doing in a way that avoids the
discrete jumps inherent in argmax().
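
To see the problem concretely, here is a minimal, self-contained sketch (toy tensors, not your actual model): argmax() returns an integer index tensor that is detached from the autograd graph, so nothing upstream of it can receive a gradient.

    import torch

    W1 = torch.randn(2, 3, requires_grad=True)

    # argmax() produces an integer (LongTensor) result that is detached
    # from the autograd graph:
    W2 = torch.argmax(W1, dim=1)
    print(W2.dtype, W2.requires_grad)   # torch.int64 False

    # Casting to float does not re-attach it, so a loss built from W2
    # cannot propagate gradients back to W1:
    loss = W2.float().sum()
    print(loss.requires_grad)           # False
    # loss.backward()  # would fail: loss does not require grad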

Good luck.

K. Frank

Thanks! Is there a workaround for this?

Hello Blade!

“Workaround” isn’t really the right term, as it implies that you are trying to do something that makes sense, and you need to work around a bug or limitation to reach the same result by a different path.

Conceptually, what are the index values in W2 supposed to mean?
How do you want your optimizer to respond when one of the values
in W2 suddenly makes a discrete jump from 2 to 1?

To illustrate my point, here is a “workaround”:

Replace W2 = torch.argmax (W1, dim = 3).float()
with W2 = 0.0 * torch.sum (W1, dim = 3).float().

Backpropagation will now work (but all of your gradients will be zero).
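
As a minimal, self-contained sketch of this “workaround” (toy tensors rather than your actual model), the gradient that flows back to the input is identically zero:

    import torch

    W1 = torch.randn(2, 4, requires_grad=True)

    # backward() runs, because 0.0 * sum() is differentiable, but the
    # gradient reaching W1 is zero everywhere:
    W2 = 0.0 * torch.sum(W1, dim=1).float()
    W2.sum().backward()
    print(W1.grad)   # tensor of all zeros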

softmax() is a smooth (differentiable) approximation to the one-hot encoding of argmax(). But this comment will only be helpful if you understand the conceptual role you want the piecewise-constant (not usefully differentiable) W2 values to play in your network training.

Good luck.

K. Frank

I have two networks that are trained together. The networks are in series, i.e. the output of one is the input of the other. The upstream network solves a classification problem for graph edge types, with output shape [Batch, node, node, class]. What you see above is my downstream network: it takes the upstream output W, applies Softmax to turn it into probabilities W1, and then I want to turn those into either one-hot vectors or class labels W2 before reshaping and feeding them to the linear layer dec_fc_5.

Hello Blade!

Once you turn the output of your upstream network into a one-hot
vector or a class label, you have done something that is not
differentiable and that breaks backpropagation.

So … you have to do something else.

Why not just pass the output of your upstream network directly to your downstream network (i.e., directly to dec_fc_5)? What breaks?

If the output of your upstream network comes directly from a Linear
without any subsequent activation function, you will want an activation
function in between. How about relu() or sigmoid() (or even
softmax())?

If passing a one-hot vector to your downstream network makes the most sense (ignoring the fact that it isn’t differentiable), perhaps you should consider my observation that softmax() is a differentiable approximation to the one-hot encoding of argmax().

You can sharpen this result – making softmax() “less soft” – by
scaling its argument, e.g., softmax (scale * W1). As scale is
increased to approach +inf, the result of softmax (scale * W1)
will approach one_hot (argmax (W1)).
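
For instance (a minimal sketch with a made-up W1, not your actual tensors), scaling the softmax argument sharpens its output toward the one-hot encoding of argmax() while keeping it differentiable:

    import torch
    import torch.nn.functional as F

    W1 = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)

    # The (non-differentiable) target we would like to approximate:
    print(F.one_hot(torch.argmax(W1), num_classes=3))   # tensor([0, 1, 0])

    for scale in (1.0, 10.0, 100.0):
        W2 = F.softmax(scale * W1, dim=0)
        print(scale, W2.detach())
    # scale = 1   -> roughly [0.29, 0.39, 0.32]  (still quite soft)
    # scale = 100 -> roughly [0.00, 1.00, 0.00]  (close to one-hot)

    # Unlike argmax(), W2 stays attached to the graph, so gradients can
    # flow back through it to W1 (and to the upstream network):
    print(W2.requires_grad)   # True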

Best.

K. Frank

That makes sense, thank you very much for your thorough explanation!