Hi everyone,
I’m trying to add noise layer into pretrained Resnet model as below:
def fault_model(model):
1 for child in model.module.layer1:
2 child.conv1 = nn.Sequential(Noise(child.conv1.weight, n_tr=0.01), child.conv1)
3 child.conv2 = nn.Sequential(Noise(child.conv2.weight, n_tr=0.01), child.conv2)
4 for child in model.module.layer2:
5 child.conv1 = nn.Sequential(Noise(child.conv1.weight, n_tr=0.01), child.conv1)
6 child.conv2 = nn.Sequential(Noise(child.conv2.weight, n_tr=0.01), child.conv2)
7 for child in model.module.layer3:
8 child.conv1 = nn.Sequential(Noise(child.conv1.weight, n_tr=0.01), child.conv1)
9 child.conv2 = nn.Sequential(Noise(child.conv2.weight, n_tr=0.01), child.conv2)
10 # for name, layer in model.module.named_modules():
11 # if ('layer' in name) and ( 'conv' in name ):
12 # # name += 'noise'
13 # layer = nn.Sequential(Noise(layer.weight, n_tr=0.01), layer)
14 # for name, layer in model.module.named_modules():
15 # if ('layer' in name) and ( 'conv' in name ):
16 # print(layer)
17 return model
18
19 class Noise(nn.Module):
20 def __init__(self, weights, n_tr):
21 super(Noise, self).__init__()
22 self.weights = weights
23 self.n_tr = n_tr
24
25 def forward(self):
26 return add_noise(self.weights, self.n_tr)
27
28 def add_noise(weights, n_tr):
29 with torch.no_grad():
30 weight_max = torch.max(weights)
31 weights_std = torch.std(weights)
32 sigma_tr = n_tr*weight_max
33 noise = torch.rand_like(weights)*sigma_tr
34 weights.add_(noise)
35 alpha = 2.0
36 weights.clamp_(-alpha*weights_std, alpha*weights_std)
But I got this error, I don’t know what cause this:
Traceback (most recent call last):
File "compress_classifier.py", line 240, in <module>
main()
File "compress_classifier.py", line 78, in main
app.run_training_loop()
File "/home/th.nguyen/distiller/distiller/apputils/image_classifier.py", line 211, in run_training_loop
top1, top5, loss = self.train_validate_with_scheduling(epoch)
File "/home/th.nguyen/distiller/distiller/apputils/image_classifier.py", line 146, in train_validate_with_scheduling
top1, top5, loss = self.train_one_epoch(epoch, verbose)
File "/home/th.nguyen/distiller/distiller/apputils/image_classifier.py", line 129, in train_one_epoch
loggers=[self.tflogger, self.pylogger], args=self.args)
File "/home/th.nguyen/distiller/distiller/apputils/image_classifier.py", line 608, in train
output = model(inputs)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
return self.module(*inputs[0], **kwargs[0])
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/th.nguyen/distiller/distiller/models/cifar10/resnet_cifar.py", line 142, in forward
x = self.layer1(x)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/th.nguyen/distiller/distiller/models/cifar10/resnet_cifar.py", line 72, in forward
out = self.conv1(x)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/home/th.nguyen/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given
I don’t know where the error is, any though? Thank you