Hi,
I am trying to create a noisy dataset for ML. Here’s what I did:
mnist_train = MNIST('../data/MNIST', download = True,
transform = transforms.Compose([
transforms.ToTensor(),
]), train = True)
mnist_test = MNIST('../data/MNIST', download = True,
transform = transforms.Compose([
transforms.ToTensor(),
]), train = False)
And this:
def add_noise_sp(image,sd,amount=0.2):
np.random.seed(seed=sd)
low_clip = 0.13
std = 0.31
image = np.asarray(image)
out = image.copy()
p = amount
q = 0.5
flipped = np.random.choice([True, False], size=image.shape,
p=[p, 1 - p])
salted = np.random.choice([True, False], size=image.shape,
p=[q, 1 - q])
peppered = ~salted
out[flipped & salted] = low_clip-2*std
return torch.tensor(out)
s = 1
class SyntheticNoiseDatasetsalttr(Dataset):
def __init__(self, data, mode='train'):
self.mode = mode
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
global s
img = self.data[index][0]
s = s + 1
return add_noise_sp(img,s), img
b = 3425
class SyntheticNoiseDatasetsaltte(Dataset):
def __init__(self, data, mode='test'):
self.mode = mode
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
global b
img = self.data[index][0]
b = b + 1
return add_noise_sp(img,b), img
This how I create the dataset:
noisy_mnist_train = SyntheticNoiseDatasetsalttr(mnist_train, 'train')
noisy_mnist_test = SyntheticNoiseDatasetsaltte(mnist_test, 'test')
train_set, val_set = torch.utils.data.random_split(noisy_mnist_train, [55000, 5000], generator=torch.Generator().manual_seed(42))
I thought this would give me a constant dataset. Does this mean that when I train my neural net, it will see different noisy version of the same input everytime? When I run the following command, I get digit 5 but everytime it has different noise on it. It seems like everytime I call DataLoader, it’s running my datageneration code and giving me different noisy data of the same subject.
torch.manual_seed(123)
bt = DataLoader(val_set,batch_size=1, shuffle= True)
bt =next(iter(bt))
noisy,clean = bt
plt.imshow(noisy.squeeze(), cmap='gray')