I'm trying to select 3 of the 4 classes in an `ImageFolder` dataset and also assign class weights to those specific classes. Can `WeightedRandomSampler` do both of these tasks?
train_dataset = dset.ImageFolder(os.path.join(dataroot), transform=transforms.Compose(
[SmallScale(image_size),
transforms.RandomCrop(image_size),
transforms.Grayscale(),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])
]))
targets = train_dataset.targets
target_values = img_class # img_class = [[1, 2, 3]] out of 4 classes
target_bool = []
for x in targets:
if x in target_values:
target_bool.append(True)
else:
target_bool.append(False)
target_idx = torch.tensor(target_bool).nonzero()
sample_idx = target_idx
sample_tar = torch.Tensor(targets)
sample_range = sample_tar[sample_idx.squeeze()]
ordered_range = sample_tar[sample_idx.squeeze()]
min_val = sample_range.min()
max_val = sample_range.max()
for val, tar in enumerate(ordered_range):
if tar == min_val:
ordered_range[val] = 0
elif tar == max_val:
ordered_range[val] = 2
elif tar != min_val and tar != max_val:
ordered_range[val] = 1
target_idx = target_idx[:, 0].tolist()
class_sample_count = torch.tensor([(sample_range == t).sum() for t in torch.Tensor(img_class)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in ordered_range.long()])
balanced_sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(samples_weight))
# temporal_sampler = torch.utils.data.sampler.SubsetRandomSampler(target_idx)
data_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False, # (train_sampler is None),
drop_last=True,
num_workers=int(workers),
sampler=balanced_sampler # but want to use the indices of target_idx instead of the entire dataset
)
# data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
# shuffle=True,
# num_workers=int(workers))
return data_loader```