You could reshape the input so that the batch and slice dimensions are merged into dim0, which would increase the batch size, via x = x.view(-1, 3, 256, 256).
This would treat each slice as an independent input, in the same way as your previous approach.
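A minimal sketch of the reshape, assuming (hypothetically) that the input has the shape [batch_size, num_slices, 3, 256, 256]. The sizes and the nn.Conv2d layer below are illustrative stand-ins for your actual data and model:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 2 samples, each with 5 slices of 3x256x256 images.
batch_size, num_slices = 2, 5
x = torch.randn(batch_size, num_slices, 3, 256, 256)

# Merge the batch and slice dims into dim0: [2 * 5, 3, 256, 256]
x_2d = x.view(-1, 3, 256, 256)

# A 2D layer now sees every slice as a separate sample.
conv2d = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
out = conv2d(x_2d)  # [10, 8, 256, 256]

# Optionally restore the per-sample grouping of slices afterwards.
out = out.view(batch_size, num_slices, 8, 256, 256)
```

Since x is contiguous, view only reinterprets the memory, so slice j of sample i ends up at index i * num_slices + j in dim0.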
Alternatively, you might want to treat the slice dimension as the depth dimension.
In that case, you would need to change the model architecture to use 3D layers such as nn.Conv3d.
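A sketch of the 3D variant, under the same hypothetical input shape [batch_size, num_slices, 3, 256, 256]. nn.Conv3d expects its input as [N, C, D, H, W], so the slice dimension has to be moved to the depth position first:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 2 samples, each with 5 slices of 3x256x256 images.
batch_size, num_slices = 2, 5
x = torch.randn(batch_size, num_slices, 3, 256, 256)

# Swap the slice and channel dims: [N, D, C, H, W] -> [N, C, D, H, W]
x_3d = x.permute(0, 2, 1, 3, 4)  # [2, 3, 5, 256, 256]

# The kernel now also spans neighboring slices along the depth dim.
conv3d = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
out = conv3d(x_3d)  # [2, 8, 5, 256, 256]
```

Unlike the reshape approach, the batch size stays the same and the layers can learn correlations between adjacent slices. Note that permute returns a non-contiguous tensor, so call .contiguous() before any later .view on it.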