Hi, you can follow this approach.
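For example, simply wrap the tensor in an `nn.Parameter`. Once it is assigned as an attribute of your module, it is registered and returned by `parameters()`: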
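```python
import torch
import torch.nn as nn

def normal(shape):
    # Small random init; nn.Parameter sets requires_grad=True by default
    return nn.Parameter(torch.randn(size=shape) * 0.01)

# gnn is your module; attributes created with normal() now show up here
for foo in gnn.parameters():
    print(foo)
```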
Output:

```
Parameter containing:
tensor([[-0.0022]], requires_grad=True)
Parameter containing:
tensor([[0.0037]], requires_grad=True)
```
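For a self-contained check, here is a minimal sketch that uses a hypothetical `TinyModel` in place of the original `gnn`. It shows that attributes created with `normal()` are picked up by `parameters()` and `named_parameters()`, while a plain tensor attribute is not registered.

```python
import torch
import torch.nn as nn


def normal(shape):
    # Small random init wrapped in nn.Parameter (requires_grad=True by default)
    return nn.Parameter(torch.randn(size=shape) * 0.01)


class TinyModel(nn.Module):
    # Hypothetical module for illustration; not the original poster's gnn
    def __init__(self):
        super().__init__()
        self.w1 = normal((1, 1))        # registered: nn.Parameter attribute
        self.w2 = normal((1, 1))        # registered: nn.Parameter attribute
        self.plain = torch.randn(1, 1)  # NOT registered: plain tensor attribute


model = TinyModel()
params = list(model.parameters())
for p in params:
    print(p)

print(len(params))  # 2
names = [name for name, _ in model.named_parameters()]
print(names)  # ['w1', 'w2']
```

Because only the `nn.Parameter` attributes are registered, passing `model.parameters()` to an optimizer will update `w1` and `w2` but leave `plain` untouched.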