Calling backward after multiple forwards

For reasons related to training with a triplet loss in parallel to a cross-entropy loss, I want to create a forward pass like this:

        def forward(self, triplet):
            # Start from an empty (0, num_categories) tensor and append each output.
            total_outputs = torch.empty((0, self.num_categories))
            output_embedding = {}
            for name, x in triplet.items():
                output = self.main_model(x)
                # Embedding stored during the forward pass above.
                output_embedding[name] = self.get_current_embedding()
                total_outputs = torch.cat([total_outputs, output], dim=0)
            return total_outputs, output_embedding

What I did is create a triplet model that inherits from nn.Module and accepts a main_model, which is a normal CNN. The triplet model takes a dict as input, passes each value (a triplet component) through the main_model, and collects the triplet embeddings. It also aggregates the regular binary outputs of the three triplet components and returns them as total_outputs.

My question:
Is there anything problematic I'm not seeing about gradients being accumulated on the main_model weights through three consecutive forward passes? Leaving aside memory issues, which are under control, is there anything inherently wrong with assuming that passing the anchor, positive, and negative one after the other, concatenating the binary outputs, and computing the loss and gradients on the total is identical to doing it all in a single pass?
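
One way this assumption could be checked empirically is to compare the gradients from both routes on a toy model. The sketch below is only an illustration: the small `nn.Sequential` network is a hypothetical stand-in for main_model, the tensor shapes and labels are made up, and it deliberately contains no BatchNorm or Dropout layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for main_model (assumption: no BatchNorm or Dropout),
# so a sample's output does not depend on which other samples share its batch.
main_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
criterion = nn.CrossEntropyLoss()

# Made-up triplet: 4 samples per component, 8 features each.
triplet = {name: torch.randn(4, 8) for name in ("anchor", "positive", "negative")}
labels = torch.randint(0, 2, (12,))


def grads_of(loss):
    """Clear old gradients, backpropagate, and return a copy of each gradient."""
    main_model.zero_grad()
    loss.backward()
    return [p.grad.clone() for p in main_model.parameters()]


# (a) Three separate forward passes, outputs concatenated, one backward call.
outputs = torch.cat([main_model(x) for x in triplet.values()], dim=0)
grads_separate = grads_of(criterion(outputs, labels))

# (b) One forward pass over the concatenated inputs.
batched = torch.cat(list(triplet.values()), dim=0)
grads_batched = grads_of(criterion(main_model(batched), labels))

# Compare the two sets of gradients up to floating-point tolerance.
grads_match = all(
    torch.allclose(a, b, atol=1e-5) for a, b in zip(grads_separate, grads_batched)
)
print(grads_match)
```

The same comparison would be worth repeating on the real main_model. If it contains BatchNorm layers in training mode, the batch statistics are computed separately for each forward call, so the three-pass and single-pass routes would not be expected to match exactly. Dropout draws a fresh random mask on every call and makes an exact comparison unreliable as well.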