# Remove column from tensor?

What is the proper way to remove a column from a tensor?

What I’m looking for is the proper way to do the following (removing the 3):

``````
import torch

t = torch.tensor([0, 1, 2, 3, 4, 5])
t = torch.cat((t[:3], t[4:]))  # drop the element at index 3
``````

Does boolean indexing work with torch tensors? Haven’t tried it myself yet.

``````
import numpy as np

a = np.arange(9, -1, -1)       # a = array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
b = a[np.arange(len(a)) != 3]  # b = array([9, 8, 7, 5, 4, 3, 2, 1, 0])
``````
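For what it's worth, the same pattern appears to carry over to torch tensors. A minimal sketch, assuming a small 2-D tensor `x` (a made-up example, not from the question) where column 3 should be dropped:

```python
import torch

x = torch.arange(12).reshape(3, 4)    # hypothetical 3x4 example tensor
keep = torch.arange(x.size(1)) != 3   # boolean mask over columns, False at index 3
y = x[:, keep]                        # boolean indexing along dim 1

print(y.shape)  # torch.Size([3, 3])
```

Note that boolean indexing returns a new tensor rather than a view, so `x` itself is left unchanged.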

Hi,

Did anyone figure out an efficient way to pop (delete) an entire column from a torch tensor? Is there a function in PyTorch that does this, like numpy.delete? Thanks!

Hi @dev-team,

Can somebody help out here?

Okay, so here is a less memory-efficient way to delete the last n columns. I have only tested it on a 2-D tensor.

``````
>>> import torch
>>> def delete_last_n_columns(a, n):
...     n_cols = a.size()[1]
...     assert n < n_cols
...     first_cols = n_cols - n
...     mask = torch.arange(first_cols, device=a.device)  # indices of the columns to keep
...     b = torch.index_select(a, 1, mask)  # retain the first columns; drop the last n
...     return b
...
>>> a = torch.rand(10, 536)  # example input with the shape shown below
>>> a.size()
torch.Size([10, 536])
>>> b = delete_last_n_columns(a, 512)
>>> b.size()
torch.Size([10, 24])
``````
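If copying is the concern, plain slicing may be enough for the trailing-columns case. A rough sketch, assuming the same shapes as in the session above and `n > 0` (with `n = 0` the slice `:-0` would be empty); the variable names are only for illustration:

```python
import torch

a = torch.rand(10, 536)   # same shape as in the session above
n = 512                   # number of trailing columns to drop (must be > 0)

b = a[:, :-n]             # basic slicing: a view that shares memory with `a`
c = a[:, :-n].clone()     # clone() if an independent copy is needed

print(b.shape)            # torch.Size([10, 24])
```

Since `b` is a view, writing into it also changes `a`; `c` does not have that coupling.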

While the below solution may be a bit heavy-handed for removing only one column, it works great when you need to remove many columns:

``````
import torch

t = torch.arange(6)
print(t)
# tensor([0, 1, 2, 3, 4, 5])

c = [0, 1, 2, 4, 5]  # indices to keep
print(t[c])
# tensor([0, 1, 2, 4, 5])
``````
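For actual columns of a 2-D tensor, the keep-list can be built from the indices to drop instead of being typed out by hand. A small sketch, where `x` and `drop` are made-up example values:

```python
import torch

x = torch.arange(4 * 8).reshape(4, 8)   # hypothetical 4x8 example tensor
drop = {1, 2, 3}                        # column indices to remove

# Build the list of columns to keep, then index along dim 1.
keep = [i for i in range(x.size(1)) if i not in drop]
y = x[:, keep]

print(keep)      # [0, 4, 5, 6, 7]
print(y.shape)   # torch.Size([4, 5])
```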

You can also use a tensor of booleans to decide what to keep. This is probably the most convenient method when the indices to act on are determined by a condition.

``````
a = torch.arange(10)
row_cond = a != 5
print(a[row_cond])
# tensor([0, 1, 2, 3, 4, 6, 7, 8, 9])
``````

``````
print(a[~row_cond])
# tensor([5])
``````

With this method, you can also apply an operation to only the selected values:

``````
a[row_cond] = a[row_cond] * 2
print(a)
# tensor([ 0,  2,  4,  6,  8,  5, 12, 14, 16, 18])
``````

Using slices and concatenation:

``````
import torch

x = torch.arange(4 * 8).reshape((4, 8))

print(x)

# Keep column 0 and columns 4 to 7, i.e. remove columns 1 to 3.
x = torch.cat([x[:, :1], x[:, 4:]], dim=1)

print(x)
``````

prints:

``````
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]])
tensor([[ 0,  4,  5,  6,  7],
        [ 8, 12, 13, 14, 15],
        [16, 20, 21, 22, 23],
        [24, 28, 29, 30, 31]])
``````
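As far as I know there is no direct counterpart to numpy.delete in torch, but the mask approach can be wrapped up to behave similarly. A rough sketch, where `delete_columns` is a hypothetical helper name (not a PyTorch function) and a 2-D input is assumed:

```python
import torch

def delete_columns(x, indices):
    """Hypothetical helper: return a copy of 2-D `x` without the given columns."""
    # Start by keeping every column, then switch off the ones to delete.
    keep = torch.ones(x.size(1), dtype=torch.bool, device=x.device)
    keep[torch.as_tensor(indices, dtype=torch.long, device=x.device)] = False
    return x[:, keep]

x = torch.arange(4 * 8).reshape(4, 8)
y = delete_columns(x, [1, 2, 3])
print(y[0])   # tensor([0, 4, 5, 6, 7])
```

This gives the same result as the slice-and-concatenate version for columns 1 to 3, and it also handles columns that are not contiguous.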