I am new to PyTorch and have written the following code.

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import itertools
import datetime
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.fc1 = nn.Linear(4, 32, bias=False)
self.fc2 = nn.Linear(32, 16, bias=False)
self.fc3 = nn.Linear(16, 7, bias=False)
def forward(self, x):
x = self.fc1(x)
x = torch.tanh(x)
x = self.fc2(x)
x = torch.tanh(x)
x = self.fc3(x)
output = torch.tanh(x)
return output
def channel(codeword, snr_db, device):
snr_value = 10 ** (snr_db / 10)
h_real = torch.normal(mean=0, std=1, size=(codeword.shape[0], 1)) * torch.sqrt(torch.as_tensor(1/2))
h_imag = torch.normal(mean=0, std=1, size=(codeword.shape[0], 1)) * torch.sqrt(torch.as_tensor(1/2))
h_real_t = h_real.repeat(1, codeword.shape[1]).to(device)
h_imag_t = h_imag.repeat(1, codeword.shape[1]).to(device)
noise_real = torch.normal(mean=0, std=1, size=codeword.shape) * torch.sqrt(torch.as_tensor(1/(2*snr_value)))
noise_imag = torch.normal(mean=0, std=1, size=codeword.shape) * torch.sqrt(torch.as_tensor(1/(2*snr_value)))
noise_real = noise_real.to(device)
noise_imag = noise_imag.to(device)
faded_cw_real = torch.mul(h_real_t, codeword) + noise_real
faded_cw_imag = torch.mul(h_imag_t, codeword) + noise_imag
return torch.cat([faded_cw_real[:, :, None], faded_cw_imag[:, :, None]], dim=2), h_real, h_imag
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.bigru = nn.GRU(input_size=4, hidden_size=100, bidirectional=True, batch_first=False)
# self.fc0 = nn.Linear(4, 1)
self.fc1 = nn.Linear(200, 16)
def forward(self, x):
x, states = self.bigru(x)
output = torch.squeeze(self.fc1(x[-1]))
output = torch.tanh(output)
# output = self.fc1(output)
# output = torch.tanh(output)
# output = torch.softmax(output, dim=0)
return output
def train(args, model1, model2, device, optimizer, epoch, snr):
model1.train()
model2.train()
count = 1000
for i in range(count):
data = np.array([list(j) for j in itertools.product([-1, 1], repeat=4)])
p = np.random.permutation(16)
# p = np.random.randint(low=0, high=16, size=(16,))
train_data = data[p]
data_one_hot = np.eye(16)
truth = data_one_hot[p]
# truth = torch.as_tensor(truth).to(device).float() # Uncomment this for BCE loss
train_data = torch.as_tensor(train_data).float()
train_data = train_data.to(device)
# optimizer1.zero_grad()
optimizer.zero_grad()
output = model1(train_data)
output = output.to(device)
ch_out, h_r, h_i = channel(output, snr, device)
h_r = torch.as_tensor(h_r[:, :, None].repeat(1, 7, 1)).to(device)
h_i = torch.as_tensor(h_i[:, :, None].repeat(1, 7, 1)).to(device)
dec_ip = torch.cat([ch_out, h_r, h_i], 2)
dec_ip = torch.transpose(dec_ip, 2, 1)
hat = model2(torch.as_tensor(dec_ip).float())
loss_d = F.mse_loss(hat, train_data)
# loss_d = F.binary_cross_entropy(hat, truth)
loss_d.backward()
optimizer.step()
if i % 10 == 0:
# print(f"Train epoch: {epoch}, Batch: {i}, Encoder Loss: {loss_e.item()}, SNR: {snr}")
print(f"Train epoch: {epoch}, Batch: {i}, Decoder Loss: {loss_d.item()}, SNR: {snr}")
def main():
epochs = 14
learning_rate = 1
learning_rate_step = 0.7
no_cuda = False
log_interval = 10
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
enc_model = Encoder().to(device)
dec_model = Decoder().to(device)
optimizer = optim.Adam(list(dec_model.parameters())+list(enc_model.parameters()), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=1, gamma=learning_rate_step)
for epoch in range(1, epochs+1):
snr = 20 - 20 * epoch / epochs
train(log_interval, enc_model, dec_model, device, optimizer, epoch, snr)
scheduler.step()
if __name__ == "__main__":
main()
```

However when I run this, the output is:

```
Train epoch: x, Batch: y, Decoder Loss: 2.0, SNR: z
```

x, y and z values are dependent on the iteration. The decoder loss is stuck at 2.0. Is there something simply wrong in the code?

Also, if this is not the correct forum to ask this question, please guide me.