I am trying to understand how to use the DataParallel API, following this tutorial:
https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
Currently I have two GPUs, and I am passing an object to the model. I can see that both GPUs are now being used. I know that PyTorch launches two “threads”, one for each GPU. I would like an API to get that number inside my forward pass so that I can allocate GPU memory correctly. How can I find it?
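A minimal sketch, assuming a plain `nn.DataParallel` setup like the one in the tutorial. `DeviceAwareModel` and `build_model` are hypothetical names used only for illustration. Under `nn.DataParallel` the batch is split along dimension 0 and each replica receives a slice that already lives on its own GPU. So inside `forward` you can read the replica's device from `x.device` and allocate scratch memory there. The total number of GPUs is available from `torch.cuda.device_count()` or from the wrapper's `device_ids` list.

```python
import torch
import torch.nn as nn


class DeviceAwareModel(nn.Module):
    """Hypothetical toy model that reports which device each replica ran on."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Under nn.DataParallel each replica gets a slice of the batch that is
        # already on that replica's GPU, so x.device tells us where we are.
        device = x.device
        # Allocate any per-replica scratch memory on that same device.
        scratch = torch.zeros(x.size(0), self.fc.out_features, device=device)
        # Tag every row with its GPU index (-1 when running on the CPU).
        index = device.index if device.type == "cuda" else -1
        tags = torch.full((x.size(0),), index, dtype=torch.long, device=device)
        return self.fc(x) + scratch, tags


def build_model():
    model = DeviceAwareModel()
    if torch.cuda.is_available():
        # DataParallel expects the parameters on device_ids[0] (cuda:0 by default).
        model = model.cuda()
    # On a CPU-only machine DataParallel just calls the wrapped module directly.
    return nn.DataParallel(model)


if __name__ == "__main__":
    model = build_model()
    # GPUs DataParallel may use (an empty list on CPU-only machines).
    print("device_ids:", model.device_ids)
    print("visible GPUs:", torch.cuda.device_count())

    batch = torch.randn(8, 4)
    if torch.cuda.is_available():
        batch = batch.cuda()
    out, tags = model(batch)
    # The gathered tags show which GPU processed each row of the batch.
    print(out.shape, tags.tolist())
```

Reading the device from the input, rather than keeping a global counter, keeps each replica independent of the others. Note that if a batch has fewer samples than there are GPUs, DataParallel only runs as many replicas as there are chunks, so not every GPU index will appear.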