Concatenating nested tensors blows up in backward

Hello,

I get an error when I run the following simple code involving nested tensors:

import torch
import torch.nn as nn

batch_size = 16
min_length = 21
max_length = 42
sizes = torch.randint(min_length, max_length, (batch_size,))

device = torch.device('cuda')

with device:
    # both nested tensors are built with the same per-sample lengths
    tokens1 = torch.nested.nested_tensor([torch.randint(5, (s,)) for s in sizes], layout=torch.jagged)
    tokens2 = torch.nested.nested_tensor([torch.randint(7, (s,)) for s in sizes], layout=torch.jagged)
    emb1 = nn.Embedding(5, 16)
    emb2 = nn.Embedding(7, 32)
    h1 = emb1(tokens1)
    h2 = emb2(tokens2)

    h = torch.cat([h1, h2], dim=-1)  # concatenate along the embedding dimension
    assert h.shape[0] == batch_size
    assert h.shape[-1] == 48

    loss = h.mean()
    print(loss)
    loss.backward()

The error is as follows:

tensor(0.0138, grad_fn=<MeanBackwardAutogradNestedTensor0>)
Traceback (most recent call last):
  File "/homes/evgeny/code/RP3Net/.scripts/test_nested_cat.py", line 25, in <module>
    loss.backward()
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/_tensor.py", line 616, in backward
    return handle_torch_function(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/overrides.py", line 1728, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/utils/_device.py", line 103, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/_tensor.py", line 625, in backward
    torch.autograd.backward(
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/hps/software/users/chembl/evgeny/micromamba/envs/rp3/lib/python3.12/site-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Function CatBackward0 returned an invalid gradient at index 1 - got [16, j1, 32] but expected shape compatible with [16, j2, 32]

The error happens both on CPU and on CUDA. Note that this is not reported as an unsupported operation: the forward pass runs fine and only the backward pass blows up. I found a similar issue on GitHub, but according to the comments it has been deprioritised. Should I raise this as an issue too? And maybe someone has already come across this and knows a workaround?
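One workaround I'm planning to try (untested sketch) is to build both token tensors on top of a single shared offsets tensor via torch.nested.nested_tensor_from_jagged, on the assumption that the symbolic ragged dimension is tied to the offsets tensor, so h1 and h2 would then share the same ragged dim and CatBackward0 would no longer see a shape mismatch:

# Untested sketch of a possible workaround: construct both jagged tensors from
# one shared offsets tensor so they (hopefully) carry the same ragged dimension.
lengths = sizes.to(device)
offsets = torch.cat([torch.zeros(1, dtype=torch.int64, device=device),
                     lengths.cumsum(0)])
total = int(offsets[-1])

values1 = torch.randint(5, (total,), device=device)
values2 = torch.randint(7, (total,), device=device)
# nested_tensor_from_jagged views the flat values as a jagged-layout nested tensor
tokens1 = torch.nested.nested_tensor_from_jagged(values1, offsets=offsets)
tokens2 = torch.nested.nested_tensor_from_jagged(values2, offsets=offsets)

h1 = emb1(tokens1)
h2 = emb2(tokens2)
h = torch.cat([h1, h2], dim=-1)  # both inputs should now share the same ragged dim
h.mean().backward()

Alternatively, I suppose I could run the cat on h1.values() and h2.values() directly and rewrap the result with nested_tensor_from_jagged, but I haven't tried that either.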

Feel free to leave a comment on the created issue, as these signals are helpful for discussing priorities and roadmaps.