Error in dist.scatter() when the MPI backend is used

Hi, I am having a problem starting a distributed program with “mpi” as the backend. The program is as follows:

import os
import socket

import torch
import torch.distributed as dist
from multiprocessing import Process

def run(rank, size, hostname):
    print("I am {} of {} in {}".format(rank, size, hostname))
    tensor = torch.zeros(1)
    group = dist.new_group([0, 1, 2])
    if rank == 0:
        # Source rank: provide one tensor per rank in the group.
        scatter_list = [torch.zeros(1) for _ in range(3)]
        dist.scatter(tensor=tensor, src=0, scatter_list=scatter_list, group=group)
        print("Master has completed scatter")
    else:
        # Non-source ranks should only receive into `tensor`.
        tensor += 1
        dist.scatter(tensor=tensor, src=0, group=group)
        print("Worker has completed scatter")
    print('Rank', rank, 'has data', tensor[0])

def init_process(rank, size, hostname, fn, backend='tcp'):
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, hostname)

if __name__ == "__main__":
    # Rank and world size come from Open MPI's environment variables.
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    hostname = socket.gethostname()
    p = Process(target=init_process,
                args=(world_rank, world_size, hostname, run, 'mpi'))
    p.start()
    p.join()
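For context, I launch the script with mpirun (e.g. mpirun -np 3 python mpi_test.py), since the rank and world size are read from Open MPI’s environment variables.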

When the program starts, it always throws an error like:

  File "mpi_test.py", line 17, in run
    dist.scatter(tensor= tensor, src=0, group=group)
TypeError: scatter() missing 1 required positional argument: 'scatter_list'

However, the error is raised on ranks 1 and 2, which are not the source rank and therefore should not need the ‘scatter_list’ argument at all.
I have tried several variations, but none of them worked. Does anybody know why?
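Is the intended workaround to pass scatter_list explicitly on the non-source ranks as well, for example as an empty list? A minimal sketch of what I mean (this is just a guess on my part, not something I know to be the correct usage):

# Guessed workaround: pass an explicit (empty) scatter_list on the
# non-source ranks so the positional-argument check is satisfied.
# I am not sure the MPI backend actually accepts this.
dist.scatter(tensor=tensor, src=0, scatter_list=[], group=group)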
Thank you for reading.