Hi, I am transferring Style Swap (Torch-based) to PyTorch, and I wonder if this snippet can be improved. The input is 4-dimensional, 1*C*H*W.
local _, argmax = torch.max(input, 2)  -- channel dim is 2 for a 1xCxHxW tensor
for i = 1, self.output:size(3) do
  for j = 1, self.output:size(4) do
    local ind = argmax[{1, 1, i, j}]
    self.output[{1, ind, i, j}] = 1
  end
end
return self.output
end
Can this be improved in the Lua version, so that I can submit a pull request to the original repo? And how can it be improved in the PyTorch version?
BTW, if I want to record the index ind, how can I do that efficiently?
-- additional
local spSize = self.output:size(3) * self.output:size(4)
self.ind_vec = torch.Tensor(spSize):zero()
local _, argmax = torch.max(input, 2)  -- channel dim is 2 for a 1xCxHxW tensor
for i = 1, self.output:size(3) do
  for j = 1, self.output:size(4) do
    local ind = argmax[{1, 1, i, j}]
    self.output[{1, ind, i, j}] = 1
    -- additional: row-major flat index (stride by the width, size(4))
    local tmp = (i - 1) * self.output:size(4) + j
    self.ind_vec[tmp] = ind
  end
end
return self.output
end
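For reference, in PyTorch the whole double loop can be replaced by a single `scatter_` call, which builds the one-hot map and gives you the flat index record at the same time. A minimal sketch, assuming a 1×C×H×W input as in the post:

```python
import torch

c, h, w = 4, 3, 3
input = torch.randn(1, c, h, w)

# channel-wise argmax, keepdim so the result is 1 x 1 x H x W
_, argmax = torch.max(input, 1, keepdim=True)

# one-hot along the channel dimension, no Python loops
output = torch.zeros_like(input)
output.scatter_(1, argmax, 1)

# flat record of the winning channel per spatial location (row-major, length H*W)
ind_vec = argmax.view(-1)
```

`scatter_(1, argmax, 1)` writes a 1 at the winning channel of every spatial position, and `ind_vec` replaces the `self.ind_vec` bookkeeping from the Lua loop.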
Sorry, I had just posted a solution to a different problem; it has been withdrawn.
To accelerate it, I use advanced indexing. (I haven't finished a complete performance comparison, but it does run much faster for my case.) I only use batch size 1 in my training, so I haven't worked out a plan for handling larger batch sizes yet.
import torch
import time
import numpy

bs = 1  # has to be 1 for now
c = 256
h = 32
w = 32
input = torch.randn(1, c, h, w).mul_(10).floor()

# spatial index helpers: sp_y tiles the column indices h times,
# sp_x repeats each row index w times
sp_y = torch.arange(0, w).long()
sp_y = torch.cat([sp_y] * h)
lst = []
for i in range(h):
    lst.extend([i] * w)
sp_x = torch.from_numpy(numpy.array(lst))

print('Method 1')
start1 = time.time()
_, c_max = torch.max(input, 1)
c_max_flatten = c_max.view(-1)
input_zero = torch.zeros_like(input)
input_zero[:, c_max_flatten, sp_x, sp_y] = 1
indlst1 = c_max_flatten
end1 = time.time()
print(type(sp_x))
print(type(sp_y))
input_zero = input_zero.cuda()

print('Method 2')
input_zero2 = torch.zeros_like(input).cuda()
indlst2 = torch.zeros(h * w).long()
start2 = time.time()
_, arg_max = torch.max(input, 1, keepdim=True)
for i in range(h):
    for j in range(w):
        ind = arg_max[:, 0, i, j]
        input_zero2[:, ind[0], i, j] = 1
        tmp = i * w + j
        indlst2[tmp] = ind[0]
end2 = time.time()

print('Speedup ratio:')
print((end2 - start2) / (end1 - start1))  # about 1-15 times faster
print('Before:')
print(end2 - start2)
print('After:')
print(end1 - start1)
print('error:')
print(torch.sum(indlst1 - indlst2))
print(torch.sum(input_zero - input_zero2))
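The advanced-indexing trick above is tied to bs = 1. One way to lift that restriction (a sketch, not verified against the original repo) is to let `F.one_hot` build the one-hot map per sample and keep the index record as a B×(H*W) tensor:

```python
import torch
import torch.nn.functional as F

bs, c, h, w = 4, 8, 5, 5
input = torch.randn(bs, c, h, w)

# winning channel per spatial location, for every sample: bs x H x W
arg_max = input.argmax(dim=1)

# F.one_hot puts the one-hot axis last (bs x H x W x C),
# so permute it back to bs x C x H x W
one_hot = F.one_hot(arg_max, num_classes=c).permute(0, 3, 1, 2).float()

# per-sample flat index record, shape bs x (H*W)
ind_vec = arg_max.view(bs, -1)
```

This avoids both Python loops and the hand-built sp_x/sp_y index tensors, and the same code works unchanged for bs = 1.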
Hi, I do have a question regarding the style swap. Why do they take the one-hot of the max index and then deconvolve with the un-normalized style patches to represent the reconstructed image (feature map)?
Shouldn't they instead use the original K value (still zeros everywhere except the largest entry) when deconvolving with the style patches?