How to save params with modified out channels

Here is one module inside my model:

class ConstantInput(nn.Module):
    """Learned constant tensor that is broadcast over the batch dimension.

    The module owns a single parameter ``input`` of shape
    ``(1, channel, size, size)``. An optional attribute ``first_k_oup`` may be
    set on the instance from outside; when it is present and not ``None``,
    only the first ``first_k_oup`` channels are returned.
    """

    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        """Return the constant repeated ``input.shape[0]`` times.

        Only the batch size of ``input`` is used; its values are ignored.
        """
        batch = input.shape[0]
        const = self.input

        # A single getattr covers both "attribute never set" and "set to None".
        first_k_oup = getattr(self, 'first_k_oup', None)  # support dynamic channel
        if first_k_oup is not None:
            assert first_k_oup <= const.shape[1]
            # Slice before repeating so the discarded channels are never
            # copied across the batch.
            const = const[:, :first_k_oup]

        return const.repeat(batch, 1, 1, 1)

I set first_k_oup from outside the module, and I need to save the modified (sliced) parameters rather than the original ones from before the modification. How can I do that?

.............
        self.model.generator = Generator(......
        pretrained_dict_generator = torch.load('./weights/ckpt-best.pt', map_location=torch.device('cpu'))['g_ema']
        self.model.generator.load_state_dict(pretrained_dict_generator)

        set_uniform_channel_ratio(self.model.generator, 0.25) #this line will modify the output channels, like first_k_oup in above class definition
..............

After modifying the output channels with set_uniform_channel_ratio, I want to save the model like this:
torch.save(self.model.generator.state_dict(), 'gen0.25.pth')

Unfortunately, the saved model is the same size as the one before the modification.

@ptrblck

I'm not sure I understand the question correctly, but here is how the number of parameters can be reduced by slicing a parameter directly:


class MyModel(nn.Module):
    """Toy module with one zero-initialized 10x100 parameter and an identity forward."""

    def __init__(self) -> None:
        super().__init__()
        zeros = torch.zeros((10, 100))
        self.param = nn.Parameter(zeros)

    def forward(self, x):
        # Identity: the argument is handed back unchanged.
        return x

model = MyModel()
print("#params before: {}".format(sum(x.numel() for x in model.parameters())))
# modify parameters
# .clone() is required: a bare slice is only a view that keeps sharing the
# original 10x100 storage, and torch.save serializes the whole storage, so
# without the copy the saved checkpoint would not get any smaller.
model.param.data = model.param.data[:, :50].clone()
print("#params after: {}".format(sum(x.numel() for x in model.parameters())))

# output:
#params before: 1000
#params after: 500

thank you! let me try with it