Why is GPU memory consumption higher on lower-resolution input with Conv3d?

Hi

I am applying the following operation to a 3D cube of dimension (48x48x48), where each voxel holds a 256-dimensional feature vector, so the input cube's shape is (48x48x48x256), analogous to an RGB image of shape (h, w, 3) where h is the height and w is the width.
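For clarity, here is a minimal sketch of how such a cube could be arranged for nn.Conv3d, which expects a channels-first (N, C, D, H, W) layout; the batch dimension and the permute step are illustrative assumptions, not part of my actual pipeline:

import torch

# hypothetical channels-last cube: (D, H, W, C) = (48, 48, 48, 256)
cube = torch.randn(48, 48, 48, 256)

# nn.Conv3d expects (N, C, D, H, W): add a batch dim and move channels first
x = cube.permute(3, 0, 1, 2).unsqueeze(0)  # -> (1, 256, 48, 48, 48)
print(x.shape)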

I track GPU memory consumption during inference using torch.cuda.max_memory_allocated(), checking it after every line, and I see that memory increases a lot after the Conv3d operations. However, when I apply the same operation to an input of (80x80x80x256), the memory consumption is lower than when I apply it to (48x48x48x256), which is counter-intuitive: memory consumption on the higher-resolution input should be higher.
The maximum memory consumption on the 48-resolution input is 9 GB, while on the 80-resolution input it is 6 GB. If anybody knows a potential reason, please describe it.

import torch
import torch.nn as nn


class Op(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels


        voxel_layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
            nn.BatchNorm3d(out_channels, eps=1e-4),
            nn.LeakyReLU(0.1, True),
            nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
            nn.BatchNorm3d(out_channels, eps=1e-4),
            nn.LeakyReLU(0.1, True),
         ]

        self.voxel_layers = nn.Sequential(*voxel_layers)

		
    def forward(self, inputs):
        voxel_features = inputs
        # full pass through the layer stack
        voxel_features = self.voxel_layers(voxel_features)

        print("################################# start")

        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        voxel_features = self.voxel_layers[0](voxel_features)
        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        voxel_features = self.voxel_layers[1](voxel_features)
        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        voxel_features = self.voxel_layers[2](voxel_features)
        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        voxel_features = self.voxel_layers[3](voxel_features)
        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        voxel_features = self.voxel_layers[4](voxel_features)
        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        voxel_features = self.voxel_layers[5](voxel_features)
        print("############### MEMORY CONSUMPTION, ", torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        return voxel_features
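For completeness, a minimal way the measurement could be driven in inference mode is sketched below; the channel count, shapes, and device are assumptions for illustration. Wrapping the call in torch.no_grad() prevents activations from being retained for backward, so the printed peaks reflect inference memory only:

import torch

# hypothetical inference-only run; in_channels == out_channels here so the
# per-layer re-application inside forward() keeps channel counts consistent
model = Op(256, 256, 3).cuda().eval()
x = torch.randn(1, 256, 48, 48, 48).cuda()  # (N, C, D, H, W)

with torch.no_grad():   # no activations kept for backward
    out = model(x)      # per-step peak memory is printed inside forward()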

Thanks

I cannot reproduce the issue and see lower memory usage for the smaller size:

# model = Op(80, 1, 3).cuda()
# x = torch.randn(1, 80, 80, 80, 256).cuda()
# out = model(x)

############### MEMORY CONSUMPTION,  550519808
############### MEMORY CONSUMPTION,  557073408
############### MEMORY CONSUMPTION,  537411584
############### MEMORY CONSUMPTION,  537411584
############### MEMORY CONSUMPTION,  543965184
############### MEMORY CONSUMPTION,  550519808
############### MEMORY CONSUMPTION,  550519808


model = Op(48, 1, 3).cuda()
x = torch.randn(1, 48, 48, 48, 256).cuda()
out = model(x)
############### MEMORY CONSUMPTION,  228869632
############### MEMORY CONSUMPTION,  238308864
############### MEMORY CONSUMPTION,  117978112
############### MEMORY CONSUMPTION,  117978112
############### MEMORY CONSUMPTION,  120337408
############### MEMORY CONSUMPTION,  122697728
############### MEMORY CONSUMPTION,  122697728