How to gather data from all GPUs for DataParallel?

Hi,

In my forward pass, I am trying to gather the features from a CNN model to compute the class prototypes for the current batch. I can do that easily for DistributedDataParallel using the example given by MoCo, but I am not sure how to do it for DataParallel. It would be great if someone could give me some pointers.

TIA.

When working with PyTorch's DataParallel, gathering features from a CNN model works a bit differently than with DistributedDataParallel. In DataParallel, the model is replicated across multiple devices (GPUs), the input batch is scattered along the batch dimension, and each replica processes its own chunk. The outputs from all replicas are then gathered back onto the default device and concatenated automatically, so you do not need an explicit all_gather like the one in the MoCo example. Here's a general outline of what the forward pass looks like for feature extraction when using DataParallel:
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel

# Your CNN model
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        # Define your model layers here

    def forward(self, x):
        # Your forward pass logic; return the features you want to gather
        return x  # Adjust this line based on your actual forward pass

# Wrap your model with DataParallel (the model should live on the
# default CUDA device for the batch to be split across GPUs)
model = YourModel()
model = DataParallel(model)

# Example input dimensions; replace with your own
batch_size, channels, height, width = 64, 3, 224, 224
input_data = torch.randn((batch_size, channels, height, width))

# Forward pass with DataParallel: the input is scattered along the
# batch dimension, each replica runs on its own chunk, and the
# outputs are gathered back onto the default device
output_data = model(input_data)

# 'output_data' now contains the features from all replicas,
# concatenated along the batch dimension (dim 0)
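Once the gathered features for the whole batch sit on one device, the class prototypes are just per-class means of those features. Below is a minimal sketch of that step, assuming the forward pass returns a (batch_size, feature_dim) tensor and that you have a labels tensor for the batch; compute_prototypes, labels, and num_classes are placeholder names for your own code, not part of any PyTorch API:

# Minimal sketch (hypothetical helper): class prototypes as per-class means
# Assumes 'features' is (batch_size, feature_dim) and 'labels' is (batch_size,)
def compute_prototypes(features, labels, num_classes):
    prototypes = torch.zeros(num_classes, features.size(1), device=features.device)
    prototypes.index_add_(0, labels, features)  # sum features per class
    counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    return prototypes / counts.unsqueeze(1)     # divide by per-class counts

# Usage with the DataParallel output from above:
# prototypes = compute_prototypes(output_data, labels, num_classes)

The clamp(min=1) guards against division by zero for classes that happen to be absent from the current batch.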