Hi. I’m trying to implement a CNN-LSTM and I have the same problem. I don’t have any `+=` operations (I think), only squeezing and unsqueezing. I’ve been tweaking things, but still haven’t found the correct solution. I marked where the error occurs in the code with # <<<< ========ERROR.
class A2C(nn.Module):
    def __init__(self, input_size, num_actions, args, convolution=True):
        super(A2C, self).__init__()
        self.num_actions = num_actions
        self.convolution = convolution
        self.sequence_lenght = args.sequence_lenght
        self.batch_size = args.batch_size
        self.lstm_hidden_size = args.lstm_hidden_size
        if convolution:
            convolution_layers = [nn.Conv2d(input_size[0], 512, kernel_size=8, stride=4), nn.ReLU(),
                                  nn.Conv2d(512, 256, kernel_size=4, stride=2), nn.ReLU(),
                                  nn.Conv2d(256, 128, kernel_size=3, stride=2), nn.ReLU(),
                                  nn.Conv2d(128, 64, kernel_size=3, stride=1), nn.ReLU()]
            self.conv = nn.Sequential(*convolution_layers)
            self.input_size = self.get_conv_out(input_size)
        if convolution == False:
            self.input_size = input_size[0]
        self.lstm = nn.LSTM(self.input_size, self.lstm_hidden_size, num_layers=1, batch_first=True)
        if args.hidden_layers_num == 1:
            layers_a = [nn.Linear(self.lstm_hidden_size, args.hidden_1), nn.ReLU(), nn.Linear(args.hidden_1, num_actions)]
            layers_c = [nn.Linear(self.lstm_hidden_size, args.hidden_1), nn.ReLU(), nn.Linear(args.hidden_1, 1)]
        if args.hidden_layers_num != 1:
            layers_a = [nn.Linear(self.lstm_hidden_size, args.hidden_1), nn.ReLU(), nn.Linear(args.hidden_1, args.hidden_2), nn.ReLU(), nn.Linear(args.hidden_2, num_actions)]
            layers_c = [nn.Linear(self.lstm_hidden_size, args.hidden_1), nn.ReLU(), nn.Linear(args.hidden_1, args.hidden_2), nn.ReLU(), nn.Linear(args.hidden_2, 1)]
        self.Actor = nn.Sequential(*layers_a)
        self.Critic = nn.Sequential(*layers_c)

    def get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, obs, h_x, c_x):
        obs = torch.FloatTensor([obs]).to(device)
        if self.convolution:
            batch_size = obs.size()[0]
            obs = self.conv(obs).view(1, -1).unsqueeze(0)
        obs, (h_x, c_x) = self.lstm(obs, (h_x, c_x))
        logits = self.Actor(obs).squeeze(0)  # <<<< ========ERROR
        values = self.Critic(obs).squeeze(0)
        action_probs = F.softmax(logits, dim=1).cpu().detach().numpy()[0]
        action = np.random.choice(self.num_actions, p=action_probs)
        return action, logits, values, h_x, c_x

    def init_hidden(self):
        return torch.zeros(1, 1, self.lstm_hidden_size), torch.zeros(1, 1, self.lstm_hidden_size)
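To check whether the squeeze/unsqueeze calls give the shapes I expect, I run a quick sanity check like this (the input size and hyperparameter values here are made up, CPU only):

from types import SimpleNamespace

device = torch.device('cpu')  # only for this sanity check; normally device is 'cuda'
args_test = SimpleNamespace(sequence_lenght=8, batch_size=16, lstm_hidden_size=256,
                            hidden_layers_num=1, hidden_1=1024)
test_net = A2C(input_size=(4, 84, 84), num_actions=2, args=args_test)
h_x, c_x = test_net.init_hidden()
obs = np.zeros((4, 84, 84), dtype=np.float32)
action, logits, values, h_x, c_x = test_net(obs, h_x, c_x)
print(logits.shape, values.shape, h_x.shape)  # torch.Size([1, 2]) torch.Size([1, 1]) torch.Size([1, 1, 256])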
This is my sample generator. The env is continuous, so I take n steps as one sequence and save the last detached hidden state as the start of the next loop. (Could that be a problem when using multiple environments, or is detaching in every env enough to keep them separate?)
def play(self, net, device):
    if self.first == True:
        self.state = self.env.reset()
        self.first = False
    done = False
    if self.h_x == None:
        self.h_x, self.c_x = net.init_hidden()
    values = []
    logits_ = []
    actions = []
    rewards = []
    total_reward = 0.0
    _idx = 0
    while True:
        action, logits, value, self.h_x, self.c_x = net(self.state, self.h_x.to('cuda'), self.c_x.to('cuda'))  # <<<< ========ERROR
        next_state, reward, done, _ = self.env.step(action)
        if _idx == 0:
            reward -= 2 * (self.env.trade_fees) * self.env.leverage * 10_000
        _idx = _idx + 1
        values.append(value)
        logits_.append(logits)
        actions.append(action)
        if done and self.if_trading_env == False:
            reward = -1  # <---
        rewards.append(reward)
        total_reward += reward
        self.state = next_state
        if len(actions) >= args.sequence_lenght:
            self.h_x = self.h_x.detach()
            self.c_x = self.c_x.detach()
            return values, logits_, actions, discounted_rewards(rewards, self.gamma), total_reward
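(For completeness, discounted_rewards just turns the reward list into discounted returns, roughly like this:)

def discounted_rewards(rewards, gamma):
    # standard discounted return: R_t = r_t + gamma * R_{t+1}, computed backwards
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns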
My training loop.
idx = 0
while True:
    batch_counter = 0
    batch_values = []
    batch_logits = []
    batch_actions = []
    batch_vals_ref = []
    while True:
        for env in enviroments:
            values, _logits, actions, vals_ref, total_reward = env.play(net, device)  # <<<< ========ERROR
            batch_values.append(values)
            batch_logits.append(_logits)
            batch_actions.append(actions)
            batch_vals_ref.append(vals_ref)
            episodes_rewrds.append(total_reward)
            batch_counter += 1
            if batch_counter >= args.batch_size:
                break
        if batch_counter >= args.batch_size:
            break
    for i in range(len(batch_values)):
        torch.cuda.empty_cache()
        values_v = torch.stack(batch_values[i]).to(device)
        logits_v = torch.stack(batch_logits[i]).squeeze(1).to(device)
        actions_t = torch.LongTensor(batch_actions[i]).to(device)
        vals_ref_v = torch.FloatTensor(batch_vals_ref[i]).to(device)
        net.zero_grad()
        value_loss = args.zeta * F.mse_loss(values_v.squeeze(-1).squeeze(-1), vals_ref_v)
        advantage = vals_ref_v - values_v.detach()
        log_probs = F.log_softmax(logits_v, dim=1)
        log_prob_action = advantage * log_probs[range(len(actions_t)), actions_t]
        policy_loss = -log_prob_action.mean()
        actions_probs = F.softmax(logits_v, dim=1)
        entropy_loss = -args.entropy_beta * (actions_probs * log_probs).sum(dim=1).mean()
        total_policy_loss = policy_loss + entropy_loss
        total_policy_loss.backward(retain_graph=True)  # <<<< ========ERROR
        value_loss.backward()
        nn_utils.clip_grad_norm_(net.parameters(), args.clip_grad)
        optimizer.step()
    idx += 1
    print(idx, round(np.mean(episodes_rewrds), 2))
    torch.save(net.state_dict(), NET_PARAMS_PATH)
    if np.mean(episodes_rewrds) > 1_000_000:
        break
And this is my error.
Warning: Error detected in MmBackward. Traceback of forward call that caused the error:
File "E:\Market Data Collection\crypto_gym\A2C_LSTM_multi_1.0.py", line 320, in <module>
values, _logits, actions, vals_ref, total_reward = env.play(net, device)
File "E:\Market Data Collection\crypto_gym\A2C_LSTM_multi_1.0.py", line 257, in play
action, logits, value, self.h_x, self.c_x = net(self.state, self.h_x.to('cuda'), self.c_x.to('cuda'))
File "C:\Python38\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "E:\Market Data Collection\crypto_gym\A2C_LSTM_multi_1.0.py", line 202, in forward
logits = self.Actor(obs).squeeze(0)
File "C:\Python38\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "C:\Python38\lib\site-packages\torch\nn\modules\container.py", line 100, in forward
input = module(input)
File "C:\Python38\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "C:\Python38\lib\site-packages\torch\nn\modules\linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Python38\lib\site-packages\torch\nn\functional.py", line 1612, in linear
output = input.matmul(weight.t())
(print_stack at ..\torch\csrc\autograd\python_anomaly_mode.cpp:60)
Traceback (most recent call last):
File "E:\Market Data Collection\crypto_gym\A2C_LSTM_multi_1.0.py", line 355, in <module>
total_policy_loss.backward(retain_graph=True)
File "C:\Python38\lib\site-packages\torch\tensor.py", line 198, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "C:\Python38\lib\site-packages\torch\autograd\__init__.py", line 98, in backward
Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024, 2]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I also have a question: can I shuffle samples from multiple batches for training when using an LSTM in my code? The hidden states are already preserved in the samples, since I add them during the environment iteration.
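What I mean is roughly this (a hypothetical sketch, the sample layout is made up):

import random

# hypothetical sketch of "hidden states preserved in samples":
# each rollout is stored together with the (detached) hidden state it started from,
# so in principle the position of a sample inside the batch shouldn't matter
batch = []
for env in enviroments:
    start_h = env.h_x.detach() if env.h_x is not None else None
    start_c = env.c_x.detach() if env.c_x is not None else None
    values, logits_, actions, vals_ref, total_reward = env.play(net, device)
    batch.append((start_h, start_c, values, logits_, actions, vals_ref))

random.shuffle(batch)  # is shuffling like this safe for the LSTM, or does it break something?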