transform = Compose([
TenCrop(size), # this is a list of PIL Images
Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
])
#In your test loop you can do the following:
input, target = batch # input is a 5d tensor, target is 2d
bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
If I want to use result_max
instead of `result_avg, then
result_max = torch.max(result.view(`bs, ncrops, -1), 1)
is ok?