Hi!
I am working on image retrieval and I would like to compute the loss over the entire dataset.
To do that, I first have to compute the features for the whole dataset, and then compute the loss batch-wise with gradient accumulation (due to memory constraints).
I wanted to use DistributedDataParallel to speed up my training, but did not manage to get it working.
The training would be on a single node with 3 to 4 GPUs.
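For context, my understanding of the standard single-node, multi-process DDP setup is roughly the following (just a minimal sketch; run_worker and build_model are placeholder names, and the master port is arbitrary):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run_worker(rank, world_size):
    # one process per GPU on a single node
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    net = build_model().cuda(rank)   # placeholder for the resnet18 + L2Norm below
    net = DDP(net, device_ids=[rank])

    # ... per-rank training loop would go here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # 3 or 4 GPUs in my case
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)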
This is a single-GPU example of what I am basically trying to do:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from tqdm import tqdm
transform = Compose(
    (
        Resize((256, 256)),
        CenterCrop(224),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    )
)
dts_train = CIFAR10("/users/r/ramzie/datasets/", train=True, transform=transform, download=True)
def get_loader(dts, sampler=None):
    return DataLoader(
        dts,
        batch_size=128,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        num_workers=10,
        sampler=sampler,
    )
class L2Norm(nn.Module):
    def forward(self, X):
        return F.normalize(X, dim=-1)
def criterion(di, lb, features, labels):
    # similarity of the current batch against the features of the whole dataset
    scores = torch.mm(di, features.t())
    gt = lb.view(-1, 1) == labels.unsqueeze(0)
    # push positive-pair scores up and negative-pair scores down
    return F.relu(-scores[gt]).mean() + F.relu(scores[~gt]).mean()
net = resnet18(pretrained=True)
net.fc = L2Norm()
_ = net.cuda()
opt = torch.optim.SGD(net.parameters(), 0.1)
scaler = torch.cuda.amp.GradScaler()
for e in range(2):
    loader = get_loader(dts_train)
    features = []
    labels = []
    # We first compute the features and labels for the whole dataset.
    # This could be a first distributed loop (see the first sketch below).
    for (x, y) in tqdm(loader, 'computing features'):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                feat = net(x.cuda())
        features.append(feat)
        labels.append(y)
    features = torch.cat(features)
    labels = torch.cat(labels).cuda()
    ####################################################
    ####################################################
    loader = get_loader(dts_train)
    # This is the bottleneck, which could also be distributed (see the second sketch below).
    for (x, y) in tqdm(loader, 'accumulating gradient'):
        with torch.cuda.amp.autocast():
            di = net(x.cuda())
            lb = y.cuda()
            loss = criterion(di, lb, features, labels) / len(features)
        # gradient accumulation for the entire dataset
        scaler.scale(loss).backward()
    ####################################################
    ####################################################
    # only at the end perform optimization (full batch)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()  # reset the accumulated gradients before the next epoch
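For the first loop, this is roughly what I had in mind for the distributed version: each rank computes features only for its shard of the dataset (via DistributedSampler) and the shards are then gathered so every process ends up with the full feature matrix. This is just a sketch; compute_all_features is my own helper name, and I am assuming the process group was already initialized as sketched above:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def compute_all_features(net, dts, rank, world_size):
    # each rank only forwards its own shard of the dataset
    sampler = DistributedSampler(dts, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(dts, batch_size=128, sampler=sampler,
                        num_workers=10, pin_memory=True)
    feats, labels = [], []
    with torch.no_grad(), torch.cuda.amp.autocast():
        for x, y in loader:
            feats.append(net(x.cuda(rank)))
            labels.append(y.cuda(rank))
    feats = torch.cat(feats)
    labels = torch.cat(labels)
    # all_gather needs same-sized tensors on every rank; DistributedSampler pads the
    # last shard by repeating a few samples, so the shapes already match
    feat_list = [torch.zeros_like(feats) for _ in range(world_size)]
    label_list = [torch.zeros_like(labels) for _ in range(world_size)]
    dist.all_gather(feat_list, feats)
    dist.all_gather(label_list, labels)
    # note: the gathered order differs from the original dataset order,
    # but features and labels stay aligned with each other
    return torch.cat(feat_list), torch.cat(label_list)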
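And for the second (bottleneck) loop, my idea was that each rank accumulates gradients over its own shard and the gradient all-reduce only happens on the last accumulation step, using DDP's no_sync(). Again only a sketch; accumulate_full_batch is a placeholder name, and criterion, features and labels are the same as in the single-GPU code above:

import contextlib
import torch

def accumulate_full_batch(ddp_net, loader, features, labels, scaler, opt):
    n_batches = len(loader)
    for i, (x, y) in enumerate(loader):
        # skip the gradient all-reduce on every backward except the last one
        ctx = ddp_net.no_sync() if i < n_batches - 1 else contextlib.nullcontext()
        with ctx:
            with torch.cuda.amp.autocast():
                di = ddp_net(x.cuda())
                loss = criterion(di, y.cuda(), features, labels) / len(features)
            # gradients keep accumulating in .grad across iterations
            scaler.scale(loss).backward()
    # single optimization step for the whole "full batch"
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()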
Thank you if you have any time to help!