How to use a scalar weight

Hi!

How can I use a scalar as a weight in PyTorch and still use state_dict for saving the model?

For example, I just want to multiply the whole tensor by w1 and have w1 saved in my state_dict.
If I do it as follows, the state_dict is empty:

import torch
from torch import nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()

        # a plain tensor attribute: nn.Module does not register it
        self.w1 = torch.rand([1])

    def forward(self, x):
        x = self.w1 * x
        return x
    
model = Network()
print(model.state_dict())

What is the correct way of doing it?

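You should register it as a parameter, either by wrapping it in nn.Parameter (assigning an nn.Parameter to a module attribute registers it automatically) or by calling register_parameter. Registered parameters appear in state_dict() and in parameters(), so they are both saved and trained.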
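A minimal sketch of both options, plus a save/load round trip through state_dict. The class names ScaledByAttribute and ScaledByRegister are hypothetical, and an in-memory io.BytesIO buffer stands in for a file on disk; this is an illustration, not the answerer's exact code.

```python
import io

import torch
from torch import nn


class ScaledByAttribute(nn.Module):
    """Option 1: assigning an nn.Parameter to an attribute registers it."""

    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ sees an nn.Parameter and registers it as "w1"
        self.w1 = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # w1 has shape [1], so it broadcasts over every element of x
        return self.w1 * x


class ScaledByRegister(nn.Module):
    """Option 2: register the parameter explicitly by name."""

    def __init__(self):
        super().__init__()
        self.register_parameter("w1", nn.Parameter(torch.rand(1)))

    def forward(self, x):
        return self.w1 * x


model = ScaledByAttribute()
print(model.state_dict())        # now contains the key 'w1'
print(list(model.parameters()))  # w1 is also visible to optimizers

# Round trip: save the state_dict, then load it into a fresh model.
# Both classes use the key "w1", so the state_dict is interchangeable.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
restored = ScaledByRegister()
restored.load_state_dict(torch.load(buffer))
print(torch.equal(restored.w1, model.w1))  # True
```

Both options produce the same state_dict key, so the choice is mostly stylistic; register_parameter is handy when the parameter name is built dynamically.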


Thank you so much! :smiley:
