The training data is floating point. The number of epochs is set to 10000, and each epoch calls a custom function. In that function, batch_num is 95, and I set each batch to 29 rows of data, where each row has 20 dimensions. I am running this on an RTX 3080, but training is far too slow. How can I speed it up?
The main function:
import math

# Training loop. `model`, `opt`, `X`, `Y`, `U`, `epoch`, `time_steps`,
# `batch_size`, `hidden_size` and `data_dim` are defined elsewhere in the
# script (not visible here).
for epo in range(epoch):
    print("In iteration: %d" % (epo + 1))
    # Summed squared error over the batch, as returned by the model.
    mse_all = model.SAE_GRU_Network(X, Y, U, time_steps, batch_size, hidden_size)
    # The original `.cuda()` only copied this scalar to the GPU; it never moved
    # the forward computation there. To actually train on the GPU, the model
    # parameters and X/Y/U must already live on the CUDA device — TODO confirm.
    loss = mse_all
    opt.zero_grad()
    loss.backward()
    opt.step()
    # .item() yields a plain Python float detached from autograd, so the RMSE
    # is computed as a number instead of adding another node to the graph.
    rmse = math.sqrt(loss.item() / (batch_size * time_steps * data_dim))
    print(epo + 1, rmse)
The `SAE_GRU_Network` method:
def SAE_GRU_Network(self, X, Y, U, time_steps, batch_size, hidden_size):
    """Roll the model forward per sample and return the summed squared error.

    For each of the first ``batch_size`` samples, start from a zero hidden
    state and the sample's first input ``X[i][0]``, then call
    ``self.one_step`` ``time_steps`` times. The output of each step is fed
    back as the next step's input and compared against ``Y[i][j]``.

    Args:
        X: Indexable as ``X[i][0]``; only the first entry per sample is used.
        Y: Targets, indexable as ``Y[i][j]``.
        U: Per-step inputs passed to ``one_step``, indexable as ``U[i][j]``.
        time_steps: Number of rollout steps per sample.
        batch_size: Number of samples to roll out.
        hidden_size: Width of the zero hidden state handed to ``one_step``.

    Returns:
        Sum over samples and steps of the squared difference between the
        flattened target and prediction (a 0-d tensor), or the int ``0`` when
        ``batch_size`` or ``time_steps`` is not positive.
    """
    # NOTE(review): the per-sample Python loop launches several tiny kernels
    # per step, which is the main cost on a GPU. If one_step accepts a leading
    # batch dimension, roll all samples together (h0 of shape
    # (batch_size, hidden_size), x_input = X[:, 0]) and loop only over time
    # steps — TODO confirm one_step's shape contract before doing that.
    if batch_size <= 0 or time_steps <= 0:
        return 0

    # `device` is a module-level name defined elsewhere in the file.
    h0 = torch.zeros(1, hidden_size, device=device)
    diffs = []
    for i in range(batch_size):
        x_input = X[i][0]
        h = h0
        for j in range(time_steps):
            h, y_output = self.one_step(h, x_input, U[i][j])
            x_input = y_output  # feed the prediction back as the next input
            y_true = torch.reshape(Y[i][j], [-1])
            y_pred = torch.reshape(y_output, [-1])
            diffs.append(y_true - y_pred)

    # One square + sum over every step, instead of square/sum/add per step.
    return torch.sum(torch.square(torch.cat(diffs)))