How to use Scripting with custom batchNorm?

Hi,
I replaced the default implementation of batchNorm with this one and added functions that store the running_mean and running_var during training. Previously I was using Tracing to convert the module, but it no longer works with the new implementation, since Tracing does not record data-dependent functions or any control flow. I tried to use Scripting instead, but I got the error below.
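
(For context, here is a toy illustration of what I mean, not my actual model: tracing only records the branch taken with the example input, while scripting preserves the control flow.)

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # data-dependent control flow
        if x.sum() > 0:
            return x + 1
        return x - 1

m = Toy()
traced = torch.jit.trace(m, torch.ones(3))   # only the "sum() > 0" branch is recorded
scripted = torch.jit.script(m)               # both branches are kept

x = -torch.ones(3)
print(traced(x))    # still takes the "x + 1" branch -> tensor([0., 0., 0.])
print(scripted(x))  # takes the "x - 1" branch       -> tensor([-2., -2., -2.])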

The function that saves the running_mean and running_var, which I call inside the batchNorm:

import os
import random
import string
import numpy as np
import torch

@torch.jit.script
def save_statisctics(x : torch.Tensor):
    random_string = ''.join(random.choice(string.ascii_lowercase) for i in range(16))
    current_directory = os.path.dirname(os.path.realpath(__file__))

    final_directory = os.path.join(current_directory, r'batchStatistics')
    if not os.path.exists(final_directory):
        os.makedirs(final_directory)

    save_x = x.numpy()
    path = os.path.join(final_directory, '{:s}'.format(random_string))

    np.save(path, save_x)

Error: torch.jit.frontend.UnsupportedNodeError: GeneratorExp aren’t supported:

What changes should I make to both the BatchNorm script and the “save_statisctics” function in order to make this work? And is there another way besides Scripting?
Thank you in advance!

Do you need to script this save method? It seems it just stores some numpy arrays to a file, so you might write this method manually outside of the model without scripting. Alternatively, you could also add the jit.ignore decorator to this method so that it won’t be exported.
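
For the first approach, something along these lines should work (an untested sketch, just to illustrate the idea): keep the model free of file I/O and dump the running stats from outside, since the batchnorm buffers are accessible as module attributes.

import os
import numpy as np

def save_running_stats(model, directory='batchStatistics'):
    # hypothetical helper, called from the training loop rather than from forward
    os.makedirs(directory, exist_ok=True)
    for name, module in model.named_modules():
        # any module that tracks running stats exposes these buffers
        if getattr(module, 'running_mean', None) is not None:
            prefix = os.path.join(directory, name.replace('.', '_'))
            np.save(prefix + '_mean', module.running_mean.detach().cpu().numpy())
            np.save(prefix + '_var', module.running_var.detach().cpu().numpy())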

I am working in a distributed setting: the server serializes the model using torch.jit.trace and sends it to the workers. My issue with the custom batchNorm is that, with Tracing, when the worker receives the model the batchNorm layer does not work (there is no access to the if/else statement).

For the save function I also tried jit.ignore. It passes with no error, but when I call it inside the if statement it is not executed, for the reason I mentioned above.
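
(To illustrate what happens: a Python side effect inside forward, like the jit.ignore'd save call, runs once while tracing but is not recorded into the traced graph, so the workers never execute it. A toy example, not my actual model:)

import torch

class WithSideEffect(torch.nn.Module):
    def forward(self, x):
        print('python side effect')  # executes during tracing only
        return x * 2

t = torch.jit.trace(WithSideEffect(), torch.ones(2))  # prints once here
t(torch.ones(2))  # prints nothing: the call was not recorded in the trace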

Could you post the if condition which is not working?
I believe I’m still not fully understanding the use case.
Wouldn’t it be possible to set up the model in the workers without scripting the save and load functions?

Thank you @ptrblck, and sorry for the ambiguity. I will try to make it clearer.

This is your implementation of BatchNorm with some modifications (a call to the saving function):

class BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    #@torch.jit.script_method
    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:

            save_statisctics(self.running_mean)

            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input

The saving function:

@torch.jit.ignore
def save_statisctics(x : torch.Tensor):
    random_string = ''.join(random.choice(string.ascii_lowercase) for i in range(16))
    current_directory = os.path.dirname(os.path.realpath(__file__))

    final_directory = os.path.join(current_directory, r'batchStatistics')
    if not os.path.exists(final_directory):
        os.makedirs(final_directory)

    save_x = x.numpy()
    path = os.path.join(final_directory, '{:s}'.format(random_string))

    np.save(path, save_x)

And this is the model I am using:

class CNNNet_CIFAR_MyBN(nn.Module):

    def __init__(self, conv1_dim=100, conv2_dim=150, conv3_dim=250, conv4_dim=500):
        super(CNNNet_CIFAR_MyBN, self).__init__()
        self.conv4_dim = conv4_dim

        self.conv1 = nn.Conv2d(3, conv1_dim, 5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(conv1_dim, conv2_dim, 3, stride=1, padding=2)
        self.conv3 = nn.Conv2d(conv2_dim, conv3_dim, 3, stride=1, padding=2)
        self.conv4 = nn.Conv2d(conv3_dim, conv4_dim, 3, stride=1, padding=2)

        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(conv4_dim * 3 * 3, 270) # the 3x3 spatial size is precalculated for 32x32 inputs; recompute it if you change the conv/pool layout
        self.fc2 = nn.Linear(270, 150)
        self.fc3 = nn.Linear(150, 10)
     
        
        self.normalize1 = syft_nn.BatchNorm2d(conv1_dim)
        self.normalize2 = syft_nn.BatchNorm2d(conv2_dim)
        self.normalize3 = syft_nn.BatchNorm2d(conv3_dim)
        self.normalize4 = syft_nn.BatchNorm2d(conv4_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.normalize1((self.conv1(x))))) # first convolutional then batch normalization then relu then max pool
        x = self.pool(F.relu(self.normalize2((self.conv2(x)))))
        x = self.pool(F.relu(self.normalize3((self.conv3(x)))))
        x = self.pool(F.relu(self.normalize4((self.conv4(x)))))

        x = x.view(-1, self.conv4_dim * 3 * 3) # flattening the features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

And here is how I call it:

model = CNNNet_CIFAR_MyBN()
traced_model = torch.jit.trace(model, torch.zeros([1, 3, 32, 32], dtype=torch.float))

With Tracing used like this, the custom batchNorm is not executed when the model is received by the workers. I tried to fix it by adding torch.jit.script_method to the BatchNorm (the commented line). However, I get several errors:

torch.jit.frontend.UnsupportedNodeError: with statements aren’t supported

And even when I ignore the “with statement”, I get this error:

TypeError: ‘ScriptMethodStub’ object is not callable

I removed the with torch.no_grad() statement and just detached the mean and var tensors.
Also, I’ve added a condition for the save method as used in this example and scripting seems to work.
Here is a reduced code snippet:

import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.jit.ignore
def save_statisctics(x : torch.Tensor):
    print('save called')


class BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True, save=False):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.save = save

    #@torch.jit.script_method
    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:

            if self.save:
                save_statisctics(self.running_mean)

            mean = input.mean([0, 2, 3]).detach()
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False).detach()
            n = input.numel() / input.size(1)
            
            self.running_mean = exponential_average_factor * mean\
                + (1 - exponential_average_factor) * self.running_mean
            # update running_var with unbiased var
            self.running_var = exponential_average_factor * var * n / (n - 1)\
                + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input


class CNNNet_CIFAR_MyBN(nn.Module):

    def __init__(self, conv1_dim=100, save=False):
        super(CNNNet_CIFAR_MyBN, self).__init__()
        self.conv1 = nn.Conv2d(3, conv1_dim, 5, stride=1, padding=2)
        self.normalize1 = BatchNorm2d(conv1_dim, save=save)

    def forward(self, x):
        x = F.relu(self.normalize1((self.conv1(x))))
        return x


model = CNNNet_CIFAR_MyBN(save=False)
x = torch.randn(1, 3, 32, 32)
out = model(x)
scripted_model = torch.jit.script(model)
out = scripted_model(x)
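
As a quick check with the flag enabled: the jit.ignore'd function is dispatched back to the Python interpreter when called from the scripted forward, so it still executes locally. (If I remember correctly, a module that calls an ignored function cannot be exported via torch.jit.save, so this fallback only helps while the model stays in the same process.)

model = CNNNet_CIFAR_MyBN(save=True)
model.train()
scripted_model = torch.jit.script(model)
out = scripted_model(torch.randn(1, 3, 32, 32))  # prints 'save called'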

Let me know if this would work for you.

Everything seems fine until I execute scripted_model = torch.jit.script(model); then I get the error below:

RuntimeError: Could not get qualified name for class 'conv2d': __module__ can't be None.

A similar issue is mentioned here but it is still open and no solution is proposed.
Any idea what the reason might be, please?
PS: I am working with torch 1.4.0 (CPU)