Stack Python lists into a tensor for subsequent computations

I’m trying to replace an ugly and inefficient nested for-loop in my forward pass with matrix operations, but the solution I came up with (stacking my lists into one tensor) makes my network stop learning. (Simplified versions of my forward pass are below.)

For the output I need to combine 4 states for each pixel of a 2D image. In the final for-loop I iterate over all points and concatenate the 4 corresponding states into one tensor (batch_size x 4 * hidden_size) to pass them through a linear layer with dimensions (4 * hidden_size, num_classes).

To be more efficient, I tried to convert the lists into a single tensor that holds all states and reshape it so that the last dimension contains all 4 hidden states (i.e. height x width x batch_size x 4 * hidden_size). Then I pass the whole tensor through the same linear layer.

My linear and softmax layers are created in exactly the same way in both versions, and from my understanding this should work as long as the last dimension of their input is the same, since nn.Linear operates only on the last dimension of its input.
All the resulting shapes match up; however, the error computed on the freshly initialized network more than doubles from version 1 to version 2 (same seeds, multiple tries), and version 2 is not able to learn anymore.
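
To double-check that assumption, here is a quick standalone sanity check (made-up sizes, not my real network) showing that nn.Linear only acts on the last dimension, so extra leading dimensions should not change the per-pixel result. One caveat: I use nn.Softmax(dim=-1) here, because a softmax with a fixed positive dim would pick a different axis once leading dimensions are added.

import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, hidden_size, num_classes = 2, 7, 3

linear = nn.Linear(4 * hidden_size, num_classes)
softmax = nn.Softmax(dim=-1)  # dim=-1 keeps the class axis in both layouts

flat = torch.randn(batch_size, 4 * hidden_size)   # version-1 layout: (batch, features)
nested = flat.expand(5, 6, -1, -1)                # version-2 layout: (height, width, batch, features)

out_flat = softmax(linear(flat))                  # (batch, num_classes)
out_nested = softmax(linear(nested))              # (height, width, batch, num_classes)

# every (h, w) slice should match the flat result
assert torch.allclose(out_nested[3, 4], out_flat)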

Version 1:

result = torch.zeros(height, width, batch_size, num_classes)
states = []

for d in directions:  # always 4 in my case    
    h_states = []
    for h in range(height):
        w_states = []
        for w in range(width):
            state = compute_hidden_state(input[h, w], last_h_state, last_w_state)
            w_states.append(state)
        h_states.append(w_states)
    states.append(h_states)

# states is now a nested list (4 x height x width) of tensors of shape batch_size x hidden_size

for h in range(height):
    for w in range(width):
        current_states = torch.cat((states[0][h][w], states[1][h][w], states[2][h][w], states[3][h][w]), 1)
        output = linear_layer(current_states)
        result[h, w] = softmax_layer(output)

Version 2:

states = []

for d in directions:  # always 4 in my case    
    h_states = []
    for h in range(height):
        w_states = []
        for w in range(width):
            state = compute_hidden_state(input[h, w], last_h_state, last_w_state)
            w_states.append(state)
        h_states.append(torch.stack(w_states))
    states.append(torch.stack(h_states))

states_tensor = torch.stack(states)

# states_tensor now has dimensions 4 x height x width x batch_size x hidden_size

states_tensor = states_tensor.permute(1, 2, 3, 0, 4).contiguous().view(height, width, batch_size, -1)
output = linear_layer(states_tensor)
result = softmax_layer(output)
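
Checking the reshape in isolation with random data (a standalone sketch, not my real network) suggests the permute + view really does reproduce the per-pixel concatenation order of version 1:

import torch

torch.manual_seed(0)
height, width, batch_size, hidden_size = 3, 5, 2, 7

# random stand-ins for the 4 directional state tensors
states_tensor = torch.randn(4, height, width, batch_size, hidden_size)

# version 2: move the direction axis next to the hidden axis, then flatten both
v2 = states_tensor.permute(1, 2, 3, 0, 4).contiguous().view(height, width, batch_size, -1)

# version 1: per-pixel concatenation along the feature dimension
for h in range(height):
    for w in range(width):
        v1 = torch.cat([states_tensor[d, h, w] for d in range(4)], dim=1)
        assert torch.equal(v1, v2[h, w])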


As always, my mistake was somewhere completely different; sorry for wasting your time if you read this.