MobileNetV3 shape mismatch after modifying layers

I am trying to use mobilenet_v3_large to classify audio spectrograms, and I hit “RuntimeError: mat1 and mat2 shapes cannot be multiplied (3840x1 and 960x1280)” when the forward pass reaches the linear layers of the net.
The input tensors have a single channel, with shape [1, 513, 469], and I load them in batches of 4.

def collate_fn(batch):
    # A data tuple has the form:
    # spectrogram, sample_rate, label, speaker_id, utterance_number
    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for spectrogram, _, label, *_ in batch:
        tensors += [spectrogram]
        targets += [label_to_index(label)]

    # Group the list of tensors into a batched tensor
    tensors = torch.stack(tensors)
    targets = torch.stack(targets)

    return tensors, targets

batch_size = 4
train_loader = torch.utils.data.DataLoader(
    train_set, 
    shuffle=True, 
    num_workers=0, 
    collate_fn=collate_fn,
    batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(
    test_set,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
    batch_size=batch_size)

from torchvision import models

model_ = models.mobilenet_v3_large(pretrained=True)
model_

I modify the classifier and the first convolution so the network accepts 1-channel input; the model then looks like this.

import torch.nn as nn

model_.classifier[3] = torch.nn.Linear(1280, len(labels))
# change the first convolution to accept 1-channel images
modified = list(model_.children())
w = modified[0][0][0].weight
modified[0][0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
# average the pretrained RGB weights into a single input channel
modified[0][0][0].weight = nn.Parameter(torch.mean(w, dim=1, keepdim=True))
model_ = nn.Sequential(*modified)

model_.to(device)

Sequential(
  (0): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
... Several Inverted Residual Blocks ...
(15): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (16): Conv2dNormActivation(
      (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
  )
  (1): AdaptiveAvgPool2d(output_size=1)
  (2): Sequential(
    (0): Linear(in_features=960, out_features=1280, bias=True)
    (1): Hardswish()
    (2): Dropout(p=0.2, inplace=True)
    (3): Linear(in_features=1280, out_features=3, bias=True)
  )
)

Training is like this.

def train(model, epoch, log_interval):
    model.train()
    for batch_idx, (data, target, *_) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        output = model(data)

        loss = F.cross_entropy(output.squeeze(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch}\t[{(batch_idx * len(data)):5.0f}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):3.0f}%)]\tLoss: {loss.item():.4f}")

        pbar.update(pbar_update)
        losses.append(loss.item())

log_interval = 20
n_epoch = 30

pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []
accuracies = []

optimizer = optim.Adam(model_.parameters(), lr=0.0001, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2) 

with tqdm(total=n_epoch, 
          bar_format = "{desc}: {percentage:.2f}%|{bar}| {n:.2f}/{total_fmt} [{elapsed}<{remaining}]",
          colour = 'GREEN',
          ) as pbar:
    for epoch in range(1, n_epoch + 1):
        train(model_, epoch, log_interval)
        accuracy_step = test(model_, epoch)
        accuracies.append(accuracy_step)
        scheduler.step()

When training I get “RuntimeError: mat1 and mat2 shapes cannot be multiplied (3840x1 and 960x1280)”. 3840 is 4 (the batch size) multiplied by 960 (the number of channels in the previous layer), but I’m not sure what AdaptiveAvgPool2d(output_size=1) does here. I haven’t yet tried simply replicating the data across 3 channels with transforms, but I am curious what is done wrong here. I would appreciate any help with this.

It seems you are trying to wrap all modules into an nn.Sequential container:

model_ = nn.Sequential(*modified)

which can easily break the model if functional API calls were used in the forward method, or if the modules are not executed in a strictly sequential way.
In your use case the torch.flatten(x, 1) call that MobileNetV3.forward performs between avgpool and classifier is missing, so the first Linear layer receives a [4, 960, 1, 1] tensor instead of [4, 960]; nn.Linear treats the trailing dimension of size 1 as the feature dimension, which is exactly the (3840x1 and 960x1280) mismatch you see. You should either derive from the base class and reimplement the forward method, or add the missing flatten op as an nn.Flatten module.
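A minimal sketch of the nn.Flatten variant (reusing your conv and classifier changes; `labels` is assumed to be defined as in your code):

import torch
import torch.nn as nn
from torchvision import models

model_ = models.mobilenet_v3_large(pretrained=True)
model_.classifier[3] = nn.Linear(1280, len(labels))

# adapt the first conv to 1-channel input by averaging the pretrained RGB weights
w = model_.features[0][0].weight
model_.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
model_.features[0][0].weight = nn.Parameter(torch.mean(w, dim=1, keepdim=True))

# rebuild as Sequential, restoring the flatten that MobileNetV3.forward
# applies between avgpool and classifier
model_ = nn.Sequential(
    model_.features,
    model_.avgpool,
    nn.Flatten(1),  # [N, 960, 1, 1] -> [N, 960]
    model_.classifier,
)

x = torch.randn(4, 1, 513, 469)  # dummy batch with your spectrogram shape
print(model_(x).shape)           # torch.Size([4, len(labels)])

Alternatively, since you already assign to model_.classifier[3] in place, you can do the same for model_.features[0][0] and skip the nn.Sequential rewrap entirely; the original forward method then keeps its torch.flatten call.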
