Efficient way to set the maximum value across channels to 1 at every spatial position

Hi, I am porting Style Swap (written in Lua Torch) to PyTorch, and I wonder if the following code could be improved.
The input is 4-dimensional, 1*C*H*W.

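  -- input: 1 x C x H x W; dim 2 is the channel dimension (Lua is 1-indexed)
  self.output:resizeAs(input):zero()
  local _, argmax = torch.max(input, 2)   -- argmax: 1 x 1 x H x W
  for i = 1, self.output:size(3) do
    for j = 1, self.output:size(4) do
      local ind = argmax[{1, 1, i, j}]
      self.output[{1, ind, i, j}] = 1
    end
  end
  return self.output
end -- end of the enclosing function (not shown)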
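As a PyTorch counterpart of the loop above, the per-pixel writes can be replaced by one `Tensor.scatter_` call. This is a minimal sketch, not the original repo's code: it takes the channel argmax with `keepdim=True` so the index tensor has the same rank as the input, then scatters a 1 into a zero tensor along dim 1 (channels are dim 1 in 0-indexed PyTorch). The helper name `max_coord` is hypothetical. Unlike the loop, it also handles batch sizes above 1.

```python
import torch


def max_coord(x):
    """Hypothetical helper: one-hot of the channel-wise argmax.

    x: N x C x H x W. Returns a tensor of the same shape holding a single 1
    per (n, h, w) position, in the channel with the maximum value, plus the
    N x H x W tensor of argmax channel indices (0-based).
    """
    idx = x.argmax(dim=1, keepdim=True)              # N x 1 x H x W
    out = torch.zeros_like(x).scatter_(1, idx, 1.0)  # write 1 at each argmax
    return out, idx.squeeze(1)


x = torch.randn(1, 256, 32, 32)
out, ind = max_coord(x)
print(out.sum().item())   # one 1 per spatial position: 1024.0
```

Since every spatial position is written in a single call, there is no Python-level loop over H and W.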
  1. Can this be improved in the Lua version? If so, I can open a pull request to the original repo.
  2. How can it be improved in the PyTorch version?
  3. BTW, if I also want to record ind, how can I do it efficiently?
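-- additional: also record ind for every spatial position
self.output:resizeAs(input):zero()
local H, W = self.output:size(3), self.output:size(4)
self.ind_vec = torch.Tensor(H * W):zero()

local _, argmax = torch.max(input, 2)   -- dim 2 = channels (1-indexed)
for i = 1, H do
  for j = 1, W do
    local ind = argmax[{1, 1, i, j}]
    self.output[{1, ind, i, j}] = 1
    -- additional: row-major linear index of (i, j), so the stride is W, not H
    local tmp = (i - 1) * W + j
    self.ind_vec[tmp] = ind
  end
end
return self.output
end -- end of the enclosing function (not shown)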
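Recording `ind` does not need a loop either. With 0-based indices, `tmp = i*W + j` is exactly the row-major order that `view(-1)` uses, so the flattened argmax already is the index vector. A small check, assuming a 1 x C x H x W input as in the question (the sizes are arbitrary); PyTorch indices are 0-based, so add 1 if you need Lua-style 1-based channel numbers.

```python
import torch

c, h, w = 8, 4, 6
x = torch.randn(1, c, h, w)

# loop version: ind_loop[i * w + j] = argmax channel at (i, j)
_, arg_max = torch.max(x, 1)                  # 1 x h x w
ind_loop = torch.zeros(h * w, dtype=torch.long)
for i in range(h):
    for j in range(w):
        ind_loop[i * w + j] = arg_max[0, i, j]

# vectorized: row-major flattening gives the same order
ind_vec = arg_max.view(-1)
print(torch.equal(ind_loop, ind_vec))         # True
```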

I have solved it myself…

Hi, may I ask how you did that efficiently? Thanks

Sorry, I had just posted a solution to a different problem, and it has been withdrawn.
To accelerate it, I use advanced indexing. (I haven't finished a complete performance comparison, but it does run much faster in my case.) I only use a batch size of 1 in training, so I haven't worked out how to handle larger batch sizes.
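To see what an advanced-indexing assignment of this kind does, here is a toy case (the values are made up for illustration). Three index tensors of equal length pick one (channel, row, column) triple per spatial position, so a single assignment writes all the 1s; `sp_x` and `sp_y` enumerate the positions in row-major order.

```python
import torch

# toy input: 1 x 3 x 2 x 2 (batch x channels x height x width)
x = torch.tensor([[[[1., 5.], [0., 2.]],
                   [[4., 0.], [3., 1.]],
                   [[2., 6.], [1., 7.]]]])
h, w = x.shape[2], x.shape[3]

_, c_max = torch.max(x, 1)                       # 1 x 2 x 2 channel indices
sp_x = torch.arange(h).repeat_interleave(w)      # rows:    [0, 0, 1, 1]
sp_y = torch.arange(w).repeat(h)                 # columns: [0, 1, 0, 1]

out = torch.zeros_like(x)
# element k sets out[:, c_max[k], sp_x[k], sp_y[k]] = 1
out[:, c_max.view(-1), sp_x, sp_y] = 1
print(c_max.view(-1).tolist())                   # [1, 2, 1, 2]
```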

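import time

import torch

bs = 1  # has to be 1 for now
c = 256
h = 32
w = 32

# run both methods on the same device so the timings are comparable
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sync():
    # CUDA kernels run asynchronously; wait for them before reading the clock
    if device.type == 'cuda':
        torch.cuda.synchronize()


# integer-valued input, so ties between channels can occur
input = torch.randn(bs, c, h, w, device=device).mul_(10).floor()

# spatial coordinates of every position in row-major order:
# sp_x = [0,...,0, 1,...,1, ..., h-1], sp_y = [0,1,...,w-1, 0,1,...,w-1, ...]
sp_x = torch.arange(h, device=device).repeat_interleave(w)
sp_y = torch.arange(w, device=device).repeat(h)

print('Method 1: advanced indexing')
sync()
start1 = time.perf_counter()
_, c_max = torch.max(input, 1)          # c_max: 1 x h x w
c_max_flatten = c_max.view(-1)          # h*w channel indices, row-major
input_zero = torch.zeros_like(input)
# one write per position: input_zero[0, c_max[k], sp_x[k], sp_y[k]] = 1
input_zero[:, c_max_flatten, sp_x, sp_y] = 1
indlst1 = c_max_flatten
sync()
end1 = time.perf_counter()

print('Method 2: Python loop (direct port of the Lua code)')
input_zero2 = torch.zeros_like(input)
indlst2 = torch.zeros(h * w, dtype=torch.long, device=device)
sync()
start2 = time.perf_counter()
_, arg_max = torch.max(input, 1, keepdim=True)   # arg_max: 1 x 1 x h x w
for i in range(h):
    for j in range(w):
        ind = arg_max[:, 0, i, j]
        input_zero2[:, ind[0], i, j] = 1
        tmp = i * w + j                           # row-major linear index
        indlst2[tmp] = ind[0]
sync()
end2 = time.perf_counter()

print('Speedup ratio:')
print((end2 - start2) / (end1 - start1))  # about 1-15 times faster in my runs
print('Before:')
print(end2 - start2)
print('After:')
print(end1 - start1)

print('identical results:')
# compare exactly; a plain sum of differences could let errors cancel out
print(torch.equal(indlst1, indlst2))
print(torch.equal(input_zero, input_zero2))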
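The script above is limited to bs = 1 because `sp_x` and `sp_y` only cover the positions of one image. One way to lift that, sketched here as an assumption rather than the poster's method, is to let advanced indexing broadcast: pass a batch index of shape N x 1 x 1, the N x H x W argmax, and row/column indices of shape 1 x H x 1 and 1 x 1 x W. The helper name `max_coord_batched` is hypothetical.

```python
import torch


def max_coord_batched(x):
    """Hypothetical helper: one-hot of the channel argmax for any batch size.

    x: N x C x H x W. The four index tensors broadcast to N x H x W, so each
    (n, h, w) position receives exactly one write.
    """
    n, _, h, w = x.shape
    _, c_max = torch.max(x, 1)                               # N x H x W
    b_idx = torch.arange(n, device=x.device).view(n, 1, 1)   # N x 1 x 1
    y_idx = torch.arange(h, device=x.device).view(1, h, 1)   # 1 x H x 1
    x_idx = torch.arange(w, device=x.device).view(1, 1, w)   # 1 x 1 x W
    out = torch.zeros_like(x)
    out[b_idx, c_max, y_idx, x_idx] = 1
    return out, c_max.view(n, -1)   # per-image indices in row-major order


out, ind = max_coord_batched(torch.randn(4, 256, 32, 32))
print(out.shape, ind.shape)
```

The returned `ind` row for image n plays the role of `indlst1` for that image.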

Hi, I have a question about Style Swap. Why does it use a one-hot encoding of the max index and then deconvolve it with the un-normalized style patches to reconstruct the image (feature map)?
Shouldn't it use the original K values (of course zero everywhere except for the largest one) to deconvolve the style patches?
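To make the two options in this question concrete, here is a minimal sketch of the reconstruction as a transposed convolution. It assumes the style patches are stacked as an N x C x k x k tensor used as the `conv_transpose2d` weight, that K is a 1 x N x H' x W' map of patch responses, and that overlapping patches are averaged by dividing by the overlap count; `swap_reconstruct` and `use_k` are hypothetical names. With `use_k=False` each output patch is a verbatim copy of the best-matching style patch; with `use_k=True` it is that patch scaled by its K value.

```python
import torch
import torch.nn.functional as F


def swap_reconstruct(k_map, style_patches, use_k=False):
    """Hypothetical sketch of the reconstruction step discussed above.

    k_map:         1 x N x H' x W' responses, one channel per style patch
    style_patches: N x C x k x k un-normalized style patches
    use_k=False -> one-hot of the argmax (the approach being asked about)
    use_k=True  -> keep the winning K value instead of 1
    """
    idx = k_map.argmax(dim=1, keepdim=True)
    one_hot = torch.zeros_like(k_map).scatter_(1, idx, 1.0)
    coeff = one_hot * k_map if use_k else one_hot
    # conv_transpose2d pastes each selected patch back at its position
    out = F.conv_transpose2d(coeff, style_patches)
    # assumed normalization: divide by how many patches cover each pixel
    ones = torch.ones_like(one_hot[:, :1])
    overlap = F.conv_transpose2d(ones, torch.ones(1, 1, *style_patches.shape[2:]))
    return out / overlap


patches = torch.randn(4, 3, 3, 3)     # N = 4 style patches, C = 3, k = 3
K = torch.randn(1, 4, 5, 5)
print(swap_reconstruct(K, patches).shape)   # torch.Size([1, 3, 7, 7])
```

With one-hot coefficients the output keeps the style patch magnitudes unchanged; with K values the magnitude also depends on how strongly each patch matched.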

Thanks