Using FSDP to run inference on an image that doesn't fit on a single GPU

I have an already trained UNet model that was not trained in a distributed environment. I would like to wrap that model in FSDP so that it can run inference on a 4K image that does not fit on a single GPU.

I'm not sure whether FSDP actually shards the model across the GPUs in a way that would also partition this image and perform the full inference.

I ran the example at Getting Started with Fully Sharded Data Parallel (FSDP) — PyTorch Tutorials 2.3.0+cu121 documentation, but it gives me the impression that it performs the inference twice rather than just once.

Hey there,
It sounds like you want to use FSDP to run inference on a large 4K image with your UNet model. FSDP shards the model parameters across GPUs, but it does not split the input: each rank runs the full forward pass on its own data.
That is why the example you ran may not be doing what you need. FSDP is designed primarily for data-parallel training and typically needs customization if the goal is to fit a single large image.
You might want to explore other methods for distributed inference or ask for advice from the PyTorch community.
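Here's a minimal sketch of what that data-parallel behavior looks like, assuming 2 GPUs, a tiny stand-in module in place of your UNet, a small dummy input instead of a real 4K image, and an arbitrary rendezvous port. It shows that FSDP shards the parameters, yet every rank still receives its own full input and executes its own forward pass, which is why you see one inference per process:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class TinyNet(nn.Module):
    """Stand-in for the real UNet; just a couple of convolutions."""
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, padding=1),
        )

    def forward(self, x):
        return self.body(x)


def run(rank, world_size):
    # Assumed single-node rendezvous settings.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # FSDP shards the parameters of the wrapped module across ranks.
    model = FSDP(TinyNet().cuda(rank))
    model.eval()

    # Each rank still gets its own full input and runs the full forward
    # pass (gathering sharded parameters layer by layer), so with 2
    # processes you observe 2 forward passes, one per rank.
    x = torch.randn(1, 3, 512, 512, device=f"cuda:{rank}")
    with torch.no_grad():
        out = model(x)
    print(f"rank {rank}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```

Note that in this setup FSDP only reduces the parameter memory per GPU; the activations of each forward pass still live entirely on the rank that runs it, so if the 4K activations are what exceeds memory, FSDP alone won't solve that.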

Could you describe which customizations are needed for inference?

Yes, please! I need an example. When I use FSDP with 2 processes, the model is executed twice.