Help needed: confirm bug in graph tracing

Could anybody test this piece of code and tell me whether it also crashes on your setup, please?

import torch
import torch.nn as nn
import torch.nn.functional as F


class Localizer(nn.Module):
    """Localization network that regresses one 2x3 affine matrix per sample.

    Two bias-free convolutions, each followed by max-pooling and ReLU, feed
    two fully connected layers. ``fc1`` expects ``8 * 8 * 16`` features per
    sample, which is what a single-channel 64x64 input produces after the
    convolution/pooling stack.
    """

    def __init__(self):
        super(Localizer, self).__init__()

        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, bias=False)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, bias=False)
        self.fc1 = nn.Linear(8 * 8 * 16, 32)
        self.fc2 = nn.Linear(32, 2 * 3)

        # Near-zero initial weights for fc1 (normal, mean 0, std 1e-5).
        nn.init.normal_(self.fc1.weight, 0, 1e-5)
        # The fc2 bias starts as the flattened 2x3 identity affine transform
        # [[1, 0, 0], [0, 1, 0]]; no_grad keeps the in-place write out of autograd.
        with torch.no_grad():
            self.fc2.bias.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, x):
        """Return affine parameters of shape ``(batch, 2, 3)`` for ``x``.

        Args:
            x: input batch with one channel (``conv1`` has ``in_channels=1``)
                whose spatial size reduces to 8x8 after both conv/pool
                stages, e.g. ``(batch, 1, 64, 64)``.

        Raises:
            RuntimeError: if the per-sample feature count does not match what
                ``fc1`` expects (wrong input spatial size).
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten per sample rather than with view(-1, 8 * 8 * 16): the latter
        # infers the batch dimension from the element count, so a wrongly
        # sized input could be silently regrouped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x).view(-1, 2, 3)


class STN(nn.Module):
    """Spatial transformer: resamples the input with affine parameters
    predicted by the wrapped ``localizer``.
    """

    def __init__(self, localizer):
        super(STN, self).__init__()
        # Callable that maps an input batch to affine parameters which can be
        # viewed as (-1, 2, 3).
        self.localizer = localizer

    def forward(self, x):
        """Warp ``x`` with the affine transform produced by the localizer."""
        affine_params = self.localizer(x).view(-1, 2, 3)
        sampling_grid = F.affine_grid(affine_params, x.size())
        return F.grid_sample(x, sampling_grid)


# OK: tracing the localizer on its own succeeds.
net = Localizer()
torch.jit.trace(torch.rand(16, 1, 64, 64))(net)

# OK: a plain (untraced) forward pass through the full STN succeeds.
net = STN(Localizer())
output = net(torch.rand(16, 1, 64, 64))
print(output.shape)

# KO: tracing the full STN fails with the RuntimeError quoted below
# (assertion `jit::tracer::ArgumentStash::empty()` failed in transpose).
torch.jit.trace(torch.rand(16, 1, 64, 64))(net)
# torch.onnx.export(net, torch.rand(16, 1, 64, 64), "mymodel", export_params=True)

It doesn’t work for me on version 0.5.0a0+e62c3a4; it throws the following error:

RuntimeError: torch/csrc/autograd/generated/VariableType.cpp:27992: transpose: Assertion `jit::tracer::ArgumentStash::empty()` failed.

The error is raised from the line grid = F.affine_grid(theta, x.size()).

Thanks, I have opened a bug at https://github.com/pytorch/pytorch/issues/8978