Multiprocessing certain parts of the model

Hi all,

I have a model with one encoder and two decoders. Each decoder has a separate loss function and target, so I think the two decoders can run independently of each other once the encoder output is available. Is this possible using torch.multiprocessing?

What I want to do is something like

code = self.encoder(input0)

decode1 = self.decoder1(code, input1)
decode2 = self.decoder2(code, input2)

loss1 = loss_fn(decode1, output1)
loss2 = loss_fn(decode2, output2)
total_loss = loss1 + loss2

self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()

but with self.decoder1 and self.decoder2 running concurrently instead of one after the other. Is this possible?
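
For reference, here is a rough sketch of one way the step above could be made concurrent. It uses Python threads (concurrent.futures.ThreadPoolExecutor) rather than torch.multiprocessing, on the assumption that keeping everything in one process is simpler because both decoders have to backpropagate into the same encoder output. The Decoder and TwoDecoderModel classes, the layer sizes, and the random batch are all made up for illustration; only the encoder/decoder1/decoder2 layout and the training step come from the snippet above. Whether this is actually faster than the sequential version will depend on the model and device, so it would need to be measured.

```python
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Hypothetical decoder that takes the shared code plus its own input."""

    def __init__(self, code_dim, extra_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(code_dim + extra_dim, out_dim)

    def forward(self, code, extra):
        return self.proj(torch.cat([code, extra], dim=1))


class TwoDecoderModel(nn.Module):
    """Hypothetical stand-in for the model: one encoder, two decoders."""

    def __init__(self, in_dim=8, code_dim=16, extra_dim=4, out_dim=3):
        super().__init__()
        self.encoder = nn.Linear(in_dim, code_dim)
        self.decoder1 = Decoder(code_dim, extra_dim, out_dim)
        self.decoder2 = Decoder(code_dim, extra_dim, out_dim)


def train_step(model, optimizer, loss_fn, pool, batch):
    input0, input1, input2, output1, output2 = batch
    code = model.encoder(input0)

    # Submit both decoder forward passes to the pool; each runs in its own
    # worker thread and both read the same encoder output.
    fut1 = pool.submit(model.decoder1, code, input1)
    fut2 = pool.submit(model.decoder2, code, input2)

    loss1 = loss_fn(fut1.result(), output1)
    loss2 = loss_fn(fut2.result(), output2)
    total_loss = loss1 + loss2

    # A single backward pass from the main thread covers both branches.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item()


torch.manual_seed(0)
model = TwoDecoderModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Random placeholder batch: (input0, input1, input2, output1, output2).
batch = (
    torch.randn(32, 8),
    torch.randn(32, 4),
    torch.randn(32, 4),
    torch.randn(32, 3),
    torch.randn(32, 3),
)

with ThreadPoolExecutor(max_workers=2) as pool:
    losses = [train_step(model, optimizer, loss_fn, pool, batch) for _ in range(3)]
```

The rest of the step (summing the two losses, one zero_grad/backward/step) is unchanged from the sequential version; only the two decoder calls are handed to the pool.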