Running mean error for BatchNorm1D after tensor.mean

Using Data2Vec on three modalities, I extracted features of shape [1, 197, 768] (image), [1, 768] (text), and [1, 499, 768] (audio) from my dataset. Because the image and audio tensors have an extra dimension, I tried to use torch.mean to reduce the dimensionality. However, I keep getting the following error:
RuntimeError: running_mean should contain 197 elements not 1024
I dug around online and most of what I found says this is a BatchNorm1d problem caused by an incorrectly shaped input. But after applying tensor.mean, why would there still be a dimension of 197 in my data?
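
For context, my understanding of nn.BatchNorm1d is that for a 3D input it treats dim 1 as the channel dimension, which is the only place I can see a 197 coming from (a minimal sketch, not my actual pipeline):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1024)

bn(torch.randn(4, 1024))       # (N, C) with C = 1024: fine
bn(torch.randn(4, 197, 1024))  # (N, C, L): dim 1 (197) is read as channels
# RuntimeError: running_mean should contain 197 elements not 1024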

For reference, this is the model I intend to use:

import torch
import torch.nn as nn

class Speaker_Dependent_Triple_Mode_with_Context(nn.Module):
    def __init__(self, n_speaker=27, input_embedding_A=768, input_embedding_B=768, input_embedding_C=768, shared_embedding=1024, projection_embedding=512, dropout=0.5, num_classes=5):
        super(Speaker_Dependent_Triple_Mode_with_Context, self).__init__()

        self.n_speaker = n_speaker

        self.input_embedding_A = input_embedding_A
        self.input_embedding_B = input_embedding_B
        self.input_embedding_C = input_embedding_C

        self.shared_embedding = shared_embedding
        self.projection_embedding = projection_embedding
        self.num_classes = num_classes
        self.dropout = dropout

        self.A_context_share = nn.Linear(self.input_embedding_A, self.shared_embedding)
        self.A_utterance_share = nn.Linear(self.input_embedding_A, self.shared_embedding)

        self.C_context_share = nn.Linear(self.input_embedding_C, self.shared_embedding)
        self.C_utterance_share = nn.Linear(self.input_embedding_C, self.shared_embedding)

        self.B_context_share = nn.Linear(self.input_embedding_B, self.shared_embedding)
        self.B_utterance_share = nn.Linear(self.input_embedding_B, self.shared_embedding)

        self.norm_A_context = nn.BatchNorm1d(self.shared_embedding)
        self.norm_A_utterance = nn.BatchNorm1d(self.shared_embedding)

        self.norm_C_context = nn.BatchNorm1d(self.shared_embedding)
        self.norm_C_utterance = nn.BatchNorm1d(self.shared_embedding)

        self.norm_B_context = nn.BatchNorm1d(self.shared_embedding)
        self.norm_B_utterance = nn.BatchNorm1d(self.shared_embedding)

        self.collaborative_gate_1 = nn.Linear(2 * self.shared_embedding, self.projection_embedding)
        self.collaborative_gate_2 = nn.Linear(self.projection_embedding, self.shared_embedding)

        self.pred_module = nn.Sequential(
            nn.Linear(self.n_speaker + 3 * self.shared_embedding, 2 * self.shared_embedding),
            nn.BatchNorm1d(2 * self.shared_embedding),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * self.shared_embedding, self.shared_embedding),
            nn.BatchNorm1d(self.shared_embedding),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.shared_embedding, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, self.num_classes)
        )

    def attention(self, featureA, featureB):
        """ This method takes two features and calculates the attention """
        input = torch.cat((featureA, featureB), dim=1)
        return nn.functional.softmax(self.collaborative_gate_1(input), dim=1)

    def attention_aggregator(self, feA, feB, feC, feD, feE, feF):
        """ This method calculates the attention for feA with respect to the others """
        input = self.attention(feA, feB) + self.attention(feA, feC) + self.attention(feA, feD) + self.attention(feA, feE) + self.attention(feA, feF)
        return nn.functional.softmax(self.collaborative_gate_2(input), dim=1)

    def forward(self, uA, cA, uB, cB, uC, cC, speaker_embedding):
        """
        Args:
            uA: Utterance Video
            uB: Utterance Text
            uC: Utterance Audio
            cA: Context Video
            cB: Context Text
            cC: Context Audio

        Returns:
            probability of emotion classes
        """
        # Pooling or averaging the sequences to match the expected input dimensions for linear layers
        uA = torch.mean(uA, dim=1)  # [1, 197, 768] -> [1, 768]
        uC = torch.mean(uC, dim=1)  # [1, 499, 768] -> [1, 768]
        cA = torch.mean(cA, dim=1)  # [1, 197, 768] -> [1, 768]
        cC = torch.mean(cC, dim=1)  # [1, 499, 768] -> [1, 768]

        shared_A_context = self.norm_A_context(nn.functional.relu(self.A_context_share(cA)))
        shared_A_utterance = self.norm_A_utterance(nn.functional.relu(self.A_utterance_share(uA)))

        shared_C_context = self.norm_C_context(nn.functional.relu(self.C_context_share(cC)))
        shared_C_utterance = self.norm_C_utterance(nn.functional.relu(self.C_utterance_share(uC)))

        shared_B_context = self.norm_B_context(nn.functional.relu(self.B_context_share(cB)))
        shared_B_utterance = self.norm_B_utterance(nn.functional.relu(self.B_utterance_share(uB)))

        updated_shared_A = shared_A_utterance * self.attention_aggregator(
            shared_A_utterance, shared_A_context, shared_C_context, shared_C_utterance, shared_B_context, shared_B_utterance)
        updated_shared_C = shared_C_utterance * self.attention_aggregator(
            shared_C_utterance, shared_C_context, shared_A_context, shared_A_utterance, shared_B_context, shared_B_utterance)
        updated_shared_B = shared_B_utterance * self.attention_aggregator(
            shared_B_utterance, shared_B_context, shared_A_context, shared_A_utterance, shared_C_context, shared_C_utterance)

        temp = torch.cat((updated_shared_A, updated_shared_C), dim=1)
        input = torch.cat((temp, updated_shared_B), dim=1)

        input = torch.cat((input, speaker_embedding), dim=1)

        return self.pred_module(input)

Your code fails with another error for me using these input shapes:

model = Speaker_Dependent_Triple_Mode_with_Context()
uA = torch.randn(1, 197, 768)
cA = torch.randn(1, 197, 768)
uB = torch.randn(1, 768)
cB = torch.randn(1, 768)
uC = torch.randn(1, 499, 768)
cC = torch.randn(1, 499, 768)

out = model(uA, cA, uB, cB, uC, cC, torch.randn(1, 1))
# ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 1024])
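
That ValueError looks like the standard BatchNorm behaviour with a batch of one in training mode rather than anything specific to your pooling; a minimal sketch of what I mean:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1024)

bn.train()
# bn(torch.randn(1, 1024))
# ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 1024])

bn.eval()  # eval mode uses running statistics, so a single sample is fine
print(bn(torch.randn(1, 1024)).shape)  # torch.Size([1, 1024])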

If I increase the batch size, I instead see a shape mismatch in a matmul, which suggests the speaker_embedding shape is wrong:

model = Speaker_Dependent_Triple_Mode_with_Context()
uA = torch.randn(2, 197, 768)
cA = torch.randn(2, 197, 768)
uB = torch.randn(2, 768)
cB = torch.randn(2, 768)
uC = torch.randn(2, 499, 768)
cC = torch.randn(2, 499, 768)

out = model(uA, cA, uB, cB, uC, cC, torch.randn(2, 1))
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3073 and 3099x2048)
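
Since n_speaker + 3 * shared_embedding = 27 + 3072 = 3099, while your input comes out at 3072 + 1 = 3073, I'd guess speaker_embedding is meant to have n_speaker = 27 features per sample (a one-hot speaker id is only my assumption). With that shape the forward pass runs for me:

# hypothetical one-hot speaker ids for a batch of 2; 27 matches the default n_speaker
speaker_embedding = torch.nn.functional.one_hot(torch.tensor([3, 11]), num_classes=27).float()  # [2, 27]

out = model(uA, cA, uB, cB, uC, cC, speaker_embedding)  # reusing the batch-size-2 tensors above
print(out.shape)  # torch.Size([2, 5])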