NaN in encoder output

I am working on an encoder-decoder architecture that performs regression on a family of sinusoidal functions. During training, the encoder takes a random subset of the input training pairs (40 pairs in total per function) and produces a feature representation by averaging the encodings of all chosen pairs. The decoder then takes this feature representation together with the input values to make predictions.

This is my initial model. It may contain other mistakes, but my immediate concern is that the encoder starts producing NaN values after a certain point, and I have not been able to figure out why. I have checked that my inputs contain no NaNs and that the learning rate is reasonable.
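To localize the problem, a forward hook along these lines can flag the first module whose output contains a NaN (a minimal sketch; model refers to the EncoderDecoder instance created further down):

def nan_hook(module, inputs, output):
    # raise as soon as any submodule produces a NaN in its output tensor
    if isinstance(output, torch.Tensor) and torch.isnan(output).any():
        raise RuntimeError("NaN in the output of " + module.__class__.__name__)

for m in model.modules():
    m.register_forward_hook(nan_hook)

torch.autograd.set_detect_anomaly(True) can additionally pinpoint the backward-pass operation that first produces a NaN, at the cost of slower training.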

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

Nf = 2000    # the number of different functions f that we will generate
Npts = 40    # the number of x values that we will use to generate each f
x = torch.zeros(Nf, Npts, 1)
for k in range(Nf):
    x[k,:,0] = torch.linspace(-2, 2, Npts)

x += torch.rand_like(x)*0.1   # small random jitter on the x grid

# per-function parameter a ~ U(-2, 2), broadcast to the shape of x
a = -2 + 4*torch.rand(Nf).view(-1,1).repeat(1, Npts).unsqueeze(2)
y = a*torch.sin(x+a)
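For reference, the generated data can be screened for NaNs and infs before training with plain torch checks:

assert not torch.isnan(x).any(), "NaN in the generated x values"
assert not torch.isnan(y).any(), "NaN in the generated y values"
assert not torch.isinf(y).any(), "inf in the generated y values"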

LR = 0.0001
MAX_EPOCH = 20
BATCH_SIZE = 80

# use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_test_split slices along the first dimension and keeps the torch tensors,
# so no torch.tensor() re-wrapping is needed
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2)

train_dataloader = DataLoader(TensorDataset(X_train.squeeze(-1), y_train.squeeze(-1)), batch_size=BATCH_SIZE,
                              pin_memory=True, shuffle=True)
val_dataloader = DataLoader(TensorDataset(X_val.squeeze(-1), y_val.squeeze(-1)), batch_size=BATCH_SIZE,
                            pin_memory=True, shuffle=False)   # no need to shuffle the validation set

class Encoder(nn.Module):
  def __init__(self, num_inputs, num_hidden_1, num_hidden_2, num_outputs):
    super(Encoder, self).__init__()
    self.encoder_hidden_layer_1 = nn.Linear(in_features=num_inputs, out_features=num_hidden_1)
    self.encoder_hidden_layer_2 = nn.Linear(in_features=num_hidden_1, out_features=num_hidden_2)
    self.encoder_output_layer = nn.Linear(in_features=num_hidden_2, out_features=num_outputs)

  def forward(self, x):
    x = x.to(device)   # .to() is not in-place, so the result must be reassigned
    x1 = F.relu(self.encoder_hidden_layer_1(x))
    x2 = F.relu(self.encoder_hidden_layer_2(x1))
    feature_representation = self.encoder_output_layer(x2)
    
    return feature_representation

class Decoder(nn.Module):
  def __init__(self, num_inputs, num_hidden_1, num_hidden_2, num_outputs):
    super(Decoder, self).__init__()

    self.decoder_hidden_layer_1 = nn.Linear(in_features=num_inputs, out_features=num_hidden_1)
    self.decoder_hidden_layer_2 = nn.Linear(in_features=num_hidden_1, out_features=num_hidden_2)
    self.decoder_output_layer = nn.Linear(in_features=num_hidden_2, out_features=num_outputs)

  def forward(self, x):
    x = x.to(device)   # reassign: .to() returns a new tensor
    x1 = F.relu(self.decoder_hidden_layer_1(x))
    x2 = F.relu(self.decoder_hidden_layer_2(x1))
    y_pred = self.decoder_output_layer(x2)

    return y_pred 

class EncoderDecoder(nn.Module):
  def __init__(self, encoder, decoder, enc_outputs):
    super(EncoderDecoder, self).__init__()
    self.num_outputs = enc_outputs
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, x_c, y_c, x_t):
    x_c = x_c.to(device)
    y_c = y_c.to(device)
    x_t = x_t.to(device)

    # accumulator for the sum of the encoder outputs over all context pairs
    r_c = torch.empty(1, self.num_outputs).to(device)
    for x, y in zip(torch.transpose(x_c,0,1), torch.transpose(y_c,0,1)):
        enc_input = torch.stack([x,y])
        enc_input = enc_input.view(-1,torch.numel(enc_input))
        enc_output = self.encoder(enc_input)

        r_c = torch.add(r_c, enc_output)

    r_c = torch.div(r_c, torch.numel(x_c))   # average over the Nc context pairs

    # decoder input: the target x concatenated with the averaged representation
    dec_input = torch.cat((x_t, r_c), -1)
    dec_input = dec_input.view(-1, torch.numel(dec_input))
    output = self.decoder(dec_input)

    return output

# Model instantiation and initialisation 
def init_weights(m):
  if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)

# Create and initialize encoder:
enc_inputs, enc_hidden_1, enc_hidden_2, enc_outputs = 2, 16, 16, 8   # input is one (x, y) pair
encoder = Encoder(enc_inputs, enc_hidden_1, enc_hidden_2, enc_outputs).to(device)
encoder.apply(init_weights)

# Create and initialize decoder:
dec_inputs, dec_hidden_1, dec_hidden_2, dec_outputs = 9, 32, 16, 1   # 9 = 1 (x value) + 8 (encoder features)
decoder = Decoder(dec_inputs, dec_hidden_1, dec_hidden_2, dec_outputs).to(device)
decoder.apply(init_weights)

# Create and initialize EncoderDecoder:
model = EncoderDecoder(encoder,decoder,enc_outputs).to(device)
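
As a quick smoke test of the shapes, a single forward pass with dummy context pairs can be run (hypothetical values; Nc is arbitrary):

Nc = 5
x_c = torch.rand(1, Nc)
y_c = torch.rand(1, Nc)
x_t = torch.rand(1, 1)
out = model(x_c, y_c, x_t)
print(out.shape, torch.isnan(out).any().item())

The expected output shape is (1, 1); if isnan already reports True here, the problem is independent of the optimizer and the training loop.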

# Adam optimizer with learning rate LR
optimizer = optim.Adam(model.parameters(), lr=LR)

# mean-squared error loss
criterion = nn.MSELoss()
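
Exploding gradients are a common source of NaNs even when the inputs are clean, so the total gradient norm can be logged after each backward pass (a hypothetical helper, not part of the original code):

def total_grad_norm(model):
    # L2 norm over all parameter gradients, useful for spotting blow-ups
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5

Calling total_grad_norm(model) right after train_loss.backward() in the loop below shows whether the gradients blow up just before the NaNs appear.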

def train(model, train_dataloader, optimizer, criterion, MAX_EPOCH):
  for epoch in range(MAX_EPOCH):
    model.train()
    loss = 0
    n_steps = 0   # number of optimisation steps taken this epoch
    for batch_X, batch_y in train_dataloader: 
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)       

        for x_t, y_t in zip(batch_X, batch_y):

            # number of context pairs
            Nc = random.randint(1,Npts)

            # create random subset of x_t and y_t
            x_c, y_c = zip(*random.sample(list(zip(x_t,y_t)), Nc))
            x_c = torch.stack(x_c)
            x_c = x_c.view(-1,torch.numel(x_c))
            y_c = torch.stack(y_c)
            y_c = y_c.view(-1,torch.numel(y_c))

            for x, y in zip(x_t, y_t):

                x = x.view(-1,torch.numel(x))
                y = y.view(-1,torch.numel(y))

                # compute outputs of model
                output = model(x_c, y_c, x)

                # compute training loss
                train_loss = criterion(output, y)
                
                # reset the gradients back to zero
                # PyTorch accumulates gradients on subsequent backward passes
                optimizer.zero_grad()

                # backpropagate to compute the gradients
                train_loss.backward()
                
                # perform parameter update based on current gradients
                optimizer.step()
                
                # accumulate this step's training loss into the epoch total
                loss += train_loss.item()
                n_steps += 1

    # average the accumulated loss over the number of training steps
    loss = loss/n_steps
    
    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, MAX_EPOCH, loss))