Can we compute nn.functional.linear on chunks of data from torch.chunk?

Hi, I want to split my input and weight tensors into x chunks with torch.chunk and then evaluate nn.functional.linear or nn.functional.conv2d on each chunk in parallel, without using a for loop. Is this possible in PyTorch?
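To make the question concrete, here is a minimal sketch that assumes the split is along in_features (the last dim of the input and dim 1 of the weight). With that split, each chunk produces a partial product and the partials sum to the full F.linear output. The names chunked_linear and batched_chunked_linear are hypothetical. The second function is an alternative not discussed in this thread: when the chunks are equal-sized, a reshape lets a single torch.bmm call compute every chunk at once, with no Python-level loop.

```python
import torch
import torch.nn.functional as F


def chunked_linear(x, weight, bias, num_chunks):
    # Assumption: split in_features (last dim of x, dim 1 of weight) the same way.
    x_chunks = torch.chunk(x, num_chunks, dim=-1)
    w_chunks = torch.chunk(weight, num_chunks, dim=1)
    # Each chunk gives a partial product of shape (..., out_features);
    # this comprehension is the Python loop the question wants to avoid.
    partials = [F.linear(xc, wc) for xc, wc in zip(x_chunks, w_chunks)]
    # Summing the partials recovers the full product; the bias is added once.
    return torch.stack(partials).sum(dim=0) + bias


def batched_chunked_linear(x, weight, bias, num_chunks):
    # Loop-free variant; assumes a 2-D x and in_features divisible by num_chunks.
    n, in_features = x.shape
    out_features = weight.shape[0]
    k = in_features // num_chunks
    xs = x.reshape(n, num_chunks, k).transpose(0, 1)                   # (C, N, k)
    ws = weight.reshape(out_features, num_chunks, k).permute(1, 2, 0)  # (C, k, out)
    partials = torch.bmm(xs, ws)  # one batched matmul over all C chunks: (C, N, out)
    return partials.sum(dim=0) + bias
```

For example, with x of shape (4, 12), a weight of shape (5, 12), and num_chunks=3, both functions should match F.linear(x, weight, bias) up to floating-point rounding.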

Hey @Sai_Kiran, are you referring to Mesh-TensorFlow-like parallelism? We don't have an API for it yet, but the split, scatter, parallel_apply, and gather steps can be done on the application side.
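A minimal sketch of doing those steps in application code, assuming at least one CUDA GPU, a 2-D input, and a split along in_features (dim 1 of both x and weight). It uses torch.nn.parallel.scatter, parallel_apply, and gather; PartialLinear and multi_gpu_linear are hypothetical names for this illustration. Because splitting in_features makes each chunk contribute a partial sum rather than a slice of the output, the gathered partials are summed at the end.

```python
import torch
import torch.nn.functional as F
from torch.nn.parallel import gather, parallel_apply, scatter


class PartialLinear(torch.nn.Module):
    """Hypothetical helper: computes one chunk's partial product.

    parallel_apply expects modules, so F.linear is wrapped here. The extra
    leading dim lets gather() stack the partials along dim 0.
    """

    def forward(self, x_chunk, w_chunk):
        return F.linear(x_chunk, w_chunk).unsqueeze(0)


def multi_gpu_linear(x, weight, bias, devices):
    # split + scatter: slice dim 1 (in_features) of both tensors across GPUs.
    # scatter() on a tuple returns one (x_i, w_i) tuple per chunk.
    chunks = scatter((x, weight), devices, dim=1)
    used = devices[: len(chunks)]  # chunking may yield fewer pieces than devices
    modules = [PartialLinear() for _ in used]
    # parallel_apply: runs PartialLinear()(x_i, w_i) on each device.
    partials = parallel_apply(modules, chunks, devices=used)
    # gather: stack partials on the first device, then reduce over chunks.
    stacked = gather(partials, used[0], dim=0)  # (num_chunks, N, out_features)
    return stacked.sum(dim=0) + bias.to(stacked.device)
```

For example, multi_gpu_linear(x, weight, bias, list(range(torch.cuda.device_count()))) should match F.linear(x, weight, bias) up to floating-point rounding.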

With parallel_apply, there won’t be a loop in the application code, but it internally uses a loop to launch multiple threads to process inputs in parallel. Here is an example usage of parallel_apply.
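To illustrate the point about a loop that launches multiple threads, here is a simplified, CPU-only sketch of that pattern in plain Python. It is not PyTorch's parallel_apply, which also handles CUDA devices, streams, autocast, and exception propagation, and threaded_apply is a hypothetical name. It just runs one callable per input in its own thread and returns the results in input order.

```python
import threading

import torch
import torch.nn.functional as F


def threaded_apply(fns, inputs):
    """Hypothetical, simplified illustration of the parallel_apply pattern.

    No devices, streams, autocast, or error wrapping: an exception in a
    worker leaves that slot as None.
    """
    results = [None] * len(inputs)

    def worker(i, fn, args):
        results[i] = fn(*args)

    # The loop lives here, not in the caller: one thread per (fn, args) pair.
    threads = [
        threading.Thread(target=worker, args=(i, fn, args))
        for i, (fn, args) in enumerate(zip(fns, inputs))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Usage: chunk-wise partial products of F.linear, one thread per chunk.
x = torch.randn(4, 12)
w = torch.randn(5, 12)
pairs = list(zip(torch.chunk(x, 3, dim=1), torch.chunk(w, 3, dim=1)))
partials = threaded_apply([F.linear] * len(pairs), pairs)
out = torch.stack(partials).sum(dim=0)  # equals F.linear(x, w) up to rounding
```

PyTorch operators generally release the GIL while they run, so threads like these can overlap in practice, but the actual speedup on CPU is not guaranteed.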