Hi All,
I was wondering if it's possible to JIT-compile a network whose output depends on a flag. For example, I have a neural network with an internal flag `self.use_det`. The network is an `nn.Module` whose `forward` method consists of a few `nn.Linear` layers that eventually produce a batch of matrices of shape `[B,N,N]`, where `B` is the batch size and `N` is the number of nodes in the input layer. However, what the network returns depends on the state of the `self.use_det` flag. If the flag is `False`, the network returns a tensor of shape `[B,N,N]`, but if `self.use_det = True`, it applies `torch.slogdet` to that tensor and returns the resulting `(sign, logabsdet)` pair: two tensors, each of shape `[B]`.
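To make the shapes concrete, here is a minimal eager-mode sketch of such a network. The class name `DetNet`, the hidden size and the `nn.Tanh` between two `nn.Linear` layers are assumptions for illustration, and `torch.slogdet` stands in for the `self.slogdet` call used in the real model:

```python
import torch
import torch.nn as nn


class DetNet(nn.Module):
    """Hypothetical stand-in for the network described above (eager mode only)."""

    def __init__(self, n: int, hidden: int = 32, use_det: bool = False):
        super().__init__()
        self.n = n
        self.use_det = use_det
        # Assumed architecture: N inputs -> hidden -> N*N outputs
        self.layers = nn.Sequential(
            nn.Linear(n, hidden),
            nn.Tanh(),
            nn.Linear(hidden, n * n),
        )

    def forward(self, x: torch.Tensor):
        # x: [B, N] -> matrices: [B, N, N]
        matrices = self.layers(x).view(-1, self.n, self.n)
        if not self.use_det:
            return matrices
        # On a batch, torch.slogdet returns (sign, logabsdet), each of shape [B]
        sign, logabsdet = torch.slogdet(matrices)
        return sign, logabsdet


B, N = 4, 3
x = torch.randn(B, N)
print(DetNet(N, use_det=False)(x).shape)  # torch.Size([4, 3, 3])
sign, logabsdet = DetNet(N, use_det=True)(x)
print(sign.shape, logabsdet.shape)        # torch.Size([4]) torch.Size([4])
```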
Now, the question: is it possible to `jit` a network whose output depends on this flag? I tried naively applying `torch.jit.script(net)`, but I get the following error:
```
Traceback (most recent call last):
  File "run_mcmc.py", line 53, in <module>
    net = torch.jit.script(net)
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_script.py", line 942, in script
    return torch.jit._recursive.create_script_module(
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError:
Previous return statement returned a value of type Tensor but this return statement returns a value of type Tuple[Tensor, Tensor]:
  File "~/main.py", line 55
        else:
            sign, logabsdet = self.slogdet(matrices)
            return sign, logabsdet
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
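As far as the error message suggests, the flag itself is not what TorchScript rejects: it compiles `forward` with a single static return type, and here one branch returns a `Tensor` while the other returns a `Tuple[Tensor, Tensor]`. A stripped-down module (the name `Repro` is hypothetical) isolates that mismatch without any linear layers; scripting it should fail with the same kind of `RuntimeError`:

```python
import torch
import torch.nn as nn


class Repro(nn.Module):
    """Hypothetical minimal module: one branch returns a Tensor, the other a tuple."""

    def __init__(self, use_det: bool = False):
        super().__init__()
        self.use_det = use_det  # ordinary (mutable) bool attribute in TorchScript

    def forward(self, matrices: torch.Tensor):
        if not self.use_det:
            return matrices  # type: Tensor
        sign, logabsdet = torch.slogdet(matrices)
        return sign, logabsdet  # type: Tuple[Tensor, Tensor], conflicts with the above


try:
    torch.jit.script(Repro())
    scripted_ok = True
except RuntimeError as err:
    # The compiler reports the conflicting return types here
    scripted_ok = False
    print(type(err).__name__)
```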
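The `forward` method of the class ends with:

```python
if not self.use_det:
    # Flag off: return the raw batch of matrices, shape [B, N, N]
    return matrices
else:
    # Flag on: sign and log|det| of each matrix, each of shape [B]
    sign, logabsdet = self.slogdet(matrices)
    return sign, logabsdet
```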
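One possible workaround, sketched under the assumption that a single `[B,2]` tensor is acceptable downstream: stack `sign` and `logabsdet` with `torch.stack` so that both branches return a plain `Tensor`, which TorchScript can type-check. `ScriptableDetNet` and its layers are hypothetical stand-ins for the real model:

```python
import torch
import torch.nn as nn


class ScriptableDetNet(nn.Module):
    """Hypothetical variant whose forward always returns a single Tensor."""

    def __init__(self, n: int, hidden: int = 32, use_det: bool = False):
        super().__init__()
        self.n = n
        self.use_det = use_det
        # Assumed architecture standing in for the real stack of nn.Linear layers
        self.layers = nn.Sequential(
            nn.Linear(n, hidden),
            nn.Tanh(),
            nn.Linear(hidden, n * n),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        matrices = self.layers(x).view(-1, self.n, self.n)  # [B, N, N]
        if not self.use_det:
            return matrices
        sign, logabsdet = torch.slogdet(matrices)  # each [B]
        # Pack both results into one [B, 2] tensor so every return is a Tensor
        return torch.stack([sign, logabsdet], dim=-1)


x = torch.randn(4, 3)
det_net = torch.jit.script(ScriptableDetNet(3, use_det=True))
mat_net = torch.jit.script(ScriptableDetNet(3, use_det=False))
print(det_net(x).shape)  # torch.Size([4, 2])
print(mat_net(x).shape)  # torch.Size([4, 3, 3])
```

The trade-off is that callers have to unpack the columns themselves (`out[:, 0]` for the sign, `out[:, 1]` for the log absolute determinant).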
Is this possible? Any help is appreciated. Thank you!