About loss backward when using model parallelism

I wrote code to do multi-label classification with CLIP. I followed the method from the DualCoOp paper, which uses prompt learning to learn a positive and a negative prompt for each category. However, the paper only deals with the 80 classes of the COCO dataset. My dataset has more than 600 classes, which means more than 1200 prompts and therefore more than 1200 text features extracted by CLIP’s text encoder. I have four RTX 3090 GPUs, so I use model parallelism to evenly distribute the prompts and their text features across the four devices and collect the classification logits on the first device. However, I don’t know how to backward the loss. Here is the code:

import torch
import torch.nn as nn

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from copy import deepcopy
import torch.nn.functional as F

_tokenizer = _Tokenizer()


from .. import almodel
from ..model import ALModel
from .dualcoop_clip import build_model_conv_proj
from ..loss import build_loss
from ...dataset.utils import load_json


def load_clip_to_cpu(backbone_name, input_size, dir_cache):
    
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url, root=dir_cache)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")
    model = build_model_conv_proj(state_dict or model.state_dict(), input_size)

    return model


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class MLCPromptLearner(nn.Module):
    def __init__(
        self,
        n_ctx_pos: int, n_ctx_neg: int, 
        classnames: list, clip_model, csc: bool,
        ctx_init_pos=None, ctx_init_neg=None, 
    ):
        super().__init__()
        n_cls = len(classnames)
        # ctx_init_pos = cfg.TRAINER.COOP_MLC.POSITIVE_PROMPT_INIT.strip()
        # ctx_init_neg = cfg.TRAINER.COOP_MLC.NEGATIVE_PROMPT_INIT.strip()
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        if ctx_init_pos and ctx_init_neg:
            # use given words to initialize context vectors
            ctx_init_pos = ctx_init_pos.replace("_", " ")
            ctx_init_neg = ctx_init_neg.replace("_", " ")
            n_ctx_pos = len(ctx_init_pos.split(" "))
            n_ctx_neg = len(ctx_init_neg.split(" "))
            prompt_pos = clip.tokenize(ctx_init_pos)
            prompt_neg = clip.tokenize(ctx_init_neg)
            with torch.no_grad():
                embedding_pos = clip_model.token_embedding(prompt_pos).type(dtype)
                embedding_neg = clip_model.token_embedding(prompt_neg).type(dtype)
            ctx_vectors_pos = embedding_pos[0, 1: 1 + n_ctx_pos, :]
            ctx_vectors_neg = embedding_neg[0, 1: 1 + n_ctx_neg, :]
            prompt_prefix_pos = ctx_init_pos
            prompt_prefix_neg = ctx_init_neg
            if csc:
                ctx_vectors_pos_ = []
                ctx_vectors_neg_ = []
                for _ in range(n_cls):
                    ctx_vectors_pos_.append(deepcopy(ctx_vectors_pos))
                    ctx_vectors_neg_.append(deepcopy(ctx_vectors_neg))
                ctx_vectors_pos = torch.stack(ctx_vectors_pos_, dim=0)
                ctx_vectors_neg = torch.stack(ctx_vectors_neg_, dim=0)

        else:
            # Random Initialization
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors_pos = torch.empty(n_cls, n_ctx_pos, ctx_dim, dtype=dtype)
                ctx_vectors_neg = torch.empty(n_cls, n_ctx_neg, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors_pos = torch.empty(n_ctx_pos, ctx_dim, dtype=dtype)
                ctx_vectors_neg = torch.empty(n_ctx_neg, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors_pos, std=0.02)
            nn.init.normal_(ctx_vectors_neg, std=0.02)
            prompt_prefix_pos = " ".join(["X"] * n_ctx_pos)
            prompt_prefix_neg = " ".join(["X"] * n_ctx_neg)

        print(f'Initial positive context: "{prompt_prefix_pos}"')
        print(f'Initial negative  context: "{prompt_prefix_neg}"')
        print(f"Number of positive context words (tokens): {n_ctx_pos}")
        print(f"Number of negative context words (tokens): {n_ctx_neg}")

        self.ctx_pos = nn.Parameter(ctx_vectors_pos)  # to be optimized
        self.ctx_neg = nn.Parameter(ctx_vectors_neg)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts_pos = [prompt_prefix_pos + " " + name + "." for name in classnames]
        prompts_neg = [prompt_prefix_neg + " " + name + "." for name in classnames]

        tokenized_prompts_pos = []
        tokenized_prompts_neg = []
        for p_pos, p_neg in zip(prompts_pos, prompts_neg):
            tokenized_prompts_pos.append(clip.tokenize(p_pos))
            tokenized_prompts_neg.append(clip.tokenize(p_neg))
        tokenized_prompts_pos = torch.cat(tokenized_prompts_pos)
        tokenized_prompts_neg = torch.cat(tokenized_prompts_neg)
        with torch.no_grad():
            embedding_pos = clip_model.token_embedding(tokenized_prompts_pos).type(dtype)
            embedding_neg = clip_model.token_embedding(tokenized_prompts_neg).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix_pos", embedding_pos[:, :1, :] )
        self.register_buffer("token_suffix_pos", embedding_pos[:, 1 + n_ctx_pos:, :])
        self.register_buffer("token_prefix_neg", embedding_neg[:, :1, :])
        self.register_buffer("token_suffix_neg", embedding_neg[:, 1 + n_ctx_neg:, :])

        self.n_cls = n_cls
        self.n_ctx_pos = n_ctx_pos
        self.n_ctx_neg = n_ctx_neg
        tokenized_prompts = torch.cat([tokenized_prompts_neg, tokenized_prompts_pos], dim=0)  # torch.Tensor
        self.register_buffer("tokenized_prompts", tokenized_prompts)
        self.name_lens = name_lens

    def forward(self, cls_id=None):
        ctx_pos = self.ctx_pos
        ctx_neg = self.ctx_neg

        if ctx_pos.dim() == 2:
            if cls_id is None:
                ctx_pos = ctx_pos.unsqueeze(0).expand(self.n_cls, -1, -1)
            else:
                ctx_pos = ctx_pos.unsqueeze(0).expand(len(cls_id), -1, -1)
        else:
            if cls_id is not None:
                ctx_pos = ctx_pos[cls_id]

        if ctx_neg.dim() == 2:
            if cls_id is None:
                ctx_neg = ctx_neg.unsqueeze(0).expand(self.n_cls, -1, -1)
            else:
                ctx_neg = ctx_neg.unsqueeze(0).expand(len(cls_id), -1, -1)
        else:
            if cls_id is not None:
                ctx_neg = ctx_neg[cls_id]

        if cls_id is None:
            prefix_pos = self.token_prefix_pos
            prefix_neg = self.token_prefix_neg
            suffix_pos = self.token_suffix_pos
            suffix_neg = self.token_suffix_neg
        else:
            prefix_pos = self.token_prefix_pos[cls_id]
            prefix_neg = self.token_prefix_neg[cls_id]
            suffix_pos = self.token_suffix_pos[cls_id]
            suffix_neg = self.token_suffix_neg[cls_id]


        prompts_pos = torch.cat(
            [
                prefix_pos,  # (n_cls, 1, dim)
                ctx_pos,  # (n_cls, n_ctx, dim)
                suffix_pos,  # (n_cls, *, dim)
            ],
            dim=1,
        )

        prompts_neg = torch.cat(
            [
                prefix_neg,  # (n_cls, 1, dim)
                ctx_neg,  # (n_cls, n_ctx, dim)
                suffix_neg,  # (n_cls, *, dim)
            ],
            dim=1,
        )

        prompts = torch.cat([prompts_neg, prompts_pos], dim=0)

        if cls_id is not None:
            tokenized_prompts_pos = self.tokenized_prompts[self.n_cls:][cls_id]
            tokenized_prompts_neg = self.tokenized_prompts[:self.n_cls][cls_id]
            tokenized_prompts = torch.cat([tokenized_prompts_neg, tokenized_prompts_pos], dim=0)
        else:
            tokenized_prompts = self.tokenized_prompts


        return prompts, tokenized_prompts


@almodel("DualCoop")
class DualCoop(ALModel):
    def __init__(
        self, 
        input_size: int, backbone: dict, loss: dict, mlc_prompt_learner: dict, f_attribute_index: list, logit_scale: float, 
        finetune_backbone: bool, finetune_attn: bool, optimizer: dict, dir_cache: str, fp16: bool = True 
    ):
        super().__init__()

        classnames = list(load_json(f_attribute_index).keys())

        self.loss_fn = build_loss(loss['name'])(**loss)

        # self.clip_model = load_clip_to_cpu(backbone["name"], input_size, dir_cache).float()

        self.optim_set = optimizer
        
        if not finetune_backbone:
            print('Freeze the backbone weights')
            backbone_params = self.backbone_params()
            for param in backbone_params:
                param.requires_grad_(False)

        if not finetune_attn:
            print('Freeze the attn weights')
            attn_params = self.attn_params()
            for param in attn_params:
                param.requires_grad_(False)


        self.visual_encoder_type = backbone["name"]

        clip_model = load_clip_to_cpu(backbone["name"], input_size, dir_cache)
        
        if not fp16:
            clip_model.float()

        self.dtype = clip_model.dtype
        print(f"Using {self.dtype}")
            
        
        self.prompt_learner = MLCPromptLearner(**mlc_prompt_learner, classnames=classnames, clip_model=clip_model)
        
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = logit_scale
        
        self.start_device = 1
        self.num_devices = torch.cuda.device_count() - self.start_device
        self.num_prompts = len(classnames) * 2
        
        self.devices = [torch.device(f'cuda:{i + self.start_device}') for i in range(self.num_devices)]

    
    def distributed_prompts(self, prompts, tokenized_prompts):
        dist_prompts, dist_tokenized_prompts = [], []
        num_prompt_per_device = prompts.shape[0] // self.num_devices
        num_rest_prompt = prompts.shape[0] - self.num_devices * num_prompt_per_device
        group_bin = [[i * num_prompt_per_device, (i + 1) * num_prompt_per_device] for i in range(self.num_devices)]
        group_bin[-1][-1] += num_rest_prompt
        
        for i, gb in enumerate(group_bin):
            dp = prompts[gb[0]: gb[1]].to(self.devices[i])
            dtp = tokenized_prompts[gb[0]: gb[1]].to(self.devices[i])
            dist_prompts.append(dp)
            dist_tokenized_prompts.append(dtp)
        
        return dist_prompts, dist_tokenized_prompts
    
    def distributed_extract_text_features(self, dist_prompts, dist_tokenized_prompts):
        encoder_device = torch.device('cuda:0')
        distributed_text_features = []
        
        for dp, dtp in zip(dist_prompts, dist_tokenized_prompts):
            
            prompt_device = dp.device

            self.text_encoder = self.text_encoder.to(prompt_device)
            
            dtf = self.text_encoder(dp, dtp)
            dtf = dtf / dtf.norm(dim=-1, keepdim=True)

            distributed_text_features.append(dtf)

        self.text_encoder = self.text_encoder.to(encoder_device)

        return distributed_text_features

    def distributed_feature_aggregation(self, image_features, distributed_text_features):
        imf_device = image_features.device
        outputs = []
        for dtf in distributed_text_features:
            dtf_device = dtf.device
            image_features = image_features.to(dtf_device)
            output = 20 * F.conv1d(image_features, dtf[:, :, None]).to(imf_device)
            outputs.append(output)
        
        return torch.cat(outputs, dim=1)


    def infer(self, data, cls_id=None):
        # get image and text features
        image = data['i']
        
        image_features, attn_weights = self.image_encoder(image.type(self.dtype))
        
        prompts, tokenized_prompts = self.prompt_learner(cls_id)

        dist_prompts, dist_tokenized_prompts = self.distributed_prompts(prompts, tokenized_prompts)

        distributed_text_features = self.distributed_extract_text_features(dist_prompts, dist_tokenized_prompts)

        # text_features = self.text_encoder(prompts, tokenized_prompts)
        # normalize features
        # text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        image_features_norm = image_features / image_features.norm(dim=1, keepdim=True)

        # Class-Specific Region Feature Aggregation
        # output = 20 * F.conv1d(image_features_norm, text_features[:, :, None])

        output = self.distributed_feature_aggregation(image_features_norm, distributed_text_features)

        b, c, _ = output.shape
        output_half = output[:,  c // 2:]
        w_half = F.softmax(output_half, dim=-1)
        w = torch.cat([w_half, w_half], dim=1)
        output = 5 * (output * w).sum(-1)

        b, c = output.shape
        # convert the shape of logits to [b, 2, num_class]
        logits = output.resize(b, 2, c//2)

        return logits

    @property
    def network_name(self):
        name = ''
        name += 'DualCoop-{}'.format(self.visual_encoder_type)
        return name

    def backbone_params(self):
        params = []
        for name, param in self.named_parameters():
            if "image_encoder" in name and "prompt_learner" not in name and 'attnpool' not in name:
                params.append(param)
        return params

    def attn_params(self):
        params = []
        for name, param in self.named_parameters():
            if 'attnpool' in name and 'image_encoder' in name:
                params.append(param)
        return params

    def prompt_params(self):
        params = []
        for name, param in self.named_parameters():
            if "prompt_learner" in name:
                params.append(param)
        return params
    
    def get_params(self):
        
        params = []
        
        params.extend(self.backbone_params())
        params.extend(self.attn_params())
        params.extend(self.prompt_params())

        return params

    def get_optimizer(self):
        params = self.get_params()
        return torch.optim.SGD(params=params, lr=self.optim_set["base_lr"], weight_decay=0)
    
    def compute_loss(self, pred, target):
        loss = self.loss_fn(pred, target)
        loss_dict = {self.loss_fn.name: loss.item()}
        return loss, loss_dict
    
    def forward(self, data):
        if self.training:
            return self.train_model(data)
        else:
            return self.infer(data)

    def train_model(self, data):
        target = data['t']
        pred = self.infer(data).float()
        return self.compute_loss(pred, target)
    

Based on your code it seems you are using explicit model sharding by moving tensors between devices. Moving a tensor to a device via the to() operation is differentiable and won’t break the computation graph, so you can compute the loss and call .backward() on it directly.
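For illustration, here is a minimal sketch (assuming at least two visible GPUs; the tensors and sizes are made up) showing that gradients flow back through a cross-device to():

import torch

# .to() is recorded by autograd, so the backward pass crosses devices transparently
w = torch.randn(4, 4, device="cuda:0", requires_grad=True)
x = torch.randn(2, 4, device="cuda:0")

y = (x @ w).to("cuda:1")   # differentiable device transfer
loss = y.pow(2).sum()      # the loss lives on cuda:1
loss.backward()            # gradients still reach w on cuda:0

print(loss.device, w.grad.device)  # cuda:1 cuda:0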


Thanks for your reply! But I still get a RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!. Here is the full terminal output:

Freeze the backbone weights
Freeze the attn weights
Using torch.float16
Initializing a generic context
Initial positive context: "X X X X X X X X X X X X X X X X"
Initial negative  context: "X X X X X X X X X X X X X X X X"
Number of positive context words (tokens): 16
Number of negative context words (tokens): 16
Using dataloader default collate function
Using dataloader default collate function
Using dataloader default collate function
Epoch: 1 / 12 Training:   0%|                                                                                                  | 0/3387 [00:00<?, ?it/s/home/wangxinran/anaconda3/envs/kgva/lib/python3.8/site-packages/torch/_tensor.py:549: UserWarning: non-inplace resize is deprecated
  warnings.warn("non-inplace resize is deprecated")
loss device cuda:0
Epoch: 1 / 12 Training:   0%|                                                                                                  | 0/3387 [00:17<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 82, in <module>
    main(args.project, task_name, task_setting, mode=args.mode)
  File "main.py", line 31, in main
    task.run()
  File "/home/wangxinran/Papercode/KGVA-ICIP/seal/task/attribute_recognition.py", line 161, in run
    self.train()
  File "/home/wangxinran/Papercode/KGVA-ICIP/seal/task/attribute_recognition.py", line 108, in train
    self.train_util(self.model, self.trainloader, optimizer, epoch, self.train_settings.get_settings()["epochs"], self.device, amp=self.train_settings.get_settings()["amp"])
  File "/home/wangxinran/Papercode/KGVA-ICIP/seal/utils/train_utils.py", line 32, in train_one_epoch
    scaler.scale(loss).backward()  
  File "/home/wangxinran/anaconda3/envs/kgva/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/wangxinran/anaconda3/envs/kgva/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument mat2 in method wrapper_mm)

And here is my training function. The printed loss device is cuda:0, while part of the learnable prompts and their text features are on cuda:3. My torch version is 1.11.0+cu113.

@train_util("train_one_epoch")
def train_one_epoch(model, dataloader, optimizer, epoch, epoch_num, device, amp=True, use_wandb=False):
    if amp:
        scaler = GradScaler()
    pbar = tqdm(dataloader)
    pbar.set_description(f'Epoch: {epoch + 1} / {epoch_num} Training')
    model.train()
    for i, data in enumerate(pbar):
        optimizer.zero_grad()
        data = load_to_device(data, device)
        if amp:
            with autocast():
                loss, loss_dict = model(data)
                print(f"loss device {loss.device}")
                
            scaler.scale(loss).backward()  
            scaler.step(optimizer)  
            scaler.update()  
        else:
            loss, loss_dict = model(data)
            loss.backward()
            optimizer.step()
        lr_dict = {f"param-group-{i}": pg['lr'] for i, pg in enumerate(optimizer.state_dict()['param_groups'])}
        loss_dict.update(lr_dict)
        pbar.set_postfix(loss_dict)
        if use_wandb:
            wandb.log(loss_dict)

And here is my loss class code:

@loss("DualAsymmetricLoss")
class DualAsymmetricLoss(Loss):
    def __init__(
        self, 
        name="DualAsymmetricLoss", 
        gamma_neg=4, gamma_pos=1, 
        clip=0.05, 
        eps=1e-6, 
        disable_torch_grad_focal_loss=True
    ):
        
        super().__init__(name)

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, y):
        """"
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_softmax = self.softmax(x)
        xs_pos = x_softmax[:, 1, :]
        xs_neg = x_softmax[:, 0, :]
        y = y.reshape(-1)
        xs_pos = xs_pos.reshape(-1)
        xs_neg = xs_neg.reshape(-1)

        xs_pos = xs_pos[y!=2]
        xs_neg = xs_neg[y!=2]
        y = y[y!=2]

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg
        # import pdb
        # pdb.set_trace()

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()
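For completeness, a minimal usage sketch of this loss (the shapes are made up: 8 images, 620 classes, logits of shape [batch, 2, num_classes], targets in {0, 1, 2} with 2 marking ignored labels):

import torch

criterion = DualAsymmetricLoss()
logits = torch.randn(8, 2, 620, requires_grad=True)  # [batch, (neg, pos), num_classes]
targets = torch.randint(0, 3, (8, 620)).float()      # 0 = negative, 1 = positive, 2 = ignore

loss = criterion(logits, targets)
loss.backward()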

Could you try to isolate the failing part and post a minimal and executable code snippet reproducing the error?

Thank you for your reply, it really helped me. The problem has been solved: I now use DataParallel to distribute the text encoder of CLIP, which does not explicitly move the text encoder from one device to another via to(), and I no longer distribute the learnable prompts (which are nn.Parameters). This keeps the model’s parameters and the loss on the same device.

    def to(self, device):
        model = super().to(device)
        model.text_encoder = model.text_encoder.to(torch.device("cuda"))
        model.text_encoder = nn.DataParallel(model.text_encoder)
        return model

And here are the updated distributed inference functions in the DualCoop class:

    def distributed_prompts(self, prompts, tokenized_prompts):
        dist_prompts, dist_tokenized_prompts = [], []
        num_prompt_per_device = prompts.shape[0] // self.num_devices
        num_rest_prompt = prompts.shape[0] - self.num_devices * num_prompt_per_device
        group_bin = [[i * num_prompt_per_device, (i + 1) * num_prompt_per_device] for i in range(self.num_devices)]
        group_bin[-1][-1] += num_rest_prompt
        
        for i, gb in enumerate(group_bin):
            dp = prompts[gb[0]: gb[1]]
            dtp = tokenized_prompts[gb[0]: gb[1]]
            dist_prompts.append(dp)
            dist_tokenized_prompts.append(dtp)     
        
        return dist_prompts, dist_tokenized_prompts 
    
    def distributed_extract_text_features(self, dist_prompts, dist_tokenized_prompts):
        encoder_device = torch.device('cuda:0')
        distributed_text_features = []
        
        for i, (dp, dtp) in enumerate(zip(dist_prompts, dist_tokenized_prompts)):

            dtf = self.text_encoder(dp.to(self.devices[i]), dtp.to(self.devices[i]))
            dtf = dtf / dtf.norm(dim=-1, keepdim=True)

            distributed_text_features.append(dtf)

            dp = dp.to(encoder_device)
            dtp = dtp.to(encoder_device)

        return distributed_text_features

    def distributed_feature_aggregation(self, image_features, distributed_text_features):
        imf_device = image_features.device
        outputs = []
        for dtf in distributed_text_features:
            dtf_device = dtf.device
            image_features = image_features.to(dtf_device)
            output = 20 * F.conv1d(image_features, dtf[:, :, None]).to(imf_device)
            outputs.append(output)
        
        return torch.cat(outputs, dim=1)
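For anyone hitting the same issue, a minimal sketch (the module and sizes are made up, and it assumes several visible GPUs) of why DataParallel keeps everything on one device. It scatters the input along dim 0 to per-GPU replicas and gathers the outputs back onto the primary device, so the registered parameters, the returned features, and therefore the loss all stay on cuda:0:

import torch
import torch.nn as nn

# Stand-in for the CLIP text encoder, only to illustrate the device flow.
encoder = nn.DataParallel(nn.Linear(512, 1024).to("cuda:0"))

prompts = torch.randn(1200, 512, device="cuda:0")  # e.g. 600 classes x 2 prompts
feats = encoder(prompts)                           # scattered across GPUs, gathered on cuda:0

loss = feats.pow(2).mean()
loss.backward()                                    # grads are reduced onto the cuda:0 parameters
print(feats.device, next(encoder.module.parameters()).grad.device)  # cuda:0 cuda:0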