Speed up torch.svd()

I am doing a classification task on CIFAR10 with PyTorch, and on each iteration I have to preprocess every batch before I can feed it forward through the model. Below is the preprocessing code for each batch:

S = torch.zeros((batch_size, C, H, W))
for i in range(batch_size):
    Img = Batch[i, :, :, :]
    for c in range(C):
        U, _, V = torch.svd(Img[c])
        S[i, c] = U[:, 0].view(-1, 1).matmul(V[:, 0].view(1, -1))

However, this calculation is very slow. Is there any way that I could speed up this code?

Here is a vectorized implementation:

import torch

N, C, H, W = 8, 3, 50, 100

# Vectorized implementation: fold the batch and channel dims together
# and run one batched SVD over all N * C matrices at once
batch = torch.rand(N, C, H, W)
u, _, v = torch.svd(batch.view(-1, H, W))

s = u[:, :, 0, None].matmul(v[:, :, 0, None].transpose(2, 1)).view(N, C, H, W)

# Your current implementation
S = torch.zeros((N, C, H, W))
for i in range(N):
    Img = batch[i, :, :, :]
    for c in range(C):                
        U, _, V = torch.svd(Img[c])
        S[i, c] = U[:, 0].view(-1, 1).matmul(V[:, 0].view(1, -1))


# Verify the result
print(S.allclose(s)) # Prints True
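As an aside, on recent PyTorch versions torch.svd is deprecated in favor of torch.linalg.svd, which returns Vh (V transposed) instead of V. A sketch of the same vectorized computation with the newer API:

```python
import torch

N, C, H, W = 8, 3, 50, 100
batch = torch.rand(N, C, H, W)

# torch.linalg.svd returns Vh, whose ROWS are the right singular vectors,
# unlike torch.svd, which returns V with them as columns.
u, sv, vh = torch.linalg.svd(batch.view(-1, H, W), full_matrices=False)

# Outer product of the leading left and right singular vectors, per matrix.
s = (u[:, :, 0, None] @ vh[:, None, 0, :]).view(N, C, H, W)
```

The result matches the torch.svd version up to floating-point error; the sign ambiguity of singular vectors cancels in the outer product.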

Thanks for the code.
I tried the vectorized version before, and the improvement is marginal. Since I only need the first singular value and its associated singular vectors, is there a way to avoid the full SVD computation?
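Since only the leading singular pair is needed, one way to avoid the full decomposition is batched power iteration on A^T A. A minimal sketch (the helper name leading_rank1 and the iteration count are my own choices, not from the thread):

```python
import torch

def leading_rank1(mats, n_iter=50):
    # mats: (B, H, W). Returns (B, H, W): the outer product u1 @ v1.T of the
    # leading singular vectors of each matrix, without computing a full SVD.
    B, H, W = mats.shape
    v = torch.randn(B, W, 1)
    v = v / v.norm(dim=1, keepdim=True)
    for _ in range(n_iter):
        u = mats @ v                      # (B, H, 1), proportional to A v
        u = u / u.norm(dim=1, keepdim=True)
        v = mats.transpose(1, 2) @ u      # (B, W, 1), proportional to A^T u
        v = v / v.norm(dim=1, keepdim=True)
    u = mats @ v
    u = u / u.norm(dim=1, keepdim=True)
    # The sign ambiguity cancels: flipping v also flips u, so u @ v.T is stable.
    return u @ v.transpose(1, 2)

N, C, H, W = 8, 3, 50, 100
batch = torch.rand(N, C, H, W)
S = leading_rank1(batch.view(-1, H, W)).view(N, C, H, W)
```

Convergence depends on the gap between the first two singular values; images with a strong dominant component converge in a few iterations. Recent PyTorch versions also ship torch.svd_lowrank, a randomized method that approximates only the top-q singular triples and may serve the same purpose.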
