How to replace a torch.tensor in-place with a new padded tensor?

I am trying to figure out how to overwrite a torch.tensor object, located inside a dict, with a new torch.tensor that is a bit longer due to padding.

import torch

# pad the tensor
zeros = torch.full((55,), 100, dtype=torch.long)  # 55 padding tokens with id 100
temp_input = torch.cat([batch['input_ids'][0][0], zeros], dim=-1)  # concatenate
temp_input.shape  # [567]
batch['input_ids'][0][0].shape  # [512]
batch['input_ids'][0][0] = temp_input
# RuntimeError: The expanded size of the tensor (512) must match the existing size (567) at non-singleton dimension 0.  Target sizes: [512].  Tensor sizes: [567]

I am struggling to find a way to extend a tensor in place, or to overwrite it with a new tensor when the dimensions change.

The dict is emitted from torch’s DataLoader and looks like this:

{'input_ids': tensor([[[  101,  3720,  2011,  ..., 25786,  2135,   102]],
 
         [[  101,  1017,  2233,  ...,     0,     0,     0]],
 
         [[  101,  1996,  2899,  ..., 14262, 20693,   102]],
 
         [[  101,  2197,  2305,  ...,  2000,  1996,   102]]]),
 'attn_mask': tensor([[[1, 1, 1,  ..., 1, 1, 1]],
 
         [[1, 1, 1,  ..., 0, 0, 0]],
 
         [[1, 1, 1,  ..., 1, 1, 1]],
 
         [[1, 1, 1,  ..., 1, 1, 1]]]),
 'cats': tensor([[-0.6410,  0.1481, -2.1568, -0.6976],
         [-0.4725,  0.1481, -2.1568,  0.7869],
         [-0.6410, -0.9842, -2.1568, -0.6976],
         [-0.6410, -0.9842, -2.1568, -0.6976]], grad_fn=<StackBackward>),
 'target': tensor([[1],
         [0],
         [1],
         [1]]),
 'idx': tensor([1391, 4000,  293,  830])}

Do I need to create new tensors, store them in a list of lists, and then transform them back to tensors?

If I understand your code correctly, you are trying to change the shape of a single subtensor in input_ids, while the other subtensors would keep their shape?
If so, this won’t be possible: a tensor has to be rectangular, i.e. every row along a dimension must have the same length, so rows of different lengths can't be stored in a single tensor.
Here is a code snippet to illustrate this error:

import torch

x = torch.randn(5, 10)
x[0] = torch.ones(10)  # works, since the size matches dim1 (10)
print(x)

x[0] = torch.ones(11)  # RuntimeError: 11 elements don't fit into a row of size 10

You would have to pad all tensors to the same length or split the original tensor into different ones.
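Applied to the original question, padding the whole batch (instead of a single row) keeps the tensor rectangular. A minimal sketch, assuming input_ids has shape [4, 1, 512] as in the printed batch and that 100 is the desired padding id; the random tensor is only a stand-in for the real batch:

```python
import torch
import torch.nn.functional as F

# Stand-in for batch['input_ids']: 4 samples, 1 sequence each, 512 token ids.
input_ids = torch.randint(0, 30000, (4, 1, 512))

# F.pad takes (left, right) amounts for the last dimension:
# nothing on the left, 55 columns on the right, filled with id 100.
pad_len = 55
padded = F.pad(input_ids, (0, pad_len), value=100)

print(padded.shape)  # torch.Size([4, 1, 567])

# The original values are kept; only new columns were appended.
assert torch.equal(padded[..., :512], input_ids)
assert (padded[..., 512:] == 100).all()
```

Because F.pad returns a new tensor, you would store it back with `batch['input_ids'] = padded` rather than writing into a slice. The attention mask would typically be padded the same way, with `value=0`.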


Thanks for your time and consideration of my problem!

I am trying to pad the tensors all to the same length, dynamically by batch, through a collate_fn via the DataLoader. The best I can come up with right now is putting them all into a list, but this seems not quite right:

import torch

def collate_fn_padd(batch):
    batch_inputs = list()
    max_size = 0
    # find the max length -- for every item
    for item in batch:
        # if len of the input ids > max_size:
        if len(item['input_ids'][0]) > max_size:
            # get a new max size
            max_size = len(item['input_ids'][0])
    # for every item
    for item in batch:
        # check current length of tokens
        current_len = len(item['input_ids'][0])
        # if tokens are smaller than the max_size
        if current_len < max_size:
            # find the difference
            len_diff = max_size - current_len
            # generate some zeros
            zeros = torch.zeros(len_diff).reshape(1, len_diff).long()
            # turn them to padding
            zeros[zeros == 0] = 100  # change to padding
            # cat them together
            temp_input = torch.cat([item['input_ids'][0].reshape(1, item['input_ids'][0].shape[0]), zeros], dim=1)
            # add the inputs to a list
            batch_inputs += list(temp_input)
        else:
            batch_inputs += list(item['input_ids'])
    return batch_inputs

Just an FYI: pad_sequence (https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html) might simplify your code a bit.
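A collate_fn built on pad_sequence might look like the following. This is a sketch, not a drop-in replacement: it assumes each item's 'input_ids' and 'attn_mask' have shape [1, seq_len] as in the question, reuses 100 as the padding id (the constant PAD_ID is hypothetical), and leaves out the other fields ('cats', 'target', 'idx'), which would still need to be collated.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 100  # hypothetical: padding id taken from the question; adjust for your tokenizer

def collate_fn_padd(batch):
    # Drop the leading singleton dim so each entry is a 1-D tensor [seq_len].
    seqs = [item["input_ids"][0] for item in batch]
    masks = [item["attn_mask"][0] for item in batch]
    # With batch_first=True, pad_sequence returns [batch_size, max_len],
    # filling the shorter sequences with padding_value on the right.
    input_ids = pad_sequence(seqs, batch_first=True, padding_value=PAD_ID)
    attn_mask = pad_sequence(masks, batch_first=True, padding_value=0)
    return {"input_ids": input_ids, "attn_mask": attn_mask}

# Usage with two toy items of different lengths.
batch = [
    {"input_ids": torch.tensor([[101, 7, 8, 102]]), "attn_mask": torch.ones(1, 4, dtype=torch.long)},
    {"input_ids": torch.tensor([[101, 9, 102]]), "attn_mask": torch.ones(1, 3, dtype=torch.long)},
]
out = collate_fn_padd(batch)
print(out["input_ids"].tolist())  # [[101, 7, 8, 102], [101, 9, 102, 100]]
print(out["attn_mask"].tolist())  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

It would be passed to the loader as `DataLoader(dataset, batch_size=4, collate_fn=collate_fn_padd)`, so every batch is padded only to the length of its own longest sequence.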


Thanks, I came across this earlier while searching the forum but was obviously using it wrong; I think this will be helpful!