Is the weight update process the same for these two structures?

Hello, everybody.

Is the weight update process the same for these two structures? In the code before the `else`, the data is split into two parts, which are forwarded separately to calculate the loss; in the code after the `else`, all the data is forwarded together to calculate the loss. Are these two methods equivalent when the backward pass updates the parameters?

        ......

        self.tr_flickrloader = iter(self.tr_flickrloader)

        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_mainloss_record = AvgMeter()
            train_branchloss_record = AvgMeter()

            # Each batch consists of one quarter from dataset A (self.tr_dutsloader)
            # and three quarters from dataset B (self.tr_flickrloader).

            # My network has a subnetwork, `self.net`;
            # they are independent of each other.
            
            for train_batch_id, duts_data in enumerate(self.tr_dutsloader):
                curr_iter = curr_epoch * len(self.tr_dutsloader) + train_batch_id
                self.main_sche.step(curr_iter)
                self.branch_sche.step(curr_iter)

                duts_inputs, duts_labels, duts_names = duts_data
                flickr_inputs, flickr_labels, flickr_names = next(self.tr_flickrloader)

                self.main_opti.zero_grad()
                self.branch_opti.zero_grad()

                # batch_size=12, duts_inputs:flickr_inputs=1:3
                if curr_epoch >= 20:
                    duts_inputs = duts_inputs.to(self.dev)
                    duts_labels = duts_labels.to(self.dev)
                    flickr_inputs = flickr_inputs.to(self.dev)
                    flickr_labels = flickr_labels.to(self.dev)

                    main_o_duts = self.net(duts_inputs)
                    main_train_loss = 1 / 4 * self.crit(main_o_duts, duts_labels)
                    main_o_flickr = self.net(flickr_inputs)
                    main_train_loss += 3 / 4 * self.crit(main_o_flickr, flickr_labels)
                    main_train_loss.backward()
                else:
                    train_inputs = torch.cat((duts_inputs, flickr_inputs), dim=0)
                    train_labels = torch.cat((duts_labels, flickr_labels), dim=0)
                    train_inputs = train_inputs.to(self.dev)
                    train_labels = train_labels.to(self.dev)

                    main_otr = self.net(train_inputs)
                    main_train_loss = self.crit(main_otr, train_labels)
                    main_train_loss.backward()

                self.main_opti.step()
                self.branch_opti.step()

                ......

Thanks!

Whether it is the same depends on the batch sizes of the two dataloaders and on the network:

  1. The first variant enforces the 1/4 to 3/4 weighting even if the batch sizes are not in that ratio. If the criterion self.crit takes the mean over the batch, the loss is identical when the batch holds 3 duts inputs and 9 flickr inputs.
  2. If you have batch norm in your network, the separate forward passes compute batch statistics for each part on its own, whereas the joint pass computes them over the combined batch. Depending on whether both datasets generate similar statistics, this might make a difference. The two variants would also lead to different running averages being used at inference, but that likely doesn’t matter if you train for many epochs using the else branch. It may have an effect on the intermediate validation stats.
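
To make the first point concrete, here is a minimal framework-free sketch. It assumes a hypothetical one-parameter model `y_hat = w * x` and a mean-reduced squared error standing in for self.crit; the data values and the helper `mse_grad` are made up for illustration and are not taken from the original code.

```python
# Hypothetical stand-in for the network: y_hat = w * x, with a mean-reduced
# squared error playing the role of self.crit.
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5

# Made-up data: 3 samples from dataset A and 9 samples from dataset B.
duts_x = [1.0, 2.0, 3.0]
duts_y = [2.0, 1.0, 0.0]
flickr_x = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5]
flickr_y = [1.0, 0.0, 1.0, 2.0, 3.0, 2.0, 1.0, 0.0, 1.0]

# "else" branch: one pass over the concatenated batch of 12.
joint_grad = mse_grad(w, duts_x + flickr_x, duts_y + flickr_y)

# "if" branch: two passes, weighted 1/4 and 3/4.
split_grad = (0.25 * mse_grad(w, duts_x, duts_y)
              + 0.75 * mse_grad(w, flickr_x, flickr_y))

# Same 1/4 and 3/4 weights, but applied to a 2:10 split of the same 12 samples.
skewed_grad = (0.25 * mse_grad(w, duts_x[:2], duts_y[:2])
               + 0.75 * mse_grad(w, duts_x[2:] + flickr_x, duts_y[2:] + flickr_y))

print("joint :", joint_grad)
print("split :", split_grad)    # matches joint for the 3:9 split
print("skewed:", skewed_grad)   # differs from joint for the 2:10 split
```

With a 3:9 split the weights 1/4 and 3/4 equal the sample fractions 3/12 and 9/12, so the weighted sum of the two means is the mean over all 12 samples. With any other split the fixed weights no longer match the fractions and the gradient changes.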

So it is likely close, but not the same.
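
The batch norm effect can be sketched in the same framework-free way. This assumes scalar features and a hypothetical helper `batch_norm` that normalizes a batch with its own mean and biased variance, which is what batch norm does in training mode (without the learnable scale and shift); the feature values are made up so that the two datasets have clearly different statistics.

```python
from statistics import fmean, pvariance

def batch_norm(xs, eps=1e-5):
    """Normalize a batch of scalars with its own mean and biased variance."""
    mu = fmean(xs)
    var = pvariance(xs, mu)
    return [(x - mu) / (var + eps) ** 0.5 for x in xs]

# Made-up activations: dataset A sits around 1, dataset B around 8.
duts_feats = [0.0, 1.0, 2.0]
flickr_feats = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]

# "if" branch: each part is normalized with its own statistics.
separate = batch_norm(duts_feats) + batch_norm(flickr_feats)

# "else" branch: all 12 samples share one mean and variance.
together = batch_norm(duts_feats + flickr_feats)

# Largest per-sample difference between the two normalized outputs.
max_diff = max(abs(a - b) for a, b in zip(separate, together))
print("max difference:", max_diff)
```

Because the normalized activations differ, everything downstream of the batch norm layer differs as well, including the loss and its gradients. The closer the statistics of the two datasets are, the smaller this gap becomes.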

Best regards

Thomas

OK, that’s exactly the answer I was looking for, thank you.

In fact, each batch contains three samples from dataset A (self.tr_dutsloader) and nine samples from dataset B (self.tr_flickrloader).

For both versions of the code, the weight update process is identical, but the forward propagation is different (the BatchNorm statistics in self.net(flickr_inputs) and in self.net(duts_inputs) are different). Do you think I got it right?