Why do nn.Embedding layers expect LongTensor inputs?

I’m a bit confused about the types that different layers are expecting.

I discovered that an nn.Embedding layer expects its input to be of type LongTensor, a.k.a. torch.int64. In my setup it then outputs a tensor of type torch.float64.

But if I then want to use an nn.LSTM layer next, I need to change the type, since nn.LSTM wants torch.float32 as input. For this I use

emb_vect = torch.tensor(emb_vect, dtype=torch.float32)

and it then works fine. But I have two questions:

  1. Why are there these type constraints?
  2. Does changing the type of a variable in a forward function have an effect on the backpropagation?

Thanks a lot


Hi,

The embedding layer takes as input the index of the element in the embedding table you want to select and returns the corresponding embedding. The input is expected to be a LongTensor because it is an index, and so it must be an integer. The output is a float type; you can call .float() or .double() on the nn.Module to switch between float32 and float64.
The LSTM takes float types for both input and output, and can be switched between single and double precision with .float() and .double() as well.
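
For concreteness, here is a minimal sketch of that behaviour (the layer sizes are arbitrary):

    import torch
    import torch.nn as nn

    # 10-entry embedding table with 4-dimensional vectors
    emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

    idx = torch.tensor([1, 5, 7])   # LongTensor (torch.int64) indices
    out = emb(idx)
    print(out.dtype)                # torch.float32 by default

    # .double() / .float() on the module switches the parameters,
    # and therefore the output, between float64 and float32
    emb.double()
    print(emb(idx).dtype)           # torch.float64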

Changes of type are treated like any other operation by the autograd engine, so all the gradients will be computed properly. Note that there may be a loss of precision when going from double to float, as you would expect.
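
For example, a tiny self-contained check that a cast stays in the autograd graph:

    import torch

    x = torch.randn(3, dtype=torch.float64, requires_grad=True)
    y = x.float()            # float64 -> float32, recorded by autograd
    loss = (y * 2).sum()
    loss.backward()
    print(x.grad)            # gradients flow back to the float64 leaf

Note that rebuilding the tensor with torch.tensor(emb_vect, dtype=torch.float32) copies the data and detaches it from the graph; the .float() / .to(torch.float32) methods are the casts that keep gradients flowing.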


Thanks a lot for your answer, albanD.

In order not to lose precision, I decided to use what you suggested (I didn’t know it existed) with

self.lstm = nn.LSTM(input_size=self.emb_dim,
                    hidden_size=n_lstm_units,
                    dropout=1 - keep_prob).double()

instead of converting the output of the embedding layer. And it works fine. Thank you very much for your time and explanation!
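
For reference, a rough sketch of that approach, assuming the embedding is also kept in double precision so the dtypes match (the vocabulary size and other numbers are made up):

    import torch
    import torch.nn as nn

    emb_dim, n_lstm_units, keep_prob = 32, 64, 0.8   # placeholder values

    embedding = nn.Embedding(num_embeddings=1000, embedding_dim=emb_dim).double()
    lstm = nn.LSTM(input_size=emb_dim,
                   hidden_size=n_lstm_units,
                   dropout=1 - keep_prob).double()

    idx = torch.randint(0, 1000, (7, 3))   # (seq_len, batch) LongTensor indices
    out, (h, c) = lstm(embedding(idx))
    print(out.dtype)                        # torch.float64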

But why doesn’t it accept an IntTensor?

What is not accepting IntTensors?


The Embedding layer. I want to save some memory by using IntTensors over LongTensors but the layer doesn’t accept IntTensors as input.


Hi,

Yes, to avoid hard-to-debug issues on the user side, we only support LongTensors, which can always represent the indices you will need.
Is your index Tensor too large? You can convert it to long just before the embedding forward, so that only a single batch is ever in full precision at a time. That should fit in memory just fine, no?
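
Something along these lines (sizes made up):

    import torch
    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=50_000, embedding_dim=128)

    # Keep the full index tensor in int32 to save memory ...
    all_indices = torch.randint(0, 50_000, (1_000_000,), dtype=torch.int32)

    # ... and promote only the current batch to int64 right before the lookup
    batch = all_indices[:256].long()
    vectors = emb(batch)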


How would supporting IntTensors cause bugs?

Hi,

It won’t cause errors in the code itself.
But you can “easily” overflow an IntTensor, and that can be very hard for a user to track down.
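
For illustration, the kind of silent wraparound being described:

    import torch

    i32_max = torch.iinfo(torch.int32).max     # 2147483647
    t = torch.tensor([i32_max], dtype=torch.int32)
    print(t + 1)                               # wraps around to -2147483648, no error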

@albanD I understand the concern about users accidentally overflowing an int32 tensor, but the requirement of int64 is causing issues for many people integrating with third-party inference engines like TensorRT and OpenVINO. You can see the long history in this GitHub issue, which has been open for almost 5 years!

I think providing full compatibility is a bigger concern than potential user errors, especially if the int32 behavior were only enabled by an explicit call like torch.tensor(dtype=torch.int32). Then basic users would not find themselves in the overflow situation. Another option would be to put this under a different class name like nn.SmallEmbedding, which again forces users to make an explicit choice and be aware of the potential downsides.

Could the team please reconsider the priority of addressing this, or provide a roadmap for community members to contribute this feature in a PR?

Hi,

Thanks for the feedback and details.
I think we already have support on CPU and CUDA for both long and int dtypes in embedding in the latest version of PyTorch. Does that solve your issue?
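
A quick way to check on a given install (minimal sketch):

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4)
    idx = torch.tensor([1, 2, 3], dtype=torch.int32)

    # Works on recent releases; older versions raise a dtype error here
    print(emb(idx).shape)   # torch.Size([3, 4])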


Excellent! I was on 1.7.1, but you are correct that 1.8.0 accepts int32 indices now. I will mention this in the GitHub issue thread so others are aware and it can be closed. FYI, the release notes for 1.8.0 only mention this support for nn.EmbeddingBag.


Oh, my bad, thanks for pointing it out. Just updated the release notes! 🙂