I need to compute covariance of input feature maps of a Conv2d layer.
I’m using a forward_pre_hook to achieve this, as follows:
def _forward_hook(self, module, input):
    """Accumulate the channel second-moment matrix of a layer's input.

    Intended to be registered as a forward pre-hook. When gradient mode is
    enabled and ``module`` is an ``nn.Conv2d``, the C x C matrix
    ``sum over n, h, w of outer(x[n, :, h, w], x[n, :, h, w])`` is added to
    ``self.m_aa[module]``, where ``x = input[0]`` is indexed as [N, C, H, W].

    Args:
        module: The layer the hook fired for; used as the key into
            ``self.m_aa``.
        input: Tuple of positional inputs to the layer; only ``input[0]``
            is read.

    Returns:
        None. The hook only updates ``self.m_aa`` and leaves the layer's
        input untouched.
    """
    if not torch.is_grad_enabled():
        return

    # Detach so the accumulated statistics never join the autograd graph.
    x = input[0].detach()

    if module not in self.m_aa:
        # Allocate on the input's device so the in-place add below cannot
        # hit a device mismatch.
        self.m_aa[module] = torch.zeros(
            (x.size(1), x.size(1)), device=x.device
        )

    if isinstance(module, nn.Conv2d):
        # [N, C, H, W] -> [C, N*H*W]. A single matrix product then sums the
        # per-pixel x[:, :, h, w].T @ x[:, :, h, w] terms over every (h, w)
        # at once instead of looping over them in Python.
        flat = x.transpose(0, 1).reshape(x.size(1), -1)
        self.m_aa[module] += flat @ flat.t()
This is extremely slow. How can I make better use of the available compute by parallelizing over the spatial dimensions (H and W) of the input feature map?
I think you can express this computation as either a batched matrix multiply or an einsum that implements a batched matrix multiply:
import time
import torch
# Benchmark input: a 64 x 3 x 64 x 64 double-precision tensor of random
# normal samples, indexed as x[:, :, h, w] by the functions under test.
x = torch.randn(64, 3, 64, 64, dtype=torch.double)
# NOTE(review): `test` is not referenced anywhere else in the visible
# code - confirm whether it is still needed.
test = torch.randn(32, 48)
def orig(x):
    """Reference implementation: accumulate x[:, :, h, w].T @ x[:, :, h, w].

    Loops over every spatial position (h, w) of a 4-D tensor and sums the
    resulting (C, C) products, where C = x.size(1).

    Args:
        x: 4-D tensor indexed as x[:, :, h, w].

    Returns:
        A (C, C) tensor with the same dtype and device as ``x``.
    """
    channels = x.size(1)
    # Follow the input's dtype/device so the result can be compared directly
    # (e.g. with torch.allclose) against the vectorised variants.
    out = torch.zeros(channels, channels, dtype=x.dtype, device=x.device)
    for h in range(x.size(2)):
        for w in range(x.size(3)):
            pixel = x[:, :, h, w]  # (N, C) slice at one spatial position
            out += pixel.T @ pixel
    return out
def new(x):  # using batchmatmul
    """Compute the sum over (h, w) of x[:, :, h, w].T @ x[:, :, h, w].

    Uses one batched matrix multiply with the spatial positions as the
    batch dimension, followed by a sum over that batch.

    Args:
        x: 4-D tensor indexed as [N, C, H, W].

    Returns:
        A (C, C) tensor.
    """
    n, c = x.size(0), x.size(1)
    # Move the matmul dimensions (N, C) to the end: [H, W, N, C].
    spatial_first = x.permute(2, 3, 0, 1)
    # Merge H and W into a single batch dimension: [H*W, N, C].
    batched = spatial_first.reshape(-1, n, c)
    # Per-position (C, N) @ (N, C) -> [H*W, C, C].
    per_position = torch.matmul(batched.transpose(1, 2), batched)
    # Collapse the batch (h, w) dimension.
    return per_position.sum(dim=0)
def new2(x):  # using einsum
    """Compute the sum over (h, w) of x[:, :, h, w].T @ x[:, :, h, w].

    A single einsum contracts the batch axis (k) and both spatial axes
    (h, w) at once, so no (C, C, H, W) intermediate is materialised before
    the reduction.

    Args:
        x: 4-D tensor indexed as [N, C, H, W].

    Returns:
        A (C, C) tensor.
    """
    return torch.einsum('kchw,kdhw->cd', x, x)
# Warm-up runs: execute every implementation a few times before timing so
# one-off start-up costs do not land in the measurements.
for i in range(3):
    out = orig(x)
    out2 = new(x)
    out3 = new2(x)

# Sanity check: both vectorised variants should match the reference loop.
print(torch.allclose(out, out2))
print(torch.allclose(out, out3))

# Time a single call of each implementation. perf_counter() is a monotonic,
# high-resolution clock intended for measuring short durations.
t1 = time.perf_counter()
out = orig(x)
t2 = time.perf_counter()
out2 = new(x)
t3 = time.perf_counter()
out3 = new2(x)
t4 = time.perf_counter()

print("original", t2 - t1)
print("batched matmul", t3 - t2)
print("einsum", t4 - t3)
However, you might want to time these implementations on your own system, as I think the speedup is shape-dependent (it needs large enough H and W).