Hi Mathias and @albanD:
Thank you for the fast reply and very clear explanation. I fix this bug as following:
Define STE-function
class STEFunction(autograd.Function):
def forward(ctx, input_x, inter_state):
ttemp = torch.squeeze(input_x[:, -1, :]).requires_grad_(True).detach().cpu().numpy()
ttemp = ttemp.transpose()
if inter_state is None:
_, tout, inter_state = control.forced_response(env.AMP_signal_acc, env.mat_step4['t'][0][
1:11], # batch_size=10
updated_state = inter_state[:, -1]
inter_state = inter_state.requires_grad_(True).detach().cpu().numpy()
_, tout, inter_state = control.forced_response(env.AMP_signal_acc, env.mat_step4['t'][0][
1:11], # batch_size=10
inter_state, return_x=True)
updated_state = inter_state[:, -1]
out_u = Variable(torch.tensor(tout.transpose()), requires_grad=True).type(torch.FloatTensor).cuda()
updated_state = Variable(torch.tensor(updated_state), requires_grad=True).type(torch.FloatTensor).cuda()
back_out_u = torch.zeros(10, 1350, 6).cuda()
for i in range(1350):
back_out_u[:, i, :] = out_u
return back_out_u, updated_state # return out_u
def backward(ctx, grad_output, grad_output1):
Here, I add None for the backward of updated_state
return F.hardtanh(grad_output), None
Define class with STE-funciton:
class StraightThroughEstimator(nn.Module):
def __init__(self, inter_state, x): # inter_state, x
self.state = inter_state
super(StraightThroughEstimator, self).__init__()
def forward(self, inter_state, x):
x, updated_state = STEFunction.apply(x, inter_state)
return x, updated_state
However, I find another error as introduced below.
It tells that the intermedia buffers have been freed when comes to the second time backward calling.
In my code, the control.forced_response applied in the STEFunction class definition, is used to simulate the reponse with respect to driving signals. It needs numpy datatype and the inter_state of last step. The problem could be the
ttemp = torch.squeeze(input_x[:, -1, :]).requires_grad_(True).detach().cpu().numpy()
the inter_state. I save it in the Rnn class as a class member variable.
In order to find the problem I upload the following snippets for reference.
Define RNN
class RNN(nn.Module):
def __init__(self, input_size):
super(RNN, self).__init__()
self.state = None
self.out = None
hidden_size = 32 # 64
self.encoder_embedding = nn.Linear(input_size, hidden_size)
self.encoder_rnn = nn.LSTM(
self.encoder_out = nn.Linear(2*hidden_size, 6) # output 6 signals to drive actuators
self.ste = StraightThroughEstimator(self.state, self.out) # As suggested, thank you
self.decoder_embedding = nn.Linear(6, hidden_size) # from 6 driving signals to LSTM
self.decoder_rnn = nn.LSTM(
hidden_size=hidden_size, # 64
self.decoder_out = nn.Linear(2 * hidden_size, input_size)
def forward(self, x, encode=True, decode=False):
if encode:
x = self.encoder_embedding(x)
r_out, (h_n, h_c) = self.encoder_rnn(x, None) # None for initialization by 0 for hidden layers
out = self.encoder_out(r_out)
**out_u_1, self.state = self.ste(self.state, out)**
out_u = out_u_1[:, -1, :]
return out_u # 20, 899, 1; 10, 6
elif decode:
# out_u = self.decoder(x) # response to sensor signal backwards
x = self.decoder_embedding(x)
r_out, (h_n, h_c) = self.decoder_rnn(x, None)
out_u = self.decoder_out(r_out)
# encoding = self.encoder(x)
# out_u = self.decoder(encoding)
x = self.encoder_embedding(x)
t_out, (h_n, h_c) = self.encoder_rnn(x, None)
out = self.encoder_out(t_out)
p_out = self.decoder_embedding(out)
r_out, (h_n, h_c) = self.decoder_rnn(p_out, None)
out_u = self.decoder_out(r_out)
return out_u
RNN parameter configuration:
n = 1 # Only 1 channel of signal for training
rnn = RNN(n).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.0002)
loss_func = nn.MSELoss()
Training process:
print(“Start training …”)
for step in range(EPOCH):
rnn.state = None
for tx, ty in trainloader:
tx = tx.to(device)
ty = ty.to(device)
output = rnn(torch.unsqueeze(tx, dim=2)) # reshape input to be 3D [samples, timesteps, features]
loss = loss_func(output[:, 2].reshape(-1, 1).to(device), ty.reshape(-1, 1))
# clear out the gradients from the last step loss.backward()
# backward propagation: calculate gradients. Success in the first train, fail for the second.
# update the weights
print(“%d / %d, loss = %E” % (step + 1, EPOCH, loss))
I delete all state variables, and it can train successfully. However, It is necessary to pass this variable for response simulation.
For real experiments, this response simulation will be replaced by sensor signals. It is supposed to jump out this problem.
Best regards and Merry Christmas