A trivial but annoying "bug" in random_split

I’m doing research on federated learning, and I’m splitting a dataset to simulate a scenario in which a few nodes form a federated learning network, each holding its own training data.

So first I simulate 5 nodes with:
trainsets = random_split(whole_trainset, [0.2, 0.2, 0.2, 0.2, 0.2], generator=generator)
And 10 nodes with:
trainsets = random_split(whole_trainset, [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], generator=generator)
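
For context, this is roughly the full setup (the seed and the toy dataset below are placeholders, not my actual data):

import torch
from torch.utils.data import random_split

# Placeholder stand-in for the real training set.
whole_trainset = range(50000)

# Seeded generator so the split is reproducible.
generator = torch.Generator().manual_seed(42)

trainsets = random_split(whole_trainset, [0.2] * 5, generator=generator)
print([len(t) for t in trainsets])  # [10000, 10000, 10000, 10000, 10000]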

Both worked fine, but then I went to 20 nodes and got:
ValueError: Sum of input lengths does not equal the length of the input dataset!
print(torch.__version__) gives “2.2.1+cu118”, so it’s not like I’m using a too-old version.
Then I realized that plain Python already gives “1.0000000000000002” for print(0.05 + 0.05 + … + 0.05) with twenty terms in the first place.
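
To double-check the rounding direction (these are plain-Python facts, nothing PyTorch-specific):

import math

print(sum([0.1] * 10))   # 0.9999999999999999 (just below 1; 10 nodes worked)
print(sum([0.05] * 20))  # 1.0000000000000002 (just above 1; 20 nodes failed)
print(math.isclose(sum([0.05] * 20), 1.0))  # True, so a tolerance check would accept it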

I believe this is easy to fix (just this check in random_split, I mean, not the underlying floating-point behavior), so I decided to create an account and report it here.
I also made an issue on GitHub.

Thanks for reporting the issue! Would you also be interested in fixing it? If so, let the code owners know in the GitHub issue you’ve created.

Sorry, I’m currently busy with my research, so I can’t just jump into all the open-source work: code review, pull requests, contributing, etc.
But I believe something like this will work:

If the sum is very close to 1 (instead of exactly equal to 1), then assume it’s the “proportion case” and apply a workaround.
For example, calculate dataset_length * each_proportion for every entry and check whether the results sum up to dataset_length.
If they don’t, then as long as the sum of all the proportions is very close to 1, there should still be a feasible workaround (a sketch of what I mean is at the end of this post). It’s probably even already implemented.
For example:

from torch.utils.data import random_split

# Fractions summing to exactly 1.0, on a dataset of odd length 11:
trainsets = random_split(range(11), [0.5, 0.5])
print(len(trainsets[0]))
print(len(trainsets[1]))

Gives 6 and 5.
I’m not sure what the complete mechanism actually is, but I think the same mechanism can be applied as long as sum([0.05, … , 0.05]) is very close to 1, or for whatever lengths list sums very close to 1. (It probably needs a few additional checks and workarounds when the sum is not exactly 1, though.)
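
To make the suggestion concrete, here is a rough sketch of what I mean (split_by_fractions is a made-up helper, not PyTorch API, and this is only my guess at a workable mechanism):

import math
from torch.utils.data import random_split

def split_by_fractions(dataset, fractions, generator=None):
    # Treat `fractions` as proportions whenever their sum is within
    # floating-point tolerance of 1, not only when it equals 1 exactly.
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum (approximately) to 1")
    total = len(dataset)
    # Floor each share, then hand out the leftover items one by one so
    # the integer lengths are guaranteed to sum to `total`.
    lengths = [int(frac * total) for frac in fractions]
    remainder = total - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return random_split(dataset, lengths, generator=generator)

# Twenty nodes at 0.05 each no longer trips the sum check:
trainsets = split_by_fractions(range(50000), [0.05] * 20)
print(len(trainsets[0]))  # 2500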

And that GitHub issue hasn’t gotten a “bug” label or any reply yet. I guess the devs just won’t notice it that soon.