Columns not contiguous after ConvTranspose2D

I need to build a model that takes the output of a transposed convolution layer and uses it as the mean parameter of a Normal distribution.

The code is as follows:

import torch
import torch.nn as nn
from torch.distributions import Normal

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.convT = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=1)
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)

    def loss(self, x, permute=False, conv=False):
        pi_mu = self.conv(x) if conv else self.convT(x)
        pi_mu = pi_mu.permute(0, 2, 3, 1).unsqueeze(1) if permute else pi_mu
        return torch.sum(Normal(pi_mu, torch.ones_like(pi_mu)).log_prob(torch.ones_like(pi_mu)))

model = Model()
loss = model.loss(torch.randn(8, 3, 64, 64))
loss.backward()
print('ConvTranspose and no permute/unsqueeze done.')

loss = model.loss(torch.randn(8, 3, 64, 64), permute=True, conv=True)
loss.backward()
print('Conv and permute/unsqueeze done.')

loss = model.loss(torch.randn(8, 3, 64, 64), permute=True)
loss.backward()
print('ConvTranspose and permute/unsqueeze done.')

I’ve outlined three variations in the code, the third of which raises an error:

  1. Use a ConvTranspose2d layer but do not permute the output. Works fine.
  2. Use a Conv2d layer and permute the output. Works fine.
  3. Use a ConvTranspose2d layer and permute the output. Raises an error on backward() (see below).

The output is as follows (Python 3.7, torch 1.4.0):

ConvTranspose and no permute/unsqueeze done.
Conv and permute/unsqueeze done.

Traceback (most recent call last):
  File "/Users/alex/Developer/project/test.py", line 26, in <module>
    loss.backward()
  File "/Users/alex/miniconda3/envs/pytorch_env/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/Users/alex/miniconda3/envs/pytorch_env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: columns needs to be contiguous
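
Since the error message mentions contiguity, here is a minimal sketch of what permute/unsqueeze does to a tensor's memory layout. It uses a plain random tensor with the same shape as the layer output (an assumption made for illustration; no convolution is involved), and it only shows the layout change. It does not confirm that this is what triggers the error in the ConvTranspose2d backward pass.

```python
import torch

# Stand-in for the layer output: same shape as in the example above.
x = torch.randn(8, 3, 64, 64)

# Same reshaping as in Model.loss when permute=True.
y = x.permute(0, 2, 3, 1).unsqueeze(1)

# permute returns a view with reordered strides, so the data is no longer
# laid out contiguously for the new dimension order.
print(x.is_contiguous(), x.stride())
print(tuple(y.shape), y.is_contiguous(), y.stride())

# .contiguous() copies the data into a fresh, densely packed tensor.
z = y.contiguous()
print(z.is_contiguous())
```

The permuted view holds the same values as its contiguous copy; only the strides differ. Whether a non-contiguous gradient flowing back through the permute is the cause in variation 3 is my guess, not something I have verified.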