Read a PyTorch project and got confused

    def forward(self, input):
        #TODO confused
        input = input.view((-1, ) + input.size()[-3:])
        if self._representation in ['mv', 'residual']:
            input = self.data_bn(input)

        base_out = self.base_model(input)
        return base_out

Can anyone tell me what this line does?

    input = input.view((-1, ) + input.size()[-3:])


“.view” reshapes a tensor to a new size, and its argument is the shape you want. In your code, the goal is to keep the last 3 dimensions unchanged (that is what [-3:] selects) and merge all the leading dimensions into one. The -1 tells PyTorch to infer the size of that merged dimension from the total number of elements.
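A minimal sketch of what the expression builds, assuming the input is a 5-D video-style batch of shape (batch, segments, C, H, W) and that base_model is a 2-D CNN. Neither is confirmed by the snippet; the sizes and the nn.Conv2d stand-in are made up for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical video batch: 2 clips x 3 segments, each frame 3x8x8 (C, H, W).
x = torch.rand(2, 3, 3, 8, 8)

# x.size() returns a torch.Size, a tuple subclass, so slicing and "+" work as for tuples.
tail = x.size()[-3:]        # last three dims: (3, 8, 8)
new_shape = (-1,) + tail    # (-1, 3, 8, 8); -1 lets view infer that size
y = x.view(new_shape)       # inferred size: 2 * 3 = 6
print(y.shape)              # torch.Size([6, 3, 8, 8])

# Stand-in for base_model (an assumption): a 2-D conv expects (N, C, H, W),
# so every segment is now treated as one image in a batch of 6.
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
out = conv(y)
print(out.shape)            # torch.Size([6, 4, 8, 8])

# Row-major order: y[k] is x[k // 3, k % 3], so frames keep their order.
assert torch.equal(y[4], x[1, 1])
```

If that assumption holds, folding segments into the batch dimension is what lets a model built for single images run over every frame in one call.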

For example:

    import torch

    a = torch.rand((1, 2, 3, 4, 5, 6))
    a = a.view((-1,) + a.size()[-3:])  # merge the first 3 dims, keep the last 3

The initial shape of the tensor is (1, 2, 3, 4, 5, 6).
After reshaping, the new shape is (6, 4, 5, 6): the last 3 dimensions are kept as they are, and the rest are merged into a single dimension of size 1x2x3 = 6.
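One thing worth knowing: “.view” never copies data, so it only works when the tensor's memory layout is compatible with the new shape. A small sketch comparing it with reshape and flatten, which give the same result on a contiguous tensor; the transposed tensor p is only an illustration of a case where view can fail:

```python
import torch

a = torch.rand((1, 2, 3, 4, 5, 6))

# Three ways to merge every dimension except the last three.
b = a.view((-1,) + a.size()[-3:])
c = a.reshape(-1, *a.shape[-3:])
d = a.flatten(0, -4)  # merge dims 0 .. ndim-4 (here 0..2) into one
assert torch.equal(b, c) and torch.equal(b, d)

# view returns a tensor that shares a's storage (no copy).
assert b.data_ptr() == a.data_ptr()

# Swapping dims 0 and 2 gives a non-contiguous tensor of shape (3, 2, 1, 4, 5, 6).
p = a.transpose(0, 2)
view_failed = False
try:
    p.view(-1, *p.shape[-3:])
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# reshape falls back to copying when a view is impossible.
q = p.reshape(-1, *p.shape[-3:])
print(q.shape)
```

So if the tensor might come out of a permute or transpose, reshape is the safer choice; view is fine when you know the input is contiguous.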

Thanks a lot! It’s a clear and concise explanation!