import torch
import random
from torchvision import datasets, transforms
class DataManager():
    """Group the samples of a labelled dataset by their class label.

    The dataset is read once at construction time. Every sample tensor is
    filed under its target, and all samples of one target are stacked along
    a new leading dimension, so each stored tensor has one more dimension
    than a single sample.

    Attributes:
        data: The dataset object passed to the constructor.
        data_len: ``len(data)`` at construction time.
        tensor_dict: Maps each target to the stacked tensor of its samples,
            in the order the samples appear in ``data``.
        avg_tensor_dict: Maps a target to the most recent result of
            ``get_average_class_tensor`` for that target.
    """

    def __init__(self, data: "torch.utils.data.Dataset"):
        """Read every sample of ``data`` and group the tensors by target.

        Args:
            data: Object supporting ``len(data)`` and ``data[i]``, where
                ``data[i]`` is a ``(tensor, target)`` pair. Targets are used
                as dictionary keys and must be hashable. All tensors that
                share a target must have the same shape.
        """
        self.data = data
        self.data_len = len(data)
        self.avg_tensor_dict = dict()

        # Collect the samples in plain lists first and stack once per class.
        # Growing a tensor with torch.cat on every sample recopies everything
        # accumulated so far, which is quadratic in the samples per class.
        grouped = {}
        for i in range(self.data_len):
            tensor, target = data[i]
            grouped.setdefault(target, []).append(tensor)

        # (Class, stacked tensor); keys keep first-appearance order.
        self.tensor_dict = {
            target: torch.stack(tensors, 0)
            for target, tensors in grouped.items()
        }

    def get_class_tensor(self, key):
        """Return the stacked tensor holding every sample of class ``key``.

        Raises:
            KeyError: If ``key`` is not a target seen in the dataset.
        """
        return self.tensor_dict[key]

    def get_average_class_tensor(self, key):
        """Return the mean over all samples of class ``key``.

        The mean is taken along dimension 0 (the sample dimension) and the
        result is also stored in ``avg_tensor_dict[key]``.

        Raises:
            KeyError: If ``key`` is not a target seen in the dataset.
        """
        tensor = torch.mean(self.tensor_dict[key], 0)
        self.avg_tensor_dict[key] = tensor
        return tensor

    def sample_class_tensor(self, key, num: int):
        """Return ``num`` distinct, randomly chosen samples of class ``key``.

        Indices are drawn without replacement using the ``random`` module,
        so ``random.seed`` controls reproducibility.

        Args:
            key: Target whose samples are drawn from.
            num: Number of samples to draw.

        Returns:
            A new tensor stacking the chosen samples along dimension 0 in
            the order they were drawn, or ``None`` when ``num`` is 0.

        Raises:
            KeyError: If ``key`` is not a target seen in the dataset.
            ValueError: If ``num`` is negative or exceeds the number of
                samples stored for ``key`` (raised by ``random.sample``).
        """
        class_tensor = self.tensor_dict[key]
        samples = random.sample(range(class_tensor.shape[0]), num)
        if not samples:
            # Kept for backward compatibility: an empty draw yields None
            # rather than an empty tensor.
            return None
        # A single gather replaces the per-sample torch.cat loop.
        return class_tensor[torch.tensor(samples, dtype=torch.long)]
I have this DataManager class that takes in a dataset and sorts CIFAR-10 into a dictionary, where each key is a class label associated with a 4D tensor containing all the image tensors of that class. The for loop in the constructor makes it pretty slow, since CIFAR-10 is quite a large dataset. Is there a way to optimize this?