# Efficient way to set the maximum value among all channels to 1 across all spatial positions

(Lambda Will) #1

Hi, I am porting Style Swap (based on Torch) to PyTorch. I wonder if this could be improved.
The input is 4-dimensional, 1×C×H×W.

``````lua
-- max over the channel dimension (dim 2 for a 1 x C x H x W tensor)
_, argmax = torch.max(input, 2)
for i = 1, self.output:size(3) do
  for j = 1, self.output:size(4) do
    local ind = argmax[{1, 1, i, j}]
    self.output[{1, ind, i, j}] = 1
  end
end
return self.output
end
``````
1. Can this be improved in the Lua version? If so, I can open a pull request to the original repo.
2. How can it be improved in the PyTorch version?
3. BTW, if I want to record the `ind` values, how can I do that efficiently?
``````lua
-- additional: record the winning channel index for every spatial position
local spSize = self.output:size(3) * self.output:size(4)
self.ind_vec = torch.Tensor(spSize):zero()

_, argmax = torch.max(input, 2)
for i = 1, self.output:size(3) do
  for j = 1, self.output:size(4) do
    local ind = argmax[{1, 1, i, j}]
    self.output[{1, ind, i, j}] = 1
    -- additional: row-major flat index (i-1) * W + j, with W = size(4)
    local tmp = (i - 1) * self.output:size(4) + j
    self.ind_vec[tmp] = ind
  end
end
return self.output
end
``````
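For reference, the nested loops can be collapsed into a single vectorized call in PyTorch. This is a minimal sketch of my own (an assumed reading of the intended semantics, not code from the repo): `scatter_` along the channel dimension writes the one-hot mask in one shot, and the flattened argmax itself doubles as the recorded `ind` vector.

``````python
import torch

# Hypothetical small shapes matching the post's 1 x C x H x W layout.
x = torch.randn(1, 4, 3, 3)

# Per-pixel channel argmax, kept 4-D so it can serve as a scatter index.
argmax = torch.argmax(x, dim=1, keepdim=True)           # 1 x 1 x H x W

# Write a 1 at the argmax channel of every spatial position.
one_hot = torch.zeros_like(x).scatter_(1, argmax, 1.0)  # 1 x C x H x W

# Flattened indices in row-major (i * W + j) order, same as the Lua loop.
ind_vec = argmax.view(-1)
``````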

(Lambda Will) #2

I have solved it myself…

#3

Hi, can I ask how you do that efficiently? Thanks

(Lambda Will) #5

Sorry, I had just posted a solution to a different problem, so it has been withdrawn.
To accelerate it, I use advanced indexing. (I haven't finished a complete perf comparison, but it does run much faster in my case.) I only use batch size 1 in my training, so I haven't worked out a plan to handle larger batch sizes.

``````python
import torch
import time
import numpy

bs = 1  # has to be 1 for now
c = 256
h = 32
w = 32

input = torch.randn(1, c, h, w).mul_(10).floor()

# spatial coordinate vectors for advanced indexing
sp_y = torch.arange(0, w).long()
sp_y = torch.cat([sp_y] * h)

lst = []
for i in range(h):
    lst.extend([i] * w)
sp_x = torch.from_numpy(numpy.array(lst))

print('Method 1')
start1 = time.time()
_, c_max = torch.max(input, 1)

c_max_flatten = c_max.view(-1)

input_zero = torch.zeros_like(input)

input_zero[:, c_max_flatten, sp_x, sp_y] = 1
indlst1 = c_max_flatten
end1 = time.time()

print(type(sp_x))
print(type(sp_y))

input_zero = input_zero.cuda()

print('Method 2')

input_zero2 = torch.zeros_like(input).cuda()
indlst2 = torch.zeros(h * w).long()
start2 = time.time()
_, arg_max = torch.max(input, 1, keepdim=True)
for i in range(h):
    for j in range(w):
        ind = arg_max[:, 0, i, j]
        input_zero2[:, ind[0], i, j] = 1
        tmp = i * w + j
        indlst2[tmp] = ind[0]

end2 = time.time()

print('Speedup ratio:')
print((end2 - start2) / (end1 - start1))  # about 1-15 times faster
print('Before:')
print(end2 - start2)
print('After:')
print(end1 - start1)

print('error:')
print(torch.sum(indlst1 - indlst2))
print(torch.sum(input_zero - input_zero2))
``````
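Since the advanced-indexing version above is pinned to batch size 1, here is a hedged sketch (my own assumption, not tested against the original training code) of a batch-agnostic variant: `scatter_` along the channel dimension sidesteps building `sp_x`/`sp_y` entirely and works for any batch size.

``````python
import torch

bs, c, h, w = 4, 256, 32, 32                # arbitrary batch size
x = torch.randn(bs, c, h, w)

# bs x 1 x h x w index of the winning channel at each position
_, arg_max = torch.max(x, 1, keepdim=True)

# One-hot along the channel dim for every sample in the batch.
one_hot = torch.zeros_like(x).scatter_(1, arg_max, 1.0)

# Per-sample flattened index vectors, row-major (i * w + j) order.
ind = arg_max.view(bs, -1)
``````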

(Shaun) #6

Hi, I have a question regarding Style Swap. Why do they use the one-hot of the max index, and then deconvolve with the un-normalized style patch to represent the reconstructed image (feature map)?
Should they instead use the original K value (zeros everywhere except the largest entry) to deconvolve the style patch?

Thanks