The training data is floating point. The number of epochs is set to 10000, and each epoch calls a custom function. In that function, batch_num is 95. I set a batch of data to be 29 rows, and each row has 20 dimensions. I am running this on a 3080, but it is too slow. How can I speed it up?

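The main function:

```
from math import sqrt

for epo in range(epoch):
    print("In iteration: %d" % (epo + 1))
    # Summed squared error over the whole batch (see SAE_GRU_Network below)
    mse_all = model.SAE_GRU_Network(X, Y, U, time_steps, batch_size, hidden_size)
    loss = mse_all.cuda()  # returns the same tensor if it is already on the GPU
    opt.zero_grad()
    loss.backward()
    opt.step()
    # RMSE per element; .item() copies the scalar back to the CPU
    print(epo + 1, sqrt(mse_all.item() / (batch_size * time_steps * data_dim)))
```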
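For comparison, here is a minimal sketch of the same loop that reads the loss back to the CPU only every few epochs, since every `print` of a GPU value forces a synchronization. The names `fit`, `loss_fn`, `n_values`, and `log_every`, and the choice of Adam, are assumptions for illustration and are not part of my code; `loss_fn` stands in for the call to `model.SAE_GRU_Network(...)`.

```python
import math
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fit(loss_fn, parameters, n_values, epochs, log_every=100, lr=1e-3):
    """Hypothetical helper: loss_fn() returns a 0-dim summed squared error on `device`."""
    opt = torch.optim.Adam(parameters, lr=lr)  # optimizer choice is an assumption
    history = []
    for epo in range(epochs):
        loss = loss_fn()      # already on the device, so no .cuda() call
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (epo + 1) % log_every == 0:
            # .item() waits for the GPU, so it is only called when logging
            rmse = math.sqrt(loss.item() / n_values)
            history.append((epo + 1, rmse))
            print(epo + 1, rmse)
    return history
```

Here `n_values` would be `batch_size * time_steps * data_dim`. This only removes per-epoch logging overhead; it does not change how the loss itself is computed.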

The SAE_GRU_Network method:

```
def SAE_GRU_Network(self, X, Y, U, time_steps, batch_size, hidden_size):
    mse_batch = 0
    h0 = torch.zeros(1, hidden_size, device=device)  # `device` is defined elsewhere
    for i in range(batch_size):  # samples are processed one at a time
        x_input = X[i][0]  # the first time step seeds the rollout
        mse_steps = 0
        h = h0
        for j in range(time_steps):
            y_target = Y[i][j]
            u_input = U[i][j]
            h, y_output = self.one_step(h, x_input, u_input)
            x_input = y_output  # feed the prediction back in as the next input
            y_true = torch.reshape(y_target, [-1])
            y_pred = torch.reshape(y_output, [-1])
            mse_steps += torch.sum(torch.square(y_true - y_pred))
        mse_batch += mse_steps
    return mse_batch
```
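The outer loop over `batch_size` treats every sample independently, so one possible restructuring is to fold it into a batch dimension and keep only the loop over time steps. The sketch below assumes `one_step` is a `GRUCell` followed by a linear readout and that `x` and `u` are concatenated; my actual `one_step` is not shown above, so `BatchedRollout`, `cell`, `readout`, and `input_dim` are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

class BatchedRollout(nn.Module):
    """Hypothetical stand-in for the model: GRUCell plus a linear readout."""

    def __init__(self, data_dim, input_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.GRUCell(data_dim + input_dim, hidden_size)
        self.readout = nn.Linear(hidden_size, data_dim)

    def one_step(self, h, x, u):
        # x: (batch, data_dim), u: (batch, input_dim), h: (batch, hidden_size)
        h = self.cell(torch.cat([x, u], dim=-1), h)
        return h, self.readout(h)

    def forward(self, X, Y, U):
        # X, Y: (batch, time, data_dim); U: (batch, time, input_dim)
        batch_size, time_steps, _ = Y.shape
        h = torch.zeros(batch_size, self.hidden_size, device=X.device)
        x = X[:, 0]               # first time step of every sample
        sse = X.new_zeros(())     # running summed squared error
        for t in range(time_steps):
            h, y = self.one_step(h, x, U[:, t])
            x = y                 # feed predictions back in, as in the original
            sse = sse + torch.sum((Y[:, t] - y) ** 2)
        return sse
```

With this layout the number of `one_step` calls per epoch drops from `batch_size * time_steps` to `time_steps`, and each call works on all samples at once. The time loop itself cannot be removed this way, because each step depends on the previous prediction.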