In a CNN-LSTM model, how do I backpropagate through the CNN with PyTorch autograd?

I have built a model with three CNN layers followed by one LSTM layer. I implemented the LSTM's backpropagation manually, but how should I implement the CNN's backpropagation? In other words, how do I use PyTorch's automatic differentiation to backpropagate through the CNN?
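
From the docs I understand that Tensor.backward() accepts a gradient argument (a vector-Jacobian product), which seems like the intended way to hand a manually computed upstream gradient back to autograd. Below is a minimal, self-contained toy sketch of the pattern I think applies (toy layer sizes and made-up names, not my real model). Is this the right idea?

import torch
import torch.nn as nn

# toy sketch: inject a manually computed upstream gradient into autograd
# via backward(gradient=...), then let the optimizer update the weights
cnn = nn.Conv1d(1, 4, kernel_size=3)
opt = torch.optim.Adam(cnn.parameters(), lr=1e-3)

x = torch.randn(1, 1, 10)           # (batch, channels, seq_len)
out = cnn(x)                        # autograd records the CNN graph

# stand-in for dLoss/d(out) coming from a manual backward pass of a
# later stage; it must have the same shape as out
grad_wrt_out = torch.randn_like(out)

opt.zero_grad()
out.backward(gradient=grad_wrt_out) # vector-Jacobian product through the CNN
opt.step()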

# CNN
class CNNModel(nn.Module):
    def __init__(self,timesteps):

        super().__init__()

        self.timesteps = timesteps;

        # CNN layer 1 - input channels: 1; output channels: 64; kernel size: 3; stride: 1
        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, stride=1)
        self.act1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(2, stride=1) # pooling layer, pooling window is 2

        # CNN layer 2 - input channels: 64; output channels: 128; kernel size: 3; stride: 1
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, stride=1)
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(2, stride=1) # pooling layer, pooling window is 2

        # CNN layer 3 - input channels: 128; output channels: 64; kernel size: 3; stride: 1
        self.conv3 = nn.Conv1d(128, 64, kernel_size=3, stride=1)
        self.act3 = nn.ReLU()
        self.pool3 = nn.MaxPool1d(2, stride=1) # pooling layer, pooling window is 2

        self.flat = nn.Flatten()

    def forward(self, x):

        print("x.shape:");
        print(x.shape);

        # CNN layer 1
        # permute to (batch, input_dim, seq_len): (1, 20, 1) => (1, 1, 20)
        x1 = self.act1(self.conv1(x.permute(0,2,1))); # (1, 1, 20) => (1, 64, 18)

        print("x1.shape:");
        print(x1.shape);

        x1 = self.pool1(x1);

        # CNN layer 2
        x2 = self.act2(self.conv2(x1));
        x2 = self.pool2(x2);

        # CNN layer 3
        x3 = self.act3(self.conv3(x2));
        x3 = self.pool3(x3);

        x4 = self.flat(x3);

        print(x1.shape);
        print(x2.shape);
        print(x3.shape);
        print(x4.shape);

        return x4

class LSTM(object):
    def __init__(self,timesteps,batch_size,input_size,hidden_size,output_size):
        
        self.times = 0;
        
        self.timesteps = timesteps;
        self.batch_size = batch_size;
        
        self.input_size = input_size;
        self.hidden_size = hidden_size;
        self.output_size = output_size;

        self.Wfh,self.Wfx,self.bf = self.Weight_bias(self.input_size,self.hidden_size);
        self.Wih,self.Wix,self.bi = self.Weight_bias(self.input_size,self.hidden_size);
        self.Woh,self.Wox,self.bo = self.Weight_bias(self.input_size,self.hidden_size);
        self.Wch,self.Wcx,self.bc = self.Weight_bias(self.input_size,self.hidden_size);
        
        self.Wp = torch.randn(self.output_size,self.hidden_size);
        self.bp = torch.randn(self.output_size);

        self.f = torch.zeros(self.batch_size,self.hidden_size);
        self.i = torch.zeros(self.batch_size,self.hidden_size);
        self.o = torch.zeros(self.batch_size,self.hidden_size);
        self.ct = torch.zeros(self.batch_size,self.hidden_size);
        
        self.h = torch.zeros(self.batch_size,self.hidden_size);
        self.c = torch.zeros(self.batch_size,self.hidden_size);
        
        self.fList = [];
        self.iList = [];
        self.oList = [];
        self.ctList = [];
        
        self.hList = [];
        self.cList = [];
        
        self.preList=[];
        
        self.fList.append(self.f);
        self.iList.append(self.i);
        self.oList.append(self.o);
        self.ctList.append(self.ct);
        
        self.hList.append(self.h);
        self.cList.append(self.c);
        
    def Weight_bias(self,input_size,hidden_size):
        return (torch.randn(hidden_size,hidden_size),
                torch.randn(input_size,hidden_size),
                torch.randn(hidden_size));
    
    def Weight_bias_grad(self,input_size,hidden_size):
        return (torch.zeros(hidden_size,hidden_size),
                torch.zeros(input_size,hidden_size),
                torch.zeros(hidden_size));
        
    def forward(self,x):
        
        for i in range(len(x)):
            
            self.times += 1;
            
            self.f = self.Sigmoid_forward(self.hList[-1] @ self.Wfh + x[i] @ self.Wfx + self.bf);
            self.i = self.Sigmoid_forward(self.hList[-1] @ self.Wih + x[i] @ self.Wix + self.bi);
            self.o = self.Sigmoid_forward(self.hList[-1] @ self.Woh + x[i] @ self.Wox + self.bo);
            self.ct = self.Tanh_forward(self.hList[-1] @ self.Wch + x[i] @ self.Wcx + self.bc);
                
            self.c = self.f * self.cList[-1] + self.i * self.ct;
            self.h = self.o * self.Tanh_forward(self.c);
            
            self.fList.append(self.f);
            self.iList.append(self.i);
            self.oList.append(self.o);
            self.ctList.append(self.ct);
            
            self.hList.append(self.h);
            self.cList.append(self.c);
        
        return self.prediction();
        
    def prediction(self):
        
        pre = self.hList[-1] @ self.Wp.T + self.bp;
        self.preList.append(pre);
        
        return pre;
    
    def backward(self,x,grad,y_grad):
        
        self.delta_Wfh,self.delta_Wfx,self.delta_bf = self.Weight_bias_grad(self.input_size,self.hidden_size);
        self.delta_Wih,self.delta_Wix,self.delta_bi = self.Weight_bias_grad(self.input_size,self.hidden_size);
        self.delta_Woh,self.delta_Wox,self.delta_bo = self.Weight_bias_grad(self.input_size,self.hidden_size);
        self.delta_Wch,self.delta_Wcx,self.delta_bc = self.Weight_bias_grad(self.input_size,self.hidden_size);
        
        self.delta_Wp = torch.zeros(self.output_size,self.hidden_size);
        self.delta_bp = torch.zeros(self.output_size);

        self.delta_hList = self.init_delta();
        self.delta_cList = self.init_delta();
        
        self.delta_fList = self.init_delta();
        self.delta_iList = self.init_delta();
        self.delta_oList = self.init_delta();
        self.delta_ctList = self.init_delta();
        
        self.delta_hList[-1] = grad;
        
        for k in range(self.times,0,-1):
            
            self.compute_gate_backward(x,k);
            
        self.compute_Weight_bias_backward(x,y_grad);
            
    def init_delta(self):
        
        X = [];
        
        for i in range(self.times + 1):
            X.append(torch.zeros(self.batch_size,self.hidden_size));
            
        return X;
    
    def compute_gate_backward(self,x,k):
        
        f = self.fList[k];
        i = self.iList[k];
        o = self.oList[k];
        ct = self.ctList[k];
        
        c = self.cList[k];
        
        h_pre = self.hList[k-1]; # hidden state that fed the step-k gates
        c_pre = self.cList[k-1];
        
        delta_hk = self.delta_hList[k];
        
        if(k == self.times):
            delta_ck = delta_hk * o * self.Tanh_backward(c);
        else:
            delta_ck = delta_hk * o * self.Tanh_backward(c) + self.delta_cList[k+1] * self.fList[k+1];
            
        delta_ctk = delta_ck * i;
        delta_fk = delta_ck * c_pre;
        delta_ik = delta_ck * ct;
        delta_ok = delta_hk * self.Tanh_forward(c);
        
        # gradient w.r.t. h_{k-1}: push each gate delta through its pre-activation
        # (which was computed from h_{k-1}, not h_k) and back through the W*h weights
        delta_hkpre = (delta_fk * self.Sigmoid_backward(h_pre @ self.Wfh + x[k-1] @ self.Wfx + self.bf) @ self.Wfh.T
                       + delta_ik * self.Sigmoid_backward(h_pre @ self.Wih + x[k-1] @ self.Wix + self.bi) @ self.Wih.T
                       + delta_ok * self.Sigmoid_backward(h_pre @ self.Woh + x[k-1] @ self.Wox + self.bo) @ self.Woh.T
                       + delta_ctk * self.Tanh_backward(h_pre @ self.Wch + x[k-1] @ self.Wcx + self.bc) @ self.Wch.T);
       
        self.delta_hList[k-1] = delta_hkpre;
        self.delta_cList[k] = delta_ck;
        
        self.delta_fList[k] = delta_fk;
        self.delta_iList[k] = delta_ik;
        self.delta_oList[k] = delta_ok;
        self.delta_ctList[k] = delta_ctk;
        
    def compute_Weight_bias_backward(self,x,y_grad):

        for t in range(self.times,0,-1):
            
            sum_delta_Wfh,sum_delta_Wfx,sum_delta_bf = self.Weight_bias_grad(self.input_size,self.hidden_size);
            sum_delta_Wih,sum_delta_Wix,sum_delta_bi = self.Weight_bias_grad(self.input_size,self.hidden_size);
            sum_delta_Woh,sum_delta_Wox,sum_delta_bo = self.Weight_bias_grad(self.input_size,self.hidden_size);
            sum_delta_Wch,sum_delta_Wcx,sum_delta_bc = self.Weight_bias_grad(self.input_size,self.hidden_size);
            
            sum_delta_Wp = torch.zeros(self.output_size,self.hidden_size);
            sum_delta_bp = torch.zeros(self.output_size);

            for i in range(self.batch_size):
                
                if t == self.times :
                    
                    delta_Wp = y_grad[i].unsqueeze(dim=1) * self.hList[-1][i].unsqueeze(dim=0);
                    delta_bp = y_grad[i];
                    
                    sum_delta_Wp += delta_Wp;
                    sum_delta_bp += delta_bp;       
                
                delta_Wfh = (self.delta_fList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Wfh + x[t-1][i] @ self.Wfx + self.bf)) * self.hList[t-1][i].unsqueeze(dim=1);
                delta_Wfx = (self.delta_fList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Wfh + x[t-1][i] @ self.Wfx + self.bf)).unsqueeze(dim=1) * x[t-1][i].unsqueeze(dim=0);
                delta_bf = self.delta_fList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Wfh + x[t-1][i] @ self.Wfx + self.bf);
                
                
                delta_Wih = (self.delta_iList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Wih + x[t-1][i] @ self.Wix + self.bi)) * self.hList[t-1][i].unsqueeze(dim=1);
                delta_Wix = (self.delta_iList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Wih + x[t-1][i] @ self.Wix + self.bi)).unsqueeze(dim=1) * x[t-1][i].unsqueeze(dim=0);
                delta_bi = self.delta_iList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Wih + x[t-1][i] @ self.Wix + self.bi);

                delta_Wch = (self.delta_ctList[t][i] * self.Tanh_backward(self.hList[t-1][i] @ self.Wch + x[t-1][i] @ self.Wcx + self.bc)) * self.hList[t-1][i].unsqueeze(dim=1);
                delta_Wcx = (self.delta_ctList[t][i] * self.Tanh_backward(self.hList[t-1][i] @ self.Wch + x[t-1][i] @ self.Wcx + self.bc)).unsqueeze(dim=1) * x[t-1][i].unsqueeze(dim=0);
                delta_bc = self.delta_ctList[t][i] * self.Tanh_backward(self.hList[t-1][i] @ self.Wch + x[t-1][i] @ self.Wcx + self.bc);

                delta_Woh = (self.delta_oList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Woh + x[t-1][i] @ self.Wox + self.bo)) * self.hList[t-1][i].unsqueeze(dim=1);
                delta_Wox = (self.delta_oList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Woh + x[t-1][i] @ self.Wox + self.bo)).unsqueeze(dim=1) * x[t-1][i].unsqueeze(dim=0);
                delta_bo = self.delta_oList[t][i] * self.Sigmoid_backward(self.hList[t-1][i] @ self.Woh + x[t-1][i] @ self.Wox + self.bo);
                
                sum_delta_Wfh += delta_Wfh;
                sum_delta_Wfx += delta_Wfx.T;
                sum_delta_bf += delta_bf;

                sum_delta_Wih += delta_Wih;
                sum_delta_Wix += delta_Wix.T;
                sum_delta_bi += delta_bi;

                sum_delta_Wch += delta_Wch;
                sum_delta_Wcx += delta_Wcx.T;
                sum_delta_bc += delta_bc;

                sum_delta_Woh += delta_Woh;
                sum_delta_Wox += delta_Wox.T;
                sum_delta_bo += delta_bo;
            
            self.delta_Wfh += sum_delta_Wfh / self.batch_size;
            self.delta_Wfx += sum_delta_Wfx / self.batch_size;
            self.delta_bf += sum_delta_bf / self.batch_size;
            
            self.delta_Wih += sum_delta_Wih / self.batch_size;
            self.delta_Wix += sum_delta_Wix / self.batch_size;
            self.delta_bi += sum_delta_bi / self.batch_size;
            
            self.delta_Wch += sum_delta_Wch / self.batch_size;
            self.delta_Wcx += sum_delta_Wcx / self.batch_size;
            self.delta_bc += sum_delta_bc / self.batch_size;
            
            self.delta_Woh += sum_delta_Woh / self.batch_size;
            self.delta_Wox += sum_delta_Wox / self.batch_size;
            self.delta_bo += sum_delta_bo / self.batch_size;
            
            self.delta_Wp += sum_delta_Wp / self.batch_size;
            self.delta_bp += sum_delta_bp / self.batch_size;

    #   Adam
    def init_unit_adam_states(self):
        
        Wh = torch.zeros(self.hidden_size,self.hidden_size);
        Wx = torch.zeros(self.input_size,self.hidden_size);
        b  = torch.zeros(self.hidden_size);
        
        return Wh,Wx,b;
    
    def init_adam_states(self):
        
        self.back_times = 1; # number of update steps taken (used for bias correction)
        self.beta1,self.beta2,self.eps = 0.9,0.99,1e-6;
        
        self.v_Wfh,self.v_Wfx,self.v_bf = self.init_unit_adam_states();
        self.v_Wih,self.v_Wix,self.v_bi = self.init_unit_adam_states();
        self.v_Woh,self.v_Wox,self.v_bo = self.init_unit_adam_states();
        self.v_Wch,self.v_Wcx,self.v_bc = self.init_unit_adam_states();
        
        self.s_Wfh,self.s_Wfx,self.s_bf = self.init_unit_adam_states();
        self.s_Wih,self.s_Wix,self.s_bi = self.init_unit_adam_states();
        self.s_Woh,self.s_Wox,self.s_bo = self.init_unit_adam_states();
        self.s_Wch,self.s_Wcx,self.s_bc = self.init_unit_adam_states();
        
        self.v_Wp = torch.zeros(self.output_size,self.hidden_size);
        self.v_bp = torch.zeros(self.output_size);
        
        self.s_Wp = torch.zeros(self.output_size,self.hidden_size);
        self.s_bp = torch.zeros(self.output_size);

    def update(self,lr):     
        
        self.v_Wfh = self.beta1 * self.v_Wfh + (1 - self.beta1) * self.delta_Wfh;
        self.v_Wfx = self.beta1 * self.v_Wfx + (1 - self.beta1) * self.delta_Wfx;
        self.v_bf  = self.beta1 * self.v_bf  + (1 - self.beta1) * self.delta_bf;
        
        self.v_Wih = self.beta1 * self.v_Wih + (1 - self.beta1) * self.delta_Wih;
        self.v_Wix = self.beta1 * self.v_Wix + (1 - self.beta1) * self.delta_Wix;
        self.v_bi  = self.beta1 * self.v_bi  + (1 - self.beta1) * self.delta_bi;
        
        self.v_Woh = self.beta1 * self.v_Woh + (1 - self.beta1) * self.delta_Woh;
        self.v_Wox = self.beta1 * self.v_Wox + (1 - self.beta1) * self.delta_Wox;
        self.v_bo  = self.beta1 * self.v_bo  + (1 - self.beta1) * self.delta_bo;
        
        self.v_Wch = self.beta1 * self.v_Wch + (1 - self.beta1) * self.delta_Wch;
        self.v_Wcx = self.beta1 * self.v_Wcx + (1 - self.beta1) * self.delta_Wcx;
        self.v_bc  = self.beta1 * self.v_bc  + (1 - self.beta1) * self.delta_bc;
        
        self.v_Wp  = self.beta1 * self.v_Wp  + (1 - self.beta1) * self.delta_Wp;
        self.v_bp  = self.beta1 * self.v_bp  + (1 - self.beta1) * self.delta_bp;
        
        # bias-correct the first-moment (momentum) estimates
        
        v_Wfh_bias_corr = self.v_Wfh / (1 - self.beta1 ** self.back_times);
        v_Wfx_bias_corr = self.v_Wfx / (1 - self.beta1 ** self.back_times);
        v_bf_bias_corr  = self.v_bf  / (1 - self.beta1 ** self.back_times);
        
        v_Wih_bias_corr = self.v_Wih / (1 - self.beta1 ** self.back_times);
        v_Wix_bias_corr = self.v_Wix / (1 - self.beta1 ** self.back_times);
        v_bi_bias_corr  = self.v_bi  / (1 - self.beta1 ** self.back_times);
        
        v_Woh_bias_corr = self.v_Woh / (1 - self.beta1 ** self.back_times);
        v_Wox_bias_corr = self.v_Wox / (1 - self.beta1 ** self.back_times);
        v_bo_bias_corr  = self.v_bo  / (1 - self.beta1 ** self.back_times);
        
        v_Wch_bias_corr = self.v_Wch / (1 - self.beta1 ** self.back_times);
        v_Wcx_bias_corr = self.v_Wcx / (1 - self.beta1 ** self.back_times);
        v_bc_bias_corr  = self.v_bc  / (1 - self.beta1 ** self.back_times);
        
        v_Wp_bias_corr  = self.v_Wp  / (1 - self.beta1 ** self.back_times);
        v_bp_bias_corr  = self.v_bp  / (1 - self.beta1 ** self.back_times);
        
        self.s_Wfh = self.beta2 * self.s_Wfh + (1 - self.beta2) * torch.square(self.delta_Wfh);
        self.s_Wfx = self.beta2 * self.s_Wfx + (1 - self.beta2) * torch.square(self.delta_Wfx);
        self.s_bf  = self.beta2 * self.s_bf  + (1 - self.beta2) * torch.square(self.delta_bf);
        
        self.s_Wih = self.beta2 * self.s_Wih + (1 - self.beta2) * torch.square(self.delta_Wih);
        self.s_Wix = self.beta2 * self.s_Wix + (1 - self.beta2) * torch.square(self.delta_Wix);
        self.s_bi  = self.beta2 * self.s_bi  + (1 - self.beta2) * torch.square(self.delta_bi);
        
        self.s_Woh = self.beta2 * self.s_Woh + (1 - self.beta2) * torch.square(self.delta_Woh);
        self.s_Wox = self.beta2 * self.s_Wox + (1 - self.beta2) * torch.square(self.delta_Wox);
        self.s_bo  = self.beta2 * self.s_bo  + (1 - self.beta2) * torch.square(self.delta_bo);
        
        self.s_Wch = self.beta2 * self.s_Wch + (1 - self.beta2) * torch.square(self.delta_Wch);
        self.s_Wcx = self.beta2 * self.s_Wcx + (1 - self.beta2) * torch.square(self.delta_Wcx);
        self.s_bc  = self.beta2 * self.s_bc  + (1 - self.beta2) * torch.square(self.delta_bc);
        
        self.s_Wp  = self.beta2 * self.s_Wp  + (1 - self.beta2) * torch.square(self.delta_Wp);
        self.s_bp  = self.beta2 * self.s_bp  + (1 - self.beta2) * torch.square(self.delta_bp);
        
        s_Wfh_bias_corr = self.s_Wfh / (1 - self.beta2 ** self.back_times);
        s_Wfx_bias_corr = self.s_Wfx / (1 - self.beta2 ** self.back_times);
        s_bf_bias_corr  = self.s_bf  / (1 - self.beta2 ** self.back_times);
        
        s_Wih_bias_corr = self.s_Wih / (1 - self.beta2 ** self.back_times);
        s_Wix_bias_corr = self.s_Wix / (1 - self.beta2 ** self.back_times);
        s_bi_bias_corr  = self.s_bi  / (1 - self.beta2 ** self.back_times);
        
        s_Woh_bias_corr = self.s_Woh / (1 - self.beta2 ** self.back_times);
        s_Wox_bias_corr = self.s_Wox / (1 - self.beta2 ** self.back_times);
        s_bo_bias_corr  = self.s_bo  / (1 - self.beta2 ** self.back_times);
        
        s_Wch_bias_corr = self.s_Wch / (1 - self.beta2 ** self.back_times);
        s_Wcx_bias_corr = self.s_Wcx / (1 - self.beta2 ** self.back_times);
        s_bc_bias_corr  = self.s_bc  / (1 - self.beta2 ** self.back_times);
        
        s_Wp_bias_corr  = self.s_Wp  / (1 - self.beta2 ** self.back_times);
        s_bp_bias_corr  = self.s_bp  / (1 - self.beta2 ** self.back_times);
        
        self.Wfh -= lr * v_Wfh_bias_corr / (torch.sqrt(s_Wfh_bias_corr) + self.eps);
        self.Wfx -= lr * v_Wfx_bias_corr / (torch.sqrt(s_Wfx_bias_corr) + self.eps);
        self.bf  -= lr * v_bf_bias_corr  / (torch.sqrt(s_bf_bias_corr)  + self.eps);
        
        self.Wih -= lr * v_Wih_bias_corr / (torch.sqrt(s_Wih_bias_corr) + self.eps);
        self.Wix -= lr * v_Wix_bias_corr / (torch.sqrt(s_Wix_bias_corr) + self.eps);
        self.bi  -= lr * v_bi_bias_corr  / (torch.sqrt(s_bi_bias_corr)  + self.eps);
        
        self.Woh -= lr * v_Woh_bias_corr / (torch.sqrt(s_Woh_bias_corr) + self.eps);
        self.Wox -= lr * v_Wox_bias_corr / (torch.sqrt(s_Wox_bias_corr) + self.eps);
        self.bo  -= lr * v_bo_bias_corr  / (torch.sqrt(s_bo_bias_corr)  + self.eps);
        
        self.Wch -= lr * v_Wch_bias_corr / (torch.sqrt(s_Wch_bias_corr) + self.eps);
        self.Wcx -= lr * v_Wcx_bias_corr / (torch.sqrt(s_Wcx_bias_corr) + self.eps);
        self.bc  -= lr * v_bc_bias_corr  / (torch.sqrt(s_bc_bias_corr)  + self.eps);
        
        self.Wp  -= lr * v_Wp_bias_corr  / (torch.sqrt(s_Wp_bias_corr)  + self.eps);
        self.bp  -= lr * v_bp_bias_corr  / (torch.sqrt(s_bp_bias_corr)  + self.eps);
        
        self.back_times += 1;
        
    def reset(self):
        
        self.times = 0;
        
        self.fList = [];
        self.iList = [];
        self.oList = [];
        self.ctList = []; # was cList, which is rebuilt below anyway; ctList must be cleared too
        
        self.preList=[];
        
        self.f = torch.zeros(self.batch_size,self.hidden_size);
        self.i = torch.zeros(self.batch_size,self.hidden_size);
        self.o = torch.zeros(self.batch_size,self.hidden_size);
        self.ct = torch.zeros(self.batch_size,self.hidden_size);
        
        self.fList.append(self.f);
        self.iList.append(self.i);
        self.oList.append(self.o);
        self.ctList.append(self.ct);
        
        self.hList = [torch.zeros(self.batch_size,self.hidden_size)];
        self.cList = [torch.zeros(self.batch_size,self.hidden_size)];
        
    def Sigmoid_forward(self,x):
        return torch.sigmoid(x);

    def Sigmoid_backward(self,x):
        return self.Sigmoid_forward(x) * (1 - self.Sigmoid_forward(x));

    def Tanh_forward(self,x):
        return torch.tanh(x);

    def Tanh_backward(self,x):
        return 1 - (self.Tanh_forward(x) * self.Tanh_forward(x));

CNN = CNNModel(train_timesteps);
criterion = nn.MSELoss();
optimizer = torch.optim.Adam(CNN.parameters(),lr = 0.0005);

# with input (1, 20, 1) the CNN outputs (1, 704): each conv (kernel 3, stride 1)
# shortens the sequence by 2 and each pool (window 2, stride 1) by 1,
# so 20 -> 18 -> 17 -> 15 -> 14 -> 12 -> 11, and 64 * 11 = 704
l = LSTM(train_timesteps,batch_size,704,3,1);
lr = 0.0009;
lossList = [];
l.init_adam_states();

for epoch in range(train_Epoch):
    
    for i in range(len(train_X)):
        
        # train_X[i]:(batchsize,timesteps,inputsize)
        # x = train_X[i].permute(1, 0, 2); # (timesteps,batchsize,inputsize)
        x = train_X[i]; # (1,20,1)
        
        # forward
        pre_CNN = CNN(x);
        
        with torch.no_grad():
            
            pre = l.forward(torch.unsqueeze(pre_CNN, 0)); # (1, 1, 704); forward() already returns prediction()
            
            # loss
            loss = (pre - train_Y[i]) * (pre - train_Y[i]);
            loss = loss.mean();
        
            y_grad = 2.0 * (pre - train_Y[i]);

            h_grad = y_grad @ l.Wp;

            # backward
            l.backward(torch.unsqueeze(pre_CNN,0),h_grad,y_grad);

            # update
            l.update(lr);

            l.reset();
            
        # CNN - this is the part I am stuck on: pre was produced by my manual
        # LSTM under torch.no_grad(), so it has no grad_fn and backward() below
        # fails (and there is no optimizer.step() yet either); see my attempted
        # rewrite at the bottom of the post
        loss_CNN = criterion(pre,train_Y[i]);
        
        optimizer.zero_grad();
        loss_CNN.backward();

    lossList.append(loss.item());
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, loss: {loss.item():.6f}");
    
plt.plot(lossList);

plt.title("Loss");
plt.xlabel("Epoch");
plt.ylabel("loss");

plt.show();
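
For the training loop itself, this is roughly what I imagine the CNN part should become (untested sketch; x_grad is hypothetical, since my l.backward() does not yet compute or return the gradient of the loss with respect to its input). Is this the right way to hand the manual LSTM gradient back to autograd?

pre_CNN = CNN(x)                    # autograd records the CNN graph here

with torch.no_grad():
    pre = l.forward(torch.unsqueeze(pre_CNN, 0))  # manual LSTM forward
    y_grad = 2.0 * (pre - train_Y[i])
    h_grad = y_grad @ l.Wp
    # hypothetical return value: backward() would have to compute
    # dLoss/d(pre_CNN), shaped exactly like pre_CNN
    x_grad = l.backward(torch.unsqueeze(pre_CNN, 0), h_grad, y_grad)
    l.update(lr)
    l.reset()

optimizer.zero_grad()
pre_CNN.backward(gradient=x_grad)   # autograd backpropagates through the three conv layers
optimizer.step()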