In my forward pass, I am trying to gather the features from a CNN model to compute the class prototypes from the current batch. I can do that easily for DistributedDataParallel using the example given by MoCo, but I am not sure how to do it for DataParallel. It would be great if someone could give me some pointers.
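When working with PyTorch’s DataParallel, extracting features from a CNN model works a bit differently than with DistributedDataParallel. DataParallel runs in a single process: it replicates the model across multiple GPUs, splits each input batch along the first dimension, and has each replica process one chunk. The outputs of the replicas are then gathered onto the output device (the first GPU by default) and concatenated along the batch dimension, so whatever the wrapped model returns already covers the whole batch. This means you do not need an explicit all-gather as in the MoCo example. The catch is that code inside forward only sees one replica’s chunk, so return the features from forward and compute the prototypes outside it. Here’s a general outline of how you can set up your forward pass for feature extraction when using DataParallel: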
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel

# Your CNN model
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        # Define your model layers here

    def forward(self, x):
        # Your forward pass logic
        return x  # Adjust this line based on your actual forward pass

# Wrap your model with DataParallel
model = YourModel()
model = DataParallel(model)

# Assuming input tensor 'input_data' for the forward pass
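To make the outline concrete, here is a minimal sketch of computing class prototypes from the gathered features. The `FeatureNet` model, the `class_prototypes` helper, and all tensor sizes are hypothetical placeholders, not part of your code. It assumes that forward returns the features alongside the logits and that prototypes are per-class means over the current batch.

```python
import torch
import torch.nn as nn


class FeatureNet(nn.Module):
    """Hypothetical small CNN that returns both logits and features."""

    def __init__(self, num_classes=3, feat_dim=8):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        # Each replica only sees its own chunk of the batch here.
        feats = self.backbone(x)
        # DataParallel gathers every tensor in the returned tuple along dim 0.
        return self.head(feats), feats


def class_prototypes(feats, labels, num_classes):
    """Mean feature per class; rows for classes absent from the batch stay zero."""
    protos = torch.zeros(num_classes, feats.size(1),
                         device=feats.device, dtype=feats.dtype)
    protos.index_add_(0, labels, feats)  # sum the features of each class
    counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    return protos / counts.unsqueeze(1).to(feats.dtype)


# Falls back to a single CPU copy of the model when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(FeatureNet().to(device))

# Dummy batch standing in for 'input_data' and its labels (assumed shapes).
input_data = torch.randn(12, 3, 16, 16, device=device)
labels = torch.randint(0, 3, (12,), device=device)

# 'feats' covers the full batch and lives on the output device.
logits, feats = model(input_data)
prototypes = class_prototypes(feats, labels, num_classes=3)
```

Because the prototypes are computed outside forward, they see the full batch no matter how many GPUs the batch was split across. They also stay attached to the autograd graph, so call `feats.detach()` first if you do not want gradients to flow through them.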