Hi, I have a question about torch.randperm. I used randperm in my model to shuffle the channels of a torch tensor, because ChannelShuffle is not implemented for CUDA.
My questions are:
Does this function persist in torch.no_grad()? Does torch.randperm generate a new permutation list in validation and testing?
Does torch.randperm generate a new permutation list for each batch during training?
This is a block in the model that I am implementing
import torch
import torch.nn as nn

class Mix_block(nn.Module):
    def __init__(self):
        super(Mix_block, self).__init__()

    def mix_func(self, tensor1, tensor2, tensor3, tensor4):
        # concatenate the four feature maps along the channel dimension
        input = torch.cat((tensor1, tensor2, tensor3, tensor4), 1)
        permute = torch.randperm(256, requires_grad=True)
        out = input[:, permute]
        # channel_shuffle = nn.ChannelShuffle(4)
        # out = channel_shuffle(input)
        return out

    def forward(self, tensor1, tensor2, tensor3, tensor4):
        out = self.mix_func(tensor1, tensor2, tensor3, tensor4)
        return out
torch.randperm returns a random permutation of integers in [0, n-1] and your code snippet won’t work:
permute = torch.randperm(256, requires_grad=True)
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients
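Dropping requires_grad from the index tensor avoids the error, and gradients still flow through the indexed input itself. A minimal check (the tensor shapes below are made up for illustration):

import torch

x = torch.randn(2, 256, 4, 4, requires_grad=True)
permute = torch.randperm(256)   # integer indices, no requires_grad needed
out = x[:, permute]             # indexing is differentiable w.r.t. x
out.sum().backward()
print(x.grad.shape)             # torch.Size([2, 256, 4, 4])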
The function can be used in a no_grad context, but I don't know what "persist" means here. torch.randperm is not a module and doesn't use the self.training attribute. It's thus not influenced by model.train() or model.eval() calls.
torch.randperm will shuffle the indices whenever called. If you are calling it for each batch, then a new output will be generated in each call.
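You can verify both points with a small snippet: the call works inside torch.no_grad(), and each call returns a fresh permutation.

import torch

with torch.no_grad():
    p1 = torch.randperm(8)
    p2 = torch.randperm(8)

print(p1)
print(p2)
print(torch.equal(p1, p2))  # almost always False: each call draws a new permutation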
Thanks ptrblck, I had the same error when testing my model. Removing requires_grad=True solved the error.
Let me explain my problem a bit:
I used torch.randperm to shuffle the channels randomly. I want the input channels to be shuffled during the training phase, and the model to remember the new channel positions from the permutation, like a parameter. During evaluation and testing, the shuffling would stop and the model would reuse the random permutation generated during training.
I guess from your answer that my implementation would not work.
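For what it's worth, here is a minimal sketch (not the original code, and the module name is made up) of one way to get a permutation that persists into evaluation: draw it once and store it with register_buffer, so it is saved in the state_dict and reused unchanged at eval time.

import torch
import torch.nn as nn

class FixedShuffle(nn.Module):  # hypothetical name
    def __init__(self, num_channels=256):
        super().__init__()
        # a buffer is saved/loaded with the model but is not a trainable parameter
        self.register_buffer("permute", torch.randperm(num_channels))

    def forward(self, x):
        return x[:, self.permute]

If instead you want a new shuffle per batch during training only, you could redraw the permutation in forward when self.training is True and fall back to the stored buffer otherwise.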