I am trying to make a recursive (recurrent) network. It is composed of a convolutional part and a fully connected part, and it reuses its previous output as part of its input. The model looks like this:
RecursiveModel(
  (conv_part): Sequential(
    (conv_0): Conv1d(1, 4, kernel_size=(128,), stride=(1,), padding=(32,))
    (conv_1): Conv1d(4, 4, kernel_size=(128,), stride=(1,), padding=(32,))
  )
  (full_part): Sequential(
    (full_conn_0): Linear(in_features=648, out_features=377, bias=True)
    (acti_lrelu_0): LeakyReLU(negative_slope=0.01)
    (full_conn_1): Linear(in_features=377, out_features=219, bias=True)
    (acti_lrelu_1): LeakyReLU(negative_slope=0.01)
    (full_conn_2): Linear(in_features=219, out_features=128, bias=True)
    (acti_lrelu_2): LeakyReLU(negative_slope=0.01)
  )
)
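For context, here is a minimal sketch of a constructor that would print the structure above. frame_size = 256 is my assumption (it is the only value consistent with the 648 in_features, see the shape check further down); everything else follows the printout:

import torch.nn as nn
from collections import OrderedDict

class RecursiveModel(nn.Module):
    def __init__(self, frame_size=256):  # frame_size is assumed, not confirmed
        super().__init__()
        self.frame_size = frame_size
        self.conv_part = nn.Sequential(OrderedDict([
            ("conv_0", nn.Conv1d(1, 4, kernel_size=128, stride=1, padding=32)),
            ("conv_1", nn.Conv1d(4, 4, kernel_size=128, stride=1, padding=32)),
        ]))
        # each conv maps length L to L + 2*32 - (128 - 1) = L - 63
        self.conv_output_shape = (4, frame_size - 126)  # (channels, length) = (4, 130)
        flat = self.conv_output_shape[0] * self.conv_output_shape[1]  # 520
        self.full_part = nn.Sequential(OrderedDict([
            ("full_conn_0", nn.Linear(flat + 128, 377)),  # 520 + 128 previous outputs = 648
            ("acti_lrelu_0", nn.LeakyReLU(0.01)),
            ("full_conn_1", nn.Linear(377, 219)),
            ("acti_lrelu_1", nn.LeakyReLU(0.01)),
            ("full_conn_2", nn.Linear(219, 128)),
            ("acti_lrelu_2", nn.LeakyReLU(0.01)),
        ]))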
The forward method is:
def forward(self, x, prev_output):
    input_shape = x.size()
    assert len(input_shape) == 3
    assert input_shape[0] == 1  # batch == 1
    assert input_shape[1] == 1  # channel == 1
    assert input_shape[2] == self.frame_size
    conv_output = self.conv_part(x)
    conv_output_plain = conv_output.view(1, self.conv_output_shape[0] * self.conv_output_shape[1])
    full_input = torch.cat((conv_output_plain, prev_output), 1)
    result = self.full_part(full_input)
    return result
Here the previous output is concatenated with the convolution output and fed into the fully connected part.
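The sizes work out if frame_size is 256 (again, my assumption): each Conv1d with kernel 128 and padding 32 shortens the sequence by 63 samples, so the conv part maps 256 samples to 130 samples in 4 channels, and 4 * 130 + 128 = 648 matches full_conn_0. A quick check:

import torch
import torch.nn as nn

conv_part = nn.Sequential(
    nn.Conv1d(1, 4, kernel_size=128, stride=1, padding=32),
    nn.Conv1d(4, 4, kernel_size=128, stride=1, padding=32),
)
x = torch.zeros(1, 1, 256)     # assumed frame_size
print(conv_part(x).shape)      # torch.Size([1, 4, 130])
print(4 * 130 + 128)           # 648 == in_features of full_conn_0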
However, it raises an error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
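I believe this happens because prev_output keeps the previous frame's graph alive: calling backward() on the next frame walks back into a graph whose buffers were already freed by the previous backward(). A toy example of mine (not the model above) that reproduces the same error:

import torch

w = torch.ones(1, requires_grad=True)
h = torch.zeros(1)
for step in range(2):
    h = h * w           # h still references the graph built in the previous step
    h.sum().backward()  # step 1 re-enters step 0's freed graph -> RuntimeError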
I also tried detaching the previous output:
full_input = torch.cat((conv_output_plain, prev_output.detach()), 1)
This runs, but the training error never converges.
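If I understand detach() correctly, that is expected: the detached tensor is cut out of the graph, so no gradient ever flows across frame boundaries and the network cannot learn anything through the recurrence:

prev = curr_output.detach()
print(prev.requires_grad)  # False
print(prev.grad_fn)        # None: no path for gradients into earlier frames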
Finally, this is the training part of the code. I cut the input sequence into segments and run one segment at a time:
def do_training():
    i_iter = 0
    while True:
        try:
            i_dataset = random.randrange(0, num_dset)
            dset_input = all_input[i_dataset]
            dset_refout = all_refout[i_dataset]
            num_frame = len(dset_input)
            assert num_frame == len(dset_refout)
            prev_output = zero_output
            for i_frame_start in range(0, num_frame, options.batch_sz):
                i_iter += 1
                optimizer.zero_grad()
                loss = 0
                for j in range(options.batch_sz):
                    i_frame = i_frame_start + j
                    if i_frame >= num_frame:
                        break
                    pseudo_batch_input[0][0] = dset_input[i_frame] / 60.0 + 1.0  # normalize to -1 ~ 1
                    pseudo_batch_refout[0] = dset_refout[i_frame]
                    curr_output = model(pseudo_batch_input.to(device), prev_output)
                    prev_output = curr_output
                    curr_loss = cri(curr_output, pseudo_batch_refout.to(device))
                    if math.isnan(float(curr_loss)):
                        pdb.set_trace()
                        continue
                    loss += curr_loss
                    curr_loss.backward()
                    # run backward for each batch step
                loss_stat_re = loss_stat.record(float(loss))
                if loss_stat_re is not None:
                    line = str(i_iter) + "\t" + "\t".join(map(str, loss_stat_re))
                    logger.info(line)
                #loss.backward()
                optimizer.step()
                if i_iter > options.n_iter:
                    return
        except KeyboardInterrupt:
            print("early out by int at iter %d" % i_iter)
            return
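What I think I actually want is truncated backpropagation through time: keep the graph alive within a segment, call backward() once on the summed loss, and detach prev_output only at segment boundaries. A sketch of the per-segment loop rewritten that way (same variable names as above; I have not confirmed this is the intended fix):

prev_output = zero_output
for i_frame_start in range(0, num_frame, options.batch_sz):
    optimizer.zero_grad()
    prev_output = prev_output.detach()  # cut the graph only at segment boundaries
    loss = 0
    for j in range(options.batch_sz):
        i_frame = i_frame_start + j
        if i_frame >= num_frame:
            break
        pseudo_batch_input[0][0] = dset_input[i_frame] / 60.0 + 1.0
        pseudo_batch_refout[0] = dset_refout[i_frame]
        curr_output = model(pseudo_batch_input.to(device), prev_output)
        prev_output = curr_output       # graph stays connected inside the segment
        loss = loss + cri(curr_output, pseudo_batch_refout.to(device))
    loss.backward()                     # one backward per segment
    optimizer.step()

Is this the right way to train through the recurrence?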