Help with Masked Autoencoder Training - Struggling to Reconstruct Signals

Hello everyone,

I’m currently working on training a masked autoencoder that takes audio signals as input, applies a time-masking transformation, and tries to reconstruct the original signal. However, the model is not learning to reconstruct the signals well, and the loss isn’t improving much over time. I would appreciate any feedback on whether my training loop or general approach might have issues.

Here’s the training loop I’m using:

import io
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import wandb
from PIL import Image

def tensor_to_image(tensor):
    # Render a 2D array (e.g. a spectrogram) to a PNG and return it as a numpy array.
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.imshow(tensor, aspect='auto', origin='lower')
    ax.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return np.array(Image.open(buf))

def save_checkpoints(model, model_name, save_to):
    os.makedirs(os.path.join(save_to, 'checkpoints'), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_to, 'checkpoints', f'{model_name}.pt'))

# Masks one random contiguous span of up to 150 time frames (filled with zeros by default).
mask = torchaudio.transforms.TimeMasking(time_mask_param=150, p=1.0)

def train(device, model, epochs, train_dataloader, val_dataloader, criterion, optim, log, save_epochs, save_path):

    print(f'\nLogging to wandb: {log}')
    if log:
        global_step = 0
        wandb.init(project='VisResAE')
        wandb.watch(model, criterion, log='all', log_freq=1)

    print('\nTraining model...')
    model.to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=5)

    print(f'\nUsing {scheduler.__class__.__name__} scheduling.')
    print(f'\nTraining on {device}')

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        train_steps = 0
        for idx_batch, signal in enumerate(train_dataloader):
            original_signal = signal.to(device)
            # Mask a random time span of the input; the reconstruction target is the unmasked original.
            masked_signal = mask(original_signal.clone())
            out, _ = model(masked_signal)
            loss = criterion(out, original_signal)
            total_train_loss += loss.item()

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optim.step()
            train_steps += 1

            # Print progress for every training batch.
            print(f'Epoch [{epoch+1}/{epochs}] - Train Step [{idx_batch+1}/{len(train_dataloader)}] - Loss: {loss.item():.3f}')

            if log:
                global_step += 1
                wandb.log({
                    'train_steps': global_step,
                    'train_loss': loss.item()
                })

        avg_train_loss = total_train_loss / len(train_dataloader)

        model.eval()
        total_val_loss = 0.0
        # No gradients are needed during validation.
        with torch.no_grad():
            for idx_batch, signal in enumerate(val_dataloader):
                original_signal = signal.to(device)
                masked_signal = mask(original_signal.clone())
                out, _ = model(masked_signal)
                loss = criterion(out, original_signal)
                total_val_loss += loss.item()

                print(f'Epoch [{epoch+1}/{epochs}] - Validation Step [{idx_batch+1}/{len(val_dataloader)}] - Loss: {loss.item():.3f}')

                if log:
                    # Only render spectrogram images when they are actually being logged.
                    original_img = tensor_to_image(original_signal[0].squeeze().cpu().numpy())
                    masked_img = tensor_to_image(masked_signal[0].squeeze().cpu().numpy())
                    reconstructed_img = tensor_to_image(out[0].squeeze().cpu().numpy())
                    wandb.log({
                        'original_spectrogram': wandb.Image(original_img, caption='Original Spectrogram'),
                        'masked_spectrogram': wandb.Image(masked_img, caption='Masked Spectrogram'),
                        'reconstructed_spectrogram': wandb.Image(reconstructed_img, caption='Reconstructed Spectrogram'),
                    })

        avg_val_loss = total_val_loss / len(val_dataloader)
        scheduler.step(avg_val_loss)

        if log:
            wandb.log({
                'epoch': epoch,
                'avg_train_loss': avg_train_loss,
                'avg_val_loss': avg_val_loss
            })

        if epoch % save_epochs == 0:
            save_checkpoints(model=model, model_name=f'melmae_{epoch+1}', save_to=save_path)

    print('Finished Training.')

    if log:
        wandb.finish()
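
For completeness, here is roughly how I call the training loop. The model class, dataloaders, and hyperparameter values below are placeholders for my actual setup, not exact values:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MelMAE()  # placeholder for my autoencoder; forward() returns (reconstruction, latent)
criterion = torch.nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

train(device=device,
      model=model,
      epochs=100,
      train_dataloader=train_dataloader,  # placeholder DataLoaders yielding spectrogram batches
      val_dataloader=val_dataloader,
      criterion=criterion,
      optim=optim,
      log=True,
      save_epochs=10,
      save_path='./runs/melmae')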

Some things I’ve noticed:

  1. The model’s reconstruction outputs don’t seem to get significantly better as training progresses.
  2. The loss remains quite high, and there’s little improvement over multiple epochs.

Some questions:

  • Does this training loop seem reasonable for a masked autoencoder setup?
  • Should I be using a different masking method or transformation? (A quick sanity check of the current mask is sketched after this list.)
  • Could the scheduler or optimizer setup be affecting the model’s ability to learn?
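
On the masking question, this is the kind of quick sanity check I can run to confirm what TimeMasking actually does to a batch. It assumes spectrogram-shaped input of [batch, n_mels, time]; the shapes below are dummies, not my real data:

import torch
import torchaudio

# Same transform as in my training script; masks a span of up to 150 time frames.
mask = torchaudio.transforms.TimeMasking(time_mask_param=150, p=1.0)

spec = torch.rand(4, 128, 400)             # dummy batch: 4 examples, 128 mel bins, 400 frames
masked = mask(spec.clone())

changed = (masked != spec).any(dim=1)      # [batch, time]: True where a frame was modified
print('masked frames per example:', changed.sum(dim=-1).tolist())
print('values written into the masked span:', masked[masked != spec].unique())  # should be 0.0, the default mask value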

Thank you in advance for your help!