Efficient implementation for input covariance

I need to compute the covariance of the input feature maps of a Conv2d layer.
I'm using a forward_pre_hook to do this, as follows:

def _forward_hook(self, module, input):
    if torch.is_grad_enabled():
        x = input[0].data
        if module not in self.m_aa:
            self.m_aa[module] = torch.zeros((x.size(1), x.size(1))).to(device)
        if isinstance(module, nn.Conv2d):
            # [N, C, H, W]
            for h in range(x.size(2)): # Parallelizable?
                for w in range(x.size(3)):
                    self.m_aa[module] += (x[:,:,h,w].T @ x[:,:,h,w])

This is extremely slow. How can I make better use of the available compute by parallelizing over the H and W dimensions of the input?

I think you can express this computation either as a batched matrix multiply or as an equivalent einsum:

import time
import torch

x = torch.randn(64, 3, 64, 64, dtype=torch.double)
def orig(x):
    out = torch.zeros(x.size(1), x.size(1), dtype=torch.double)
    for h in range(x.size(2)):
        for w in range(x.size(3)):
            out += x[:,:,h,w].T @ x[:,:,h,w]
    return out

def new(x): # using batchmatmul
    n = x.size(0)
    c = x.size(1)
    permuted = x.permute(2, 3, 0, 1) # move matmul dimensions to end
    batched = permuted.reshape(-1, n, c) # combine h,w dimensions to be batched
    temp = torch.matmul(batched.permute(0, 2, 1), batched) # permute is like transpose here
    return torch.sum(temp, dim=0) # sum along batch (h,w) dimension

def new2(x): # using einsum
    temp = torch.einsum('kchw,kdhw->cdhw', x, x)
    return torch.sum(temp, dim=[2,3])

# warmup runs
for i in range(3):
    out = orig(x)
    out2 = new(x)
    out3 = new2(x)

print(torch.allclose(out, out2))
print(torch.allclose(out, out3))

t1 = time.time()
out = orig(x)
t2 = time.time()
out2 = new(x)
t3 = time.time()
out3 = new2(x)
t4 = time.time()

print("original", t2-t1)
print("batched matmul", t3-t2)
print("einsum", t4-t3)

However, you might want to time these implementations on your own system, as I think the speedup is shape-dependent (it needs a large enough h and w).
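
Inside your pre-hook this would look something like the following (an untested sketch; self.m_aa and the surrounding class are assumed from your snippet, and I use .detach() instead of .data):

def _forward_hook(self, module, input):
    if torch.is_grad_enabled() and isinstance(module, nn.Conv2d):
        x = input[0].detach() # [N, C, H, W]
        c = x.size(1)
        if module not in self.m_aa:
            self.m_aa[module] = torch.zeros(c, c, dtype=x.dtype, device=x.device)
        flat = x.permute(1, 0, 2, 3).reshape(c, -1) # [C, N*H*W]
        self.m_aa[module] += flat @ flat.T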
