I have built a model with three CNN layers followed by one LSTM layer. The LSTM backpropagation is implemented manually by myself (code below), but how should I implement backpropagation for the CNN part? In other words, how do I use PyTorch's autograd to backpropagate through the CNN when the gradient arriving at its output comes from my hand-written LSTM backward?
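My understanding so far is that autograd can finish the job from any tensor it recorded, as long as I hand it dLoss/d(that tensor) via Tensor.backward(gradient=...). A minimal toy sketch of that mechanism (the shapes and names here are made up for illustration, not from my model):

import torch
import torch.nn as nn

conv = nn.Conv1d(1, 4, kernel_size=3)    # toy stand-in for the CNN part
x = torch.randn(1, 1, 20)                # (batch, channels, length)
out = conv(x)                            # autograd records this forward pass
grad_out = torch.randn_like(out)         # pretend this is dLoss/d(out) from a manual backward
out.backward(gradient=grad_out)          # vector-Jacobian product into the conv parameters
print(conv.weight.grad.shape)            # torch.Size([4, 1, 3])

If that is the right mechanism, the remaining question is how to wire it into my training loop below.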
# CNN
class CNNModel(nn.Module):
    def __init__(self, timesteps):
        super().__init__()
        self.timesteps = timesteps
        # CNN layer 1 - input channels: 1; output channels: 64; kernel size: 3; stride: 1
        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, stride=1)
        self.act1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(2, stride=1)  # pooling window 2, stride 1
        # CNN layer 2 - input channels: 64; output channels: 128; kernel size: 3; stride: 1
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, stride=1)
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(2, stride=1)  # pooling window 2, stride 1
        # CNN layer 3 - input channels: 128; output channels: 64; kernel size: 3; stride: 1
        self.conv3 = nn.Conv1d(128, 64, kernel_size=3, stride=1)
        self.act3 = nn.ReLU()
        self.pool3 = nn.MaxPool1d(2, stride=1)  # pooling window 2, stride 1
        self.flat = nn.Flatten()

    def forward(self, x):
        print("x.shape:", x.shape)
        # CNN layer 1: permute (batch, seq_len, input_dim) -> (batch, input_dim, seq_len)
        x1 = self.act1(self.conv1(x.permute(0, 2, 1)))  # (1, 20, 1) -> (1, 1, 20) -> (1, 64, 18)
        print("x1.shape:", x1.shape)
        x1 = self.pool1(x1)
        # CNN layer 2
        x2 = self.act2(self.conv2(x1))
        x2 = self.pool2(x2)
        # CNN layer 3
        x3 = self.act3(self.conv3(x2))
        x3 = self.pool3(x3)
        x4 = self.flat(x3)
        print(x1.shape, x2.shape, x3.shape, x4.shape)
        return x4
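For reference, with timesteps = 20 the flattened size works out like this: each conv layer (kernel 3, stride 1) shortens the sequence by 2 and each pooling layer (window 2, stride 1) by 1, so the length goes 20 → 18 → 17 → 15 → 14 → 12 → 11, and flattening gives 64 × 11 = 704, which is the input_size I pass to the LSTM below.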
class LSTM(object):
    def __init__(self, timesteps, batch_size, input_size, hidden_size, output_size):
        self.times = 0
        self.timesteps = timesteps
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # gate parameters: f = forget, i = input, o = output, c = candidate cell
        self.Wfh, self.Wfx, self.bf = self.Weight_bias(self.input_size, self.hidden_size)
        self.Wih, self.Wix, self.bi = self.Weight_bias(self.input_size, self.hidden_size)
        self.Woh, self.Wox, self.bo = self.Weight_bias(self.input_size, self.hidden_size)
        self.Wch, self.Wcx, self.bc = self.Weight_bias(self.input_size, self.hidden_size)
        # output projection
        self.Wp = torch.randn(self.output_size, self.hidden_size)
        self.bp = torch.randn(self.output_size)
        # per-step activations; each history list gets a zero entry for t = 0
        self.f = torch.zeros(self.batch_size, self.hidden_size)
        self.i = torch.zeros(self.batch_size, self.hidden_size)
        self.o = torch.zeros(self.batch_size, self.hidden_size)
        self.ct = torch.zeros(self.batch_size, self.hidden_size)
        self.h = torch.zeros(self.batch_size, self.hidden_size)
        self.c = torch.zeros(self.batch_size, self.hidden_size)
        self.fList = [self.f]
        self.iList = [self.i]
        self.oList = [self.o]
        self.ctList = [self.ct]
        self.hList = [self.h]
        self.cList = [self.c]
        self.preList = []

    def Weight_bias(self, input_size, hidden_size):
        return (torch.randn(hidden_size, hidden_size),
                torch.randn(input_size, hidden_size),
                torch.randn(hidden_size))

    def Weight_bias_grad(self, input_size, hidden_size):
        return (torch.zeros(hidden_size, hidden_size),
                torch.zeros(input_size, hidden_size),
                torch.zeros(hidden_size))
    def forward(self, x):
        # x: (timesteps, batch, input_size)
        for i in range(len(x)):
            self.times += 1
            self.f = self.Sigmoid_forward(self.hList[-1] @ self.Wfh + x[i] @ self.Wfx + self.bf)
            self.i = self.Sigmoid_forward(self.hList[-1] @ self.Wih + x[i] @ self.Wix + self.bi)
            self.o = self.Sigmoid_forward(self.hList[-1] @ self.Woh + x[i] @ self.Wox + self.bo)
            self.ct = self.Tanh_forward(self.hList[-1] @ self.Wch + x[i] @ self.Wcx + self.bc)
            self.c = self.f * self.cList[-1] + self.i * self.ct
            self.h = self.o * self.Tanh_forward(self.c)
            self.fList.append(self.f)
            self.iList.append(self.i)
            self.oList.append(self.o)
            self.ctList.append(self.ct)
            self.hList.append(self.h)
            self.cList.append(self.c)
        return self.prediction()

    def prediction(self):
        pre = self.hList[-1] @ self.Wp.T + self.bp
        self.preList.append(pre)
        return pre
    def backward(self, x, grad, y_grad):
        self.delta_Wfh, self.delta_Wfx, self.delta_bf = self.Weight_bias_grad(self.input_size, self.hidden_size)
        self.delta_Wih, self.delta_Wix, self.delta_bi = self.Weight_bias_grad(self.input_size, self.hidden_size)
        self.delta_Woh, self.delta_Wox, self.delta_bo = self.Weight_bias_grad(self.input_size, self.hidden_size)
        self.delta_Wch, self.delta_Wcx, self.delta_bc = self.Weight_bias_grad(self.input_size, self.hidden_size)
        self.delta_Wp = torch.zeros(self.output_size, self.hidden_size)
        self.delta_bp = torch.zeros(self.output_size)
        self.delta_hList = self.init_delta()
        self.delta_cList = self.init_delta()
        self.delta_fList = self.init_delta()
        self.delta_iList = self.init_delta()
        self.delta_oList = self.init_delta()
        self.delta_ctList = self.init_delta()
        self.delta_hList[-1] = grad  # gradient of the loss w.r.t. the last hidden state
        for k in range(self.times, 0, -1):
            self.compute_gate_backward(x, k)
        self.compute_Weight_bias_backward(x, y_grad)

    def init_delta(self):
        return [torch.zeros(self.batch_size, self.hidden_size) for _ in range(self.times + 1)]
    def compute_gate_backward(self, x, k):
        i = self.iList[k]
        o = self.oList[k]
        ct = self.ctList[k]
        c = self.cList[k]
        c_pre = self.cList[k - 1]
        h_pre = self.hList[k - 1]  # the gates at step k were computed from h_{k-1}
        delta_hk = self.delta_hList[k]
        if k == self.times:
            delta_ck = delta_hk * o * self.Tanh_backward(c)
        else:
            delta_ck = delta_hk * o * self.Tanh_backward(c) + self.delta_cList[k + 1] * self.fList[k + 1]
        delta_ctk = delta_ck * i
        delta_fk = delta_ck * c_pre
        delta_ik = delta_ck * ct
        delta_ok = delta_hk * self.Tanh_forward(c)
        # propagate to h_{k-1} through the four gate pre-activations (z = h @ W, so dz/dh uses W.T)
        delta_hkpre = (delta_fk * self.Sigmoid_backward(h_pre @ self.Wfh + x[k - 1] @ self.Wfx + self.bf) @ self.Wfh.T
                       + delta_ik * self.Sigmoid_backward(h_pre @ self.Wih + x[k - 1] @ self.Wix + self.bi) @ self.Wih.T
                       + delta_ok * self.Sigmoid_backward(h_pre @ self.Woh + x[k - 1] @ self.Wox + self.bo) @ self.Woh.T
                       + delta_ctk * self.Tanh_backward(h_pre @ self.Wch + x[k - 1] @ self.Wcx + self.bc) @ self.Wch.T)
        self.delta_hList[k - 1] = delta_hkpre
        self.delta_cList[k] = delta_ck
        self.delta_fList[k] = delta_fk
        self.delta_iList[k] = delta_ik
        self.delta_oList[k] = delta_ok
        self.delta_ctList[k] = delta_ctk
    def compute_Weight_bias_backward(self, x, y_grad):
        for t in range(self.times, 0, -1):
            sum_delta_Wfh, sum_delta_Wfx, sum_delta_bf = self.Weight_bias_grad(self.input_size, self.hidden_size)
            sum_delta_Wih, sum_delta_Wix, sum_delta_bi = self.Weight_bias_grad(self.input_size, self.hidden_size)
            sum_delta_Woh, sum_delta_Wox, sum_delta_bo = self.Weight_bias_grad(self.input_size, self.hidden_size)
            sum_delta_Wch, sum_delta_Wcx, sum_delta_bc = self.Weight_bias_grad(self.input_size, self.hidden_size)
            sum_delta_Wp = torch.zeros(self.output_size, self.hidden_size)
            sum_delta_bp = torch.zeros(self.output_size)
            for i in range(self.batch_size):
                h_pre = self.hList[t - 1][i]
                x_t = x[t - 1][i]
                if t == self.times:
                    # output projection: pre = h @ Wp.T + bp
                    sum_delta_Wp += torch.outer(y_grad[i], self.hList[-1][i])
                    sum_delta_bp += y_grad[i]
                # gradients w.r.t. each gate's pre-activation
                dzf = self.delta_fList[t][i] * self.Sigmoid_backward(h_pre @ self.Wfh + x_t @ self.Wfx + self.bf)
                dzi = self.delta_iList[t][i] * self.Sigmoid_backward(h_pre @ self.Wih + x_t @ self.Wix + self.bi)
                dzo = self.delta_oList[t][i] * self.Sigmoid_backward(h_pre @ self.Woh + x_t @ self.Wox + self.bo)
                dzc = self.delta_ctList[t][i] * self.Tanh_backward(h_pre @ self.Wch + x_t @ self.Wcx + self.bc)
                sum_delta_Wfh += torch.outer(h_pre, dzf)
                sum_delta_Wfx += torch.outer(x_t, dzf)
                sum_delta_bf += dzf
                sum_delta_Wih += torch.outer(h_pre, dzi)
                sum_delta_Wix += torch.outer(x_t, dzi)
                sum_delta_bi += dzi
                sum_delta_Woh += torch.outer(h_pre, dzo)
                sum_delta_Wox += torch.outer(x_t, dzo)
                sum_delta_bo += dzo
                sum_delta_Wch += torch.outer(h_pre, dzc)
                sum_delta_Wcx += torch.outer(x_t, dzc)
                sum_delta_bc += dzc
            # average over the batch, accumulate over timesteps
            self.delta_Wfh += sum_delta_Wfh / self.batch_size
            self.delta_Wfx += sum_delta_Wfx / self.batch_size
            self.delta_bf += sum_delta_bf / self.batch_size
            self.delta_Wih += sum_delta_Wih / self.batch_size
            self.delta_Wix += sum_delta_Wix / self.batch_size
            self.delta_bi += sum_delta_bi / self.batch_size
            self.delta_Wch += sum_delta_Wch / self.batch_size
            self.delta_Wcx += sum_delta_Wcx / self.batch_size
            self.delta_bc += sum_delta_bc / self.batch_size
            self.delta_Woh += sum_delta_Woh / self.batch_size
            self.delta_Wox += sum_delta_Wox / self.batch_size
            self.delta_bo += sum_delta_bo / self.batch_size
            self.delta_Wp += sum_delta_Wp / self.batch_size
            self.delta_bp += sum_delta_bp / self.batch_size
    # Adam
    def init_unit_adam_states(self):
        Wh = torch.zeros(self.hidden_size, self.hidden_size)
        Wx = torch.zeros(self.input_size, self.hidden_size)
        b = torch.zeros(self.hidden_size)
        return Wh, Wx, b

    def init_adam_states(self):
        self.back_times = 1  # Adam step counter
        self.beta1, self.beta2, self.eps = 0.9, 0.99, 1e-6
        # first-moment (v) and second-moment (s) estimates per parameter
        self.v_Wfh, self.v_Wfx, self.v_bf = self.init_unit_adam_states()
        self.v_Wih, self.v_Wix, self.v_bi = self.init_unit_adam_states()
        self.v_Woh, self.v_Wox, self.v_bo = self.init_unit_adam_states()
        self.v_Wch, self.v_Wcx, self.v_bc = self.init_unit_adam_states()
        self.s_Wfh, self.s_Wfx, self.s_bf = self.init_unit_adam_states()
        self.s_Wih, self.s_Wix, self.s_bi = self.init_unit_adam_states()
        self.s_Woh, self.s_Wox, self.s_bo = self.init_unit_adam_states()
        self.s_Wch, self.s_Wcx, self.s_bc = self.init_unit_adam_states()
        self.v_Wp = torch.zeros(self.output_size, self.hidden_size)
        self.v_bp = torch.zeros(self.output_size)
        self.s_Wp = torch.zeros(self.output_size, self.hidden_size)
        self.s_bp = torch.zeros(self.output_size)
    def update(self, lr):
        # one Adam step for every parameter: update the moment estimates,
        # bias-correct them, then apply the scaled update
        names = ['Wfh', 'Wfx', 'bf', 'Wih', 'Wix', 'bi',
                 'Woh', 'Wox', 'bo', 'Wch', 'Wcx', 'bc', 'Wp', 'bp']
        for name in names:
            g = getattr(self, 'delta_' + name)
            v = self.beta1 * getattr(self, 'v_' + name) + (1 - self.beta1) * g
            s = self.beta2 * getattr(self, 's_' + name) + (1 - self.beta2) * torch.square(g)
            setattr(self, 'v_' + name, v)
            setattr(self, 's_' + name, s)
            v_corr = v / (1 - self.beta1 ** self.back_times)  # bias-corrected momentum
            s_corr = s / (1 - self.beta2 ** self.back_times)
            setattr(self, name, getattr(self, name) - lr * v_corr / (torch.sqrt(s_corr) + self.eps))
        self.back_times += 1
    def reset(self):
        self.times = 0
        self.f = torch.zeros(self.batch_size, self.hidden_size)
        self.i = torch.zeros(self.batch_size, self.hidden_size)
        self.o = torch.zeros(self.batch_size, self.hidden_size)
        self.ct = torch.zeros(self.batch_size, self.hidden_size)
        self.h = torch.zeros(self.batch_size, self.hidden_size)
        self.c = torch.zeros(self.batch_size, self.hidden_size)
        # rebuild every history list so the indices used in backward start fresh
        self.fList = [self.f]
        self.iList = [self.i]
        self.oList = [self.o]
        self.ctList = [self.ct]
        self.hList = [self.h]
        self.cList = [self.c]
        self.preList = []
    def Sigmoid_forward(self, x):
        return torch.sigmoid(x)

    def Sigmoid_backward(self, x):
        # derivative of sigmoid, evaluated at pre-activation x
        return self.Sigmoid_forward(x) * (1 - self.Sigmoid_forward(x))

    def Tanh_forward(self, x):
        return torch.tanh(x)

    def Tanh_backward(self, x):
        # derivative of tanh, evaluated at pre-activation x
        return 1 - self.Tanh_forward(x) * self.Tanh_forward(x)
CNN = CNNModel(train_timesteps)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(CNN.parameters(), lr=0.0005)
l = LSTM(train_timesteps, batch_size, 704, 3, 1)
lr = 0.0009
lossList = []
l.init_adam_states()
for epoch in range(train_Epoch):
    for i in range(len(train_X)):
        # train_X[i]: (batch_size, timesteps, input_size)
        x = train_X[i]  # (1, 20, 1)
        # forward
        pre_CNN = CNN(x)
        with torch.no_grad():
            pre = l.forward(torch.unsqueeze(pre_CNN, 0))  # LSTM input: (1, 1, 704)
        # loss
        loss = ((pre - train_Y[i]) ** 2).mean()
        y_grad = 2.0 * (pre - train_Y[i])  # dLoss/dpre
        h_grad = y_grad @ l.Wp             # dLoss/dh_T, since pre = h @ Wp.T + bp
        # LSTM backward and update (manual)
        l.backward(torch.unsqueeze(pre_CNN, 0), h_grad, y_grad)
        l.update(lr)
        l.reset()
        # CNN update - this is the part I cannot get right: pre was computed
        # under no_grad by my hand-written LSTM, so it is not connected to the
        # CNN's graph and backward() fails here (see my sketch below the plot).
        loss_CNN = criterion(pre, train_Y[i])
        optimizer.zero_grad()
        loss_CNN.backward()
        optimizer.step()
        lossList.append(loss.item())
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, loss: {loss.item()}")
plt.plot(lossList)
plt.title("Loss")
plt.xlabel("Epoch")
plt.ylabel("loss")
plt.show()
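What I believe is missing is the gradient of the loss with respect to the CNN output: my backward() propagates deltas to the hidden states only, never to the LSTM input x, so there is nothing to hand back to autograd. Since each gate computes z = x @ Wfx (and likewise for Wix, Wox, Wcx), dLoss/dx at step k would be the sum of each gate's pre-activation gradient times the transposed input weight, e.g. dzf @ self.Wfx.T. If backward() were extended to return that (the x_grad below is my own hypothetical addition, not something the class above provides yet), I imagine the loop would look like this:

# forward: autograd records only the CNN part
pre_CNN = CNN(x)
lstm_in = torch.unsqueeze(pre_CNN, 0).detach()  # manual LSTM works on a detached copy
pre = l.forward(lstm_in)

loss = ((pre - train_Y[i]) ** 2).mean()
y_grad = 2.0 * (pre - train_Y[i])
h_grad = y_grad @ l.Wp

# manual LSTM backward, hypothetically extended to also return dLoss/d(lstm_in)
x_grad = l.backward(lstm_in, h_grad, y_grad)    # would be (1, 1, 704)
l.update(lr)
l.reset()

# hand the incoming gradient to autograd, which backpropagates through the CNN
optimizer.zero_grad()
pre_CNN.backward(gradient=x_grad.squeeze(0))    # vector-Jacobian product
optimizer.step()

Is this the intended way to splice a hand-written backward into autograd, or should I wrap the whole LSTM in a custom torch.autograd.Function instead?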