How can I jit a network whose output depends on a flag?

Hi All,

I was wondering if it’s possible to jit a network whose output depends on a flag. For example, I have a neural network with an internal flag self.use_det. The network is an nn.Module whose forward method comprises a few nn.Linear layers that eventually produce a batch of matrices of shape [B,N,N], where B is the batch size and N is the number of nodes in the input layer. However, the value returned by the network is determined by the state of the self.use_det flag. If the flag is set to False, the network returns a tensor of shape [B,N,N], but if self.use_det = True it returns a tensor of shape [B,2] (due to the use of torch.slogdet on the tensor).

Now, the question: is it possible to jit a network where the output depends on this flag? I tried naively applying torch.jit.script(net), but I get the following error:

Traceback (most recent call last):
  File "run_mcmc.py", line 53, in <module>
    net = torch.jit.script(net)
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_script.py", line 942, in script
    return torch.jit._recursive.create_script_module(
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "~/anaconda3/lib/python3.8/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
Previous return statement returned a value of type Tensor but this return statement returns a value of type Tuple[Tensor, Tensor]:
  File "~/main.py", line 55
    else:
      sign, logabsdet = self.slogdet(matrices)
      return sign, logabsdet
      ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

The forward method of the class concludes with:

    if self.use_det:
        return matrices
    else:
        sign, logabsdet = self.slogdet(matrices)
        return sign, logabsdet
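
For reference, here is a minimal self-contained version of the pattern (the layer shapes are invented) that reproduces the error:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, n: int = 4):
        super().__init__()
        self.n = n
        self.fc = nn.Linear(n, n * n)  # stand-in for the real layers
        self.use_det = True

    def forward(self, x):
        # x: [B, N] -> matrices: [B, N, N]
        matrices = self.fc(x).view(-1, self.n, self.n)
        if self.use_det:
            return matrices
        else:
            sign, logabsdet = torch.slogdet(matrices)
            return sign, logabsdet

net = torch.jit.script(Net())  # raises the same RuntimeError about mismatched return types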

Is this possible? Any help is appreciated! Thank you! :smiley:

Yes, it should work as long as you don’t change the output type. In your case you could return an empty tensor in the if path as seen here:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 10)
        self.fc2 = nn.Linear(1, 2)
        self.flag = True

    def forward(self, x):
        if self.flag:
            out = self.fc1(x)
            # pad with an empty tensor so both branches return Tuple[Tensor, Tensor]
            return out, torch.empty_like(out)
        else:
            out = self.fc2(x)
            out2 = out + 1
            return out, out2
    
# eager mode
model = MyModel()
x = torch.randn(1, 1)

out = model(x)
print(out[0].shape)
> torch.Size([1, 10])

model.flag = False
out = model(x)
print(out[0].shape)
> torch.Size([1, 2])

# script
model = torch.jit.script(model)
out = model(x)
print(out[0].shape)
> torch.Size([1, 2])

model.flag = True
out = model(x)
print(out[0].shape)
> torch.Size([1, 10])
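
Applied to the end of your forward, it could look like this (an untested sketch; only the if branch changes):

    if self.use_det:
        # pad with an empty tensor so both branches return Tuple[Tensor, Tensor]
        return matrices, torch.empty(0, device=matrices.device)
    else:
        sign, logabsdet = self.slogdet(matrices)
        return sign, logabsdet

Call sites then index the returned tuple, e.g. net(x)[0].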

Ah, great! Thank you @ptrblck! :slight_smile:

Hi ptrblck,

I’ve just implemented your solution, but I get an error. I was wondering if you could take a look?

Traceback (most recent call last):
  File "main.py", line 152, in <module>
    X, _ = sampler(burn_in)
  File "~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "~/src/Samplers.py", line 83, in forward
    self.step()
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "~/src/Samplers.py", line 67, in step
    log_a = self._log_pdf(xcand).detach_() - self._log_pdf(self.chains).detach_()  #calculate log acceptance probability 
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "~/src/Samplers.py", line 47, in _log_pdf
    return torch.slogdet(self.network(x)[0])[1].mul(2).detach_()
  File "~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
RuntimeError: nvrtc: error: failed to open libnvrtc-builtins.so.11.1.
  Make sure that libnvrtc-builtins.so.11.1 is installed correctly.
nvrtc compilation failed: 

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)


template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

extern "C" __global__
void fused_tanh_add(double* t0, double* t1, double* aten_add) {
{
  double v = __ldg(t1 + (64 * (((512 * blockIdx.x + threadIdx.x) / 64) % 24) + ((512 * blockIdx.x + threadIdx.x) / 1536) * 1536) + (512 * blockIdx.x + threadIdx.x) % 64);
  double v_1 = __ldg(t0 + (64 * (((512 * blockIdx.x + threadIdx.x) / 64) % 24) + ((512 * blockIdx.x + threadIdx.x) / 1536) * 1536) + (512 * blockIdx.x + threadIdx.x) % 64);
  aten_add[(64 * (((512 * blockIdx.x + threadIdx.x) / 64) % 24) + ((512 * blockIdx.x + threadIdx.x) / 1536) * 1536) + (512 * blockIdx.x + threadIdx.x) % 64] = (tanh(v)) + v_1;
}
}

I assume this is more likely a non-PyTorch issue?

I’m running Ubuntu 20.04, CUDA 11.2, driver 460.80, and PyTorch 1.8.1+cu111 (so I assume it’s using CUDA 11.1; could that be an issue?).

Thank you! :slight_smile:

This error seems to be related to this issue, which was already fixed, so you might want to update to the latest release (1.9.0) or the nightly.
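
As a quick sanity check after upgrading, you can print the CUDA toolkit version your binary was built with:

import torch

print(torch.__version__)    # e.g. 1.9.0+cu111
print(torch.version.cuda)   # toolkit bundled with the binary, not the system install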


@ptrblck saves the day once again! Thank you! :slight_smile: