Hi, I am currently working on a project that involves designing a Reinforcement Learning framework aimed at generating high-quality radiofrequency (RF) pulses for Magnetic Resonance Imaging (MRI).

My network architecture consists of a **Gated Recurrent Unit (GRU) cell followed by several dense layers**. The final layer produces three output values: the mean values mu_a and mu_p, which are used to create Normal distributions N(mu_a, σ) and N(mu_p, σ), respectively, from which both amplitude and phase values are sampled. Additionally, the third output value represents the value function.
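To illustrate what the head produces, here is a minimal sketch of how the three outputs map to distributions and a value estimate (`sigma` is a fixed placeholder here, not my actual std computation):

```python
import torch
from torch.distributions import Normal

# Suppose the final dense layer produced these three values per pulse:
output = torch.randn(256, 3)            # (batch, [mu_a, mu_p, value])
mu_a, mu_p, value = output[:, 0], output[:, 1], output[:, 2]

sigma = 0.1                              # fixed std, for illustration only
amp = Normal(mu_a, sigma).sample()       # sampled amplitudes, shape (256,)
ph = Normal(mu_p, sigma).sample()        # sampled phases, shape (256,)
```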

The main objective achieved by the loop in the code is to **generate 32 (amplitude, phase) pairs sequentially**, which I will subsequently interpolate to create 256 pairs and feed into a Gym environment that simulates an MRI system to calculate rewards.
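The 32-to-256 interpolation step is not part of the loop shown below; assuming simple linear interpolation along the pulse index, it could be sketched as:

```python
import numpy as np

pairs32 = np.random.rand(32, 2)                    # 32 (amplitude, phase) pairs
x_new = np.linspace(0, 31, 256)                    # 256 target sample positions
pairs256 = np.stack([np.interp(x_new, np.arange(32), pairs32[:, k])
                     for k in range(2)], axis=1)   # shape (256, 2)
```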

To accomplish this task, I am using a **batch size of 256**, implying that during each of the 32 iterations in the loop, I attempt to produce a new (amplitude, phase) pair for every one of the 256 RF pulses.

The input vector provided to the GRU has a shape of (256, 1, 2), indicating one pair of (amplitude, phase) values for each of the 256 RF pulses. The hidden-state input required by the GRU has a shape of (1, 256, 256): with a single GRU layer, there is one state of 256 values per RF pulse. At each iteration, the preceding state_out serves as the new state_in.
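As a sanity check on the shapes described above: with `batch_first=True`, PyTorch's `nn.GRU` expects the input as `(batch, seq, features)` but the hidden state always as `(num_layers, batch, hidden)`, and its two return values differ accordingly:

```python
import torch
import torch.nn as nn

gru = nn.GRU(2, 256, batch_first=True)
x = torch.ones(256, 1, 2)          # (batch=256, seq=1, features=2)
h0 = torch.zeros(1, 256, 256)      # (num_layers=1, batch=256, hidden=256)

out, hn = gru(x, h0)
print(out.shape)  # torch.Size([256, 1, 256]) -- per-step outputs, batch first
print(hn.shape)   # torch.Size([1, 256, 256]) -- final hidden state, layers first
```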

Here is my code:

```
import math

import numpy as np
import torch
import torch.nn as nn

# Network definition: shared GRU trunk + MLP with a 3-unit head (mu_a, mu_p, value)
class SharedNetwork(nn.Module):
    def __init__(self, hidden_sizes):
        super(SharedNetwork, self).__init__()
        # input: one (amplitude, phase) pair -> hidden size 256
        self.gru = nn.GRU(2, 256, batch_first=True)
        dense_layers = []
        prev_dim = hidden_sizes[0]
        for hidden_size in hidden_sizes[1:]:
            dense_layers.append(nn.Linear(prev_dim, hidden_size))
            dense_layers.append(nn.ReLU())
            prev_dim = hidden_size
        dense_layers.pop()  # remove the last ReLU
        dense_layers.append(nn.Linear(prev_dim, 3))
        self.mlp = nn.Sequential(*dense_layers)

    def forward(self, x, state_in):
        gru_out, _ = self.gru(x, state_in)
        dense_out = self.mlp(gru_out[:, -1, :])
        return dense_out, gru_out

# RL agent behaviour
class PPO:
    def __init__(self):
        self.model = SharedNetwork(hidden_sizes=(256, 128, 64, 32)).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)

    def get_action(self, rnn_in, state_in):
        output, state_out = self.model(rnn_in, state_in)
        mean = torch.cat((output[:, 0].unsqueeze(-1), output[:, 1].unsqueeze(-1)), dim=1)
        value_out = output[:, 2]
        log_std = torch.Tensor([[math.log(max_rad * (1 / args.amp) * args.amp_std),
                                 math.log(2 * np.pi * (1 / args.ph) * args.ph_std)]]).to(device)
        std = torch.exp(log_std)
        # reparameterised sample: pi = mean + eps * std
        r = torch.empty_like(mean).normal_(mean=0.0, std=1.0).to(device)
        pi = mean + r * std
        # clamp amplitude to [EPS, max_rad / args.amp]; leave phase unbounded
        pi_clip = torch.clamp(pi,
                              min=torch.Tensor([EPS, -np.inf]).to(device),
                              max=torch.Tensor([max_rad * (1 / args.amp), np.inf]).to(device))
        # Gaussian density of the clipped action under N(mean, std)
        middle1 = torch.exp((pi_clip - mean).pow(2) * (-1.0 / (2.0 * std.pow(2)))).prod(dim=-1, keepdim=True)
        p_pi_a = middle1 * (1 / (torch.sqrt(torch.tensor(np.pi * 2.0, device=device)) * std)).prod(dim=-1, keepdim=True) + EPS
        return pi_clip, p_pi_a.squeeze(dim=-1), value_out, state_out

agent = PPO()
rnn_in_ = torch.ones((256, 1, 2), dtype=torch.float32, device=device)
state_in_ = torch.zeros((1, 256, 256), dtype=torch.float32, device=device)

# Inner loop: one (amp, phase) pair generated per iteration for each of the 256 RF pulses
for ep in range(32):
    pi_clip, p_pi, value_out, state_out = agent.get_action(rnn_in_, state_in_)
    rnn_in_ = torch.clone(pi_clip).unsqueeze(0)
    state_in_ = state_out.permute(1, 0, 2)
```

I currently get the following error inside the loop, related to the dimensionality of the GRU's hidden state:

```
File "../envs/new_torch_generation.py", line 228, in <module>
pi_clip, p_pi, value_out, state_out=agent.get_action(rnn_in_, state_in_)
File "../envs/new_torch_generation.py", line 136, in get_action
output, state_out = self.model(rnn_in, state_in)
...
File "/home/rohitkumar/.conda/envs/october/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 226, in check_hidden_size
raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden size (1, 1, 256), got [1, 256, 256]
```

This error appears only during the second iteration (`ep=1`); the first iteration runs smoothly, with `state_in_` having a shape of `torch.Size([1, 256, 256])` and `state_out` a shape of `torch.Size([256, 1, 256])`. To resolve this discrepancy, I permute the dimensions of the updated state before feeding it back into the GRU for the next iteration.

**Could someone help me spot the error?**

Is the expected hidden size changing between the two iterations?

Is there a problem with my GRU cell definition or my tensor shapes?

Am I returning the wrong one of the GRU's two outputs?