Hello, everybody.
In the two code paths below, is the weight update the same? In the branch before `else`, the data is split into two parts that are forwarded separately to compute the loss; in the branch after `else`, all the data is forwarded together to compute the loss. Do these two approaches produce the same parameter updates after `backward()`?
......
# Training loop fragment: each iteration draws one batch from the duts loader
# and one from the flickr iterator, computes the main loss, backpropagates it,
# and steps both optimizers.
#
# NOTE(review): this line rebinds the loader attribute to an iterator over it.
# Once that iterator is exhausted, `next(self.tr_flickrloader)` below raises
# StopIteration, and the original loader object is no longer reachable to
# restart it. Verify the flickr loader yields at least
# (end_epoch - start_epoch) * len(self.tr_dutsloader) batches, or re-create
# the iterator when StopIteration is raised.
self.tr_flickrloader = iter(self.tr_flickrloader)
for curr_epoch in range(self.start_epoch, self.end_epoch):
train_mainloss_record = AvgMeter()
train_branchloss_record = AvgMeter()
# Each batch mixes two datasets: one quarter of the samples comes from
# dataset A (self.tr_dutsloader) and three quarters from dataset B
# (self.tr_flickrloader).
# The model has a main part and a branch part (separate optimizers and
# schedulers below). Only `self.net` appears in the loss; per the author, the
# parts are independent of each other.
for train_batch_id, duts_data in enumerate(self.tr_dutsloader):
# Global iteration index; both schedulers are stepped once per iteration with it.
curr_iter = curr_epoch * len(self.tr_dutsloader) + train_batch_id
self.main_sche.step(curr_iter)
self.branch_sche.step(curr_iter)
duts_inputs, duts_labels, duts_names = duts_data
flickr_inputs, flickr_labels, flickr_names = next(self.tr_flickrloader)
# Clear gradients from the previous iteration before the new backward().
self.main_opti.zero_grad()
self.branch_opti.zero_grad()
# Total batch size is 12, split duts:flickr = 1:3 (presumably 3 + 9 samples).
#
# Are the two branches below equivalent?
# - Loss and gradients: assuming self.crit averages over samples
#   (e.g. reduction='mean') and every sample contributes the same number of
#   elements, the mean over 12 samples equals
#   (3 * mean_duts + 9 * mean_flickr) / 12
#   = 1/4 * mean_duts + 3/4 * mean_flickr.
#   The losses, and therefore the gradients, then match. They do NOT match if
#   self.crit sums instead of averaging, or if a batch is not exactly 3 + 9
#   samples (e.g. a smaller last batch when drop_last is False — TODO confirm).
# - Batch-dependent layers: if self.net contains BatchNorm layers in training
#   mode (not visible here — TODO confirm), forwarding 3 and 9 samples
#   separately normalizes each sub-batch with its own statistics. It also
#   updates the running statistics twice per iteration. Outputs, gradients and
#   running stats therefore differ from a single forward pass over all 12
#   samples. If Dropout is present, masks are drawn differently in each case.
# - Floating-point summation order differs, so tiny numerical differences are
#   expected even when the math is identical.
if curr_epoch >= 20:
duts_inputs = duts_inputs.to(self.dev)
duts_labels = duts_labels.to(self.dev)
flickr_inputs = flickr_inputs.to(self.dev)
flickr_labels = flickr_labels.to(self.dev)
# Separate forward passes. The weights 1/4 and 3/4 mirror the 1:3 sample ratio.
main_o_duts = self.net(duts_inputs)
main_train_loss = 1 / 4 * self.crit(main_o_duts, duts_labels)
main_o_flickr = self.net(flickr_inputs)
main_train_loss += 3 / 4 * self.crit(main_o_flickr, flickr_labels)
# One backward() on the summed loss accumulates the gradients of both terms.
main_train_loss.backward()
else:
# Concatenate both sub-batches along the batch dimension and run one forward pass.
train_inputs = torch.cat((duts_inputs, flickr_inputs), dim=0)
train_labels = torch.cat((duts_labels, flickr_labels), dim=0)
train_inputs = train_inputs.to(self.dev)
train_labels = train_labels.to(self.dev)
main_otr = self.net(train_inputs)
main_train_loss = self.crit(main_otr, train_labels)
main_train_loss.backward()
self.main_opti.step()
# NOTE(review): no loss in the visible code depends on the branch network.
# branch_opti only receives gradients if elided code (or parameters shared
# with self.net) provides them — confirm this step is intended.
self.branch_opti.step()
......
Thanks!