Do you think something like the following is possible? (pseudo-code):
...
def _split_classes(batch, targets):
class_a, class_b = [], []
for i, sample in enumerate(batch):
if targets[i] == 1:
class_a.append(sample)
else:
class_b.append(sample)
return torch.Tensor(class_a), torch.Tensor(class_b)
def forward(x, targets):
    """Encode *x*, decode each class with its own decoder, and keep batch order.

    Samples whose target equals 1 are decoded by ``decoderA`` and all others
    by ``decoderB``. The decoded rows are written back to their original
    positions, so row ``i`` of the result still corresponds to
    ``targets[i]``. Concatenating the two decoder outputs would instead
    put every class-1 row first and misalign the outputs with the targets.

    Args:
        x: Input batch. The first dimension indexes samples.
        targets: One label per sample (tensor or sequence).

    Returns:
        Tensor with one decoded row per input sample, in input order.

    NOTE(review): ``encoder``, ``decoderA`` and ``decoderB`` are assumed to
    be callables in scope (e.g. submodules); confirm against the enclosing
    module. The original code called ``self._split_classes`` without a
    ``self`` parameter, so the class mask is computed directly here.
    NOTE(review): assumes both decoders return the same trailing shape and
    accept a zero-row batch (layers such as BatchNorm in training mode do
    not) — confirm.
    """
    x = encoder(x)
    # Flatten so (B,) and (B, 1) labels both give a 1-D mask of length B.
    mask = torch.as_tensor(targets, device=x.device).reshape(-1) == 1
    out_a = decoderA(x[mask])
    out_b = decoderB(x[~mask])
    # Masked assignment is differentiable w.r.t. out_a/out_b and restores
    # the original sample order.
    out = out_a.new_empty((mask.shape[0], *out_a.shape[1:]))
    out[mask] = out_a
    out[~mask] = out_b
    return out