Hi,
I'm trying to use torch.nn.utils.parametrizations.orthogonal to enforce unitary weights in a custom module on PyTorch 1.10.2, but it always raises this error:
Module 'DimTransModel(
(increase_mlp): Linear(in_features=8, out_features=16, bias=False)
(decrease_mlp): Linear(in_features=8, out_features=16, bias=False)
)' does not have a parameter, a buffer, or a parametrized element with name 'weight'
I'm not sure whether this is a bug in PyTorch itself. Here is my implementation of the module:
class DimTransModel(nn.Module):
    weight = torch.Tensor

    def __init__(self, num_antenna, num_output):
        super(DimTransModel, self).__init__()
        self.increase_mlp = nn.Linear(num_antenna, num_output,
                                      bias=False).to(torch.complex64)
        self.decrease_mlp = nn.Linear(num_antenna, num_output,
                                      bias=False).to(torch.complex64)
        self.weight = nn.Parameter(torch.Tensor(num_antenna, num_output)).to(
            torch.complex64)

    def increase_dim(self, data):
        self.increase_mlp.weight = self.weight
        return self.increase_mlp(data)

    def decrease_dim(self, data):
        self.decrease_mlp.weight = self.weight.conj().transpose()
        return self.decrease_mlp(data)
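One thing I noticed while checking: calling .to() on a freshly created nn.Parameter returns a plain Tensor, so the assignment to self.weight above may leave it unregistered. A small sketch with a throwaway Demo module (not the module above) that reproduces just that assignment pattern:

```python
import torch
import torch.nn as nn

# Sketch: .to() on a freshly created nn.Parameter returns a plain Tensor,
# so the value stored on self.weight is NOT registered as a module parameter.
class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(4, 4)).to(torch.complex64)

m = Demo()
print(sorted(dict(m.named_parameters())))  # no 'weight' entry appears
```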
And this is the function call:
self.dim_trans_model = nn.utils.parametrizations.orthogonal(
    DimTransModel(num_antenna, num_antenna * 2), "weight")
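For comparison, the same call does go through when the target attribute is a registered parameter. A minimal sketch on a plain nn.Linear (assuming PyTorch >= 1.10, where this parametrization was added; the 8 -> 16 sizes mirror the layers above):

```python
import torch
import torch.nn as nn

# Minimal sketch: orthogonal() applied to a module whose `weight` IS a
# registered parameter. The parametrized weight is semi-orthogonal by
# construction, i.e. its columns are orthonormal.
layer = nn.Linear(8, 16, bias=False)
layer = nn.utils.parametrizations.orthogonal(layer, "weight")

w = layer.weight  # shape (16, 8), recomputed through the parametrization
print(torch.allclose(w.T @ w, torch.eye(8), atol=1e-5))
```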