How can I zero out the weights of a Conv2d module? In particular, I can't figure out how to access the weights.
This should work:
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3, 1, 1)
with torch.no_grad():
    conv.weight.zero_()
print(conv.weight)
I'm sorry, I didn't specify that I need the solution in C++ (I thought the C++ tag was enough).
Anyway, I managed to solve the problem:
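#include <torch/torch.h>

auto conv = torch::nn::Conv2d(
    torch::nn::Conv2dOptions(3, 16, 3)
);
// Turn off gradient tracking for the module's parameters so that the
// in-place zero_() below is allowed on these leaf tensors.
conv->requires_grad_(false);
conv->weight.zero_();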
Ah, sorry, I missed the tag. Yes, it should generally be sufficient, but well…