How to use TorchScript with Dropout for training?

My model is:

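import torch
import torch.nn as nn

class Net(nn.Module):
    # activation default assumed: the original signature was cut off, but
    # the call Net(14, 512, 2, 0.2) below needs activation to be optional
    def __init__(self, num_input, num_hidden, num_classes, dropout,
                 activation='tanh'):
        super(Net, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(num_input, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_classes)

        if activation == 'tanh':
            self.activation_f = torch.tanh
        elif activation == 'relu':
            self.activation_f = torch.relu
        else:
            raise ValueError('unknown activation: ' + activation)

    def forward(self, x):
        x = self.activation_f(self.fc1(x))
        x = self.dropout(x)  # only active in training mode
        x = self.fc2(x)
        return x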

For instance, I instantiate my model as:

model = Net(14,512,2,0.2).to(device)

However, once I trace it with TorchScript:

traced_model = torch.jit.trace(model, torch.zeros([1, 14], dtype=torch.float))

I receive the following error:

IndexError: The shape of the mask [2] at index 0 does not match the shape of the indexed tensor [1, 2] at index 0

I know that if I call model.eval() first I don't get any error, BUT I want to use my model for training, not evaluation. Does anybody know a solution or workaround for this problem?

PS: I am using PyTorch version 1.4.

I am unable to reproduce this error with Python 3.8.3 and pytorch-1.4.0 from conda.

Based on your error message, the input shape might be the problem. Try using an input of shape [14] during tracing:

traced_model = torch.jit.trace(model, torch.zeros([14], dtype=torch.float))

Stepping back a bit, I would advise you to use scripting (torch.jit.script) rather than tracing for this use case. Tracing only records the operations that were executed for the example input, so it does not capture control flow. Also, from what I remember, the backward pass for dropout decides whether each incoming gradient element is propagated based on whether the corresponding input element was kept during the forward pass.
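A minimal sketch of the scripting route, assuming a recent PyTorch (it has not been checked against 1.4). `ScriptableNet` is a hypothetical rewrite of the model above: it replaces the stored `torch.tanh`/`torch.relu` function attributes with `nn.Tanh`/`nn.ReLU` submodules, because plain function attributes may not be convertible by the TorchScript compiler. Since a scripted module keeps `training` as an attribute, `.train()` and `.eval()` should still switch dropout on and off.

```python
import torch
import torch.nn as nn


class ScriptableNet(nn.Module):
    # Hypothetical scriptable variant of Net (not the original author's code).
    def __init__(self, num_input: int, num_hidden: int, num_classes: int,
                 dropout: float, activation: str = "tanh"):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(num_input, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_classes)
        # Activation as a submodule instead of a bare function attribute.
        if activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"unknown activation: {activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.fc1(x))
        x = self.dropout(x)  # respects self.training after scripting
        return self.fc2(x)


torch.manual_seed(0)
scripted = torch.jit.script(ScriptableNet(14, 512, 2, 0.2))

# Training mode: dropout is active and gradients flow through the scripted module.
scripted.train()
optimizer = torch.optim.SGD(scripted.parameters(), lr=0.1)
x = torch.randn(4, 14)           # placeholder batch with a batch dimension
target = torch.tensor([0, 1, 0, 1])
loss = nn.functional.cross_entropy(scripted(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Evaluation mode: dropout is disabled, so repeated calls give identical outputs.
scripted.eval()
with torch.no_grad():
    a = scripted(x)
    b = scripted(x)
```

The optimizer, loss, and batch here are placeholders; the point is that one scripted module serves both training and evaluation.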

Hi, but I can't use an input of size [14]: where would the batch dimension go ([1, 14])? I don't think the dimension is the problem, since once I set the dropout rate to 0.0 there is no error message and everything works just fine.
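The fact that dropout 0.0 works is consistent with the trace checker being involved: `torch.jit.trace` defaults to `check_trace=True`, which re-runs the traced code on the example inputs and compares its outputs against the Python module, and active dropout makes those outputs differ. Assuming (not confirmed here for 1.4) that this IndexError comes from that check, a sketch that disables it follows. `Net` below is a trimmed, hypothetical stand-in for the model above (tanh only).

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Trimmed stand-in for the model in the question (tanh activation only).
    def __init__(self, num_input, num_hidden, num_classes, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(num_input, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_classes)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


model = Net(14, 512, 2, 0.2)
model.train()
# check_trace=False skips the re-run-and-compare step, which cannot pass
# when dropout makes two runs on the same input produce different outputs.
traced_model = torch.jit.trace(model, torch.zeros([1, 14]), check_trace=False)

x = torch.randn(8, 14)  # batch size differs from the example input
out = traced_model(x)
out.sum().backward()    # gradients reach the traced module's parameters
```

A caveat, hedged: as far as I know, tracing records the dropout training flag as a constant, so calling `.eval()` on the traced module may not turn dropout off. If you need to switch modes, scripting is the safer option.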

I still cannot reproduce your error. Can you provide more details on your setup? There is an environment details collection script in the PyTorch repository.
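For reference, the script lives at `torch/utils/collect_env.py` and can be run as `python -m torch.utils.collect_env`. A minimal way to get the same report from inside Python:

```python
# get_pretty_env_info() builds the report that `python -m torch.utils.collect_env`
# prints (PyTorch, Python, CUDA, and OS versions, among others).
from torch.utils.collect_env import get_pretty_env_info

info = get_pretty_env_info()
print(info)
```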