Can I split my batch in the forward method?

Is something like the following possible (pseudo-code)?

...
def _split_classes(batch, targets):
    class_a, class_b = [], []
    for i, sample in enumerate(batch):

        if targets[i] == 1:
            class_a.append(sample)
        else:
            class_b.append(sample)

    return torch.Tensor(class_a), torch.Tensor(class_b)


def forward(self, x, targets):
    """Encode ``x`` and decode each class with its own decoder.

    Outputs are returned in the original sample order.

    Intended as the ``forward`` method of a ``torch.nn.Module`` subclass that
    owns ``encoder``, ``decoderA`` and ``decoderB`` submodules. The attribute
    names are taken from the pseudo-code (TODO confirm).

    Splitting the batch inside ``forward`` works with autograd as long as the
    split is done with tensor indexing: gradients from both decoders flow back
    through the shared encoder.

    Concatenating the two decoder outputs alone would put all class-A rows
    before all class-B rows, so row ``i`` of the result would no longer match
    ``targets[i]``. The rows are therefore permuted back into input order.

    Args:
        x: input batch; first dimension indexes samples.
        targets: one label per sample; samples labelled 1 go to ``decoderA``,
            all others to ``decoderB``.

    Returns:
        Tensor whose row ``i`` is the decoder output for input sample ``i``.
    """
    h = self.encoder(x)
    mask = torch.as_tensor(targets, device=h.device).reshape(-1) == 1
    # Positions of each class within the original batch.
    idx_a = mask.nonzero(as_tuple=True)[0]
    idx_b = (~mask).nonzero(as_tuple=True)[0]
    # NOTE(review): either split may be empty for a single-class batch; some
    # layers (e.g. normalization layers) may not accept an empty input, so
    # confirm the decoders tolerate it.
    out_a = self.decoderA(h[idx_a])
    out_b = self.decoderB(h[idx_b])
    merged = torch.cat((out_a, out_b))
    # merged[k] belongs to input sample order[k]. The argsort of a permutation
    # is its inverse, so indexing with it restores the input order.
    order = torch.cat((idx_a, idx_b))
    return merged[torch.argsort(order)]