Inserting zeroed channels into a convolutional feature map

Hi,
I’d like to build a layer that, given the output of a Conv2d, inserts a series of zero channels into the feature map. For example, if I have a tensor t1 of shape (B×3×H×W), I want to obtain a new tensor t2 of shape (B×5×H×W) by inserting two blocks of shape (B×1×H×W) at particular positions of t1, defined by a list of indices.

Here is a 1D example of the expected procedure:

t1 = torch.tensor([1, 2, 3, 4, 5])  # starting tensor of size 5
target_size = 8
t2 = torch.zeros(target_size, dtype=t1.dtype)
idxs = torch.tensor([0, 3, 4, 6, 7])  # where to put the elements of t1 in t2
t2[idxs] = t1  # insert the elements of t1 in t2: t2 = [1, 0, 0, 2, 3, 0, 4, 5]
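Applied along the channel dimension of a 4D feature map, the same mapping looks like this (a minimal sketch; bs, h, w, target_ch and the index values are placeholder choices, not the real sizes):

```python
import torch

bs, h, w = 2, 4, 4
x = torch.randn(bs, 3, h, w)    # stand-in for a Conv2d output with 3 channels
target_ch = 5
idx = torch.tensor([0, 2, 4])   # output positions of the 3 input channels

t2 = torch.zeros(bs, target_ch, h, w)
t2[:, idx] = x                  # channels 1 and 3 of t2 stay zero
```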

I found three different ways of doing this with 4D tensors (Conv2d outputs) that work on CUDA devices:

x = torch.randn(bs, ch, shape, shape)  # Conv2d output
idx = torch.tensor(...)  # target channel positions, one per input channel, shape (ch,)

# matmul: build a (1, ch, target_ch) selection matrix once,
# then expand the channels with a batched matmul at inference time
eye = torch.eye(x.shape[1]).unsqueeze(0).to(device)
idx_mat = idx[None, :, None].expand(eye.shape)
zeros = torch.zeros(x.shape[1], target_ch).unsqueeze(0).to(device)
target = torch.scatter(zeros.permute(0, 2, 1), 1, idx_mat, eye).permute(0, 2, 1)

starter.record()  # starter/ender are torch.cuda.Event(enable_timing=True)
x_flat = x.view(x.shape[0], x.shape[1], -1)  # inference
expanded_x = torch.matmul(target.permute(0, 2, 1), x_flat).view(x.shape[0], -1, x.shape[2], x.shape[3])
ender.record()
torch.cuda.synchronize()

# scatter: write x's channels directly into a zero tensor along dim 1
zeros = torch.zeros(x.shape[0], target_ch, *x.shape[2:]).to(device)
idx_sc = idx[None, :, None, None].expand(x.shape)

starter.record()
expanded_x = torch.scatter(zeros, 1, idx_sc, x)
ender.record()
torch.cuda.synchronize()

# index_select: append one zero channel with pad, then gather channels
idxs = []
current = 0
for i in range(target_ch):
    if i in idx:
        idxs.append(current)  # next real channel of x
        current += 1
    else:
        idxs.append(x.shape[1])  # index of the zero channel appended by pad below
idxs = torch.tensor(idxs, device=x.device)

starter.record()
x = pad(x, (0, 0, 0, 0, 0, 1))  # pad = torch.nn.functional.pad; adds one zero channel at the end
expanded_x = torch.index_select(x, 1, idxs)
ender.record()
torch.cuda.synchronize()
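For reference, the three variants can be checked against each other in a self-contained CPU sketch (sizes and index values are placeholders; the index_select bookkeeping is written with a vectorized assignment instead of the loop above, which gives the same selection vector):

```python
import torch
import torch.nn.functional as F

bs, ch, h, w, target_ch = 2, 3, 4, 4, 5
x = torch.randn(bs, ch, h, w)
idx = torch.tensor([0, 2, 4])  # output positions of the input channels

# matmul: (1, ch, target_ch) selection matrix, then batched matmul
eye = torch.eye(ch).unsqueeze(0)
zeros_m = torch.zeros(1, ch, target_ch)
target = torch.scatter(zeros_m.permute(0, 2, 1), 1,
                       idx[None, :, None].expand(eye.shape), eye).permute(0, 2, 1)
out_matmul = torch.matmul(target.permute(0, 2, 1),
                          x.view(bs, ch, -1)).view(bs, target_ch, h, w)

# scatter: write channels directly into a zero tensor along dim 1
zeros_s = torch.zeros(bs, target_ch, h, w)
out_scatter = torch.scatter(zeros_s, 1, idx[None, :, None, None].expand(x.shape), x)

# index_select: pad one zero channel (at index ch), then gather
sel = torch.full((target_ch,), ch, dtype=torch.long)  # default: the zero channel
sel[idx] = torch.arange(ch)                           # real channels at their slots
out_index = torch.index_select(F.pad(x, (0, 0, 0, 0, 0, 1)), 1, sel)

assert torch.allclose(out_matmul, out_scatter)
assert torch.allclose(out_matmul, out_index)
```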

My problem is that the computational time of all of these methods increases drastically with the batch size.
Is there, by chance, another way to do this with a lower computational cost?

I attach a plot of batch size vs time for all three procedures.
