Cholesky decomposition of a tensor - translation from tensorflow code

I am trying to transform a tensorflow code into a pytorch code but I am getting stucked in parts of it - I am learning pytorch, so maybe this is very basic…

The code is:

class Generator(nn.Module):

    def __init__(self, hilbert_size, num_points, noise=None):
        super(Generator, self).__init__()
    
        self.ops = nn.Parameter(torch.empty(1, hilbert_size, hilbert_size, num_points * 2))
        self.inputs = torch.empty((1, num_points), requires_grad=True)
    
        layer = nn.Linear(num_points, 16 * 16 * 2, bias=False)
        init.normal_(layer.weight, mean=0.0, std=0.02)

        self.x = nn.Sequential(
            layer,
            nn.LeakyReLU(),
            nn.Unflatten(1, (2,16,16))
            )

        self.conv_transpose_1 = nn.Sequential(
            nn.ConvTranspose2d(2, 64, kernel_size=4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(64),
            #nn.InstanceNorm2d(x.shape[1])
            nn.LeakyReLU(),
        )

        self.conv_transpose_2 = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=1, padding=2, bias=False),
            nn.InstanceNorm2d(64),
            #nn.InstanceNorm2d(x.shape[1])
            nn.LeakyReLU(),
        )

        self.conv_transpose_3 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=1, padding=1, bias=False),
        )

        self.conv_transpose_4 = nn.Sequential(
            nn.ConvTranspose2d(32, 2, kernel_size=4, stride=1, padding=2, bias=False),
        )

        self.density_matrix = DensityMatrix()
        self.expectation = Expectation()
        self.noise = nn.GaussianNoise(noise)

    def forward(self, ops, inputs):
        x = self.x(inputs)
        x = self.conv_transpose_1(x)
        x = self.conv_transpose_2(x)
        x = self.conv_transpose_3(x)
        x = self.conv_transpose_4(x)
        x = self.density_matrix(x)
        complex_ops = convert_to_complex_ops(ops)
        prefactor = 1.0
        x = self.expectation(complex_ops, x, prefactor)
        x = self.noise(x)

        return x

Where:

class DensityMatrix(nn.Module):

    def __init__(self):
        super(DensityMatrix, self).__init__()

    def forward(self, inputs):
    """
    The forward function which applies the Cholesky decomposition

    Args:
        inputs (`torch.Tensor`): a 4D real valued tensor (batch_size, hilbert_size, hilbert_size, 2)
                                 representing batch_size random outputs from a neural network.
                                 The last dimension is for separating the real and imaginary part

    Returns:
        dm (`torch.Tensor`): A 3D complex valued tensor (batch_size, hilbert_size, hilbert_size)
                              representing valid density matrices from a Cholesky decomposition of the
                              cleaned input
    """
        T = clean_cholesky(inputs)
        return density_matrix_from_T(T)

And the cholesky function is:

def clean_cholesky(img):
    real = img[:, :, :, 0]
    imag = img[:, :, :, 1]

    diag_all = torch.diag(imag)
    diags = torch.linalg.diag(diag_all)

    imag = imag - diags
    imag = torch.linalg.band_part(imag, -1, 0)
    real = torch.linalg.band_part(real, -1, 0)
    T = torch.complex(real, imag)
    return T

The inputs:

In [381]: inputs
Out[381]: 
tensor([[1.9826e+05, 4.5778e-41, 5.0742e+02,  ..., 4.5779e-41, 2.1865e+05,
         4.5778e-41]], requires_grad=True)

In [382]: inputs.type 
Out[382]: <function Tensor.type>

I have two problems:

  • After x = self.conv_transpose_1(x) I got the following output:

     Out[387]: 
    tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
    
       [[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]],
    
       [[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]],
    
       ...,
    
       [[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]],
    
       [[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]],
    
       [[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]]], grad_fn=<LeakyReluBackward0>)
    

just a bunch of zeros.
And the second problem is:

  • Even if I ignore the zeros and move on with the code, the part x = self.density_matrix(x) gives the error:

    RuntimeError                     Traceback (most recent call last)
    Cell In[390], line 1
    ----> 1 zzz = density_matrix(yyyy)
    
     File ~/.virtualenvs/cgan/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
     1190 # If we don't have any hooks, we want to skip the rest of the logic in
     1191 # this function, and just call forward.
     1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1193         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194     return forward_call(*input, **kwargs)
     1195 # Do not call functions when jit is used
     1196 full_backward_hooks, non_full_backward_hooks = [], []
    
    Cell In[3], line 50, in DensityMatrix.forward(self, inputs)
         36 def forward(self, inputs):
         37     """
         38     The forward function which applies the Cholesky decomposition
         39 
        (...)
         48                               cleaned input
         49     """
    ---> 50     T = clean_cholesky(inputs)
         51     return density_matrix_from_T(T)
    
     Cell In[389], line 9, in clean_cholesky(img)
           5 imag = img[:, 1, :, :]
           7 #diag_all = torch.linalg.diag_part(imag, k=0, padding_value=0)
           8 #diag_all = torch.diag(imag, k=0, padding_value=0)
     ----> 9 diag_all = torch.diag(imag)
           10 diags = torch.linalg.diag(diag_all)
           12 imag = imag - diags
    
     RuntimeError: matrix or a vector expected
    

And I have no idea on how to fix it and move on.
This is really really important for me. Any help to clarify this issue and fix it would be really appreciated.