Unable to optimize with DistributedDataParallel

Hello,

Thank you so much for reading this and trying to help me out!
I am trying to develop a vision transformer for classifying spectrograms.
To speed up training, I am using DistributedDataParallel to split the work across two A100 GPUs.
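
For reference, my understanding is that DDP itself only synchronizes gradients; splitting the data across ranks comes from giving each rank's DataLoader a DistributedSampler, roughly like this (a minimal sketch; my real loaders come from coredldev's train_validation_test_dataloaders in the full code below, which I am assuming handles the per-rank split):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_train_loader(dataset, rank, world_size, batch_size=64):
    # DistributedSampler gives each rank a disjoint shard of the dataset;
    # without it, every rank iterates over the full dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)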

However, I run into an issue where the loss stays constant throughout training, so it seems the optimization step is not doing anything.
My data consists of float64 400×400 tensors with values on the order of 10^-21.
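
To show the scale issue concretely, here is a standalone snippet (not from my pipeline) of what happens to values this small once they are cast to float32, which is what the model receives after the .to(torch.float) call in my training loop:

import torch

x = torch.randn(400, 400, dtype=torch.float64) * 1e-21
x32 = x.to(torch.float)   # the values themselves survive the cast
print(x32.var())          # but the variance (~1e-42) is subnormal in float32,
                          # since normal floats bottom out near 1.2e-38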

Here is my code:

import datetime
import gc
import os
import pathlib as p
import pickle
import time

import pandas as pd
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import wandb
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchmetrics import (
    AUROC,
    ROC,
    Accuracy,
    F1Score,
    FBetaScore,
    MetricCollection,
    Precision,
    Recall,
)

from coredldev._resources.eos_to_index_map import eos_to_index_map as eosmap
from coredldev.dataloaders import train_validation_test_dataloaders
from coredldev.dataset import CoReDataset
from coredldev.preprocessing.ligo_noise.file_noise import generate_noise
from coredldev.preprocessing.raw_postmerger.detector_angle_mixing import (
    detector_angle_mixing,
)
from coredldev.preprocessing.raw_postmerger.distance_scale import distance_scale
from coredldev.preprocessing.raw_postmerger.wavelet_transform import wavelet_transform
from coredldev.preprocessing.to_tensor import to_tensor_clean
from coredldev.sources.distance_scaling.h5 import h5Source
from coredldev.utilites.pipeline import pipeline
from vit_pytorch.deepvit import DeepViT

eosmap = eosmap[0]

print("Starting program", os.getpid(), os.getppid())


def get_df_from_rdict(rdict):
    return pd.DataFrame(pd.Series(rdict).map(lambda x: x.cpu().item())).T


def save_model_to_file(model, path):
    torch.save(model.state_dict(), path)


def init_model():
    # a BatchNorm2d layer normalizes the single input channel before the ViT
    return nn.Sequential(
        nn.BatchNorm2d(num_features=1).cuda(),
        DeepViT(
            image_size=400,
            patch_size=20,
            num_classes=19,
            dim=1024,
            depth=2,
            heads=25,
            mlp_dim=int(2048 / 2),
            dropout=0.1,
            emb_dropout=0.1,
            channels=1,
        ).cuda(),
    )


class trainer:
    def __init__(
        self,
        wandb_config,
        save_interval,
        model: torch.nn.Module,
        optimizers: torch.optim,
        lossfn: torch.nn,
        devices: list,
        load_device: int,
        rank_metrics: torchmetrics.Metric,
        metrics: torchmetrics.MetricCollection,
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        test_dataloader: torch.utils.data.DataLoader,
        scheduler: torch.optim.lr_scheduler,
        scheduler_params: dict,
    ):
        self.lossfn = lossfn
        self.devices = devices
        self.load_device = load_device
        # wrap the model in DDP so gradients are all-reduced across ranks
        self.model = DDP(model, device_ids=self.devices, output_device=self.load_device)
        self.epochs = wandb_config.epochs
        self.wandb_config = wandb_config
        self.save_interval = save_interval
        self.rank_metrics = rank_metrics
        self.metrics = metrics
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.history = pd.DataFrame()
        self.optimizers = []
        for i in optimizers:
            self.optimizers.append(
                i(params=self.model.parameters(), lr=self.wandb_config.lr)
            )
        print(self.optimizers)
        self.scheduler = scheduler(self.optimizers[0], **scheduler_params)

    def _run_batch(
        self,
        source,
        targets,
        epoch,
        batchnum,
    ):
        self.model.train()
        stime = time.time()
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        shape = source.shape
        # add a channel dimension and cast from float64 to float32 for the model
        source = (
            source.to(self.devices[0])
            .view(shape[0], 1, shape[1], shape[2])
            .to(torch.float)
        )
        # the first column of the parameter tensor is the class label
        targets = targets[:, 0].to(self.devices[0]).to(torch.long)
        outputs = self.model(source)
        loss = self.lossfn(outputs, targets)
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()
        # torch.cuda.empty_cache()
        rmetric = self.rank_metrics(outputs.to("cpu"), targets.to("cpu"))

        infodict = dict(
            loss=loss.item(),
            rank_metric=rmetric.item(),
            time=time.time() - stime,
            lr=self.scheduler.get_last_lr()[0],
            step=epoch * len(self.train_dataloader) + batchnum,
            epoch=epoch,
            batch_number=batchnum,
        )

        wandb.log(infodict)
        print(
            f"{self.devices[0]} | {epoch:3}/{self.wandb_config.epochs:3} {batchnum:5}/{len(self.train_dataloader):5} loss: {loss.item():.3f} rank_metric: {rmetric.item():.3f} time: {infodict['time']} lr: {infodict['lr']}, rank: {self.devices[0]}"
        )

    def eval_model(self, dataloader):
        print("Evaluating Model - This will take a while")
        self.model.eval()
        raw_out = []
        raw_targets = []
        btime = time.time()
        aggtime = 0.0
        with torch.no_grad():
            for batch, (sg, params) in enumerate(dataloader):
                stime = time.time()
                sg = sg.to(self.devices[0]).float()
                sgsh = sg.shape
                sg = sg.view(sgsh[0], 1, sgsh[1], sgsh[2]).to(self.devices[0])
                params = params[:, 0].to(self.devices[0]).to(torch.long)
                raw_out.append(self.model(sg).detach().cpu())
                raw_targets.append(params.cpu())
                mltime = time.time() - stime
                aggtime += mltime
                tot_time = time.time() - btime
                del sg, params
                torch.cuda.empty_cache()
                gc.collect()
                print(
                    f"{self.devices[0]:<3} {batch+1:<7} / {len(dataloader):<7} | [% complete]: {round(100*(batch+1)/len(dataloader),6):<10} | [total time]: {datetime.timedelta(seconds = tot_time)} | [ml time]:{round(mltime,5):<10} | [%/hour]:{round((100 * (batch+1)/len(dataloader))/((tot_time)/3600),5):<10} | [(total - just ml time)%]: {round((tot_time - aggtime)/tot_time,5):<10}",
                )
            self.model.train()
            output = torch.vstack(raw_out)
            parameters = torch.concat(raw_targets, dim=0)
            return self.metrics(output.cpu(), parameters.cpu()), self.rank_metrics(
                output.cpu(), parameters.cpu()
            )

    def save_model(self, dataloader, epoch, step=-1):
        out, rmetric = self.eval_model(dataloader)
        self.history = pd.concat([self.history, get_df_from_rdict(out)], axis=0)
        wandb.log({"epoch": epoch} | out | dict(self.history.max()))
        if rmetric >= max(self.history["MulticlassAccuracy"]):
            save_model_to_file(
                self.model,
                p.Path(__file__).parent.absolute()
                / f"saved_models/classifier-dvit-{self.wandb_config.start_time.replace(':','-')}-maxacc-{rmetric}-{epoch}-{step}-rank{self.devices}.pt",
            )

    def _run_epoch(self, epoch):
        self.model.train()
        for batchnum, (source, targets) in enumerate(self.train_dataloader):
            self._run_batch(source, targets, epoch, batchnum)
            if batchnum % self.save_interval == 0:
                self.save_model(self.val_dataloader, epoch, batchnum)

    def final_eval(self):
        self.model.eval()
        raw_out = []
        raw_targets = []
        eoscomp = []
        with torch.no_grad():
            for batch, (sg, params) in enumerate(self.test_dataloader):
                stime = time.time()
                sg = sg.to(self.devices[0]).float()
                sgsh = sg.shape
                sg = sg.view(sgsh[0], 1, sgsh[1], sgsh[2]).to(self.devices[0])
                params = params[:, 0].to(self.devices[0]).to(torch.long)
                eoscomp.append(params)
                raw_out.append(self.model(sg).detach().cpu())
                raw_targets.append(params.cpu())
                print(
                    f"{self.devices[0]} {batch+1} / {len(self.test_dataloader)} {time.time()-stime}",
                    end="\r",
                )
            self.model.train()
            output = torch.vstack(raw_out)
            parameters = torch.concat(raw_targets, dim=0)
            roc = ROC(task="multiclass", num_classes=19)
            fpr, tpr, thresholds = roc(output, parameters)
            wandb.log(
                {"final": 0}
                | self.metrics(output.cpu(), parameters.cpu())
                | {"fpr": fpr, "tpr": tpr, "thresholds": thresholds},
            )

            torch.save([fpr, tpr, thresholds], "./roc.pt")
            wandb.save("./roc.pt")

        output = torch.argmax(output, dim=1)
        comparisons = torch.eq(output.to("cpu"), torch.concat(eoscomp).to("cpu")).to(
            torch.float
        )
        finalacc = {"final_acc": torch.mean(comparisons)}
        wandb.log(finalacc)

        parameters = torch.concat([comparisons, parameters])

        df = pd.DataFrame(parameters)
        df = df.rename(
            columns={
                0: "Correct",
                2: "EOS",
                3: "M1",
                4: "M2",
                5: "SHFT",
                6: "DIST",
                7: "EXRAD",
                8: "RA",
                9: "DEC",
                10: "POL",
                11: "spec",
            }
        )

        df.to_csv("./c_results.csv")

        wandb.save("./c_results.csv")
        # wandb.save(df)
        wandb.finish()
        print("finished with training")

    def train(self):
        print(
            "started training for rank:",
            self.devices[0],
        )
        for epoch in range(self.epochs):
            self._run_epoch(epoch)
            self.scheduler.step()
        self.final_eval()


acc = Accuracy(task="multiclass", num_classes=19)
combined = MetricCollection(
    [
        acc,
        AUROC(task="multiclass", num_classes=19),
        Precision(task="multiclass", num_classes=19),
        Recall(task="multiclass", num_classes=19),
        F1Score(task="multiclass", num_classes=19),
        FBetaScore(task="multiclass", num_classes=19, beta=0.5),
    ]
)


class config_setup:
    """Simple attribute container used in place of a wandb config object."""


def main(rank: int, world_size: int, use_gpu, stime):
    wandb.init(project="dvit-psu-cluster", group=f"DDP-{stime}")
    print("in main function", rank)

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "0"

    init_process_group(backend="gloo", rank=rank, world_size=world_size)
    print(os.getpid(), os.getppid(), rank, world_size)
    torch.cuda.set_device(rank)

    config = config_setup()
    config.epochs = 5
    config.start_time = datetime.datetime.now().isoformat()
    config.lr = 5e-4

    datapoints = pickle.load(
        open(p.Path(__file__).parent.absolute() / "datapoints.p", "rb")
    )
    source = h5Source(eos_to_index_map=eosmap, sync=False)
    complete_dataset = CoReDataset(
        source,
        datapoints,
        pipeline(
            {
                "dam": detector_angle_mixing(),
                "dis": distance_scale(),
                "mwt": wavelet_transform(gpu=use_gpu, device=rank),
                "gnn": generate_noise(),
                "ttc": to_tensor_clean(),
            }
        ),
    )

    train_dl, valid_dl, test_dl = train_validation_test_dataloaders(
        complete_dataset,
        train_split=0.7,
        test_split=0.15,
        validation_split=0.15,
        train_batch_size=64,
        validation_batch_size=200,
        test_batch_size=200,
        shuffle_dataset=False,
    )
    print("creating trainer")
    t = trainer(
        wandb_config=config,
        save_interval=30_000,
        model=init_model(),
        optimizers=[optim.Adam, optim.AdamW, optim.NAdam],
        lossfn=nn.CrossEntropyLoss(),
        devices=[rank],
        load_device=0,
        rank_metrics=acc,
        metrics=combined,
        train_dataloader=train_dl,
        val_dataloader=valid_dl,
        test_dataloader=test_dl,
        scheduler=optim.lr_scheduler.StepLR,
        scheduler_params={"step_size": 1, "gamma": 0.7},
    )
    t.train()
    destroy_process_group()


if __name__ == "__main__":
    print("in main if statement of process: ", os.getpid(), "parent is:", os.getppid())
    world_size = torch.cuda.device_count()
    mp.spawn(
        main,
        args=(world_size, True, datetime.datetime.now().isoformat()),
        nprocs=world_size,
        join=True,
    )
    print("completed training")

Thanks!

Did you try to normalize the data?

I am using a batch normalization layer with an epsilon of 10^-70 to normalize the data before it is fed to the network, but I do not see any change in the results.
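
In case it is useful, this is the kind of explicit rescaling I could try instead of relying only on BatchNorm. It is just a sketch (none of this is in my code yet, and the 1e-30 floor is an arbitrary guess): standardize each sample in float64, before the float32 cast, so the statistics are computed at full precision.

def standardize(source: torch.Tensor) -> torch.Tensor:
    # per-sample zero-mean/unit-variance scaling, computed in float64
    # so values around 1e-21 do not underflow during the statistics
    mean = source.mean(dim=(-2, -1), keepdim=True)
    std = source.std(dim=(-2, -1), keepdim=True)
    return ((source - mean) / (std + 1e-30)).to(torch.float)

After something like this, the values reaching the network are of order 1, so the BatchNorm epsilon should no longer matter.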