I have a model that produces logits, which I sample from. Sampling gives me a sequence of one-hot vectors.

Part of my loss function involves passing this sequence into an RNN, which outputs a scalar. The RNN takes a batch of lists of indices; because the lists have variable lengths, they are packed before being sent through the network, which contains an embedding layer.
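For concreteness, here is a minimal sketch of the kind of RNN scorer described above. The class name `RNNScorer`, the choice of a GRU, and the layer sizes are all hypothetical; the point is only that the indices stay `Long` all the way into `nn.Embedding`, and packing happens after the lookup.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


class RNNScorer(nn.Module):
    """Hypothetical scorer: variable-length index sequences -> one scalar each."""

    def __init__(self, vocab_size, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, seqs):
        # seqs: list of 1-D LongTensors with different lengths
        lengths = torch.tensor([len(s) for s in seqs])
        padded = pad_sequence(seqs, batch_first=True)  # (B, T_max), still Long
        emb = self.embed(padded)                        # integer table lookup
        packed = pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False
        )
        _, h_n = self.rnn(packed)                       # h_n: (1, B, hidden_dim)
        return self.head(h_n[-1]).squeeze(-1)           # one score per sequence


torch.manual_seed(0)
scorer = RNNScorer(vocab_size=10)
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
scores = scorer(seqs)
print(scores.shape)  # torch.Size([3])
```

With `enforce_sorted=False`, PyTorch sorts the batch internally and returns `h_n` in the original batch order, so the scores line up with `seqs`.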

If the output of my RNN is part of my loss, does my batch of lists of indices (the input to the RNN) need to be differentiable? If so, I think it needs to be of type float to store gradients, but if I do this I get the error RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.FloatTensor instead (while checking arguments for embedding).
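Both halves of this dilemma can be reproduced in a few lines: a float index tensor is rejected by the embedding lookup, and an integer index tensor cannot be marked as requiring gradients at all. The exact error wording varies between PyTorch versions (newer ones also accept `Int` indices), so treat the printed messages as indicative only.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)

# 1) Float indices can hold gradients, but the embedding lookup rejects them.
float_idx = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
try:
    emb(float_idx)
    float_error = None
except RuntimeError as e:
    float_error = str(e)

# 2) Integer indices are accepted by the lookup, but cannot require gradients.
try:
    torch.tensor([1, 2, 3], requires_grad=True)
    long_error = None
except RuntimeError as e:
    long_error = str(e)

print(float_error)
print(long_error)
```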

Am I correct in thinking that the need to have integer embedding indices poses a differentiability problem?

Yes. The “indexing” operation is not differentiable with respect to the indices, mainly because indices are discrete rather than continuous. And if you added a “floor” op to turn a continuous value into an index, you would get a gradient of 0 almost everywhere.
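The “gradient of 0 almost everywhere” point can be checked directly. `floor` is a step function, and PyTorch defines its backward as zero; casting to an integer dtype is worse still, since the result is detached from the graph entirely. The values `1.3` and `2.7` below are arbitrary.

```python
import torch

x = torch.tensor([1.3, 2.7], requires_grad=True)

# floor is piecewise constant, so its derivative is zero between integers
idx = torch.floor(x)
idx.sum().backward()
print(x.grad)  # all zeros: no useful learning signal

# casting to an integer dtype drops out of autograd altogether
idx_long = x.long()
print(idx_long.requires_grad)  # False
```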

Here’s what I mean: my Model is supposed to take any PyTorch model as an argument and optimize against its outputs. If the model passed in is (say) a simple feedforward neural network (e.g. an MLP), then the Model’s own weights can change in response to the loss computed from the MLP’s outputs: the error is backpropagated through the MLP’s layers/weights and then into the Model’s own weights.

In order for the Model to actually learn, the function/model passed in needs to be differentiable so the Model’s weights can receive the gradients. As long as this criterion is satisfied, it seems to me that a model is “modular” in the sense that it can just be passed in and everything will work fine.
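A minimal sketch of the “modular” case that does work, assuming the passed-in model is a frozen MLP operating on continuous inputs. `model` and `critic` are hypothetical stand-ins for the Model and the passed-in network; the loss is simply the critic’s mean output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical "Model" whose weights we want to train.
model = nn.Linear(8, 5)

# Hypothetical passed-in network: a plain MLP on continuous inputs, kept frozen.
critic = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1))
for p in critic.parameters():
    p.requires_grad_(False)

z = torch.randn(4, 8)
out = model(z)             # continuous output, still attached to the graph
loss = critic(out).mean()  # the critic's scalar output serves as the loss
loss.backward()

print(model.weight.grad is not None)  # True: gradient flowed through the MLP
print(critic[0].weight.grad)          # None: the critic itself is not updated
```

Freezing the critic is one design choice among several; the gradient reaches `model` either way, because every operation between `model`'s weights and `loss` is differentiable.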

But it also seems to me that if the passed-in model uses embedding layers, it will be impossible to compute the gradient of the loss with respect to its input, which means the Model won’t be able to change its weights in a way that reduces the loss incurred by the passed-in model. Does that make sense?

Yes, that’s right.
Basically, an Embedding layer cannot be plugged into the middle of a system, because it is not differentiable with respect to its inputs. It has to be the first layer.
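The broken chain can be seen end to end in a small sketch. The names `model`, `critic_embed`, and `critic_head` and all sizes are hypothetical; the Model emits logits, indices are sampled from them, and a critic whose first layer is `nn.Embedding` scores the result. The embedding table receives a gradient, but nothing reaches the Model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, seq_len = 10, 6

# Hypothetical Model producing per-position logits.
model = nn.Linear(8, seq_len * vocab_size)

# Hypothetical critic whose first layer is an embedding lookup.
critic_embed = nn.Embedding(vocab_size, 4)
critic_head = nn.Linear(4, 1)

logits = model(torch.randn(2, 8)).view(2, seq_len, vocab_size)
idx = torch.distributions.Categorical(logits=logits).sample()  # Long, no grad_fn
print(idx.dtype, idx.requires_grad)  # torch.int64 False

score = critic_head(critic_embed(idx).mean(dim=1)).mean()
score.backward()

print(critic_embed.weight.grad is not None)  # True: the lookup table is trainable
print(model.weight.grad)                     # None: no gradient reaches the Model
```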