Hi. I want to set weight parameters to zeros or ones.
Here is my code.
# Channel (filter) pruning for the conv layers: for every Conv2d in
# net.module.features, keep the round((1 - args.pr) * out_channels) output
# filters with the largest L1 norm (sum of |W[j]|) and write a 0/1 mask for
# them into the first parameter of the matching layer of `mask`.
if args.convpr == 1:
    for i, layer in enumerate(net.module.features):
        # Not used in the visible code; kept in case the elided part below
        # needs it.
        fake_input_size = fake_input.size()
        if layer.__class__.__name__ == 'Conv2d':
            weight_size = layer.weight.size()
            # Conv2d(bias=False) has bias None; .size() on it would raise.
            bias_size = layer.bias.size() if layer.bias is not None else None
            n_out = weight_size[0]
            n_keep = round((1 - args.pr) * n_out)
            # Everything below is bookkeeping, not something to differentiate.
            # An in-place write into a leaf Parameter that requires grad
            # raises a RuntimeError unless autograd is disabled.
            with torch.no_grad():
                # L1 norm of each output filter j in one op instead of a
                # Python loop over j.
                filter_sum[idx][0:n_out] = (
                    layer.weight.abs().reshape(n_out, -1).sum(dim=1))
                sorted_values_conv, sorted_indices_conv = torch.topk(
                    filter_sum[idx][0:n_out], n_keep)
                # Same object as list(mask.module.features[i].parameters())[0]
                # (the weight for a standard nn.Conv2d).
                mask_param = next(mask.module.features[i].parameters())
                # 1 for kept filters, 0 for pruned ones. Scattering the topk
                # indices replaces the O(n*k) `j in sorted_indices_conv` test.
                keep = torch.zeros(n_out, dtype=mask_param.dtype,
                                   device=mask_param.device)
                keep[sorted_indices_conv.to(keep.device)] = 1
                # Broadcast each filter's 0/1 over its (in, kh, kw) slice.
                mask_param[0:n_out] = keep.view(-1, 1, 1, 1)
            # Code elided in the original post (presumably advances idx).
            # TODO confirm it belongs at this nesting level.
            ...
# Input-feature pruning for the fully connected layers: for every Linear in
# net.module.classifier (skipped while idx_fc == 0), keep the
# round((1 - args.pr) * in_features) input columns with the largest L1 norm
# (sum of |W[:, j]|) and write a 0/1 column mask into the first parameter of
# the matching layer of `mask`.
if args.fcpr == 1:
    for i, layer in enumerate(net.module.classifier):
        if layer.__class__.__name__ == 'Linear':
            weight_size = layer.weight.size()
            # Linear(bias=False) has bias None; .size() on it would raise.
            bias_size = layer.bias.size() if layer.bias is not None else None
            # NOTE(review): while idx_fc == 0 this layer's mask is left
            # untouched. If idx_fc is only advanced after each Linear (in the
            # elided code below), the FIRST Linear layer is never masked.
            # That would explain why mask.module.classifier[1] (presumably
            # that first Linear) shows no change. Confirm the skip is intended.
            if idx_fc != 0:
                n_out, n_in = weight_size[0], weight_size[1]
                n_keep = round((1 - args.pr) * n_in)
                # An in-place write into a leaf Parameter that requires grad
                # raises a RuntimeError unless autograd is disabled.
                with torch.no_grad():
                    # L1 norm of each input column j in one op.
                    # NOTE(review): indexes filter_sum with idx, not idx_fc,
                    # as the original did. Confirm that is intended.
                    filter_sum[idx][0:n_in] = layer.weight.abs().sum(dim=0)
                    sorted_values_fc, sorted_indices_fc = torch.topk(
                        filter_sum[idx][0:n_in], n_keep)
                    # Same object as
                    # list(mask.module.classifier[i].parameters())[0]
                    # (the weight for a standard nn.Linear).
                    mask_param = next(mask.module.classifier[i].parameters())
                    # 1 for kept columns, 0 for pruned ones.
                    keep = torch.zeros(n_in, dtype=mask_param.dtype,
                                       device=mask_param.device)
                    keep[sorted_indices_fc.to(keep.device)] = 1
                    # The (n_in,) row vector broadcasts down every row, so
                    # column j becomes all keep[j].
                    mask_param[0:n_out, 0:n_in] = keep
                # NOTE(review): print(tensor) summarizes large tensors (only
                # the first/last few rows and columns are shown), so masked
                # columns may simply be hidden. Check with
                # (mask_param == 0).sum() instead.
            # Code elided in the original post (presumably advances idx /
            # idx_fc). TODO confirm it belongs at this nesting level.
            ...
In this code, the `if args.convpr == 1:` part works well. However, the `if args.fcpr == 1:` part doesn't work.
When I print `mask.module.classifier[1].weight.data`, there is no change.
When I print `params_mask[0][:,j]`, the changes are reflected.
I cannot find any difference between the two parts.
Where is the problem?