Does Autograd work when slicing tensors and re-assigning slices to different tensors?

Hi all! I’m fairly new to PyTorch and still learning how Autograd works.

I am experimenting with fusing different word embeddings to feed into a neural network, and I am not sure whether the operations I am using will be tracked and differentiated as intended.

Let me explain better:

I am building an embedding fusion layer that is intended to behave exactly like a TimeDistributedDense layer in Keras. Namely, this layer takes multiple embeddings for each token and fuses them by applying the same Linear transformation to all of them. However, I do not want to perform fusion for padding tokens, which are always all-zero vectors, so I skip them like this:
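Since `nn.Linear` operates on the last dimension of its input and treats all leading dimensions as batch dimensions, a single call already applies the same weights to every token, which is what TimeDistributedDense does in Keras. Below is a minimal sketch. It assumes the multiple embeddings of each token have been concatenated along the last dimension, and the sizes are hypothetical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 2 sentences, 5 tokens each, 300-dim concatenated embeddings
batch, seq_len, in_features, out_features = 2, 5, 300, 128
fuse = nn.Linear(in_features, out_features)

x = torch.randn(batch, seq_len, in_features)

# nn.Linear acts on the last dimension, so one call applies the
# same weight matrix to every token of every sentence
y = fuse(x)
print(y.shape)  # torch.Size([2, 5, 128])

# Same result (up to float rounding) as applying the layer token by token
y_loop = torch.stack([torch.stack([fuse(tok) for tok in sent]) for sent in x])
print(torch.allclose(y, y_loop, atol=1e-5))
```

So the per-token loop is not needed for the fusion itself. Only the padding mask requires extra handling.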

def _apply_to_nonzero(self, x):
    # mask of non-padding tokens: True wherever the embedding vector is not all zeros
    # (vectorized over all leading dims; also handles vectors whose values are all negative)
    non_zeros = (x != 0).any(dim=-1)
    # build a zero tensor for the batch using the output size of the Linear layer (self.output_size),
    # on the same device and with the same dtype as x
    all_zeros = x.new_zeros((*x.size()[:-1], self.output_size))
    # now fill the non-padding positions of the zero tensor with the fused embeddings
    all_zeros[non_zeros] = self._fuse(x[non_zeros])
    # at this point, all_zeros contains the fused embeddings for tokens that are not padding, and
    # zero vectors of length self.output_size at the positions that contained padding elements
    return all_zeros

The code above runs as intended, but will these operations be differentiated correctly? What concerns me is that I am slicing tensors and copying the slices into a different tensor.
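One way to check this directly is to build a tiny self-contained version of the layer and inspect the gradients after a backward pass. This is a sketch under stated assumptions. The class name `EmbeddingFusion` is hypothetical, and `_fuse` is assumed to be a single shared `nn.Linear`, which the question does not specify:

```python
import torch
import torch.nn as nn


class EmbeddingFusion(nn.Module):
    """Hypothetical module: applies one shared Linear layer to every
    non-padding token and leaves padding positions as zero vectors."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.output_size = output_size
        self.linear = nn.Linear(input_size, output_size)

    def _fuse(self, x):
        # assumption: fusion is a single Linear layer on the last dimension
        return self.linear(x)

    def forward(self, x):
        # True where the token vector is not all zeros (non-padding)
        non_zeros = (x != 0).any(dim=-1)
        # output buffer on the same device/dtype as x, created without requires_grad
        out = x.new_zeros((*x.shape[:-1], self.output_size))
        # boolean-mask indexing and masked assignment are both recorded by autograd
        out[non_zeros] = self._fuse(x[non_zeros])
        return out


torch.manual_seed(0)
layer = EmbeddingFusion(input_size=6, output_size=4)

# 2 sequences of 3 tokens; the last token of each sequence is padding
x = torch.randn(2, 3, 6)
x[:, -1] = 0.0

out = layer(x)
out.sum().backward()

print(out.requires_grad)                     # True
print(out[:, -1].abs().sum().item())         # 0.0, padding rows stay zero
print(layer.linear.weight.grad is not None)  # True
```

Because the loss here is a plain sum, the weight gradient should equal the sum of the non-padding input vectors (one copy per output row), and each bias entry should receive a gradient equal to the number of non-padding tokens. Both are easy to verify by hand.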


Hey, I put together this simple example, and slicing does seem to propagate the gradients:

import torch

a = torch.tensor([2., 3, 3], requires_grad=True)
b = torch.tensor([6., 4, 2], requires_grad=True)
Q = 3*a**3 - b**2
# take a slice of Q and assign it to a new name
x = Q[1:3]
# backpropagate from the sum of the slice
x.sum().backward()
print(a.grad)
print(b.grad)

Output is:

tensor([ 0., 81., 81.])
tensor([-0., -8., -4.])

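which shows that gradients propagate back through the slice: the elements inside the slice (indices 1 and 2) receive non-zero gradients, while index 0, which is not part of x, gets zero.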
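The example above only reads from a slice, whereas the question also writes the selected values into a different tensor. A small extension of the same example suggests that masked assignment is differentiated as well. The boolean `mask` and the `out` tensor are illustrative additions, not part of the original code:

```python
import torch

a = torch.tensor([2., 3, 3], requires_grad=True)
b = torch.tensor([6., 4, 2], requires_grad=True)
Q = 3*a**3 - b**2

# boolean mask, as in the question: keep positions 1 and 2
mask = torch.tensor([False, True, True])

# copy the selected elements of Q into a *different* tensor that was
# created without requires_grad; the masked assignment is recorded by autograd
out = torch.zeros(3)
out[mask] = Q[mask]
print(out.requires_grad)  # True

out.sum().backward()
print(a.grad)  # tensor([ 0., 81., 81.])
print(b.grad)  # tensor([-0., -8., -4.])
```

The gradients match the slicing example: only the positions that were copied into `out` contribute, which is the behaviour the padding-skipping layer relies on.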