How do I print the scheduler nicely in PyTorch?

For example, look how nicely the optimizer prints (with all its fields):

optimizer
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.001
    lr: 0.001
    weight_decay: 0
)

But look at the scheduler:

scheduler
<torch.optim.lr_scheduler.MultiStepLR at 0x12035d4a8>

Why does the scheduler print so badly?


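To answer your question, that's most likely because the scheduler's parameters were not considered as important as the optimizer's, so no custom __str__() (or __repr__()) method has been implemented for it, and you get the default representation: the class name and a memory address.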

You can either subclass MultiStepLR and give it a __str__() method that prints the elements you want, or write an external function that extracts those elements directly (e.g. milestones, gamma, last_epoch).
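Here is a minimal sketch of the external-function approach. It assumes MultiStepLR keeps its settings in the milestones, gamma and last_epoch attributes, as it does in current PyTorch releases; describe_scheduler is a hypothetical name, not part of PyTorch.

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR


def describe_scheduler(scheduler):
    """Hypothetical helper: format a MultiStepLR's settings like the optimizer's repr."""
    lines = [type(scheduler).__name__ + " ("]
    # milestones is stored as a collections.Counter; sorted() lists the distinct epochs in order
    lines.append("    milestones: {}".format(sorted(scheduler.milestones)))
    lines.append("    gamma: {}".format(scheduler.gamma))
    lines.append("    last_epoch: {}".format(scheduler.last_epoch))
    lines.append(")")
    return "\n".join(lines)


model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
print(describe_scheduler(scheduler))
```

The function only reads attributes, so it works on an existing scheduler without changing its class.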


The parameters are obviously important! They affect the final generalization of any model!

Anyway, that's a pity it's not implemented. How does one implement it and submit a pull request?

Have a look at the CONTRIBUTING document for some guidelines.

If you are interested in contributing to PyTorch, your contributions will fall into two categories:

You want to propose a new feature and implement it.
    Post about your intended feature, and we shall discuss the design and implementation. Once we agree that the plan looks good, go ahead and implement it.
You want to implement a feature or bug-fix for an outstanding issue.
    Search for your issue here: https://github.com/pytorch/pytorch/issues
    Pick an issue and comment on it to say that you want to work on it.
    If you need more context on a particular issue, please ask and we shall provide.

Once you finish implementing a feature or bug-fix, please send a Pull Request to https://github.com/pytorch/pytorch


One way, similar to what @alex.veuthey mentioned, is to implement a __repr__() method in your scheduler class. The following is the __repr__() method of the optimizer class:

def __repr__(self):
    # Start with the class name, e.g. "Adam ("
    format_string = self.__class__.__name__ + ' ('
    for i, group in enumerate(self.param_groups):
        format_string += '\n'
        format_string += 'Parameter Group {0}\n'.format(i)
        # List every hyperparameter of the group except the parameter tensors
        for key in sorted(group.keys()):
            if key != 'params':
                format_string += '    {0}: {1}\n'.format(key, group[key])
    format_string += ')'
    return format_string

It does exactly what you want when you print(optimizer), so you can write something similar for your scheduler.
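A minimal sketch of that idea, assuming a recent PyTorch where MultiStepLR exposes its attributes (minus the wrapped optimizer) through state_dict(). The subclass name PrintableMultiStepLR is hypothetical, and skipping underscore-prefixed (private) keys is a choice made here, not something the optimizer's version does.

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR


class PrintableMultiStepLR(MultiStepLR):
    """Hypothetical subclass: MultiStepLR with an Optimizer-style __repr__."""

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        # state_dict() holds the scheduler's attributes except the wrapped optimizer
        state = self.state_dict()
        for key in sorted(state.keys()):
            # Skip private bookkeeping such as _step_count and _last_lr
            if not key.startswith('_'):
                format_string += '\n    {0}: {1}'.format(key, state[key])
        format_string += '\n)'
        return format_string


model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = PrintableMultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
print(scheduler)
```

Defining __repr__ covers both print(scheduler) and typing scheduler at an interactive prompt, because the default __str__ falls back to __repr__. Note that milestones prints as a collections.Counter, since that is how MultiStepLR stores it.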

Alternatively, you can call the scheduler's state_dict() method directly, e.g. print(scheduler.state_dict()), which returns a dict of its parameters (such as milestones, gamma and last_epoch) that you might want to look into.
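For example, the standard-library pprint module makes the state dict easier to read. This is a sketch assuming a recent PyTorch; the exact set of keys varies between versions, so no output is shown here.

```python
import pprint

import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

state = scheduler.state_dict()
# pprint sorts dict keys by default and breaks long dicts across lines
pprint.pprint(state)
```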