Hi,
I have a custom module (let’s call it MyModule) that has conditional logic in its forward method. This custom module is part of an nn::Sequential.
I want to be able to do something along the following lines:
auto flag = true;
auto module = MyModule(flag);
auto sequential = nn::Sequential(
    module
);
auto x = torch::rand({T, inChannels});
auto y = sequential->forward(x);
// Attempt to flip the flag before the second forward pass.
sequential->modules()[0]->as<MyModule>()->flag = false;
x = torch::rand({T, inChannels});
y = sequential->forward(x);
How do I go about this? The line
sequential->modules()[0]->as<MyModule>()->flag = false;
doesn’t actually set the flag (I think it only returns a temporary rvalue and changes its flag, not that of the underlying module).
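To make my suspicion concrete, here is a minimal standalone sketch of the behavior I think I am seeing. It uses only the standard library, not libtorch: ToyModule, ToyContainer and the helper functions are hypothetical stand-ins I made up for illustration, so this only shows the copy-versus-shared-handle distinction and may not match what nn::Sequential really does internally.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a module with a runtime switch (not a libtorch type).
struct ToyModule {
    explicit ToyModule(bool flag) : flag(flag) {}
    // The flag selects one of two branches, like the conditional in forward.
    int forward(int x) const { return flag ? x + 1 : x - 1; }
    bool flag;
};

// Hypothetical container that stores shared handles to its children.
struct ToyContainer {
    std::vector<std::shared_ptr<ToyModule>> children;

    // Returns a copy: edits to the result never reach the stored child.
    ToyModule copy_of(std::size_t i) const { return *children.at(i); }

    // Returns the shared handle: edits are seen by later forward calls.
    std::shared_ptr<ToyModule> handle_to(std::size_t i) const { return children.at(i); }

    int forward(int x) const {
        for (const auto& child : children) x = child->forward(x);
        return x;
    }
};

// Builds a one-child container whose flag starts out true.
inline ToyContainer make_container() {
    ToyContainer c;
    c.children.push_back(std::make_shared<ToyModule>(true));
    return c;
}

// Flips the flag on a copy; the container keeps using the old flag.
inline int forward_after_copy_edit(int x) {
    ToyContainer c = make_container();
    ToyModule copy = c.copy_of(0);
    copy.flag = false;
    return c.forward(x);
}

// Flips the flag through the shared handle; the container sees the change.
inline int forward_after_handle_edit(int x) {
    ToyContainer c = make_container();
    c.handle_to(0)->flag = false;
    assert(!c.children.at(0)->flag);  // the stored child itself changed
    return c.forward(x);
}
```

In this toy version, forward_after_copy_edit(10) still takes the flag == true branch, while forward_after_handle_edit(10) takes the flag == false branch. What I am looking for is the libtorch equivalent of handle_to, i.e. a way to reach the stored module itself rather than a copy.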
Thanks.