RuntimeError: CUDA out of memory?

The same code, with the same setup and the same torch version, was running fine two days ago. Today, when I tried to rerun it, I got the error below.
I'm using Colab.

/content/WS_DAN_PyTorch-master
Dataset Name:car, Train:[223927], Val:[61159]
Batch Size:[12], Total:::Train Batches:[18661],Val Batches:[5097]
Namespace(action='train', alpha=0.95, batch_size=12, checkpoint_path='checkpoint/car', dataset='car', epochs=80, gpu_ids='0', image_size=512, input_size=448, lr=0.001, model_name='inception', momentum=0.9, multi_gpu=True, optim='sgd', parts=32, print_freq=100, resume='', scheduler='step', use_gpu=True, weight_decay=1e-05, workers=4)
/usr/local/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:123: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
Start epoch 0 ==========,lr=0.001000
/usr/local/lib/python3.6/site-packages/torch/nn/functional.py:3121: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))
[W TensorIterator.cpp:924] Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())
[W TensorIterator.cpp:918] Warning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (function operator())
Epoch: [0][0/18661]	Time 4.517 (4.517)	Data 1.092 (1.092)	Loss 10.3476 (10.3476)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
loss1,loss2,loss3,feature_center_loss 9.355782508850098 9.292245864868164 9.394888877868652 0.9999999403953552
Traceback (most recent call last):
  File "train_bap.py", line 210, in <module>
    train()
  File "train_bap.py", line 145, in train
    train_prec, train_loss = engine.train(state, e)
  File "/content/WS_DAN_PyTorch-master/utils/engine.py", line 58, in train
    _, _, output3 = model(img_crop)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/WS_DAN_PyTorch-master/model/inception_bap.py", line 185, in forward
    ftm = self.Mixed_6e(x) #N x 768 x 17 x 17
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/WS_DAN_PyTorch-master/model/inception_bap.py", line 296, in forward
    branch7x7 = self.branch7x7_1(x)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/WS_DAN_PyTorch-master/model/inception_bap.py", line 422, in forward
    x = self.bn(x)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 136, in forward
    self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/functional.py", line 2016, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 11.17 GiB total capacity; 10.49 GiB already allocated; 9.81 MiB free; 10.80 GiB reserved in total by PyTorch)

engine code:

def train(self,state,epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        config = state['config']
        print_freq = config.print_freq
        model = state['model']
        criterion = state['criterion']
        optimizer = state['optimizer']
        train_loader = state['train_loader']
        model.train()
        end = time.time()
        for i, (img, label) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            target = label.cuda()
            input = img.cuda()
            # compute output
            attention_maps, raw_features, output1 = model(input)
            features = raw_features.reshape(raw_features.shape[0], -1)

            feature_center_loss, center_diff = calculate_pooling_center_loss(
                features, state['center'], target, alfa=config.alpha)

            # update model.centers
            state['center'][target] += center_diff

            # compute refined loss
            # img_drop = attention_drop(attention_maps,input)
            # img_crop = attention_crop(attention_maps, input)
            img_crop, img_drop = attention_crop_drop(attention_maps, input)
            _, _, output2 = model(img_drop)
            _, _, output3 = model(img_crop)

Thanks for your time.

I guess this is the culprit: you are adding the graph here as well. Try using .item() on it and you should be fine.
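
As a minimal, self-contained sketch of the idea (toy tensors, not your actual training code): accumulating the Python float returned by .item() instead of the loss tensor lets the graph be freed after every iteration:

import torch

w = torch.randn(3, requires_grad=True)
running_loss = 0.0
for step in range(5):
    loss = (w * w).sum()         # toy loss, still attached to the autograd graph
    loss.backward()
    running_loss += loss.item()  # plain Python float, so no graph is kept alive
    w.grad = None
print(running_loss)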

This doesn’t solve the issue. Maybe I misplaced the .item(). Where do I put it exactly?

Check the attention_maps and raw_features as well. If they are not scalars and do not participate in the training (i.e. you are just using them as pure outputs), consider calling .detach() on them too.
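
A minimal sketch of what detaching does (hypothetical tensors, not your model outputs):

import torch

x = torch.randn(2, 3, requires_grad=True)
y = x * 2              # still attached to the autograd graph
y_stored = y.detach()  # same values, no graph, cheap to keep around for logging
print(y_stored.requires_grad)  # False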

I need them for the training process. The problem is that the same code was working before! I don’t know what happened. Does the PyTorch version have anything to do with this problem?

I saw that reducing the batch size may be a solution; however, I don't know how to do that.

[W TensorIterator.cpp:924] Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())
[W TensorIterator.cpp:918] Warning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (function operator())

I think this warning is causing the problem… It wasn’t there in my previous execution. I tried to lower the batch size, but the problem persists.


I think I have a problem with the CUDA memory?

%cd /content/WS_DAN_PyTorch-master
!python train_bap.py train \
    --model-name inception \
    --batch-size 8 \
    --dataset car \
    --image-size 512 \
    --input-size 448 \
    --checkpoint-path checkpoint/car \
    --optim sgd \
    --scheduler step \
    --lr 0.001 \
    --momentum 0.9 \
    --weight-decay 1e-5 \
    --workers 4 \
    --parts 32 \
    --epochs 80 \
    --use-gpu \
    --multi-gpu \
    --gpu-ids 0

Does reducing any of these parameters solve the problem?


I reduced the batch size to 8 and it seems to work now, but I’m afraid it will take 13 h to train one epoch. Any advice?

There seem to be multiple issues in this topic, so I’ll try to address them separately:

  • If your code was running fine and suddenly runs out of memory without any software or code changes, you should check whether the GPU is empty or whether another process is using memory via nvidia-smi (see the quick check after this list).
  • Are you using memory_format=torch.channels_last somewhere in your code? If so, could you post that code snippet so that we can have a look at the warning?
  • If you haven’t allocated any tensors on the GPU, torch.cuda.memory_summary() is expected to report empty allocations.
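
For reference, a quick check from a Colab cell could look roughly like this (assuming a PyTorch version that already provides torch.cuda.memory_summary()):

# list the processes currently holding GPU memory
!nvidia-smi

# summary of memory allocated/reserved by this PyTorch process
import torch
print(torch.cuda.memory_summary())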

Thanks for the reply.
The code is running now (with batch size 12); I reduced the input size and the image size to 224, but I’m worried this will affect the performance of my model.
For the second point, I’m not sure about that, so I’m going to post the code snippets here. The warnings disappeared when I used PyTorch 1.5; with PyTorch 1.6 it gave me another type of error from torch.where(): missing 2 required positional arguments: "input", "other".
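
If it matters, I guess a version-independent way to get those indices (similar to what attention_crop_drop2 below already does with torch.nonzero) would be something like:

import torch

mask = torch.rand(17, 17)
threshold = 0.5
# (N, 2) tensor of (row, col) indices where the condition holds
itemindex = torch.nonzero(mask >= threshold * mask.max())
height_min, height_max = itemindex[:, 0].min(), itemindex[:, 0].max()
width_min, width_max = itemindex[:, 1].min(), itemindex[:, 1].max()
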
attention script (attention.py):

############################################################
#   File: attention.py                                     #
#   Created: 2019-11-05 19:19:08                           #
#   Author : wvinzh                                        #
#   Email : wvinzh@qq.com                                  #
#   ------------------------------------------             #
#   Description:attention.py                               #
#   Copyright@2019 wvinzh, HUST                            #
############################################################

import numpy as np
import random
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import time


def attention_crop(attention_maps,input_image):
    
    # start = time.time()
    B,N,W,H = input_image.shape
    input_tensor = input_image
    batch_size, num_parts, height, width = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps.detach(),size=(W,H),mode='bilinear')
    part_weights = F.avg_pool2d(attention_maps,(W,H)).reshape(batch_size,-1)
    part_weights = torch.add(torch.sqrt(part_weights),1e-12)
    part_weights = torch.div(part_weights,torch.sum(part_weights,dim=1).unsqueeze(1)).cpu()
    part_weights = part_weights.numpy()
    ret_imgs = []
    # print(part_weights[3])
    for i in range(batch_size):
        attention_map = attention_maps[i]
        part_weight = part_weights[i]
        selected_index = np.random.choice(
            np.arange(0, num_parts), 1, p=part_weight)[0]
        mask = attention_map[selected_index, :, :]
        # print(type(mask))
        # mask = (mask-mask.min())/(mask.max()-mask.min())
        threshold = random.uniform(0.4, 0.6)
        # threshold = 0.5
        # itemindex = np.where(mask >= threshold)
        itemindex = np.where(mask >= mask.max() * threshold)

        # itemindex = torch.nonzero(mask >= threshold)
        padding_h = int(0.1*H)
        padding_w = int(0.1*W)
        height_min = itemindex[0].min()
        height_min = max(0,height_min-padding_h)
        height_max = itemindex[0].max() + padding_h
        width_min = itemindex[1].min()
        width_min = max(0,width_min-padding_w)
        width_max = itemindex[1].max() + padding_w
        out_img = input_tensor[i][:,height_min:height_max,width_min:width_max].unsqueeze(0)
        out_img = torch.nn.functional.interpolate(out_img,size=(W,H),mode='bilinear',align_corners=True)
        out_img = out_img.squeeze(0)
        # print(out_img.shape)
        ret_imgs.append(out_img)
    ret_imgs = torch.stack(ret_imgs)
    return ret_imgs


def attention_drop(attention_maps,input_image):
    B,N,W,H = input_image.shape
    input_tensor = input_image
    batch_size, num_parts, height, width = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps.detach(),size=(W,H),mode='bilinear')
    part_weights = F.avg_pool2d(attention_maps,(W,H)).reshape(batch_size,-1)
    part_weights = torch.add(torch.sqrt(part_weights),1e-12)
    part_weights = torch.div(part_weights,torch.sum(part_weights,dim=1).unsqueeze(1)).cpu().numpy()
    # attention_maps = torch.nn.functional.interpolate(attention_maps,size=(W,H),mode='bilinear', align_corners=True)
    # print(part_weights.shape)
    masks = []
    for i in range(batch_size):
        attention_map = attention_maps[i].detach()
        part_weight = part_weights[i]
        selected_index = np.random.choice(
            np.arange(0, num_parts), 1, p=part_weight)[0]
        mask = attention_map[selected_index:selected_index + 1, :, :]

        # soft mask
        # threshold = random.uniform(0.2, 0.5)
        # threshold = 0.5
        # mask = (mask-mask.min())/(mask.max()-mask.min())
        # mask = (mask < threshold).float()
        threshold = random.uniform(0.2, 0.5)
        mask = (mask < threshold * mask.max()).float()
        masks.append(mask)
    masks = torch.stack(masks)
    # print(masks.shape)
    ret = input_tensor*masks
    return ret

def attention_crop_drop(attention_maps,input_image):
    # start = time.time()
    B,N,W,H = input_image.shape
    input_tensor = input_image
    batch_size, num_parts, height, width = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps.detach(),size=(W,H),mode='bilinear')
    part_weights = F.avg_pool2d(attention_maps.detach(),(W,H)).reshape(batch_size,-1)
    part_weights = torch.add(torch.sqrt(part_weights),1e-12)
    part_weights = torch.div(part_weights,torch.sum(part_weights,dim=1).unsqueeze(1)).cpu()
    part_weights = part_weights.numpy()
    # print(part_weights.shape)
    ret_imgs = []
    masks = []
    # print(part_weights[3])
    for i in range(batch_size):
        attention_map = attention_maps[i]
        part_weight = part_weights[i]
        selected_index = np.random.choice(np.arange(0, num_parts), 1, p=part_weight)[0]
        selected_index2 = np.random.choice(np.arange(0, num_parts), 1, p=part_weight)[0]
        ## create crop imgs
        mask = attention_map[selected_index, :, :]
        # mask = (mask-mask.min())/(mask.max()-mask.min())
        threshold = random.uniform(0.4, 0.6)
        # threshold = 0.5
        itemindex = torch.where(mask >= mask.max()*threshold)
        # print(itemindex.shape)
        # itemindex = torch.nonzero(mask >= threshold*mask.max())
        padding_h = int(0.1*H)
        padding_w = int(0.1*W)
        height_min = itemindex[0].min()
        height_min = max(0,height_min-padding_h)
        height_max = itemindex[0].max() + padding_h
        width_min = itemindex[1].min()
        width_min = max(0,width_min-padding_w)
        width_max = itemindex[1].max() + padding_w
        # print('numpy',height_min,height_max,width_min,width_max)
        out_img = input_tensor[i][:,height_min:height_max,width_min:width_max].unsqueeze(0)
        out_img = torch.nn.functional.interpolate(out_img,size=(W,H),mode='bilinear',align_corners=True)
        out_img = out_img.squeeze(0)
        ret_imgs.append(out_img)

        ## create drop imgs
        mask2 = attention_map[selected_index2:selected_index2 + 1, :, :]
        threshold = random.uniform(0.2, 0.5)
        mask2 = (mask2 < threshold * mask2.max()).float()
        masks.append(mask2)
    # bboxes = np.asarray(bboxes, np.float32)
    crop_imgs = torch.stack(ret_imgs)
    masks = torch.stack(masks)
    drop_imgs = input_tensor*masks
    return (crop_imgs,drop_imgs)

def mask2bbox(attention_maps,input_image):
    input_tensor = input_image
    B,C,H,W = input_tensor.shape
    batch_size, num_parts, Hh, Ww = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps,size=(W,H),mode='bilinear')
    ret_imgs = []
    # print(part_weights[3])
    for i in range(batch_size):
        attention_map = attention_maps[i]
        # print(attention_map.shape)
        mask = attention_map.mean(dim=0)
        # print(type(mask))
        # mask = (mask-mask.min())/(mask.max()-mask.min())
        # threshold = random.uniform(0.4, 0.6)
        threshold = 0.1
        max_activate = mask.max()
        min_activate = threshold * max_activate
        itemindex = torch.nonzero(mask >= min_activate)

        padding_h = int(0.05*H)
        padding_w = int(0.05*W)
        height_min = itemindex[:, 0].min()
        height_min = max(0,height_min-padding_h)
        height_max = itemindex[:, 0].max() + padding_h
        width_min = itemindex[:, 1].min()
        width_min = max(0,width_min-padding_w)
        width_max = itemindex[:, 1].max() + padding_w
        # print(height_min,height_max,width_min,width_max)
        out_img = input_tensor[i][:,height_min:height_max,width_min:width_max].unsqueeze(0)
        out_img = torch.nn.functional.interpolate(out_img,size=(W,H),mode='bilinear',align_corners=True)
        out_img = out_img.squeeze(0)
        # print(out_img.shape)
        ret_imgs.append(out_img)
    ret_imgs = torch.stack(ret_imgs)
    # print(ret_imgs.shape)
    return ret_imgs

def calculate_pooling_center_loss(features, centers, label, alfa=0.95):
    # centers = model.centers
    # print('111111111',sum(sum(centers)))
    # mse_loss = torch.nn.MSELoss()
    features = features.reshape(features.shape[0], -1)
    # print(features.shape)
    centers_batch = centers[label]
    # print(centers_batch)
    # print(centers_batch.shape,centers.shape)
    centers_batch = torch.nn.functional.normalize(centers_batch, dim=-1)
    diff =  (1-alfa)*(features.detach() - centers_batch)
    distance = torch.pow(features - centers_batch,2)
    distance = torch.sum(distance, dim=-1)
    center_loss = torch.mean(distance)
    # loss2 = mse_loss(features,centers_batch)
    # print('================',center_loss.item(),loss2.item())
    return center_loss, diff

def attention_crop_drop2(attention_maps,input_image):
    # start = time.time()
    B,N,W,H = input_image.shape
    input_tensor = input_image
    batch_size, num_parts, height, width = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps.detach(),size=(W,H),mode='bilinear')
    part_weights = F.avg_pool2d(attention_maps.detach(),(W,H)).reshape(batch_size,-1)
    part_weights = torch.add(torch.sqrt(part_weights),1e-12)
    part_weights = torch.div(part_weights,torch.sum(part_weights,dim=1).unsqueeze(1)).cpu()
    part_weights = part_weights.numpy()
    # print(part_weights.shape)
    ret_imgs = []
    masks = []
    # print(part_weights[3])
    for i in range(batch_size):
        attention_map = attention_maps[i]
        part_weight = part_weights[i]
        selected_index = np.random.choice(np.arange(0, num_parts), 1, p=part_weight)[0]
        selected_index2 = np.random.choice(np.arange(0, num_parts), 1, p=part_weight)[0]
        ## create crop imgs
        mask = attention_map[selected_index, :, :]
        # mask = (mask-mask.min())/(mask.max()-mask.min())
        threshold = random.uniform(0.4, 0.6)
        # threshold = 0.5
        # itemindex = np.where(mask >= mask.max()*threshold)
        # print(itemindex.shape)
        itemindex = torch.nonzero(mask >= threshold*mask.max())
        padding_h = int(0.1*H)
        padding_w = int(0.1*W)
        height_min = itemindex[:,0].min()
        height_min = max(0,height_min-padding_h)
        height_max = itemindex[:,0].max() + padding_h
        width_min = itemindex[:,1].min()
        width_min = max(0,width_min-padding_w)
        width_max = itemindex[:,1].max() + padding_w
        # print(height_min,height_max,width_min,width_max)
        out_img = input_tensor[i][:,height_min:height_max,width_min:width_max].unsqueeze(0)
        out_img = torch.nn.functional.interpolate(out_img,size=(W,H),mode='bilinear',align_corners=True)
        out_img = out_img.squeeze(0)
        ret_imgs.append(out_img)

        ## create drop imgs
        mask2 = attention_map[selected_index2:selected_index2 + 1, :, :]
        threshold = random.uniform(0.2, 0.5)
        mask2 = (mask2 < threshold * mask2.max()).float()
        masks.append(mask2)
    # bboxes = np.asarray(bboxes, np.float32)
    crop_imgs = torch.stack(ret_imgs)
    masks = torch.stack(masks)
    drop_imgs = input_tensor*masks
    return (crop_imgs,drop_imgs)






if __name__ == '__main__':
    import torch
    a = torch.rand(4*26*26*32).reshape(4, 32, 26, 26)
    # a = torch.Tensor((4, 32, 26, 26))
    img = torch.arange(4*3*448*448.0).reshape(4, 3, 448, 448)
    # a = torch.arange(4*1*1*8.0).reshape(4, 8, 1, 1)
    # b = torch.ones(10*1*1*8).reshape(10, 8)
    # label = torch.LongTensor([1, 2, 3, 4])
    # a = torch.div(a,4*26*26*8)
    # ret = attention_drop2(a,img)
    ret1 = attention_crop_drop(a,img)
    ret2 = attention_crop_drop2(a,img)
    # ret2 = attention_crop2(a,img)
    # ret = calculate_pooling_center_loss(a, b, label)
    # print(ret)
    # print(ret.shape,ret2.shape)
    # print(type(ret),type(ret2))

train script (train_bap.py):

############################################################
#   File: train_bap.py                                     #
#   Created: 2019-11-06 13:22:23                           #
#   Author : wvinzh                                        #
#   Email : wvinzh@qq.com                                  #
#   ------------------------------------------             #
#   Description:train_bap.py                               #
#   Copyright@2019 wvinzh, HUST                            #
############################################################

# system
import os
import time
import shutil
import random
import numpy as np

# my implementation
from model.inception_bap import inception_v3_bap
from model.resnet import resnet50
from dataset.custom_dataset import CustomDataset

from utils import calculate_pooling_center_loss, mask2bbox
from utils import attention_crop, attention_drop, attention_crop_drop
from utils import getDatasetConfig, getConfig
from utils import accuracy, get_lr, save_checkpoint, AverageMeter, set_seed
from utils import Engine

# pytorch
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.nn.functional as F
from tensorboardX import SummaryWriter

GLOBAL_SEED = 1231
def _init_fn(worker_id):
    set_seed(GLOBAL_SEED+worker_id)

def train():
    # input params
    set_seed(GLOBAL_SEED)
    config = getConfig()
    data_config = getDatasetConfig(config.dataset)
    sw_log = 'logs/%s' % config.dataset
    sw = SummaryWriter(log_dir=sw_log)
    best_prec1 = 0.
    rate = 0.875

    # define train_dataset and loader
    transform_train = transforms.Compose([
        transforms.Resize((int(config.input_size//rate), int(config.input_size//rate))),
        transforms.RandomCrop((config.input_size,config.input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=32./255.,saturation=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_dataset = CustomDataset(
        data_config['train'], data_config['train_root'], transform=transform_train)
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.workers, pin_memory=True, worker_init_fn=_init_fn)

    transform_test = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_dataset = CustomDataset(
        data_config['val'], data_config['val_root'], transform=transform_test)
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.workers, pin_memory=True, worker_init_fn=_init_fn)
    # logging dataset info
    print('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.format(
        dataset_name=config.dataset,
        train_num=len(train_dataset),
        val_num=len(val_dataset)))
    print('Batch Size:[{0}], Total:::Train Batches:[{1}],Val Batches:[{2}]'.format(
        config.batch_size, len(train_loader), len(val_loader)
    ))
    # define model
    if config.model_name == 'inception':
        net = inception_v3_bap(pretrained=True, aux_logits=False,num_parts=config.parts)
    elif config.model_name == 'resnet50':
        net = resnet50(pretrained=True,use_bap=True)

    
    in_features = net.fc_new.in_features
    new_linear = torch.nn.Linear(
        in_features=in_features, out_features=train_dataset.num_classes)
    net.fc_new = new_linear
    # feature center
    feature_len = 768 if config.model_name == 'inception' else 512
    center_dict = {'center': torch.zeros(
        train_dataset.num_classes, feature_len*config.parts)}

    # gpu config
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
        center_dict['center'] = center_dict['center'].cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and config.multi_gpu:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)

    # define optimizer
    assert config.optim in ['sgd', 'adam'], 'optim name not found!'
    if config.optim == 'sgd':
        optimizer = torch.optim.SGD(
            net.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
    elif config.optim == 'adam':
        optimizer = torch.optim.Adam(
            net.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    # define learning scheduler
    assert config.scheduler in ['plateau',
                                'step'], 'scheduler not supported!!!'
    if config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=3, factor=0.1)
    elif config.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=2, gamma=0.9)

    # define loss
    criterion = torch.nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()

    # train val parameters dict
    state = {'model': net, 'train_loader': train_loader,
             'val_loader': val_loader, 'criterion': criterion,
             'center': center_dict['center'], 'config': config,
             'optimizer': optimizer}
    ## train and val
    engine = Engine()
    print(config)
    for e in range(config.epochs):
        if config.scheduler == 'step':
            scheduler.step()
        lr_val = get_lr(optimizer)
        print("Start epoch %d ==========,lr=%f" % (e, lr_val))
        train_prec, train_loss = engine.train(state, e)
        prec1, val_loss = engine.validate(state)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': e + 1,
            'state_dict': net.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
            'center': center_dict['center']
        }, is_best, config.checkpoint_path)
        sw.add_scalars("Accurancy", {'train': train_prec, 'val': prec1}, e)
        sw.add_scalars("Loss", {'train': train_loss, 'val': val_loss}, e)
        if config.scheduler == 'plateau':
            scheduler.step(val_loss)

def test():
    ##
    engine = Engine()
    config = getConfig()
    data_config = getDatasetConfig(config.dataset)
    # define dataset
    transform_test = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_dataset = CustomDataset(
        data_config['val'], data_config['val_root'], transform=transform_test)
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.workers, pin_memory=True)
    # define model
    if config.model_name == 'inception':
        net = inception_v3_bap(pretrained=True, aux_logits=False)
    elif config.model_name == 'resnet50':
        net = resnet50(pretrained=True)

    in_features = net.fc_new.in_features
    new_linear = torch.nn.Linear(
        in_features=in_features, out_features=val_dataset.num_classes)
    net.fc_new = new_linear

    # load checkpoint
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and len(gpu_ids) > 1:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)
    #checkpoint_path = os.path.join(config.checkpoint_path,'model_best.pth.tar')
    net.load_state_dict(torch.load(config.checkpoint_path)['state_dict'])

    # define loss
    criterion = torch.nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()
    prec1, prec5 = engine.test(val_loader, net, criterion)


if __name__ == '__main__':
    config = getConfig()
    engine = Engine()
    if config.action == 'train':
        train()
    else:
        test()