index = [1,2,3,....]
train_tensor.shape = torch.size(128,30,100)
train_label.shape=torch.size(128,1,1)
train_set = dataset(train_tensor,train_label)
Now , I want to remove data form Train_set using index?
I had tried this function:
def tensor_remove(tensors,index):
# print("tensors:",tensors.shape[0])
sub_tensors = torch.empty([0,30,100])
if index == []:
return tensors
if tensors.shape[1] == 1:
sub_tensors = torch.empty([0,1,1])
last_idx = -1
for idx in index:
if idx != last_idx+1:
sub_tensors = torch.cat((sub_tensors,tensors[last_idx+1:idx]),0)
if idx == index[-1]:
sub_tensors = torch.cat((sub_tensors,tensors[idx+1:]),0)
last_idx = idx
# print("sub_tensor:",sub_tensors.shape[0])
return sub_tensors
It words!But when I running train.py ,it rasing CUDA out of memerys! so ,I guess maybe the function caused that problem! Please help me! thanks!