Training an encoder-decoder model on 4 GPUs: different device error

Hi everyone.
I'm training an encoder + decoder model (EfficientNetV2 + Transformer) on a server with 4 GPUs.

My problem: I'm getting a "different device" error.
A rough sketch of my code:

class Combined_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_model = ...  # EfficientNetV2 encoder
        self.decoder_model = ...  # Transformer decoder

    def forward(self, x):
        x = self.encoder_model(x)
        x = self.decoder_model(x)
        return x


model = Combined_model()
model = torch.nn.DataParallel(model)
model = model.to(device)

That's roughly how my logic goes. Whenever I need to create a new tensor, I allocate it on the same device as the input (see the small example below).
(If I don't move it to a device manually, it ends up on the CPU.)
I want to know what’s wrong with my logic.
I can share the full code if you need it to answer my question.

Thanks and have a great day

Your current model definition has a few errors and won’t be executable. Could you post a minimal, executable code snippet which would reproduce this issue, please?

Can I upload my git repository? I think it’s too long to post here

I don’t think the full repository should be needed to reproduce the issue, so please try to create a minimal and executable code snippet that raises the error.

I’m sorry, but I can’t upload fully executable code because it relies on too many custom implementations. I don’t want you to spend extra time trying to understand all of my unpolished code.
Img2smiles_net → main model

class Img2smiles_net(nn.Module):
    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, d_model, n_heads, max_len,
                ffn_hidden, n_layers, drop_prob, img_embedding_dim, device):
        super().__init__()
        self.img_encoder = EfficientNetV2('b3',
                        in_channels=3,
                        pretrained=False)
        self.transformer_decoder = Transformer(src_pad_idx, 
                                  trg_pad_idx,
                                  trg_sos_idx,
                                  enc_voc_size, 
                                  dec_voc_size, 
                                  d_model, 
                                  n_heads, 
                                  max_len, 
                                  ffn_hidden,
                                  n_layers,
                                  drop_prob,
                                  img_embedding_dim)
        
    
    def forward(self, img, smiles, mask):
        src = self.img_encoder.forward(img)
        out = self.transformer_decoder(src, smiles, mask)
        return out

train function

def train_one_epoch(epoch, Net, trainloader, validloader, criterion, optimizer):  
    Net.train()
    epoch_loss = 0
    valid_epoch_loss = 0
    for i, data in enumerate(tqdm(trainloader)):
        img, smiles = data
        # optimizer.optimizer.zero_grad()
        optimizer.zero_grad()
        smiles_batch_input = smiles[:, :-1]
        smiles_batch_target = smiles[:, 1:]
        mask = transformer_model_parallel.create_mask(smiles_batch_target)
        out = Net(img, smiles_batch_input, mask)
        # if i % 100 == 0:
        #     images = wandb.Image(img[:5])
        #     pred = convert_2_selfie(torch.argmax(out[:5], axis = -1))
        #     ans = convert_2_selfie(smiles_batch_target[:5])
        #     train_table.add_data(epoch, i//100, images, pred, ans)
        batch_loss = criterion(torch.transpose(out, -2, -1), smiles_batch_target.to(out.device))
        epoch_loss += batch_loss.item()
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(Net.parameters(), 10)
        optimizer.step()

The error message is below.
I think it says something is going wrong in batch norm during the backward pass.

Traceback (most recent call last):
  File "eff2_train_model_parallel.py", line 282, in <module>
    epoch_train_loss, epoch_valid_loss = train_one_epoch(epoch, Net, train_loader, val_loader, criterion, custom_opt)
  File "eff2_train_model_parallel.py", line 222, in train_one_epoch
    batch_loss.backward()
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:3! (when checking argument for argument weight in method wrapper__cudnn_batch_norm_backward)
Traceback (most recent call last):
  File "eff2_train_model_parallel.py", line 282, in <module>
    epoch_train_loss, epoch_valid_loss = train_one_epoch(epoch, Net, train_loader, val_loader, criterion, custom_opt)
  File "eff2_train_model_parallel.py", line 222, in train_one_epoch
    batch_loss.backward()
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:3! (when checking argument for argument weight in method wrapper__cudnn_batch_norm_backward)

The EfficientNetV2 implementation is from
pytorch-efficientnet/efficientnet_v2.py at master · abhuse/pytorch-efficientnet (github.com)

I only changed its forward function as shown below (I only need feature extraction); a quick shape check follows the snippet.

def forward(self, x):
        x = self.stem_conv(x)
        x = self.stem_bn(x)
        x = self.stem_act(x)
        feature = None
        feat_idx = 0
        for block_idx, block in enumerate(self.blocks):
            x = block(x)
            if block_idx == self.feature_block_ids[feat_idx]:
                feature = x
                feat_idx += 1
        feature = feature.view(feature.shape[0], feature.shape[1], -1)
        feature = torch.permute(feature, (0,2,1))
        return feature
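
For reference, the new forward can be shape-checked with a random input like this (just a sanity check, not part of my training code):

import torch
from efficientnet_v2 import EfficientNetV2  # the modified file above

enc = EfficientNetV2('b3', in_channels=3, pretrained=False)
with torch.no_grad():
    feat = enc(torch.randn(2, 3, 299, 299))
print(feat.shape)  # (batch, H*W, channels): a sequence of spatial feature vectors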

And I changed the block's forward as below → I think this is the problem now. With the first version I can run the forward pass without any error, but I think it sends the BN layers (or other layers) to a specific device, so I figured I should first check whether an op is a Conv2d and only then move it to that device.

#### Before ###########
def forward(self, x):
        x_device = x.device # added
        inp = x
        for op in self.ops_lst:
            op = op.to(x_device) # added
            x = op(x)
        if self.skip_enabled:
            return x + inp
        else:
            return x
#############################
# Modified
def forward(self, x):
        inp = x
        x_device = x.device 
        for op in self.ops_lst:
            if isinstance(op, nn.Conv2d): #check if layer is conv2d
                op = op.to(x_device)
                x = op(x)
            else:
                x = op(x)
        if self.skip_enabled:
            return x + inp
        else:
            return x

error_message

Traceback (most recent call last):
  File "eff2_train_model_parallel.py", line 282, in <module>
    epoch_train_loss, epoch_valid_loss = train_one_epoch(epoch, Net, train_loader, val_loader, criterion, custom_opt)
  File "eff2_train_model_parallel.py", line 213, in train_one_epoch
    out = Net(img, smiles_batch_input, mask)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/shared/home/qlql323/img2smiles/model_eff2.py", line 29, in forward
    src = self.img_encoder.forward(img)
  File "/shared/home/qlql323/img2smiles/efficientnet_v2.py", line 712, in forward
    x = block(x)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/shared/home/qlql323/img2smiles/efficientnet_v2.py", line 393, in forward
    x = op(x)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/shared/home/qlql323/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2438, in batch_norm
    return torch.batch_norm(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_batch_norm)

ops_lst consists of these layers (a small device-check sketch follows the list):

Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False)
BatchNorm2d(232, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
DropConnect()
Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False)
BatchNorm2d(1392, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
SiLU()
Conv2d(1392, 1392, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1392, bias=False)
BatchNorm2d(1392, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
SiLU()
SqueezeExcitate(
  (dim_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1))
  (dim_restore): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1))
  (activation): SiLU()
.....
)
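
If it helps, this is the kind of check I have in mind to see where each op actually lives. It is only a debugging sketch and assumes the DataParallel-wrapped Net from my training setup; Net.module.img_encoder.blocks[-1] is just one example block:

block = Net.module.img_encoder.blocks[-1]       # one MBConv / FusedMBConv block instance
registered = {id(m) for m in block.modules()}   # everything PyTorch tracks as a submodule of this block
for op in block.ops_lst:
    p = next(op.parameters(), None)             # SiLU / DropConnect have no parameters
    dev = p.device if p is not None else "no params"
    print(type(op).__name__, "| param device:", dev, "| tracked submodule:", id(op) in registered)
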

Thanks for the follow-up. I’m currently unsure what might be causing the issue and would need to debug it. Are you able to recreate the error using only the model definition and random inputs? If so, could you post that code snippet so that I could try to reproduce it?

Sample Run Code

# Model config
img_embedding_dim = 232
d_model = 512
ffn_hidden = 2048
n_layers = 4
n_heads = 8
src_pad_idx = 0
trg_pad_idx = 0
trg_sos_idx = 8
enc_voc_size = 232
dec_voc_size = 30
max_len = 46
drop_prob = 0.1
from_ckpt = 0


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import math
import tensorflow as tf
import torchvision
from torchvision import datasets, models, transforms
import torchvision.transforms as T

# Custom implementation
import transformer_model_v2
from transformer_model_v2 import Transformer
from model_eff2 import Img2smiles_net

if 'device' not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
Net = Img2smiles_net(src_pad_idx, 
                              trg_pad_idx, 
                              trg_sos_idx,
                              enc_voc_size, 
                              dec_voc_size,
                              d_model, 
                              n_heads,
                              max_len, 
                              ffn_hidden,
                              n_layers,
                              drop_prob,
                              img_embedding_dim)

Net = torch.nn.DataParallel(Net)
Net.cuda()

print("loaded model")

print("start feed forward test")
rand_img = torch.randn(8,3,299,299)
trg = torch.randint(30,(8,46)).type(torch.LongTensor)
trg_input = trg[:, :-1]
trg_output = trg[:, 1:]
trg_mask = transformer_model_v2.create_mask(trg_output)
sample_result = Net.forward(rand_img, trg_input, trg_mask)
print(f"sample_result is on {sample_result.device}")
print("forward test done")



print("start backward test")
Net = Img2smiles_net(src_pad_idx, 
                              trg_pad_idx, 
                              trg_sos_idx,
                              enc_voc_size, 
                              dec_voc_size,
                              d_model, 
                              n_heads,
                              max_len, 
                              ffn_hidden,
                              n_layers,
                              drop_prob,
                              img_embedding_dim)

Net = torch.nn.DataParallel(Net)
Net.cuda()
print("re-initialized network")
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.AdamW(Net.parameters(),lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
optimizer.zero_grad()

Net.train()
pred = Net.forward(rand_img, trg_input, trg_mask)
loss = criterion(torch.transpose(pred, -2, -1), trg_output.to(device))
loss.backward()
optimizer.step()

print("backward test done")
exit()

custom implementation code

transformer_model_v2

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def make_pad_mask(k):
    mask = ~k.ne(0).unsqueeze(1).unsqueeze(2)
    return mask

def make_no_peak_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal = 1).type(torch.BoolTensor)
    return mask

def create_mask(q):
    size = q.size(1)
    pad_mask = make_pad_mask(q)
    no_peak_mask = make_no_peak_mask(size)
    
    return ~torch.maximum(pad_mask , no_peak_mask)


## Scaled dot attention
class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(0.1)
    def forward(self, q, k ,v, mask = None, e = -1e9):
        batch_size, head, length, d_tensor = k.size()

        k_t = k.view(batch_size, head, d_tensor, length)
        score = (q @ k_t)/math.sqrt(d_tensor)
        if mask is not None:
            score = score.masked_fill(mask == 0, e) 

        score = self.softmax(score)
        score = self.dropout(score)
        v = score @ v

        return v

# # Multi-Head Attention

class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model) # B, len, d_model
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def split_head(self, tensor): 
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model//self.n_head
        tensor = tensor.view(batch_size, self.n_head, length, d_tensor)
        return tensor

    def concat_head(self, tensor):
        batch_size, head, length, d_tensor = tensor.size()
        tensor = tensor.view(batch_size, length, head * d_tensor)
        return tensor

    def forward(self, q,k,v, mask = None):
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)

        q = self.split_head(q)
        k = self.split_head(k)
        v = self.split_head(v)

        out = self.attention(q,k,v, mask = mask)

        out = self.concat_head(out)
        out = self.w_concat(out)

        return out

# # FFN

class PositionwiseFeedForward(nn.Module):
    
    def __init__(self,d_model,hidden,drop_prob = 0.1):
        super(PositionwiseFeedForward,self).__init__()
        self.linear1 = nn.Linear(d_model,hidden)
        self.linear2 = nn.Linear(hidden,d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = drop_prob)
    
    def forward(self,x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


# # Encoder Layer

class EncoderLayer(nn.Module):
    
    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(d_model, n_head)

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p = drop_prob)

        self.ffn = PositionwiseFeedForward(d_model = d_model, hidden = ffn_hidden, drop_prob = drop_prob)

        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p = drop_prob)

    def forward(self, x):
        attn_output = self.attention(q = x, k = x, v = x)
        attn_output = self.dropout1(attn_output)
        out1 = self.norm1(x + attn_output)
        # x = self.dropout1(x)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        out2 = self.norm2(out1 + ffn_output)
        # x = self.dropout2(x)

        return out2

# # Encoder

class Encoder(nn.Module):

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.embed = nn.Linear(232, d_model)
        self.dropout = nn.Dropout(p = drop_prob)
        self.layers = nn.ModuleList([EncoderLayer(d_model = d_model, ffn_hidden = ffn_hidden, n_head = n_head, drop_prob = drop_prob) for _ in range(n_layers)])
        
    def positional_encoding(self, x,  max_len, d_model):
        encoding = torch.zeros(max_len, d_model, device = x.device)
        encoding.requires_grad = False 

        pos = torch.arange(0, max_len) 
        pos = pos.float().unsqueeze(dim = 1) 

        _2i = torch.arange(0, d_model, step = 2).float() 

        encoding[:, 0::2] = torch.sin(pos/(10000**(_2i/d_model)))
        encoding[:, 1::2] = torch.cos(pos/(10000**(_2i/d_model)))
        batch_size, seq_len, _ = x.size()

        return encoding[:seq_len, :]
        
    def forward(self, x):
        x = self.embed(x)
        pe = self.positional_encoding(x, 100, self.d_model)
        x = x + pe
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)

        return x


# # Decoder Layer


class DecoderLayer(nn.Module):

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(DecoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(d_model = d_model, n_head = n_head)

        self.norm1 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(p = drop_prob)

        self.enc_dec_attention = MultiHeadAttention(d_model = d_model, n_head = n_head)

        self.norm2 = nn.LayerNorm(d_model)

        self.dropout2 = nn.Dropout(p = drop_prob)

        self.ffn = PositionwiseFeedForward(d_model = d_model, hidden= ffn_hidden, drop_prob= drop_prob)

        self.norm3 = nn.LayerNorm(d_model)

        self.dropout3 = nn.Dropout(p = drop_prob)

    def forward(self, dec, enc, trg_mask):
        attn1 = self.self_attention(q = dec, k = dec, v = dec, mask = trg_mask)
        attn1 = self.dropout1(attn1)
        out1 = self.norm1(attn1 + dec)
        
        attn2 = self.enc_dec_attention(q = out1, k = enc, v = enc)
        attn2 = self.dropout2(attn2)
        out2 = self.norm2(attn2 + out1)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output)
        
        out3 = self.norm3(ffn_output + out2)
        # x = self.dropout3(x)

        return out3




# # Decoder Layer

class Decoder(nn.Module):
    def __init__(self,dec_voc_size,max_len,d_model,ffn_hidden,n_head,n_layers,
                drop_prob):
        super().__init__()
        
        #Embedding
        self.embed = nn.Embedding(num_embeddings = dec_voc_size, embedding_dim = d_model)
        
        #Add decoder layers
        self.layers = nn.ModuleList([DecoderLayer(d_model = d_model,
                                                 ffn_hidden = ffn_hidden,
                                                 n_head = n_head,
                                                 drop_prob = drop_prob) for _ in range(n_layers)])
        self.max_len = max_len
        #Linear
        self.linear = nn.Linear(d_model,dec_voc_size)
        self.softmax = nn.Softmax()
        self.d_model = d_model
        
    def positional_encoding(self, x,  max_len, d_model):
        encoding = torch.zeros(max_len, d_model, device = x.device) # (len, d_model) 
        encoding.requires_grad = False 

        pos = torch.arange(0, max_len) 
        pos = pos.float().unsqueeze(dim = 1)

        _2i = torch.arange(0, d_model, step = 2).float() 

        encoding[:, 0::2] = torch.sin(pos/(10000**(_2i/d_model)))
        encoding[:, 1::2] = torch.cos(pos/(10000**(_2i/d_model)))
        batch_size, seq_len, _ = x.size()

        return encoding[:seq_len, :]
    
    def forward(self,trg,src,trg_mask):
        #Compute Embedding
        trg = self.embed(trg)
        trg *= math.sqrt(self.d_model)
        #Get Positional Encoding
        trg_pe = self.positional_encoding(trg, self.max_len, self.d_model)
        
        #Embedding + Positional Encoding
        trg = trg + trg_pe
        
        #Compute Decoder layers
        for layer in self.layers:
            trg = layer(trg,src,trg_mask)
        
        #pass to LM head
        output = self.linear(trg)

        return output

class Transformer(nn.Module):
    
    def __init__(self,src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                ffn_hidden, n_layers, drop_prob, img_embedding_dim):
        super().__init__()
        #Get <PAD> idx
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.trg_sos_idx = trg_sos_idx
        self.dropout = nn.Dropout(p = 0.1)
        #Encoder
        self.encoder = Encoder(enc_voc_size = enc_voc_size,
                              max_len = max_len,
                              d_model = d_model,
                              ffn_hidden = ffn_hidden,
                              n_head = n_head,
                              n_layers = n_layers,
                              drop_prob = drop_prob)
        
        #Decoder
        self.decoder = Decoder(dec_voc_size = dec_voc_size,
                              max_len = max_len,
                              d_model = d_model,
                              ffn_hidden = ffn_hidden,
                              n_head = n_head,
                              n_layers = n_layers,
                              drop_prob = drop_prob)
     
    def forward(self,src,trg, mask):
        src = self.dropout(src)
        enc_src = self.encoder(src)
        output = self.decoder(trg, enc_src, mask)
        
        return output
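
A toy check of the create_mask helper above (separate from my training code; it assumes this file is importable as transformer_model_v2). Note that make_no_peak_mask builds its tensor without a device argument, so for a CPU input the whole mask stays on the CPU:

import torch
from transformer_model_v2 import create_mask

q = torch.tensor([[5, 7, 2, 0, 0]])  # one sequence; the last two tokens are padding (idx 0)
m = create_mask(q)
print(m.shape, m.device)             # expecting a (1, 1, 5, 5) boolean mask on the CPU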

model_eff2

import torch
import torch.nn as nn
from transformer_model_v2 import Transformer
from efficientnet_v2 import EfficientNetV2

class Img2smiles_net(nn.Module):
    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, d_model, n_heads, max_len,
                ffn_hidden, n_layers, drop_prob, img_embedding_dim):
        super().__init__()
        self.img_encoder = EfficientNetV2('b3',
                        in_channels=3,
                        pretrained=False)
        self.transformer_decoder = Transformer(src_pad_idx, 
                                  trg_pad_idx,
                                  trg_sos_idx,
                                  enc_voc_size, 
                                  dec_voc_size, 
                                  d_model, 
                                  n_heads, 
                                  max_len, 
                                  ffn_hidden,
                                  n_layers,
                                  drop_prob,
                                  img_embedding_dim)
        
    
    def forward(self, img, smiles, mask):
        src = self.img_encoder.forward(img)
        out = self.transformer_decoder(src, smiles, mask)
        return out

efficientnet_v2

import collections.abc as container_abc
from collections import OrderedDict
from math import ceil, floor

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo


def _pair(x):
    if isinstance(x, container_abc.Iterable):
        return x
    return (x, x)


def torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride):
    if in_spatial_shape is None:
        return None
    # in_spatial_shape -> [H,W]
    hin, win = _pair(in_spatial_shape)
    kh, kw = _pair(kernel_size)
    sh, sw = _pair(stride)

    # dilation and padding are ignored since they are always fixed in efficientnetV2
    hout = int(floor((hin - kh - 1) / sh + 1))
    wout = int(floor((win - kw - 1) / sw + 1))
    return hout, wout


def get_activation(act_fn: str, **kwargs):
    if act_fn in ('silu', 'swish'):
        return nn.SiLU(**kwargs)
    elif act_fn == 'relu':
        return nn.ReLU(**kwargs)
    elif act_fn == 'relu6':
        return nn.ReLU6(**kwargs)
    elif act_fn == 'elu':
        return nn.ELU(**kwargs)
    elif act_fn == 'leaky_relu':
        return nn.LeakyReLU(**kwargs)
    elif act_fn == 'selu':
        return nn.SELU(**kwargs)
    elif act_fn == 'mish':
        return nn.Mish(**kwargs)
    else:
        raise ValueError('Unsupported act_fn {}'.format(act_fn))


def round_filters(filters, width_coefficient, depth_divisor=8):
    """Round number of filters based on depth multiplier."""
    min_depth = depth_divisor
    filters *= width_coefficient
    new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
    return int(new_filters)


def round_repeats(repeats, depth_coefficient):
    """Round number of filters based on depth multiplier."""
    return int(ceil(depth_coefficient * repeats))


class DropConnect(nn.Module):
    def __init__(self, rate=0.5):
        super(DropConnect, self).__init__()
        self.keep_prob = None
        self.set_rate(rate)

    def set_rate(self, rate):
        if not 0 <= rate < 1:
            raise ValueError("rate must be 0<=rate<1, got {} instead".format(rate))
        self.keep_prob = 1 - rate

    def forward(self, x):
        if self.training:
            random_tensor = self.keep_prob + torch.rand([x.size(0), 1, 1, 1],
                                                        dtype=x.dtype).to(x.device)
            # ,
            #                                              device=x.device)
            binary_tensor = torch.floor(random_tensor)
            return torch.mul(torch.div(x, self.keep_prob), binary_tensor)
        else:
            return x


class SamePaddingConv2d(nn.Module):
    def __init__(self,
                 in_spatial_shape,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 dilation=1,
                 enforce_in_spatial_shape=False,
                 **kwargs):
        super(SamePaddingConv2d, self).__init__()

        self._in_spatial_shape = _pair(in_spatial_shape)
        # e.g. throw exception if input spatial shape does not match in_spatial_shape
        # when calling self.forward()
        self.enforce_in_spatial_shape = enforce_in_spatial_shape
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        dilation = _pair(dilation)

        in_height, in_width = self._in_spatial_shape
        filter_height, filter_width = kernel_size
        stride_heigth, stride_width = stride
        dilation_height, dilation_width = dilation

        out_height = int(ceil(float(in_height) / float(stride_heigth)))
        out_width = int(ceil(float(in_width) / float(stride_width)))

        pad_along_height = max((out_height - 1) * stride_heigth +
                               filter_height + (filter_height - 1) * (dilation_height - 1) - in_height, 0)
        pad_along_width = max((out_width - 1) * stride_width +
                              filter_width + (filter_width - 1) * (dilation_width - 1) - in_width, 0)

        pad_top = pad_along_height // 2
        pad_bottom = pad_along_height - pad_top
        pad_left = pad_along_width // 2
        pad_right = pad_along_width - pad_left

        paddings = (pad_left, pad_right, pad_top, pad_bottom)
        if any(p > 0 for p in paddings):
            self.zero_pad = nn.ZeroPad2d(paddings)
        else:
            self.zero_pad = None
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              dilation=dilation,
                              **kwargs)

        self._out_spatial_shape = (out_height, out_width)

    @property
    def out_spatial_shape(self):
        return self._out_spatial_shape

    def check_spatial_shape(self, x):
        if x.size(2) != self._in_spatial_shape[0] or \
                x.size(3) != self._in_spatial_shape[1]:
            raise ValueError(
                "Expected input spatial shape {}, got {} instead".format(self._in_spatial_shape, x.shape[2:]))

    def forward(self, x):
        if self.enforce_in_spatial_shape:
            self.check_spatial_shape(x)
        if self.zero_pad is not None:
            x = self.zero_pad(x)
        x = self.conv(x)
        return x


class SqueezeExcitate(nn.Module):
    def __init__(self,
                 in_channels,
                 se_size,
                 activation=None):
        super(SqueezeExcitate, self).__init__()
        self.dim_reduce = nn.Conv2d(in_channels=in_channels,
                                    out_channels=se_size,
                                    kernel_size=1)
        self.dim_restore = nn.Conv2d(in_channels=se_size,
                                     out_channels=in_channels,
                                     kernel_size=1)
        self.activation = F.relu if activation is None else activation

    def forward(self, x):
        inp = x
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = self.dim_reduce(x)
        x = self.activation(x)
        x = self.dim_restore(x)
        x = torch.sigmoid(x)
        return torch.mul(inp, x)


class MBConvBlockV2(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 expansion_factor,
                 act_fn,
                 act_kwargs=None,
                 bn_epsilon=None,
                 bn_momentum=None,
                 se_size=None,
                 drop_connect_rate=None,
                 bias=False,
                 tf_style_conv=False,
                 in_spatial_shape=None):

        super().__init__()

        if act_kwargs is None:
            act_kwargs = {}
        exp_channels = in_channels * expansion_factor

        self.ops_lst = []

        # expansion convolution
        if expansion_factor != 1:
            self.expand_conv = nn.Conv2d(in_channels=in_channels,
                                         out_channels=exp_channels,
                                         kernel_size=1,
                                         bias=bias)

            self.expand_bn = nn.BatchNorm2d(num_features=exp_channels,
                                            eps=bn_epsilon,
                                            momentum=bn_momentum)

            self.expand_act = get_activation(act_fn, **act_kwargs)
            self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act])

        # depth-wise convolution
        if tf_style_conv:
            self.dp_conv = SamePaddingConv2d(in_spatial_shape=in_spatial_shape,
                                             in_channels=exp_channels,
                                             out_channels=exp_channels,
                                             kernel_size=kernel_size,
                                             stride=stride,
                                             groups=exp_channels,
                                             bias=bias)
            self.out_spatial_shape = self.dp_conv.out_spatial_shape
        else:
            self.dp_conv = nn.Conv2d(in_channels=exp_channels,
                                     out_channels=exp_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=1,
                                     groups=exp_channels,
                                     bias=bias)
            self.out_spatial_shape = torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride)

        self.dp_bn = nn.BatchNorm2d(num_features=exp_channels,
                                    eps=bn_epsilon,
                                    momentum=bn_momentum)

        self.dp_act = get_activation(act_fn, **act_kwargs)
        self.ops_lst.extend([self.dp_conv, self.dp_bn, self.dp_act])

        # Squeeze and Excitate
        if se_size is not None:
            self.se = SqueezeExcitate(exp_channels,
                                      se_size,
                                      activation=get_activation(act_fn, **act_kwargs))
            self.ops_lst.append(self.se)

        # projection layer
        self.project_conv = nn.Conv2d(in_channels=exp_channels,
                                      out_channels=out_channels,
                                      kernel_size=1,
                                      bias=bias)

        self.project_bn = nn.BatchNorm2d(num_features=out_channels,
                                         eps=bn_epsilon,
                                         momentum=bn_momentum)

        # no activation function in projection layer

        self.ops_lst.extend([self.project_conv, self.project_bn])

        self.skip_enabled = in_channels == out_channels and stride == 1

        if self.skip_enabled and drop_connect_rate is not None:
            self.drop_connect = DropConnect(drop_connect_rate)
            self.ops_lst.append(self.drop_connect)

    def forward(self, x):
        x_device = x.device # added
        inp = x
        for op in self.ops_lst:
            op = op.to(x_device)
            x = op(x)
        if self.skip_enabled:
            return x + inp
        else:
            return x


class FusedMBConvBlockV2(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 expansion_factor,
                 act_fn,
                 act_kwargs=None,
                 bn_epsilon=None,
                 bn_momentum=None,
                 se_size=None,
                 drop_connect_rate=None,
                 bias=False,
                 tf_style_conv=False,
                 in_spatial_shape=None):

        super().__init__()

        if act_kwargs is None:
            act_kwargs = {}
        exp_channels = in_channels * expansion_factor

        self.ops_lst = []

        # expansion convolution
        expansion_out_shape = in_spatial_shape
        if expansion_factor != 1:
            if tf_style_conv:
                self.expand_conv = SamePaddingConv2d(in_spatial_shape=in_spatial_shape,
                                                     in_channels=in_channels,
                                                     out_channels=exp_channels,
                                                     kernel_size=kernel_size,
                                                     stride=stride,
                                                     bias=bias)
                expansion_out_shape = self.expand_conv.out_spatial_shape
            else:
                self.expand_conv = nn.Conv2d(in_channels=in_channels,
                                             out_channels=exp_channels,
                                             kernel_size=kernel_size,
                                             padding=1,
                                             stride=stride,
                                             bias=bias)
                expansion_out_shape = torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride)

            self.expand_bn = nn.BatchNorm2d(num_features=exp_channels,
                                            eps=bn_epsilon,
                                            momentum=bn_momentum)

            self.expand_act = get_activation(act_fn, **act_kwargs)
            self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act])

        # Squeeze and Excitate
        if se_size is not None:
            self.se = SqueezeExcitate(exp_channels,
                                      se_size,
                                      activation=get_activation(act_fn, **act_kwargs))
            self.ops_lst.append(self.se)

        # projection layer
        kernel_size = 1 if expansion_factor != 1 else kernel_size
        stride = 1 if expansion_factor != 1 else stride
        if tf_style_conv:
            self.project_conv = SamePaddingConv2d(in_spatial_shape=expansion_out_shape,
                                                  in_channels=exp_channels,
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  bias=bias)
            self.out_spatial_shape = self.project_conv.out_spatial_shape
        else:
            self.project_conv = nn.Conv2d(in_channels=exp_channels,
                                          out_channels=out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          padding=1 if kernel_size > 1 else 0,
                                          bias=bias)
            self.out_spatial_shape = torch_conv_out_spatial_shape(expansion_out_shape, kernel_size, stride)

        self.project_bn = nn.BatchNorm2d(num_features=out_channels,
                                         eps=bn_epsilon,
                                         momentum=bn_momentum)

        self.ops_lst.extend(
            [self.project_conv, self.project_bn])

        if expansion_factor == 1:
            self.project_act = get_activation(act_fn, **act_kwargs)
            self.ops_lst.append(self.project_act)

        self.skip_enabled = in_channels == out_channels and stride == 1

        if self.skip_enabled and drop_connect_rate is not None:
            self.drop_connect = DropConnect(drop_connect_rate)
            self.ops_lst.append(self.drop_connect)

    def forward(self, x):
        inp = x
        x_device = x.device # added
        for op in self.ops_lst:
            op = op.to(x_device)
            x = op(x)
        if self.skip_enabled:
            return x + inp
        else:
            return x


class EfficientNetV2(nn.Module):
    _models = {'b0': {'num_repeat': [1, 2, 2, 3, 5, 8],
                      'kernel_size': [3, 3, 3, 3, 3, 3],
                      'stride': [1, 2, 2, 2, 1, 2],
                      'expand_ratio': [1, 4, 4, 4, 6, 6],
                      'in_channel': [32, 16, 32, 48, 96, 112],
                      'out_channel': [16, 32, 48, 96, 112, 192],
                      'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
                      'conv_type': [1, 1, 1, 0, 0, 0],
                      'is_feature_stage': [False, True, True, False, True, True],
                      'width_coefficient': 1.0,
                      'depth_coefficient': 1.0,
                      'train_size': 192,
                      'eval_size': 224,
                      'dropout': 0.2,
                      'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVBhWkZRcWNXR3dINmRLP2U9UUI5ZndH/root/content',
                      'model_name': 'efficientnet_v2_b0_21k_ft1k-a91e14c5.pth'},
               'b1': {'num_repeat': [1, 2, 2, 3, 5, 8],
                      'kernel_size': [3, 3, 3, 3, 3, 3],
                      'stride': [1, 2, 2, 2, 1, 2],
                      'expand_ratio': [1, 4, 4, 4, 6, 6],
                      'in_channel': [32, 16, 32, 48, 96, 112],
                      'out_channel': [16, 32, 48, 96, 112, 192],
                      'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
                      'conv_type': [1, 1, 1, 0, 0, 0],
                      'is_feature_stage': [False, True, True, False, True, True],
                      'width_coefficient': 1.0,
                      'depth_coefficient': 1.1,
                      'train_size': 192,
                      'eval_size': 240,
                      'dropout': 0.2,
                      'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVJnVGV5UndSY2J2amwtP2U9dTBiV1lO/root/content',
                      'model_name': 'efficientnet_v2_b1_21k_ft1k-58f4fb47.pth'},
               'b2': {'num_repeat': [1, 2, 2, 3, 5, 8],
                      'kernel_size': [3, 3, 3, 3, 3, 3],
                      'stride': [1, 2, 2, 2, 1, 2],
                      'expand_ratio': [1, 4, 4, 4, 6, 6],
                      'in_channel': [32, 16, 32, 48, 96, 112],
                      'out_channel': [16, 32, 48, 96, 112, 192],
                      'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
                      'conv_type': [1, 1, 1, 0, 0, 0],
                      'is_feature_stage': [False, True, True, False, True, True],
                      'width_coefficient': 1.1,
                      'depth_coefficient': 1.2,
                      'train_size': 208,
                      'eval_size': 260,
                      'dropout': 0.3,
                      'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVY4M2NySVFZbU41X0tGP2U9ZERZVmxK/root/content',
                      'model_name': 'efficientnet_v2_b2_21k_ft1k-db4ac0ee.pth'},
               'b3': {'num_repeat': [1, 2, 2, 3, 5, 8],
                      'kernel_size': [3, 3, 3, 3, 3, 3],
                      'stride': [1, 2, 2, 2, 1, 2],
                      'expand_ratio': [1, 4, 4, 4, 6, 6],
                      'in_channel': [32, 16, 32, 48, 96, 112],
                      'out_channel': [16, 32, 48, 96, 112, 192],
                      'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
                      'conv_type': [1, 1, 1, 0, 0, 0],
                      'is_feature_stage': [False, True, True, False, True, True],
                      'width_coefficient': 1.2,
                      'depth_coefficient': 1.4,
                      'train_size': 240,
                      'eval_size': 300,
                      'dropout': 0.3,
                      'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVpkamdZUzhhaDdtTTZLP2U9anA4VWN2/root/content',
                      'model_name': 'efficientnet_v2_b3_21k_ft1k-3da5874c.pth'},
               's': {'num_repeat': [2, 4, 4, 6, 9, 15],
                     'kernel_size': [3, 3, 3, 3, 3, 3],
                     'stride': [1, 2, 2, 2, 1, 2],
                     'expand_ratio': [1, 4, 4, 4, 6, 6],
                     'in_channel': [24, 24, 48, 64, 128, 160],
                     'out_channel': [24, 48, 64, 128, 160, 256],
                     'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
                     'conv_type': [1, 1, 1, 0, 0, 0],
                     'is_feature_stage': [False, True, True, False, True, True],
                     'width_coefficient': 1.0,
                     'depth_coefficient': 1.0,
                     'train_size': 300,
                     'eval_size': 384,
                     'dropout': 0.2,
                     'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllbFF5VWJOZzd0cmhBbm8/root/content',
                     'model_name': 'efficientnet_v2_s_21k_ft1k-dbb43f38.pth'},
               'm': {'num_repeat': [3, 5, 5, 7, 14, 18, 5],
                     'kernel_size': [3, 3, 3, 3, 3, 3, 3],
                     'stride': [1, 2, 2, 2, 1, 2, 1],
                     'expand_ratio': [1, 4, 4, 4, 6, 6, 6],
                     'in_channel': [24, 24, 48, 80, 160, 176, 304],
                     'out_channel': [24, 48, 80, 160, 176, 304, 512],
                     'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25],
                     'conv_type': [1, 1, 1, 0, 0, 0, 0],
                     'is_feature_stage': [False, True, True, False, True, False, True],
                     'width_coefficient': 1.0,
                     'depth_coefficient': 1.0,
                     'train_size': 384,
                     'eval_size': 480,
                     'dropout': 0.3,
                     'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllN1ZDazRFb0o1bnlyNUE/root/content',
                     'model_name': 'efficientnet_v2_m_21k_ft1k-da8e56c0.pth'},
               'l': {'num_repeat': [4, 7, 7, 10, 19, 25, 7],
                     'kernel_size': [3, 3, 3, 3, 3, 3, 3],
                     'stride': [1, 2, 2, 2, 1, 2, 1],
                     'expand_ratio': [1, 4, 4, 4, 6, 6, 6],
                     'in_channel': [32, 32, 64, 96, 192, 224, 384],
                     'out_channel': [32, 64, 96, 192, 224, 384, 640],
                     'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25],
                     'conv_type': [1, 1, 1, 0, 0, 0, 0],
                     'is_feature_stage': [False, True, True, False, True, False, True],
                     'feature_stages': [1, 2, 4, 6],
                     'width_coefficient': 1.0,
                     'depth_coefficient': 1.0,
                     'train_size': 384,
                     'eval_size': 480,
                     'dropout': 0.4,
                     'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmcmIyRHEtQTBhUTBhWVE/root/content',
                     'model_name': 'efficientnet_v2_l_21k_ft1k-08121eee.pth'},
               'xl': {'num_repeat': [4, 8, 8, 16, 24, 32, 8],
                      'kernel_size': [3, 3, 3, 3, 3, 3, 3],
                      'stride': [1, 2, 2, 2, 1, 2, 1],
                      'expand_ratio': [1, 4, 4, 4, 6, 6, 6],
                      'in_channel': [32, 32, 64, 96, 192, 256, 512],
                      'out_channel': [32, 64, 96, 192, 256, 512, 640],
                      'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25],
                      'conv_type': [1, 1, 1, 0, 0, 0, 0],
                      'is_feature_stage': [False, True, True, False, True, False, True],
                      'feature_stages': [1, 2, 4, 6],
                      'width_coefficient': 1.0,
                      'depth_coefficient': 1.0,
                      'train_size': 384,
                      'eval_size': 512,
                      'dropout': 0.4,
                      'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmVXQtRHJLa21taUkxWkE/root/content',
                      'model_name': 'efficientnet_v2_xl_21k_ft1k-1fcc9744.pth'}}

    def __init__(self,
                 model_name,
                 in_channels=3,
                 n_classes=1000,
                 tf_style_conv=False,
                 in_spatial_shape=None,
                 activation='silu',
                 activation_kwargs=None,
                 bias=False,
                 drop_connect_rate=0.2,
                 dropout_rate=None,
                 bn_epsilon=1e-3,
                 bn_momentum=0.01,
                 pretrained=False,
                 progress=False,
                 ):
        super().__init__()

        self.blocks = nn.ModuleList()
        self.model_name = model_name
        self.cfg = self._models[model_name]

        if tf_style_conv and in_spatial_shape is None:
            in_spatial_shape = self.cfg['eval_size']

        activation_kwargs = {} if activation_kwargs is None else activation_kwargs
        dropout_rate = self.cfg['dropout'] if dropout_rate is None else dropout_rate
        _input_ch = in_channels

        self.feature_block_ids = []

        # stem
        if tf_style_conv:
            self.stem_conv = SamePaddingConv2d(
                in_spatial_shape=in_spatial_shape,
                in_channels=in_channels,
                out_channels=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']),
                kernel_size=3,
                stride=2,
                bias=bias
            )
            in_spatial_shape = self.stem_conv.out_spatial_shape
        else:
            self.stem_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']),
                kernel_size=3,
                stride=2,
                padding=1,
                bias=bias
            )

        self.stem_bn = nn.BatchNorm2d(
            num_features=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']),
            eps=bn_epsilon,
            momentum=bn_momentum)

        self.stem_act = get_activation(activation, **activation_kwargs)

        drop_connect_rates = self.get_dropconnect_rates(drop_connect_rate)

        stages = zip(*[self.cfg[x] for x in
                       ['num_repeat', 'kernel_size', 'stride', 'expand_ratio', 'in_channel', 'out_channel', 'se_ratio',
                        'conv_type', 'is_feature_stage']])

        idx = 0

        for stage_args in stages:
            (num_repeat, kernel_size, stride, expand_ratio,
             in_channels, out_channels, se_ratio, conv_type, is_feature_stage) = stage_args

            in_channels = round_filters(
                in_channels, self.cfg['width_coefficient'])
            out_channels = round_filters(
                out_channels, self.cfg['width_coefficient'])
            num_repeat = round_repeats(
                num_repeat, self.cfg['depth_coefficient'])

            conv_block = MBConvBlockV2 if conv_type == 0 else FusedMBConvBlockV2

            for _ in range(num_repeat):
                se_size = None if se_ratio is None else max(1, int(in_channels * se_ratio))
                _b = conv_block(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                expansion_factor=expand_ratio,
                                act_fn=activation,
                                act_kwargs=activation_kwargs,
                                bn_epsilon=bn_epsilon,
                                bn_momentum=bn_momentum,
                                se_size=se_size,
                                drop_connect_rate=drop_connect_rates[idx],
                                bias=bias,
                                tf_style_conv=tf_style_conv,
                                in_spatial_shape=in_spatial_shape
                                )
                self.blocks.append(_b)
                idx += 1
                if tf_style_conv:
                    in_spatial_shape = _b.out_spatial_shape
                in_channels = out_channels
                stride = 1

            if is_feature_stage:
                self.feature_block_ids.append(idx - 1)

        head_conv_out_channels = round_filters(1280, self.cfg['width_coefficient'])

        self.head_conv = nn.Conv2d(in_channels=in_channels,
                                   out_channels=head_conv_out_channels,
                                   kernel_size=1,
                                   bias=bias)
        self.head_bn = nn.BatchNorm2d(num_features=head_conv_out_channels,
                                      eps=bn_epsilon,
                                      momentum=bn_momentum)
        self.head_act = get_activation(activation, **activation_kwargs)

        self.dropout = nn.Dropout(p=dropout_rate)

        self.avpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(head_conv_out_channels, n_classes)

        if pretrained:
            self._load_state(_input_ch, n_classes, progress, tf_style_conv)

        return

    def _load_state(self, in_channels, n_classes, progress, tf_style_conv):
        state_dict = model_zoo.load_url(self.cfg['weight_url'],
                                        progress=progress,
                                        file_name=self.cfg['model_name'])

        strict = True

        if not tf_style_conv:
            state_dict = OrderedDict(
                [(k.replace('.conv.', '.'), v) if '.conv.' in k else (k, v) for k, v in state_dict.items()])

        if in_channels != 3:
            if tf_style_conv:
                state_dict.pop('stem_conv.conv.weight')
            else:
                state_dict.pop('stem_conv.weight')
            strict = False

        if n_classes != 1000:
            state_dict.pop('fc.weight')
            state_dict.pop('fc.bias')
            strict = False

        self.load_state_dict(state_dict, strict=strict)
        print("Model weights loaded successfully.")

    def get_dropconnect_rates(self, drop_connect_rate):
        nr = self.cfg['num_repeat']
        dc = self.cfg['depth_coefficient']
        total = sum(round_repeats(nr[i], dc) for i in range(len(nr)))
        return [drop_connect_rate * i / total for i in range(total)]

    def get_features(self, x):
        x = self.stem_act(self.stem_bn(self.stem_conv(x)))
        
        feature = None
        feat_idx = 0
        for block_idx, block in enumerate(self.blocks):
            x = block(x)
            if block_idx == self.feature_block_ids[feat_idx]:
                feature = x
                feat_idx += 1

        return feature

    def forward(self, x):
        x = self.stem_conv(x)
        x = self.stem_bn(x)
        x = self.stem_act(x)
        
        feature = None
        feat_idx = 0
        for block_idx, block in enumerate(self.blocks):
            x = block(x)
            if block_idx == self.feature_block_ids[feat_idx]:
                feature = x
                feat_idx += 1
        feature = feature.view(feature.shape[0], feature.shape[1], -1)
        feature = torch.permute(feature, (0,2,1))
        return feature

Those are all my custom implementations! Please tell me if you have any problems running my sample code.