GRU cell expected hidden size changes

Hi, I am currently working on a project that involves designing a Reinforcement Learning framework aimed at generating high-quality radiofrequency (RF) pulses for Magnetic Resonance Imaging (MRI).

My network architecture consists of a Gated Recurrent Unit (GRU) cell followed by several dense layers. The final layer produces three output values: the mean values mu_a and mu_p, which are used to create Normal distributions N(mu_a, σ) and N(mu_p, σ), respectively, from which both amplitude and phase values are sampled. Additionally, the third output value represents the value function.
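For readers following along, here is a minimal sketch of the sampling step described above. It assumes fixed, hypothetical standard deviations (0.5 and 0.2) and zero means as stand-ins for the network output. It draws one (amplitude, phase) pair per pulse with `torch.distributions.Normal` and compares the joint density against the hand-written Gaussian formula used later in `get_action`.

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Stand-ins (assumptions): means as if produced by the network, fixed sigmas.
mean = torch.zeros(256, 2)             # columns: (mu_a, mu_p)
std = torch.tensor([[0.5, 0.2]])       # hypothetical (sigma_a, sigma_p)

dist = Normal(mean, std)
action = dist.sample()                 # shape (256, 2)

# Amplitude and phase are sampled independently, so the joint
# log-density is the sum of the two per-dimension log-densities.
log_prob = dist.log_prob(action).sum(dim=-1)   # shape (256,)

# The same density written out by hand.
manual = (
    torch.exp(-(action - mean).pow(2) / (2.0 * std.pow(2)))
    / (math.sqrt(2.0 * math.pi) * std)
).prod(dim=-1)

max_abs_diff = (log_prob.exp() - manual).abs().max().item()
print(action.shape, log_prob.shape, max_abs_diff)
```

Working with `log_prob` directly, rather than the raw density, tends to be numerically safer when probability ratios are formed later.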

The main objective achieved by the loop in the code is to generate 32 (amplitude, phase) pairs sequentially, which I will subsequently interpolate to create 256 pairs and feed into a Gym environment that simulates an MRI system to calculate rewards.
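The 32-to-256 upsampling step could look like the sketch below. It assumes plain linear interpolation along the time axis (the post does not say which scheme is used) and random stand-in data, and it ignores phase wrap-around.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the 32 sampled pairs of each of the 256 pulses.
batch_size, n_steps, n_out = 256, 32, 256
pairs = torch.rand(batch_size, n_steps, 2)     # (batch, steps, [amp, phase])

# F.interpolate expects (batch, channels, length), so the
# (amp, phase) axis is moved to the channel position first.
pairs_cf = pairs.permute(0, 2, 1)              # (256, 2, 32)
upsampled = F.interpolate(pairs_cf, size=n_out, mode="linear", align_corners=True)
pulses = upsampled.permute(0, 2, 1)            # (256, 256, 2)

print(pulses.shape)  # torch.Size([256, 256, 2])
```

With `align_corners=True` the first and last interpolated samples coincide with the first and last original pairs.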

To accomplish this task, I am using a batch size of 256, implying that during each of the 32 iterations in the loop, I attempt to produce a new (amplitude, phase) pair for every one of the 256 RF pulses.

The input vector provided to the GRU has a shape of (256, 1, 2), i.e. one (amplitude, phase) pair for each of the 256 RF pulses. The state input to the GRU has a shape of (1, 256, 256), i.e. one state of 256 values per RF pulse for the single GRU layer. At each iteration, the preceding state_out serves as the new state_in.

Here is my code:

import math
import numpy as np
import torch
import torch.nn as nn

# Network definition
class SharedNetwork(nn.Module):

    def __init__(self, hidden_sizes):
        super(SharedNetwork, self).__init__()
        self.gru = nn.GRU(2, 256, batch_first=True)
        dense_layers = []
        prev_dim = hidden_sizes[0]
        for hidden_size in hidden_sizes[1:]:
            dense_layers.append(nn.Linear(prev_dim, hidden_size))
            dense_layers.append(nn.ReLU())
            prev_dim = hidden_size
        dense_layers.pop()  # Remove last ReLU
        dense_layers.append(nn.Linear(prev_dim, 3))
        self.mlp = nn.Sequential(*dense_layers)

    def forward(self, x, state_in):
        gru_out, _ = self.gru(x, state_in)
        dense_out = self.mlp(gru_out[:, -1, :])
        return dense_out, gru_out

#RL Agent behaviour
class PPO:

    def __init__(self):
        self.model = SharedNetwork(hidden_sizes=(256,128,64,32)).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(),

    def get_action(self, rnn_in, state_in):
        output, state_out = self.model(rnn_in, state_in)
        mean = torch.cat((output[:, 0].unsqueeze(-1), output[:, 1].unsqueeze(-1)), dim=1)
        value_out=output[:, 2]
        log_std = torch.Tensor([[math.log(max_rad * (1 / args.amp) * args.amp_std), math.log(2 * np.pi * (1 / * args.ph_std)]]).to(device)
        std = torch.exp(log_std)

        r = torch.empty_like(mean).normal_(mean=0.0, std=1.0).to(device)
        pi = mean + r * std
        pi_clip = torch.clamp(pi, min=torch.Tensor([EPS, -np.inf]).to(device), max=torch.Tensor([max_rad * (1 / args.amp), np.inf]).to(device))

        middle1 = torch.exp(torch.multiply((pi_clip - mean).pow(2), (-1.0 / (2.0 * std.pow(2))))).prod(dim=-1, keepdim=True)
        p_pi_a = middle1 * (1 / (torch.sqrt(torch.tensor(np.pi * 2.0, device=device)) * std)).prod(dim=-1, keepdim=True) + EPS

        return pi_clip, p_pi_a.squeeze(dim=-1), value_out, state_out

rnn_in_ = torch.ones((256,1, 2), dtype=torch.float32, device=device) 
state_in_ = torch.zeros((1, 256, 256), dtype=torch.float32, device=device)

# Inner loop: one (amp, phase) pair generated per iteration for each of the 256 RF pulses
for ep in range(32):

    pi_clip, p_pi, value_out, state_out=agent.get_action(rnn_in_, state_in_)

    rnn_in_= torch.clone(pi_clip).unsqueeze(0)


I currently have this error happening inside the loop, related to the dimensionality of the hidden state of my GRU unit:

 File "../envs/", line 228, in <module>
    pi_clip, p_pi, value_out, state_out=agent.get_action(rnn_in_, state_in_)
  File "../envs/", line 136, in get_action
    output, state_out = self.model(rnn_in, state_in)
 File "/home/rohitkumar/.conda/envs/october/lib/python3.6/site-packages/torch/nn/modules/", line 226, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden size (1, 1, 256), got [1, 256, 256]

This error appears solely during the second iteration (ep=1), whereas the initial iteration runs smoothly with state_in_ having a shape of torch.Size([1, 256, 256]), and state_out with a shape of torch.Size([256, 1, 256]). To resolve this discrepancy, I perform tensor reshaping via dimension permutation prior to supplying the updated state back into the GRU for the next iteration.

Could someone help me spot the error?
Is the expected hidden size changing between 2 iterations?
Is there a problem with my GRU Cell definition/Tensor shapes?
Am I using the wrong one of the GRU's two return values?

@ptrblck Hey, any chance you could help me with this?

I don’t fully understand your use case and code.
E.g. in SharedNetwork you are using:

    def forward(self, x, state_in):
        gru_out, _ = self.gru(x, state_in)
        dense_out = self.mlp(gru_out[:, -1, :])
        return dense_out, gru_out

which is assigned as:

output, state_out = self.model(rnn_in, state_in)

claiming gru_out (the original output of the GRU) is the state, which is not the case and a bit confusing.
Note that state_out has the shape [batch_size, seq_len, out_features], i.e. [256, 1, 256]. Later you are permuting it, which creates a tensor in the shape [seq_len=1, batch_size=256, features=256].
However, rnn_in_ was unsqueezed to [1, 256, 2] and does not fit into the expected shape anymore.
I assume you wanted to unsqueeze the seq_len dimension instead of the batch dimension in rnn_in_, so this should work:

rnn_in_= torch.clone(pi_clip).unsqueeze(1)

However, make sure the actual dimensions are treated as expected.
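To make the shape mismatch concrete, here is a small standalone sketch with random stand-in data instead of the real `pi_clip`. With `batch_first=True` the GRU reads the input as [batch, seq_len, features], so `unsqueeze(0)` gives a batch of 1 with 256 time steps. The GRU then expects a hidden state of [num_layers=1, batch=1, 256], which is the size named in the error message.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

gru = nn.GRU(2, 256, batch_first=True)
h = torch.zeros(1, 256, 256)       # (num_layers, batch, hidden)
pi_clip = torch.rand(256, 2)       # stand-in for the sampled actions, (batch, 2)

# unsqueeze(0): (1, 256, 2) is read as batch=1, seq_len=256.
bad_in = pi_clip.unsqueeze(0)
try:
    gru(bad_in, h)
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# unsqueeze(1): (256, 1, 2) is read as batch=256, seq_len=1.
good_in = pi_clip.unsqueeze(1)
out, h_n = gru(good_in, h)
print(out.shape)   # torch.Size([256, 1, 256])
print(h_n.shape)   # torch.Size([1, 256, 256])
```

Printing the shapes of both the input and the hidden state at the top of each iteration is a cheap way to catch this kind of drift early.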

Thank you! It does work.

However, now I am wondering which of the GRU's two return values I should actually be using.
I thought that passing gru_out[:, -1, :] was the same as using h_n, but I may not have understood the GRU usage correctly.
My goal is indeed to pass the hidden state iteratively to the GRU during the 32 iterations (one state of 256 values per RF pulse, with 256 RF pulses in total).
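A quick way to compare the two return values is the sketch below, using random stand-in input with a sequence length of 5. For a single-layer, unidirectional `nn.GRU`, the last time step of the output holds the same values as `h_n[-1]`; only the layout differs. The output is [batch, seq_len, hidden] when `batch_first=True`, while `h_n` is always [num_layers, batch, hidden].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

gru = nn.GRU(2, 256, batch_first=True)   # single layer, unidirectional
x = torch.randn(256, 5, 2)               # (batch, seq_len, features)
h0 = torch.zeros(1, 256, 256)            # (num_layers, batch, hidden)

output, h_n = gru(x, h0)
print(output.shape)  # torch.Size([256, 5, 256])
print(h_n.shape)     # torch.Size([1, 256, 256])

# Same values, different layout.
same = torch.allclose(output[:, -1, :], h_n[-1])
print(same)  # True
```

Since `h_n` already has the layout the GRU expects for its state input, feeding it back avoids any permutation. With stacked or bidirectional GRUs the two are no longer interchangeable.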

Does my current implementation make sense?
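One way the iterative hand-off could look is sketched below. `TinyPolicy` is a hypothetical, simplified stand-in for `SharedNetwork`, and the fixed standard deviation of 0.1 is a placeholder for the real sampling code. The points to note are that `forward` returns `h_n` as the state, and that the sampled action is unsqueezed along the sequence dimension before being fed back.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyPolicy(nn.Module):
    """Hypothetical, simplified stand-in for SharedNetwork."""

    def __init__(self, hidden=256):
        super().__init__()
        self.gru = nn.GRU(2, hidden, batch_first=True)
        self.head = nn.Linear(hidden, 3)   # (mu_a, mu_p, value)

    def forward(self, x, h):
        out, h_n = self.gru(x, h)
        # Return h_n (num_layers, batch, hidden) as the recurrent state.
        return self.head(out[:, -1, :]), h_n


model = TinyPolicy()
rnn_in = torch.ones(256, 1, 2)       # (batch, seq_len=1, features)
state = torch.zeros(1, 256, 256)     # (num_layers, batch, hidden)
actions = []

with torch.no_grad():                # rollout only, no gradients kept
    for step in range(32):
        dense_out, state = model(rnn_in, state)
        mean = dense_out[:, :2]
        # Placeholder sampling with an assumed fixed std of 0.1.
        action = mean + 0.1 * torch.randn_like(mean)
        actions.append(action)
        rnn_in = action.unsqueeze(1)  # back to (256, 1, 2)

pulse = torch.stack(actions, dim=1)
print(pulse.shape)  # torch.Size([256, 32, 2])
print(state.shape)  # torch.Size([1, 256, 256])
```

The `torch.no_grad()` context fits a pure rollout. Computing the PPO loss through these steps would require keeping the graph instead.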