Hello,
I am very new to this topic, but I am trying to prune the model I am working with. For reference, I am using this page. The model is quite big and contains several encoders, ResNet modules, and decoders, so I'm guessing that I have to prune each network individually. (I couldn't find an example where the whole model is pruned at once; please share links if you know of one.) The parameter names look like this:
module.model_enc1.1.weight
module.model_enc1.1.bias
module.model_enc1.2.weight
module.model_enc1.2.bias
module.model_enc1.4.weight
module.model_enc1.4.bias
module.model_enc1.5.weight
...
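The names in that list can be reproduced with a small stand-in model. This is a sketch under assumptions: ToyNet is hypothetical and only mirrors the model_enc1 block printed further down. The "module." prefix comes from the DataParallel wrapper, each dotted segment is a child-module name, and the indices jump from 2 to 4 because ReflectionPad2d and ReLU own no parameters.

```python
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in that only reproduces model_enc1 from the printout."""

    def __init__(self):
        super().__init__()
        self.model_enc1 = nn.Sequential(
            nn.ReflectionPad2d((3, 3, 3, 3)),                      # index 0, no parameters
            nn.Conv2d(3, 64, kernel_size=7),                       # index 1
            nn.BatchNorm2d(64),                                    # index 2
            nn.ReLU(inplace=True),                                 # index 3, no parameters
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # index 4
            nn.BatchNorm2d(128),                                   # index 5
            nn.ReLU(inplace=True),                                 # index 6, no parameters
        )


# DataParallel keeps the wrapped model in its .module attribute,
# which is where the "module." prefix in the parameter names comes from.
netM = nn.DataParallel(ToyNet())
names = [name for name, _ in netM.named_parameters()]
for name in names:
    print(name)
```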
To start, I only want to prune module.model_enc1.1.weight, so I take the model_enc1 sub-network with the following code:
test = netM.module.model_enc1
where netM is the full model wrapped in DataParallel (<class 'torch.nn.parallel.data_parallel.DataParallel'>).
So test contains the following model:
Sequential(
  (0): ReflectionPad2d((3, 3, 3, 3))
  (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU(inplace=True)
  (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
)
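A short sketch of how those printed indices behave as names, using a shape-only rebuild of the first two children (no trained weights). nn.Sequential registers each child under its index as a string, so "1" is a valid attribute name on test. However, Python's getattr() does not walk dots, so a name such as "1.weight" is looked up as one single attribute and fails.

```python
import torch.nn as nn

# Shape-only rebuild of the first two children of the printed Sequential.
test = nn.Sequential(
    nn.ReflectionPad2d((3, 3, 3, 3)),
    nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1)),
)

# Children are registered under the string names "0", "1", ...
conv = getattr(test, "1")  # the same object as test[1]
print(type(conv).__name__, tuple(conv.weight.shape))  # Conv2d (64, 3, 7, 7)

# getattr() treats "1.weight" as a single attribute name, so this raises.
try:
    getattr(test, "1.weight")
    msg = None
except AttributeError as err:
    msg = str(err)
print(msg)
```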
When I run PyTorch's pruning method on it:
prune.random_unstructured(test, name='1.weight', amount=0.3)
I get the following error:
AttributeError                            Traceback (most recent call last)
in <module>
----> 1 prune.random_unstructured(test, name='1.weight', amount=0.3)

/usr/local/lib/python3.6/dist-packages/torch/nn/utils/prune.py in random_unstructured(module, name, amount)
    851 
    852     """
--> 853     RandomUnstructured.apply(module, name, amount)
    854     return module
    855 

/usr/local/lib/python3.6/dist-packages/torch/nn/utils/prune.py in apply(cls, module, name, amount)
    473         """
    474         return super(RandomUnstructured, cls).apply(
--> 475             module, name, amount=amount
    476         )
    477 

/usr/local/lib/python3.6/dist-packages/torch/nn/utils/prune.py in apply(cls, module, name, *args, **kwargs)
    155         # starting from the state it is found in prior to this iteration of
    156         # pruning
--> 157         orig = getattr(module, name)
    158 
    159         # If this is the first time pruning is applied, take care of moving

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
    592                 return modules[name]
    593         raise AttributeError("'{}' object has no attribute '{}'".format(
--> 594             type(self).__name__, name))
    595 
    596     def __setattr__(self, name, value):

AttributeError: 'Sequential' object has no attribute '1.weight'
How do I fix this? Is there any better way to prune these networks?
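For pruning the whole model at once rather than network by network, torch.nn.utils.prune also provides global_unstructured, which ranks the selected tensors from all layers together. A minimal sketch, assuming a hypothetical three-block model (the encoder/middle/decoder split is invented for illustration); with the real model you would collect the layers from netM.module.modules() instead. Note that it uses L1Unstructured (smallest magnitudes) rather than the random method from the question.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical model made of several sub-networks, standing in for the real one.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()),   # "encoder"
    nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU()),  # "middle" block
    nn.Sequential(nn.Conv2d(16, 3, 3), nn.Tanh()),   # "decoder"
)

# modules() walks every nested sub-network, so one pass finds all Conv2d layers.
to_prune = [(m, "weight") for m in model.modules() if isinstance(m, nn.Conv2d)]

# Rank all these weights together by |w| and zero the smallest 20% overall.
prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

total = sum(m.weight_mask.numel() for m, _ in to_prune)
zeros = sum(int((m.weight_mask == 0).sum()) for m, _ in to_prune)
print(zeros, total)  # 634 3168
```

Because the 20% is global, individual layers can end up more or less than 20% sparse; only the overall count is fixed.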