Methods to reduce the size of intermediate variables

Hi, I am a newbie to PyTorch and I ran into trouble while implementing a model for 3D point clouds. I use a point pair feature representation for each local patch, so the input size is [bs, 2048, 1024, 4], where 2048 is the number of interest points, 1024 is the number of neighboring points per local patch, and 4 is the dimension of the point pair feature. The network architecture is relatively simple (just an auto-encoder), but even with a batch size of 1 it cannot run because the GPU runs out of memory. I used torchsummary to inspect the parameter sizes and intermediate variable sizes, and found that the intermediate variables take up most of the memory. So I wonder: are there any methods to reduce the size of the intermediate variables?

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv1d-1           [2048, 64, 1024]             320
       BatchNorm1d-2           [2048, 64, 1024]             128
              ReLU-3           [2048, 64, 1024]               0
            Conv1d-4          [2048, 128, 1024]           8,320
       BatchNorm1d-5          [2048, 128, 1024]             256
              ReLU-6          [2048, 128, 1024]               0
            Conv1d-7          [2048, 256, 1024]          33,024
       BatchNorm1d-8          [2048, 256, 1024]             512
              ReLU-9          [2048, 256, 1024]               0
           Linear-10          [2048, 1024, 512]         360,960
           Linear-11          [2048, 1024, 512]         262,656
          Encoder-12             [2048, 1, 512]               0
           Conv1d-13          [2048, 256, 1024]         131,840
             ReLU-14          [2048, 256, 1024]               0
           Conv1d-15          [2048, 128, 1024]          32,896
             ReLU-16          [2048, 128, 1024]               0
           Conv1d-17           [2048, 64, 1024]           8,256
             ReLU-18           [2048, 64, 1024]               0
           Conv1d-19           [2048, 32, 1024]           2,080
             ReLU-20           [2048, 32, 1024]               0
           Conv1d-21            [2048, 4, 1024]             132
             ReLU-22            [2048, 4, 1024]               0
           Conv1d-23          [2048, 256, 1024]         132,352
             ReLU-24          [2048, 256, 1024]               0
           Conv1d-25          [2048, 128, 1024]          32,896
             ReLU-26          [2048, 128, 1024]               0
           Conv1d-27           [2048, 64, 1024]           8,256
             ReLU-28           [2048, 64, 1024]               0
           Conv1d-29           [2048, 32, 1024]           2,080
             ReLU-30           [2048, 32, 1024]               0
           Conv1d-31            [2048, 4, 1024]             132
             ReLU-32            [2048, 4, 1024]               0
          Decoder-33            [2048, 1024, 4]               0
================================================================
Total params: 1,017,096
Trainable params: 1,017,096
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 32.00
Forward/backward pass size (MB): 68936.00
Params size (MB): 3.88
Estimated Total Size (MB): 68971.88
----------------------------------------------------------------
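For scale, even a single one of the [2048, 256, 1024] activations above is already 2 GB on its own (assuming float32, i.e. 4 bytes per element), which a quick calculation confirms:

```python
# Memory of one intermediate activation of shape [2048, 256, 1024]
# stored as float32 (4 bytes per element).
elements = 2048 * 256 * 1024
mb = elements * 4 / (1024 ** 2)
print(mb)  # 2048.0 MB for a single activation tensor
```

And the network keeps dozens of such tensors alive for the backward pass, which is where the ~68 GB forward/backward estimate comes from.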

You could use torch.utils.checkpoint to trade compute for memory: checkpointed segments do not store their intermediate activations during the forward pass and instead recompute them during the backward pass.
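A minimal sketch of how that looks, using a toy two-block encoder as a hypothetical stand-in for your conv stack (the layer sizes here are illustrative, not your actual model):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder(nn.Module):
    """Toy stand-in for the memory-heavy Conv1d/ReLU stack."""

    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv1d(4, 64, 1), nn.ReLU())
        self.block2 = nn.Sequential(nn.Conv1d(64, 128, 1), nn.ReLU())

    def forward(self, x):
        # Checkpointed blocks drop their intermediate activations after the
        # forward pass and recompute them during backward, trading extra
        # compute for a much smaller activation footprint.
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return x


model = TinyEncoder()
# Small stand-in batch: [batch, channels, points]
inp = torch.randn(8, 4, 1024, requires_grad=True)
out = model(inp)
out.sum().backward()  # block activations are recomputed here
```

Note that `use_reentrant=False` is the recommended variant in recent PyTorch versions; on older releases you may need to omit it. Checkpointing roughly doubles the forward compute for the wrapped segments, so it is most useful on the largest blocks only.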