Hello. I am messing around with flow-based generative models, and decided to play around with a simple implementation of Invertible Convolutions. I tried adapting one from a time-series model, only to catch this error:
PS C:\Users\AyazA> & C:/Python37/python.exe c:/Users/AyazA/Desktop/RRF/modules.py
Traceback (most recent call last):
File "c:/Users/AyazA/Desktop/RRF/modules.py", line 49, in <module>
z, log = conv(x)
File "C:\Python37\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "c:/Users/AyazA/Desktop/RRF/modules.py", line 43, in forward
log_det_W = height * width * torch.logdet(W)
IndexError: dimension specified as -2 but tensor has no dimensions
Here is the code:
import torch
import torch.nn as nn
import torch.nn.functional as F
# From https://github.com/NVIDIA/waveglow/blob/master/glow.py
class Invertible1x1Conv(nn.Module):
"""
The layer outputs both the convolution, and the log determinant
of its weight matrix. If reverse=True it does convolution with
inverse
"""
def __init__(self, c):
super(Invertible1x1Conv, self).__init__()
self.conv = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0,
bias=False)
# Sample a random orthonormal matrix to initialize weights
W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0:
W[:,0] = -1*W[:,0]
W = W.view(c, c, 1)
self.conv.weight.data = W
def forward(self, z, reverse=False):
# shape
batch_size, channels, height, width = z.size()
W = self.conv.weight.squeeze()
if reverse:
if not hasattr(self, 'W_inverse'):
# Reverse computation
W_inverse = W.float().inverse()
W_inverse = Variable(W_inverse[..., None])
if z.type() == 'torch.cuda.HalfTensor':
W_inverse = W_inverse.half()
self.W_inverse = W_inverse
z = F.conv2d(z, self.W_inverse, bias=None, stride=1, padding=0)
return z
else:
# Forward computation
log_det_W = height * width * torch.logdet(W)
z = self.conv(z)
return z, log_det_W
x = torch.rand([1, 1, 28, 28])
conv = Invertible1x1Conv(1)
z, log = conv(x)
print(z.shape)