# Speed up torch.svd()

I am working on a CIFAR10 classification task in PyTorch, and on each iteration I have to preprocess the batch before I can feed it forward to the model. Below is the preprocessing code for each batch:

``````
S = torch.zeros((batch_size, C, H, W))
for i in range(batch_size):
    Img = Batch[i, :, :, :]
    for c in range(C):
        U, _, V = torch.svd(Img[c])
        S[i, c] = U[:, 0].view(-1, 1).matmul(V[:, 0].view(1, -1))
``````

However, this calculation is very slow. Is there any way that I could speed up this code?

Here is a vectorized implementation:

``````
import torch

N, C, H, W = 8, 3, 50, 100

# Vectorized implementation: treat every channel of every image as one
# matrix in a single batch of N*C matrices of shape (H, W)
batch = torch.rand(N, C, H, W)
u, _, v = torch.svd(batch.view(-1, H, W))

# Outer product of the first left and right singular vectors
s = u[:, :, 0, None].matmul(v[:, :, 0, None].transpose(2, 1)).view(N, C, H, W)

# Loop-based version from the question, used as a reference
S = torch.zeros((N, C, H, W))
for i in range(N):
    Img = batch[i, :, :, :]
    for c in range(C):
        U, _, V = torch.svd(Img[c])
        S[i, c] = U[:, 0].view(-1, 1).matmul(V[:, 0].view(1, -1))

# Verify the result
print(S.allclose(s))  # Prints True
``````

Thanks for the code.
I tried the vectorized version before, and the improvement is marginal. Since I only need the largest singular value and its associated singular vectors, is there a way to avoid computing the full SVD?
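One possible direction, offered as a minimal sketch rather than a benchmarked answer, is batched power iteration, which only ever tracks the leading pair of singular vectors. The function name `top_rank1` and the parameters `n_iter` and `eps` are made up for illustration. The sketch assumes the largest singular value is clearly separated from the second one (convergence slows down as they get closer), and it makes no claim about being faster on any particular hardware.

``````python
import torch


def top_rank1(batch, n_iter=20, eps=1e-12):
    """Approximate u1 @ v1.T for every (H, W) slice of an (N, C, H, W) tensor.

    Hypothetical helper: alternates A @ v and A.T @ u (power iteration),
    so no full SVD is computed.
    """
    N, C, H, W = batch.shape
    A = batch.reshape(-1, H, W)  # (N*C, H, W)

    # Random start vector for each matrix, normalized to unit length
    v = torch.randn(A.shape[0], W, 1, dtype=A.dtype, device=A.device)
    v = v / v.norm(dim=1, keepdim=True).clamp(min=eps)

    for _ in range(n_iter):
        u = A @ v                                   # (N*C, H, 1)
        u = u / u.norm(dim=1, keepdim=True).clamp(min=eps)
        v = A.transpose(1, 2) @ u                   # (N*C, W, 1)
        v = v / v.norm(dim=1, keepdim=True).clamp(min=eps)

    # Outer product of the leading left and right singular vectors
    return (u @ v.transpose(1, 2)).reshape(N, C, H, W)
``````

Calling `top_rank1(batch)` returns a tensor with the same shape as `S` in the snippets above. The signs of `u` and `v` flip together, so the outer product should not depend on the sign convention, but it is worth comparing against the full SVD result on your own data and tuning `n_iter` before relying on it.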
