How to backpropagate a black box generated cost value

Hi Mathias and @albanD:
Thank you for the fast reply and the very clear explanation. I fixed this bug as follows:

Define STE-function

class STEFunction(autograd.Function):

    @staticmethod
    def forward(ctx, input_x, inter_state):
        ttemp = torch.squeeze(input_x[:, -1, :]).requires_grad_(True).detach().cpu().numpy()
        ttemp = ttemp.transpose()
        if inter_state is None:
            _, tout, inter_state = control.forced_response(env.AMP_signal_acc,
                                                           env.mat_step4['t'][0][1:11],  # batch_size=10
                                                           ttemp,
                                                           return_x=True)
            updated_state = inter_state[:, -1]
        else:
            inter_state = inter_state.requires_grad_(True).detach().cpu().numpy()
            _, tout, inter_state = control.forced_response(env.AMP_signal_acc,
                                                           env.mat_step4['t'][0][1:11],  # batch_size=10
                                                           ttemp,
                                                           inter_state, return_x=True)
            updated_state = inter_state[:, -1]

        out_u = Variable(torch.tensor(tout.transpose()), requires_grad=True).type(torch.FloatTensor).cuda()
        updated_state = Variable(torch.tensor(updated_state), requires_grad=True).type(torch.FloatTensor).cuda()

        back_out_u = torch.zeros(10, 1350, 6).cuda()
        for i in range(1350):
            back_out_u[:, i, :] = out_u
        return back_out_u, updated_state  # return out_u

    @staticmethod
    def backward(ctx, grad_output, grad_output1):

Here I add a None, since backward has to return one gradient for each input of forward; the None is for inter_state:

        return F.hardtanh(grad_output), None

Define a class with the STE function:

class StraightThroughEstimator(nn.Module):

    def __init__(self, inter_state, x):  # inter_state, x
        super(StraightThroughEstimator, self).__init__()
        self.state = inter_state

    def forward(self, inter_state, x):
        x, updated_state = STEFunction.apply(x, inter_state)
        return x, updated_state
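
For reference, this is the general pattern I am trying to follow, written as a minimal self-contained sketch (black_box, np.tanh, and the shapes below are only placeholders for my real control.forced_response simulation): the forward pass hands detached NumPy data to the black box and wraps the results back into tensors, and the backward pass sends the gradient straight through for the signal input and returns None for the state input.

import numpy as np
import torch
from torch import autograd
import torch.nn.functional as F


def black_box(u, state):
    # placeholder for the non-differentiable simulator: NumPy in, NumPy out
    out = np.tanh(u + state)
    new_state = out.copy()
    return out, new_state


class BlackBoxSTE(autograd.Function):
    @staticmethod
    def forward(ctx, input_x, state):
        # leave the autograd graph: the simulator only ever sees plain NumPy data
        out_np, state_np = black_box(input_x.detach().cpu().numpy(),
                                     state.detach().cpu().numpy())
        out = torch.as_tensor(out_np, dtype=input_x.dtype, device=input_x.device)
        new_state = torch.as_tensor(state_np, dtype=input_x.dtype, device=input_x.device)
        return out, new_state

    @staticmethod
    def backward(ctx, grad_output, grad_state):
        # straight-through estimator: treat the black box as identity for the gradient,
        # clipped by hardtanh; one return value per input of forward, None for the state
        return F.hardtanh(grad_output), None


# quick check
x = torch.randn(10, 6, requires_grad=True)
state = torch.zeros(10, 6)
y, state = BlackBoxSTE.apply(x, state)
y.sum().backward()  # x.grad is all ones (hardtanh of ones); no gradient flows into the state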

However, I then run into another error, described below.
It says the intermediate buffers have already been freed when backward() is called for the second time.

In my code, control.forced_response, which is called inside the STEFunction class, is used to simulate the response to the driving signals. It needs NumPy data and the inter_state from the previous step. The problem could be the line
ttemp = torch.squeeze(input_x[:, -1, :]).requires_grad_(True).detach().cpu().numpy()
and/or the inter_state, which I save in the RNN class as a member variable.

To narrow down the problem, I post the following snippets for reference.

Define RNN

class RNN(nn.Module):

    def __init__(self, input_size):
        super(RNN, self).__init__()
        self.state = None
        self.out = None
        hidden_size = 32  # 64
        self.encoder_embedding = nn.Linear(input_size, hidden_size)
        self.encoder_rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
            dropout=0.1,
            bidirectional=True,
        )
        self.encoder_out = nn.Linear(2 * hidden_size, 6)  # output 6 signals to drive actuators
        self.ste = StraightThroughEstimator(self.state, self.out)  # As suggested, thank you

        self.decoder_embedding = nn.Linear(6, hidden_size)  # from 6 driving signals to LSTM
        self.decoder_rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,  # 64
            num_layers=1,
            batch_first=True,
            dropout=0.1,
            bidirectional=True,
        )
        self.decoder_out = nn.Linear(2 * hidden_size, input_size)

    def forward(self, x, encode=True, decode=False):
        if encode:
            x = self.encoder_embedding(x)
            r_out, (h_n, h_c) = self.encoder_rnn(x, None)  # None for initialization by 0 for hidden layers
            out = self.encoder_out(r_out)

            out_u_1, self.state = self.ste(self.state, out)  # <-- the STE call with the saved state
            out_u = out_u_1[:, -1, :]
            return out_u  # 20, 899, 1;  10, 6
        elif decode:
            # out_u = self.decoder(x)  # response to sensor signal backwards
            x = self.decoder_embedding(x)
            r_out, (h_n, h_c) = self.decoder_rnn(x, None)
            out_u = self.decoder_out(r_out)

        else:
            # encoding = self.encoder(x)
            # out_u = self.decoder(encoding)
            x = self.encoder_embedding(x)
            t_out, (h_n, h_c) = self.encoder_rnn(x, None)
            out = self.encoder_out(t_out)
            p_out = self.decoder_embedding(out)
            r_out, (h_n, h_c) = self.decoder_rnn(p_out, None)
            out_u = self.decoder_out(r_out)
        return out_u

RNN parameter configuration:

n = 1 # Only 1 channel of signal for training
rnn = RNN(n).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.0002)
loss_func = nn.MSELoss()
EPOCH = 5

Training process:

print("Start training ...")
for step in range(EPOCH):
    rnn.state = None
    for tx, ty in trainloader:
        tx = tx.to(device)
        ty = ty.to(device)
        output = rnn(torch.unsqueeze(tx, dim=2))  # reshape input to be 3D [samples, timesteps, features]
        loss = loss_func(output[:, 2].reshape(-1, 1).to(device), ty.reshape(-1, 1))
        # clear out the gradients from the last loss.backward()
        optimizer.zero_grad()
        # backward propagation: calculate gradients. Succeeds on the first batch, fails on the second.
        loss.backward()
        # update the weights
        optimizer.step()
    print("%d / %d, loss = %E" % (step + 1, EPOCH, loss))

If I delete all the state variables, training runs successfully. However, it is necessary to pass this variable for the response simulation.
In real experiments, the response simulation will be replaced by sensor signals, so the problem should go away there.
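
Maybe a better fix than deleting the state is to keep it, but cut its link to the previous batch's graph after every update, the way hidden states are detached in truncated BPTT. Below is only a sketch of what I mean (assuming rnn.state holds a tensor after the first batch), not tested on my setup:

for step in range(EPOCH):
    rnn.state = None
    for tx, ty in trainloader:
        tx = tx.to(device)
        ty = ty.to(device)
        output = rnn(torch.unsqueeze(tx, dim=2))
        loss = loss_func(output[:, 2].reshape(-1, 1).to(device), ty.reshape(-1, 1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # keep the simulator state for the next batch, but drop its autograd history,
        # so the next backward() does not try to walk the already-freed graph
        if rnn.state is not None:
            rnn.state = rnn.state.detach()
    print("%d / %d, loss = %E" % (step + 1, EPOCH, loss))

That way the state values are still handed to control.forced_response, but the second backward() no longer has to reach back into the first batch's graph.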

Best regards and Merry Christmas

Chen