Hello,
I get an error when I run the following simple code involving nested tensors:
import torch
import torch.nn as nn
batch_size = 16
min_length = 21
max_length = 42
sizes = torch.randint(min_length, max_length, (batch_size,))
device = torch.device('cuda')
with device:
    tokens1 = torch.nested.nested_tensor([torch.randint(5, (s,)) for s in sizes], layout=torch.jagged)
    tokens2 = torch.nested.nested_tensor([torch.randint(7, (s,)) for s in sizes], layout=torch.jagged)
    emb1 = nn.Embedding(5, 16)
    emb2 = nn.Embedding(7, 32)
    h1 = emb1(tokens1)
    h2 = emb2(tokens2)
    h = torch.cat([h1, h2], dim=-1)
    assert h.shape[0] == batch_size
    assert h.shape[-1] == 48
    loss = h.mean()
    print(loss)
    loss.backward()
The output and the error are as follows:
tensor(0.0138, grad_fn=<MeanBackwardAutogradNestedTensor0>)
Traceback (most recent call last):
  File "/homes/evgeny/code/RP3Net/.scripts/test_nested_cat.py", line 25, in <module>
    loss.backward()
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/_tensor.py", line 616, in backward
    return handle_torch_function(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/overrides.py", line 1728, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/utils/_device.py", line 103, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/_tensor.py", line 625, in backward
    torch.autograd.backward(
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
    return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Function CatBackward0 returned an invalid gradient at index 1 - got [16, j1, 32] but expected shape compatible with [16, j2, 32]
The error occurs on both CPU and CUDA. Note that PyTorch does not report this as an unsupported operation: the forward pass succeeds (the loss is printed) and the failure only shows up in the backward pass. I found a similar issue on GitHub, but according to the comments it has been deprioritised. Should I raise this as a separate issue as well? And has anyone already come across this and found a workaround?