I want to slice the array.
The code below is working but I was wondering if there are more efficient ways prepared by PyTorch developer (Basically I want to avoid loops)
By the way, I changed inputs from torch.Tensor to numpy.ndarray.
train <numpy.ndarray>: (1000, 19, 1024, 2048)
val <numpy.ndarray>: (1000, 19, 1024, 2048)
def logit_preprocess(dataset):
split_len = 64
arr = []
h_len = 1024 // split_len
w_len = 2048 // split_len
for d in tqdm(dataset):
for h in range(h_len):
h_start = h * split_len
h_end = h_start + split_len
for w in range(w_len):
w_start = w * split_len
w_end = w_start + split_len
arr.append(d[:, h_start:h_end, w_start:w_end])
return np.array(arr)
train_data = logit_preprocess(train)
val_data = logit_preprocess(val)
class MyDataset(Dataset):
def __init__(self, data, transform=None):
self.data = torch.from_numpy(data).float()
self.transform = transform
def __getitem__(self, index):
x = self.data[index]
if self.transform:
x = self.transform(x)
return x
def __len__(self):
return len(self.data)
train = MyDataset(train_data)
val = MyDataset(val_data)
batch_size = 1
train_loader = DataLoader(train, batch_size=batch_size,shuffle=False,num_workers=2)
validation_loader = DataLoader(val, batch_size=batch_size,shuffle=False,num_workers=2)