To make trainable hyperparameters easier to manage, I am looking for a way to create a class Hyperparameter that acts both as an nn.Parameter and as an nn.Module. In particular, I would like to use Hyperparameter objects as an nn.Parameter (e.g. for algebraic manipulations) while still having access to the interface provided by nn.Module, for example to store them in an nn.ModuleDict alongside other modules, or to use methods like zero_grad() and parameters().
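To make the requirement concrete, this is what each base class contributes on its own (plain PyTorch, nothing custom):

import torch

# nn.Parameter side: a leaf tensor that can be used directly in algebra.
p = torch.nn.Parameter(torch.ones(5))
loss = (2 * p).sum()
loss.backward()  # gradients accumulate in p.grad

# nn.Module side: the management interface I also want.
m = torch.nn.Linear(5, 3)
m.zero_grad()                       # reset gradients
params = list(m.parameters())       # enumerate parameters
d = torch.nn.ModuleDict({"m": m})   # storable alongside other modules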
I tried to accomplish this through multiple inheritance, but I suspect this might be dangerous:
import torch

class Hyperparameter(torch.nn.Parameter, torch.nn.Module):
    def __new__(cls, tensor, name):
        # Create the underlying Parameter with the given data.
        return torch.nn.Parameter.__new__(cls, data=tensor)

    def __init__(self, tensor, name):
        torch.nn.Parameter.__init__(self)
        torch.nn.Module.__init__(self)
        # Register the object as a parameter of itself, so that
        # parameters() and zero_grad() can find it.
        self.register_parameter(name, self)

hp1 = Hyperparameter(torch.ones(5), "test1")
hp2 = Hyperparameter(torch.ones(8), "test2")
hp_dict = torch.nn.ModuleDict({"hp1": hp1, "hp2": hp2})
hp_dict.to(torch.device("cpu"))
# KeyError: "attribute 'data' already exists"
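To be concrete about what does succeed, here is the mixed usage (run before the failing to() call, using hp1 from above):

loss = (hp1 * 2).sum()         # Parameter-style: direct algebra
loss.backward()                # hp1.grad is populated, hp1 is a leaf
print(list(hp1.parameters()))  # Module-style: yields hp1 itself
hp1.zero_grad()                # Module-style: clears the gradient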
All of this works for the things I described, but calling to() raises the KeyError shown above. I suspect the multiple inheritance leaves the object in a state that nn.Module no longer expects, but I am not sure what exactly goes wrong.
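For completeness: plain composition, i.e. a small nn.Module that owns an nn.Parameter, avoids the error (a minimal sketch; the class name HyperparameterModule and the attribute value are placeholders of mine). But then the object is no longer usable directly in algebraic expressions, which is exactly the part I want to keep:

import torch

class HyperparameterModule(torch.nn.Module):
    def __init__(self, tensor):
        super().__init__()
        # The Module owns the Parameter instead of being one.
        self.value = torch.nn.Parameter(tensor)

hp = HyperparameterModule(torch.ones(5))
hp_dict = torch.nn.ModuleDict({"hp": hp})
hp_dict.to(torch.device("cpu"))  # no error with composition
loss = (2 * hp.value).sum()      # but algebra needs the extra .value hop
loss.backward()

Is there a way to make the multiple-inheritance version behave, or is this combination fundamentally unsupported?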