Gradient backprop for different ways of slicing a tensor

Say we have two ways of slicing when computing the output of a convolutional layer:

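import torch
import torch.nn as nn

# N  - batch size
# CI - number of in channels
# CO - number of out channels
# T  - sequence length (think of e.g. a time series 0, ..., t, ..., T-1)
N, CI, CO, T = 4, 3, 2, 10   # example sizes (arbitrary, T must be >= 5)

# model
conv = nn.Conv1d(CI, CO, kernel_size=1)

# input
x = torch.rand(N, CI, T)

# method 1: slice the input, then convolve only the last 5 time steps
out1 = conv(x[:, :, -5:])

# method 2: convolve the whole sequence, then slice the output
out2 = conv(x)[:, :, -5:]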

Theoretically, out1 and out2 should result in the same gradient updates for the conv kernel, but method 1 requires less computation, since it only convolves the last 5 elements instead of the whole tensor x. In both cases we discard all but the last 5 elements, so the gradients should not be influenced by any of the earlier points. Is that right? Or do the earlier points in the T-dimension of x also influence the gradient computation (for method 2)?
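A minimal sketch for checking the question directly (the sizes, the seed, and the squared-sum loss are arbitrary choices, not from the post): in method 2 the full conv(x) is computed, but the slice passes back zero gradient for every time step it discarded, so those time steps cannot contribute to the weight or bias gradients. retain_grad() is used only so the gradient of the intermediate (non-leaf) tensor can be inspected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, CI, CO, T = 2, 3, 4, 10  # arbitrary example sizes
conv = nn.Conv1d(CI, CO, kernel_size=1)
x = torch.rand(N, CI, T, requires_grad=True)

full = conv(x)        # method 2: convolve the whole sequence first
full.retain_grad()    # keep .grad on this non-leaf tensor for inspection
out2 = full[:, :, -5:]
(out2**2).sum().backward()

# Time steps that were sliced away receive exactly zero gradient ...
print(bool(torch.all(full.grad[:, :, :-5] == 0)))  # True
# ... and, with kernel_size=1, so do the matching time steps of x.
print(bool(torch.all(x.grad[:, :, :-5] == 0)))     # True
```

Since the gradient flowing into the discarded positions is zero, their terms in the weight gradient are multiplied by zero and drop out.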

Hi Jay!

This is correct.

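No,* because you have kernel_size = 1: each output time step depends only
on the input at the same time step, so neighboring points don’t get mixed
with one another, and the earlier points in x don’t affect out1 or out2.
In method 2 the earlier outputs of conv (x) are computed but then sliced
away, so they receive zero gradient and contribute nothing to the
gradients of the conv weight and bias.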
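The argument hinges on kernel_size = 1. As a hedged illustration (kernel_size=3 with the default padding=0, the sizes, and float64 are our choices, not part of the thread): with a wider kernel each output mixes k neighboring inputs, so slicing the input to its last 5 points, as method 1 does, yields only 5 - k + 1 outputs. Slicing the last 5 + k - 1 input points instead reproduces the last 5 outputs of the full convolution.

```python
import torch

torch.manual_seed(0)

k = 3
# padding=0 (the default), so the output is k - 1 steps shorter than the input
conv = torch.nn.Conv1d(3, 2, kernel_size=k).double()
x = torch.rand(1, 3, 10, dtype=torch.float64)

full = conv(x)[:, :, -5:]              # method 2: last 5 outputs of the full convolution
naive = conv(x[:, :, -5:])             # method 1 as written: only 5 - k + 1 = 3 outputs
window = conv(x[:, :, -(5 + k - 1):])  # also keep the k - 1 extra points of the receptive field

print(naive.shape)                   # torch.Size([1, 2, 3])
print(torch.allclose(full, window))  # True
```

float64 is used only so the comparison is not sensitive to round-off; the same slicing rule applies in float32.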

See this example:

>>> import torch
>>> torch.__version__
'1.12.0'
>>>
>>> conv1 = torch.nn.Conv1d (3, 3, kernel_size = 1)
>>> conv2 = torch.nn.Conv1d (3, 3, kernel_size = 1)
>>>
>>> with torch.no_grad():   # initialize the two convs identically
...     _ = conv2.weight.copy_ (conv1.weight)
...     _ = conv2.bias.copy_ (conv1.bias)
...
>>> x = torch.rand (1, 3, 10)
>>>
>>> out1 = conv1 (x[:, :, -5:])
>>> out2 = conv2 (x)[:, :, -5:]
>>> torch.equal (out1, out2)
True
>>>
>>> (out1**2).sum().backward()
>>> (out2**2).sum().backward()
>>> torch.equal (conv1.weight.grad, conv2.weight.grad)
False
>>> torch.equal (conv1.bias.grad, conv2.bias.grad)
False
>>> torch.allclose (conv1.weight.grad, conv2.weight.grad)
True
>>> torch.allclose (conv1.bias.grad, conv2.bias.grad)
True

*) To be precise, the gradient computations are mathematically equivalent,
but they carry out the operations in different (mathematically equivalent)
orders, which leads to different floating-point round-off error. That’s why
the equal() tests for the gradients return False, but the allclose() tests
return True.
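For kernel_size = 1 the mathematical equivalence can also be checked by writing out the gradients by hand. This sketch assumes the same squared-sum loss as the example above (so dL/dout = 2 * out) and uses float64 and an arbitrary seed; the einsum is yet another ordering of the same sums, so it is compared with allclose() rather than equal().

```python
import torch

torch.manual_seed(0)

conv = torch.nn.Conv1d(3, 3, kernel_size=1).double()
x = torch.rand(1, 3, 10, dtype=torch.float64)

out = conv(x)[:, :, -5:]   # method 2
(out**2).sum().backward()

# For kernel_size=1 the conv is a per-time-step linear map:
#   out[n, o, t] = sum_i W[o, i, 0] * x[n, i, t] + b[o]
# so with g = dL/dout = 2 * out, the chain rule gives
#   dL/dW[o, i, 0] = sum over n and the last 5 t of g[n, o, t] * x[n, i, t]
#   dL/db[o]       = sum over n and the last 5 t of g[n, o, t]
g = 2 * out.detach()
xs = x[:, :, -5:]
w_grad = torch.einsum('not,nit->oi', g, xs).unsqueeze(-1)  # shape (3, 3, 1), like conv.weight
b_grad = g.sum(dim=(0, 2))                                 # shape (3,), like conv.bias

print(torch.allclose(conv.weight.grad, w_grad))  # True
print(torch.allclose(conv.bias.grad, b_grad))    # True
```

Only the last 5 time steps appear in these sums, which is why both slicing methods give the same gradients.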

Best.

K. Frank


Absolutely great example, thanks much for the perfect explanation!

@KFrank: maybe you have a solution for this one, too? 😃 I'm kind of stuck there…