Audio classifier code throwing an error because it does not adapt to variable input sizes

I want the model to accept input tensors of shape [B, C, F, T], where B is the batch size, C is the number of channels, F is the number of FFT frequency bins, and T is the time dimension, which varies from batch to batch. The model should return output tensors of shape [B, 10] representing logits.
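
For concreteness, here is the contract I'm after as a minimal sketch (the hparam values are inferred from my logs below, e.g. F = 513 implies n_fft = 1024, and `AudioClassifier` is a placeholder name for my model class):

    import torch

    model = AudioClassifier(n_fft=1024, sample_rate=8000, fft_hop_length=160)

    # Two batches whose time dimension T differs; both should produce [B, 10] logits.
    batch_a = torch.randn(1, 2, 513, 102)  # [B, C, F, T]
    batch_b = torch.randn(2, 2, 513, 72)   # same C and F, different B and T

    print(model(batch_a).shape)  # expected: torch.Size([1, 10])
    print(model(batch_b).shape)  # expected: torch.Size([2, 10]), but this currently fails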

I feel like I'm 95% of the way there, but I'm having trouble debugging where my layers fail to adapt to the B, C, F, and T dimensions.

Here are my layer definitions and forward method:

        self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=32, kernel_size=(3, 3), stride=1, padding=1)  # 2 input channels -> 32 feature maps, spatial size preserved
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # halves both F and T
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=1, padding=1)  # 32 -> 64 feature maps
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(32 * (n_fft // 4) * (sample_rate // (2 * fft_hop_length)), 128)  # in_features hard-coded from hparams ("Method 1")
        self.fc2 = torch.nn.Linear(128, 10)  # 10 classes for digits 0-9

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        print("*"*50)
        B, C, F, T = input.size()  # Extract dimensions from the input tensor

        print("Input shape",input.shape)
        print("")
        x = self.maxpool(self.relu(self.conv1(input)))
        print("x1 shape",x.shape)
        x = self.maxpool(self.relu(self.conv2(x)))
        print("x2 shape",x.shape)

        # Calculate the flattened size
        flattened_size = 32 * (self.hparams.n_fft // 4) * (self.hparams.sample_rate // (2 * self.hparams.fft_hop_length)) # NOTE: Method 1
        print("F//2:",(F // 2), " T // 2:",(T//2),"new flatted:",C * (F // 2) * (T // 2))

        print("flatten_size",flattened_size," sample rate:",self.hparams.sample_rate)
        print("*"*50)
        x = x.view(B, flattened_size)
        print("after flattening",x.size())
        
        x = self.fc1(x)
        print("After fc1",x.size())
        x = self.relu(x)
        x = self.fc2(x)
        print("Final shape:",x.size())
        return x
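
To sanity-check the shapes, here is a standalone sketch of just the two conv/pool stages (same hyperparameters as the layer definitions above). Each `MaxPool2d(kernel_size=2, stride=2)` halves F and T, so after both stages a sample flattens to 64 * (F // 4) * (T // 4) values:

    import torch

    conv1 = torch.nn.Conv2d(2, 32, kernel_size=3, stride=1, padding=1)
    conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
    pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

    x = torch.randn(1, 2, 513, 102)      # the working batch from below
    x = pool(torch.relu(conv1(x)))       # -> [1, 32, 256, 51]
    x = pool(torch.relu(conv2(x)))       # -> [1, 64, 128, 25]
    print(x.flatten(start_dim=1).shape)  # torch.Size([1, 204800]) = 64 * 128 * 25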

Here is an example of a batch that works:

Input shape torch.Size([1, 2, 513, 102])

x1 shape torch.Size([1, 32, 256, 51])
x2 shape torch.Size([1, 64, 128, 25])
F//2: 256  T // 2: 51 new flattened: 26112
flatten_size 204800  sample rate: 8000
**************************************************
after flattening torch.Size([1, 204800])
After fc1 torch.Size([1, 128])
Final shape: torch.Size([1, 10])
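
Working through the arithmetic, this batch flattens to 64 * 128 * 25 = 204800 values per sample, and the hard-coded formula happens to give the same number only because 32 * 256 = 64 * 128 = 8192 and this particular T gives T // 4 = 25 (note also that the 32 in the fc1 formula doesn't match conv2's 64 output channels; it only cancels out because of that same identity). So as far as I can tell, this batch works by coincidence rather than by design.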

I'm able to load some of the data, but here is an example of where it breaks:

Input shape torch.Size([2, 2, 513, 72])

x1 shape torch.Size([2, 32, 256, 36])
x2 shape torch.Size([2, 64, 128, 18])
F//2: 256  T // 2: 36 new flattened: 18432
flatten_size 409600  sample rate: 8000

    x = x.view(B, flattened_size)
RuntimeError: shape '[2, 409600]' is invalid for input of size 294912
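
Working through the same arithmetic here: this batch flattens to 64 * 128 * 18 = 147456 values per sample (294912 in total for B = 2), which the static flattened_size can never match. My guess is that I need a classifier head whose in_features doesn't depend on T, for example by pooling to a fixed grid before the Linear layers. Is something like the following the idiomatic fix? (A rough, untested sketch; the (8, 8) target grid is an arbitrary choice of mine.)

    # In __init__: pool the variable-size feature map down to a fixed grid,
    # so fc1's in_features no longer depends on T.
    self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((8, 8))
    self.fc1 = torch.nn.Linear(64 * 8 * 8, 128)  # 64 channels from conv2, fixed 8x8 grid

    # In forward(), after the second conv/pool stage:
    x = self.adaptive_pool(x)  # [B, 64, 8, 8] regardless of the input's T
    x = x.view(B, -1)          # [B, 4096]
    x = self.fc2(self.relu(self.fc1(x)))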