Matplotlib doesn't work in distributed training

I am training a model on a 4-GPU machine with torch.distributed, and I want ONLY the rank_0 process to be responsible for plotting, so I wrote code like this:

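    import torch.distributed as distributed

    # (inside the plotting function; is_distributed() is a helper defined elsewhere in the script)
    if is_distributed() and distributed.get_rank() != 0:
        print('Only rank_0 will do plotting,this is rank_{}'.format(distributed.get_rank()))
        return  # in a parallel context, a single plot is enough
    print('this is rank_0 and  will do plotting')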
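The snippet below is a stdlib analogy, not torch.distributed code: threads stand in for the four ranks and `threading.Barrier` stands in for a collective call. It sketches the general rule that every rank should reach the same synchronization points, with rank-specific work such as plotting done between them rather than by returning early past one. `PLOT_RANK`, `train_step`, and the bookkeeping lists are hypothetical names chosen for illustration.

```python
import threading

WORLD_SIZE = 4  # mirrors the 4-GPU setup; each "rank" is simulated by a thread
PLOT_RANK = 0   # hypothetical: which rank does the plotting

barrier = threading.Barrier(WORLD_SIZE)  # stand-in for a collective such as a barrier
plotted_by = []  # ranks that did the rank-specific work
finished = []    # ranks that got past the final synchronization point
lock = threading.Lock()


def train_step(rank):
    # Every rank reaches this synchronization point.
    barrier.wait()
    if rank == PLOT_RANK:
        with lock:
            plotted_by.append(rank)  # rank-specific side work (e.g. plotting)
    # All ranks, including the plotting one, meet again afterwards;
    # none of them returns before this point.
    barrier.wait()
    with lock:
        finished.append(rank)


threads = [threading.Thread(target=train_step, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join(timeout=5)

print(plotted_by, sorted(finished))
```

Changing `PLOT_RANK` to 1 moves the side work to the second "rank" without changing which synchronization points each rank reaches.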

So, if the process rank is not 0, it should print:

    Only rank_0 will do plotting,this is rank_x

and I do get three lines of this kind, one per non-zero rank.
If the process rank is 0, it should print:

    this is rank_0 and  will do plotting

but I never get this line. Meanwhile, all processes hang and no exception is thrown.

`watch -n0.1 nvidia-smi` shows that before this code runs, every GPU uses more than 10341MB of memory. When execution reaches these lines, the first GPU's memory usage drops to 2387MB while the others stay the same. Stranger still, if I change the code to

    if is_distributed() and distributed.get_rank() != 1:

so that the second GPU is responsible for plotting, then when execution reaches these lines the 1st, 3rd and 4th GPUs still use more than 10341MB, but the 2nd GPU's memory usage drops to 1073MB. Training hangs and no exception is thrown.
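The post does not establish why the job hangs. One common way distributed jobs hang without an exception, though, is a rank taking a branch that skips a synchronization point its peers are waiting on. The stdlib analogy below, not torch.distributed code, illustrates that: a thread stands in for each rank, and `threading.Barrier` stands in for a collective. A torch.distributed collective would typically just block, whereas `Barrier.wait(timeout=...)` lets the sketch detect the missing peer. `skipped_sync_demo` is a hypothetical name.

```python
import threading


def skipped_sync_demo(timeout=0.5):
    # Two simulated ranks share one synchronization point.
    barrier = threading.Barrier(2)
    outcome = {}

    def waiting_rank():
        try:
            # Waits for a peer that never arrives; after `timeout` seconds the
            # barrier is put into the broken state and BrokenBarrierError is raised.
            barrier.wait(timeout=timeout)
            outcome["timed_out"] = False
        except threading.BrokenBarrierError:
            outcome["timed_out"] = True

    def early_return_rank():
        # Takes a different branch and returns without reaching the barrier.
        return

    threads = [threading.Thread(target=waiting_rank),
               threading.Thread(target=early_return_rank)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome["timed_out"]


print(skipped_sync_demo())  # True: the waiting rank never sees its peer
```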
With the same code in non-distributed training, plotting works fine. How can I make plotting work in the distributed case?
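This is not necessarily the fix the poster found. One common safeguard when plotting from training scripts, especially on a headless machine, is to select Matplotlib's non-interactive Agg backend and write figures to files instead of calling `plt.show()`. The minimal sketch below assumes only that matplotlib is installed. `save_loss_plot` and the output file name are hypothetical.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: renders to files, never opens a window
import matplotlib.pyplot as plt


def save_loss_plot(losses, path):
    """Plot a sequence of loss values and write the figure to `path`."""
    fig, ax = plt.subplots()
    ax.plot(range(len(losses)), losses)
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    fig.savefig(path)
    plt.close(fig)  # release the figure so repeated calls don't accumulate open figures
    return path


out = save_loss_plot([1.0, 0.7, 0.5, 0.4],
                     os.path.join(tempfile.gettempdir(), "loss.png"))
print(out)
```

In a distributed job, a helper like this would be called only on the plotting rank.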

After adding:
before any rank_x-specific operation, everything works fine. Silly me.