Are conversions to string differentiable?

I’m building a model that generates sequences. Internally, the sequences are a (N,1, length, W) tensor of width-W one-hots, but part of the loss function involves converting sequences to strings (using a dictionary) and passing them as an argument to another PyTorch neural network that returns a scalar. Furthermore, certain one-hots are (deterministically) ignored during this conversion.

Is this okay to do? It isn’t clear to me how gradients would be backpropagated properly, but I’m not getting any errors.

Hi,

As you mentioned, strings are “discrete”, so gradients for them don’t really make sense. So no, you won’t be able to backprop through that op, I’m afraid.
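A small illustration of why no error appears. The vocabulary `"ABC"` and the shapes are hypothetical. Once the values leave the tensor world through `argmax`/`.tolist()` and become a Python string, any tensor rebuilt from that string starts a new graph with no connection to the generator. The forward pass runs fine; the generator simply receives no gradient from any loss term computed this way.

```python
import torch

# Hypothetical 3-symbol dictionary standing in for the poster's one-hot -> character map.
vocab = "ABC"

logits = torch.randn(10, 3, requires_grad=True)  # stand-in for the generator's output
probs = logits.softmax(dim=-1)                    # still part of the autograd graph

# argmax and .tolist() both leave the graph: the string holds plain Python data,
# so nothing computed from it is connected to `logits`.
seq = "".join(vocab[i] for i in probs.argmax(dim=-1).tolist())

# Re-encoding the string (as a string-consuming network would) builds a brand-new tensor.
rebuilt = torch.tensor([vocab.index(c) for c in seq])

print(probs.requires_grad)                     # True
print(rebuilt.requires_grad, rebuilt.grad_fn)  # False None
```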

I thought so! Instead, I’ll need to train a model that can accept a (N,1, *, W) tensor, correct? In order to (deterministically) ignore some one-hots in the original (N,1, length, W) tensor, how should I filter so that gradients are still sensible?

For instance, suppose I have a (1, 1, 10, 3) tensor, and I want to ignore the one-hot (0,1,0). How would I turn:

tensor([[[[0., 0., 1.],
          [0., 0., 1.],
          [0., 1., 0.],
          [0., 1., 0.],
          [0., 0., 1.],
          [0., 0., 1.],
          [0., 1., 0.],
          [0., 0., 1.],
          [1., 0., 0.],
          [0., 0., 1.]]]])

into

tensor([[[[0., 0., 1.],
          [0., 0., 1.],
          [0., 0., 1.],
          [0., 0., 1.],
          [0., 0., 1.],
          [1., 0., 0.],
          [0., 0., 1.]]]])

in a gradient-friendly way?
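On the first part of the question: yes, one option is a scoring network that reads the one-hots directly instead of strings. A minimal sketch, assuming a hypothetical `OneHotScorer` that mean-pools over the length dimension (chosen only so that any length works, not because it is the right architecture for your task). A bias-free `nn.Linear` applied to a one-hot row returns one column of its weight, so it behaves like an embedding lookup, but unlike a dictionary/string lookup it is differentiable with respect to its input.

```python
import torch
import torch.nn as nn


class OneHotScorer(nn.Module):
    """Hypothetical scorer mapping a (N, 1, L, W) one-hot tensor to one scalar per sequence."""

    def __init__(self, width: int, hidden: int = 16):
        super().__init__()
        # Bias-free linear layer: one_hot @ weight.T selects one weight column,
        # like an embedding lookup, but stays differentiable w.r.t. the input.
        self.embed = nn.Linear(width, hidden, bias=False)
        self.out = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed(x)                       # (N, 1, L, hidden)
        h = h.mean(dim=2)                       # mean-pool over length -> (N, 1, hidden)
        return self.out(h).reshape(x.size(0))   # (N,)
```

If the sequences in a batch are zero-padded to a common length, the plain mean above also counts the padding positions in its denominator; a masked mean using per-sequence lengths would be needed in that case.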

Well, removing entries is easy: you can simply index the elements you want to keep, and the gradients will flow back properly to those elements.
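A quick check of this claim on a made-up 1-D tensor: after boolean-mask indexing, `backward()` writes a gradient of 1 into each kept element and 0 into each dropped one.

```python
import torch

x = torch.arange(6.0, requires_grad=True)                  # [0, 1, 2, 3, 4, 5]
keep = torch.tensor([True, False, True, True, False, True])

y = x[keep]           # tensor([0., 2., 3., 5.]), still attached to x
y.sum().backward()

print(x.grad)         # tensor([1., 0., 1., 1., 0., 1.])
```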

Be careful, though, to generate these one-hot encodings in a way that gives non-zero gradients (for example, `argmax` followed by `one_hot` gives no gradient at all).
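One common way to get exact one-hots that still carry gradients is the straight-through estimator, sketched below; the function name `straight_through_one_hot` is hypothetical. The forward value is the hard one-hot, while the backward pass uses the softmax gradient. PyTorch also ships a sampling variant of the same idea, `torch.nn.functional.gumbel_softmax(logits, hard=True)`.

```python
import torch
import torch.nn.functional as F


def straight_through_one_hot(logits: torch.Tensor) -> torch.Tensor:
    """Return one-hots over the last dim whose gradient is that of softmax(logits)."""
    soft = logits.softmax(dim=-1)
    # argmax + one_hot on their own would give no gradient at all.
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=logits.size(-1)).to(soft.dtype)
    # Forward value equals `hard` (up to float rounding); backward uses d(soft)/d(logits).
    return hard - soft.detach() + soft


# Usage: a (1, 1, 10, 3) batch of one-hots that still backpropagates into `logits`.
logits = torch.randn(1, 1, 10, 3, requires_grad=True)
one_hots = straight_through_one_hot(logits)
```

Note the "up to float rounding": entries that should be 1 can be off by a rounding error, which is one reason an exact `==` comparison on such tensors can be fragile.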


Thank you! It isn’t clear to me how to index the tensor in bulk like that. How would I do the indexing without collecting the indices in a for loop?

Sure, something like this will work:

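import torch

# Example input of shape (N=1, 1, length=10, W=3)
inp = torch.tensor([[[[0., 0., 1.],
                      [0., 0., 1.],
                      [0., 1., 0.],
                      [0., 1., 0.],
                      [0., 0., 1.],
                      [0., 0., 1.],
                      [0., 1., 0.],
                      [0., 0., 1.],
                      [1., 0., 0.],
                      [0., 0., 1.]]]])

# The one-hot row to drop
to_remove_line = torch.tensor([0., 1., 0.])

# Compare every row with to_remove_line (broadcast over the last dim),
# then require all W entries to match: shape (1, 1, 10)
to_remove = (inp == to_remove_line).all(-1)

# Boolean-mask indexing keeps the other rows and is differentiable w.r.t. inp.
# It flattens the masked dims, so the result has shape (7, 3).
res = inp[to_remove.logical_not()]

# Restore the (N, 1, *, W) layout. Only valid when N == 1, or when every
# sequence drops the same number of rows.
res = res.view(inp.size(0), 1, -1, inp.size(-1))

print(res)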

Note that if your values are floating point and may not be exact, the `==` can be relaxed by checking that the absolute difference is below a given threshold.
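A sketch of the relaxed comparison, extended (as an assumption beyond the answer above) to a full batch of shape (N, 1, L, W). Because each sequence may drop a different number of rows, the result can't stay rectangular without padding, so this version packs the kept rows to the front with a stable sort and `torch.gather`, zero-pads the tail, and returns the per-sequence lengths. `drop_rows` and `atol` are hypothetical names.

```python
import torch


def drop_rows(x: torch.Tensor, row: torch.Tensor, atol: float = 1e-6):
    """Remove every position of `x` (shape (N, 1, L, W)) that matches `row` (shape (W,)).

    Returns (packed, lengths): packed is (N, 1, max_len, W) with the kept rows first,
    in their original order, and zero rows after each sequence's length; lengths is (N, 1).
    """
    # Relaxed equality: a position matches if every entry is within `atol` of `row`.
    remove = ((x - row).abs() <= atol).all(dim=-1)               # (N, 1, L) bool

    # Stable sort on 0/1 keys moves kept positions (key 0) to the front
    # without reordering them among themselves.
    _, order = torch.sort(remove.long(), dim=-1, stable=True)    # (N, 1, L)

    # gather is differentiable: gradients are routed back to the source positions.
    index = order.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))   # (N, 1, L, W)
    packed = torch.gather(x, dim=2, index=index)

    lengths = (~remove).sum(dim=-1)                              # (N, 1)
    max_len = int(lengths.max())
    packed = packed[:, :, :max_len]

    # Zero the tail so removed rows that were sorted to the end don't leak through.
    valid = torch.arange(max_len, device=x.device) < lengths.unsqueeze(-1)  # (N, 1, max_len)
    return packed * valid.unsqueeze(-1).to(x.dtype), lengths
```

For the (1, 1, 10, 3) example in the question, this returns the desired (1, 1, 7, 3) tensor with `lengths == [[7]]`. The zero padding rows are not valid one-hots, so a downstream model would need to use `lengths` (or treat all-zero rows as padding).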
