Training will be stop after a while in GRU layer

I use my custom dataset class to convert audio files to mel- Spectrogram images. the shape will be padded to (128,1024). I have 10 classes. after a while in the first epoch, my network will be crashed due to this error:

Current run is terminating due to exception: Expected hidden size (1, 7, 32), got [1, 16, 32]
Engine run is terminating due to exception: Expected hidden size (1, 7, 32), got [1, 16, 32]
Traceback (most recent call last):
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-b8f3a45f8e35>", line 1, in <module>
    runfile('/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/tools/train_net.py', wdir='/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/tools')
  File "/home/omid/OMID/program/pycharm-professional-2020.2.4/pycharm-2020.2.4/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/omid/OMID/program/pycharm-professional-2020.2.4/pycharm-2020.2.4/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/tools/train_net.py", line 60, in <module>
    main()
  File "/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/tools/train_net.py", line 56, in main
    train(cfg)
  File "/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/tools/train_net.py", line 35, in train
    do_train(
  File "/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/engine/trainer.py", line 79, in do_train
    trainer.run(train_loader, max_epochs=epochs)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 702, in run
    return self._internal_run()
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 775, in _internal_run
    self._handle_exception(e)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 469, in _handle_exception
    raise e
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 745, in _internal_run
    time_taken = self._run_once_on_dataset()
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 850, in _run_once_on_dataset
    self._handle_exception(e)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 469, in _handle_exception
    raise e
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/engine.py", line 833, in _run_once_on_dataset
    self.state.output = self._process_function(self, self.state.batch)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/ignite/engine/__init__.py", line 103, in _update
    y_pred = model(x)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/modeling/model.py", line 113, in forward
    x, h1 = self.gru1(x, h0)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 819, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 229, in check_forward_args
    self.check_hidden_size(hidden, expected_hidden_size)
  File "/home/omid/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 223, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden size (1, 7, 32), got [1, 16, 32]

my network is :

import torch
import torch.nn as nn
import torch.nn.functional as F

print('cuda', torch.cuda.is_available())


class MusicClassification(nn.Module):
    def __init__(self, cfg):
        super(MusicClassification, self).__init__()
        device = cfg.MODEL.DEVICE
        num_class = cfg.MODEL.NUM_CLASSES

        self.np_layers = 4
        self.np_filters = [64, 128, 128, 128]
        self.kernel_size = (3, 3)

        self.pool_size = [(2, 2), (4, 2)]

        self.channel_axis = 1
        self.frequency_axis = 2
        self.time_axis = 3

        # self.h0 = torch.zeros((1, 16, 32)).to(device)

        self.bn0 = nn.BatchNorm2d(num_features=self.channel_axis)
        self.bn1 = nn.BatchNorm2d(num_features=self.np_filters[0])
        self.bn2 = nn.BatchNorm2d(num_features=self.np_filters[1])
        self.bn3 = nn.BatchNorm2d(num_features=self.np_filters[2])
        self.bn4 = nn.BatchNorm2d(num_features=self.np_filters[3])

        self.conv1 = nn.Conv2d(1, self.np_filters[0], kernel_size=self.kernel_size)
        self.conv2 = nn.Conv2d(self.np_filters[0], self.np_filters[1], kernel_size=self.kernel_size)
        self.conv3 = nn.Conv2d(self.np_filters[1], self.np_filters[2], kernel_size=self.kernel_size)
        self.conv4 = nn.Conv2d(self.np_filters[2], self.np_filters[3], kernel_size=self.kernel_size)

        self.max_pool_2_2 = nn.MaxPool2d(self.pool_size[0])
        self.max_pool_4_2 = nn.MaxPool2d(self.pool_size[1])

        self.drop_01 = nn.Dropout(0.1)
        self.drop_03 = nn.Dropout(0.3)

        self.gru1 = nn.GRU(input_size=128, hidden_size=32, batch_first=True)
        self.gru2 = nn.GRU(input_size=32, hidden_size=32, batch_first=True)

        self.activation = nn.ELU()

        self.dense = nn.Linear(32, num_class)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # x [16, 1, 128,938]
        x = self.bn0(x)
        # x [16, 1, 128,938]
        x = F.pad(x, (0, 0, 2, 1))
        # x [16, 1, 131,938]
        x = self.conv1(x)
        # x [16, 64, 129,936]
        x = self.activation(x)
        # x [16, 64, 129,936]
        x = self.bn1(x)
        # x [16, 64, 129,936]
        x = self.max_pool_2_2(x)
        # x [16, 64, 64,468]
        x = self.drop_01(x)
        # x [16, 64, 64,468]
        x = F.pad(x, (0, 0, 2, 1))
        # x [16, 64, 67,468]
        x = self.conv2(x)
        # x [16, 128, 65,466]
        x = self.activation(x)
        # x [16, 128, 65,466]
        x = self.bn2(x)
        # x [16, 128, 65,455]
        x = self.max_pool_4_2(x)
        # x [16, 128, 16,233]
        x = self.drop_01(x)
        # x [16, 128, 16,233]
        x = F.pad(x, (0, 0, 2, 1))
        # x [16, 128, 19,233]
        x = self.conv3(x)
        # x [16, 128, 17,231]
        x = self.activation(x)
        # x [16, 128, 17,231]
        x = self.bn3(x)
        # x [16, 128, 17,231]
        x = self.max_pool_4_2(x)
        # x [16, 128, 4,115]
        x = self.drop_01(x)
        # x [16, 128, 4,115]
        x = F.pad(x, (0, 0, 2, 1))
        # x [16, 128, 7,115]
        x = self.conv4(x)
        # x [16, 128, 5,113]
        x = self.activation(x)
        # x [16, 128, 5,113]
        x = self.bn4(x)
        # x [16, 128, 5,113]
        x = self.max_pool_4_2(x)
        # x [16, 128, 1,56]
        x = self.drop_01(x)
        # x [16, 128, 1,56]

        x = x.permute(0, 3, 1, 2)
        # x [16, 56, 128,1]
        resize_shape = list(x.shape)[2] * list(x.shape)[3]
        # x [16, 128, 56,1], reshape size is 128
        x = torch.reshape(x, (list(x.shape)[0], list(x.shape)[1], resize_shape))
        # x [16, 56, 128]
        device = torch.device("cuda" if torch.cuda.is_available()
                              else "cpu")
        h0 = torch.zeros((1, 16, 32)).to(device)
        x, h1 = self.gru1(x, h0)
        # x [16, 56, 32]
        x, _ = self.gru2(x, h1)
        # x [16, 56, 32]
        x = x[:, -1, :]
        x = self.dense(x)
        # x [16,10]
        x = self.softmax(x)
        # x [16, 10]
        # x = torch.argmax(x, 1)
        return x



my dataset is :


from __future__ import print_function, division

import os

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from torch.utils.data import Dataset
from utils.util import pad_along_axis

print(torch.__version__)
print(torchaudio.__version__)

# Ignore warnings
import warnings

warnings.filterwarnings("ignore")

plt.ion()

import pathlib

print(pathlib.Path().absolute())


class GTZANDataset(Dataset):
    def __init__(self,
                 genre_folder='/home/omid/OMID/projects/python/mldl/NeuralMusicClassification/data/dataset/genres_original',
                 one_hot_encoding=False,
                 sr=16000, n_mels=128,
                 n_fft=2048, hop_length=512,
                 transform=None):

        self.genre_folder = genre_folder
        self.one_hot_encoding = one_hot_encoding
        self.audio_address, self.labels = self.extract_address()
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.transform = transform
        self.le = LabelEncoder()
        self.hop_length = hop_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        address = self.audio_address[index]
        y, sr = librosa.load(address, sr=self.sr)
        S = librosa.feature.melspectrogram(y, sr=sr,
                                           n_mels=self.n_mels,
                                           n_fft=self.n_fft,
                                           hop_length=self.hop_length)

        sample = librosa.amplitude_to_db(S, ref=1.0)
        sample = np.expand_dims(sample, axis=0)
        sample = pad_along_axis(sample, 1024, axis=2)
        # print(sample.shape)
        sample = torch.from_numpy(sample)

        label = self.labels[index]
        # label = torch.from_numpy(label)
        print(sample.shape,label)
        if self.transform:
            sample = self.transform(sample)
        return sample, label

    def extract_address(self):
        label_map = {
            'blues': 0,
            'classical': 1,
            'country': 2,
            'disco': 3,
            'hiphop': 4,
            'jazz': 5,
            'metal': 6,
            'pop': 7,
            'reggae': 8,
            'rock': 9
        }
        labels = []
        address = []
        # extract all genres' folders
        genres = [path for path in os.listdir(self.genre_folder)]
        for genre in genres:
            # e.g. ./data/generes_original/country
            genre_path = os.path.join(self.genre_folder, genre)
            # extract all sounds from genre_path
            songs = os.listdir(genre_path)

            for song in songs:
                song_path = os.path.join(genre_path, song)
                genre_id = label_map[genre]
                # one_hot_targets = torch.eye(10)[genre_id]
                labels.append(genre_id)
                address.append(song_path)

        samples = np.array(address)
        labels = np.array(labels)
        # convert labels to one-hot encoding
        # if self.one_hot_encoding:
        #     labels = OneHotEncoder(sparse=False).fit_transform(labels)
        # else:
        #     labels = LabelEncoder().fit_transform(labels)

        return samples, labels


and trainer :


# encoding: utf-8


import logging

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import Accuracy, Loss, RunningAverage


def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    model = model.to(device)

    logger = logging.getLogger("template_model.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(),
                                                            'ce_loss': Loss(loss_fn)}, device=device)
    checkpointer = ModelCheckpoint(output_dir, 'mnist', None, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict()})
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                        .format(engine.state.epoch, iter, len(train_loader), engine.state.metrics['avg_loss']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['ce_loss']
        logger.info("Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                    .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_loss = metrics['ce_loss']
            logger.info("Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                        .format(engine.state.epoch, avg_accuracy, avg_loss)
                        )

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)