With this implementation, your batches have to be the same size. `rx.t()` and `ry` are square matrices whose sizes match their respective batch sizes. If the batch sizes differ, the two matrices end up with different shapes and can't be added elementwise (and `zz` also won't be square).
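Here is a minimal sketch of the mismatch. It assumes `rx`, `ry` and `zz` are built the way a typical Gaussian-kernel MMD implementation builds them: from the diagonals of `x @ x.t()` and `y @ y.t()`, and from the cross term `x @ y.t()`. The helper `pairwise_terms` is hypothetical, made up for illustration.

```python
import torch

def pairwise_terms(x, y):
    # Gram matrices: xx is (Bx, Bx), yy is (By, By), zz is (Bx, By)
    xx, yy, zz = x @ x.t(), y @ y.t(), x @ y.t()
    # Squared norms expanded into square matrices shaped like xx and yy
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)
    return rx, ry, zz

x = torch.randn(8, 3)  # batch of 8 samples
y = torch.randn(5, 3)  # batch of 5 samples
rx, ry, zz = pairwise_terms(x, y)
print(rx.shape, ry.shape, zz.shape)
# torch.Size([8, 8]) torch.Size([5, 5]) torch.Size([8, 5])

try:
    # (8, 8) + (5, 5) cannot broadcast, so PyTorch raises here
    _ = rx.t() + ry - 2 * zz
except RuntimeError as err:
    print("shape mismatch:", err)
```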
The best option is just to randomly sample `N` elements from each of your batches so they are the same size.
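A rough sketch of that subsampling in PyTorch follows. The helper name `subsample_to_same_size` is hypothetical. Defaulting `N` to the smaller of the two batch sizes is an assumption on my part, not something the approach requires. `torch.randperm` gives a random permutation of row indices, and keeping the first `n` of them samples rows without replacement.

```python
import torch

def subsample_to_same_size(x, y, n=None, generator=None):
    # Hypothetical helper: assumes N defaults to the smaller batch size
    if n is None:
        n = min(x.size(0), y.size(0))
    if n > x.size(0) or n > y.size(0):
        raise ValueError("n cannot exceed either batch size")
    # Random permutation of row indices; the first n give a sample without replacement
    ix = torch.randperm(x.size(0), generator=generator)[:n]
    iy = torch.randperm(y.size(0), generator=generator)[:n]
    return x[ix], y[iy]

x = torch.randn(8, 3)
y = torch.randn(5, 3)
xs, ys = subsample_to_same_size(x, y)
print(xs.shape, ys.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```

Both subsampled batches now have `N` rows, so `rx.t()`, `ry` and `zz` all come out as `N x N`. The trade-off is that the dropped samples don't contribute to that step's loss, which makes the estimate somewhat noisier.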