Help converting CRNN tensorflow model to pytorch

I have been using tensorflow for a while and rocking a modified version of OCR model for reading Captchas with support for a transformer encoder network to boost its accuracy. I need to drop tensorflow so I tried rewriting the model for pytorch but I think it handles dimensions differently.
This is my pytorch model:

def CRNN(shape, num_classes, use_trencod, learning_rate=0.001):
    # Inputs to the model
    x = layers.Input(shape=shape, name='image', dtype='float32')

    # First conv block
    x = layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(x) 
    x = layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block
    x = layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = layers.MaxPooling2D((2, 2), name="pool2")(x)

    # Third conv block
    x = layers.Conv2D(
        128,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv3",
    )(x)
    x = layers.MaxPooling2D((2, 2), name="pool3")(x)

    # Fourth conv block
    x = layers.Conv2D(
        256,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv4",
    )(x)
    x = layers.MaxPooling2D((2, 2), name="pool4")(x)

    # We have used two max pool with pool size and strides 2.
    # Hence, downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model
    new_shape = ((shape[0] // 4), (shape[1] // 4) * 16) # 64
    x = layers.Reshape(target_shape=new_shape, name="reshape")(x)

    x = layers.Dense(256, name="dense1")(x) # 128
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Dropout(0.2)(x) # 0.2

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)

    if use_trencod:
        # Encoder-only Transformer (https://dl.acm.org/doi/abs/10.1145/3558100.3563845)
        x = TransformerEncoder(intermediate_dim=256, num_heads=8, dropout=0.2, name='trencod')(x)

    # Output layer
    outputs = layers.Dense(num_classes + 1, activation="softmax", name="dense2")(x)

    # Define the model
    model = keras.models.Model( # CustomModel
        inputs=input_img, outputs=outputs, name="ocr_model_v2"
    )
    
    # Optimizer
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    # Compile the model and return
    model.compile(optimizer=optimizer, loss=ctc_loss, metrics=[cer, wer], run_eagerly=False)

    return model

I’m ignoring use_trencod for now just so I can get the model working first.
This is the pytorch version:

class CRNN(nn.Module):
	def __init__(self, shape=(256, 64, 1), num_conv=4, num_chars):
		super(CRNN, self).__init__()
		width, height, channels = shape
		self.convs = nn.ModuleList([ nn.Conv2d(channels if i == 0 else 2 ** (i+4), 2 ** (i+5), kernel_size=3, padding=1) for i in range(num_conv) ])
		self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

		self.fc1 = nn.Linear(256 * (width // 16) * (height // 16), 256)

		self.batch_norm = nn.BatchNorm1d(256)
		self.dropout = nn.Dropout(0.2)

		self.lstm1 = nn.LSTM(256, 128, batch_first=True, bidirectional=True) # , dropout=0.25
		self.lstm2 = nn.LSTM(256, 128, batch_first=True, bidirectional=True) # , dropout=0.25

		self.fc2 = nn.Linear(256, num_chars + 1)

	def forward(self, x):
		for conv in self.convs:
			x = self.pool(nn.ReLU()(conv(x)))

		x = x.view(x.size(0), -1)

		x = self.fc1(x)
		x = self.batch_norm(x)
		x = nn.ReLU(inplace=True)(x)

		x = self.dropout(x)

		x, _ = self.lstm1(x)
		x, _ = self.lstm2(x)

		x = self.fc2(x)
		x = nn.Softmax(dim=1)(x)

		return x

Input shape is [batch, channels, width, height]
Label shape is [batch, max_length]
Meanwhile the model output shape (which I assume is wrong) is [batch, num_chars]

I think its obvious I made some mistake when reshaping with ‘view’ or even somewhere else, some help would be very welcome. Thank you!

Edit: the first training step looks ok but then on the second step they become Nan and with ‘torch.autograd.set_detect_anomaly(True)’ it throws ‘RuntimeError: Function ‘CtcLossBackward0’ returned nan values in its 0th output.’