Training fails with an out-of-memory error on PyTorch 0.4 but runs fine on 0.3.1

Hi, I recently updated my PyTorch installation to the latest version, 0.4. I installed the new version like this:

me@shishosama:/media/me/tmpstore/SimpNet_PyTorch$ pip install http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl 
Collecting torch==0.4.0 from http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl
  Downloading http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl (566.4MB)
    100% |████████████████████████████████| 566.4MB 943kB/s 
Installing collected packages: torch
  Found existing installation: torch 0.3.1
    Uninstalling torch-0.3.1:
      Successfully uninstalled torch-0.3.1
Successfully installed torch-0.4.0
You are using pip version 9.0.1, however version 10.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
me@shishosama:/media/me/tmpstore/SimpNet_PyTorch$ pip install torchvision
Requirement already satisfied: torchvision in /home/me/anaconda3/lib/python3.6/site-packages
Requirement already satisfied: six in /home/me/anaconda3/lib/python3.6/site-packages (from torchvision)
Requirement already satisfied: numpy in /home/me/anaconda3/lib/python3.6/site-packages (from torchvision)
Requirement already satisfied: torch in /home/me/anaconda3/lib/python3.6/site-packages (from torchvision)
Requirement already satisfied: pillow>=4.1.1 in /home/me/anaconda3/lib/python3.6/site-packages (from torchvision)
You are using pip version 9.0.1, however version 10.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

This is my training script (gist link), and as you can see it's pretty simple and straightforward.
When the script reaches the evaluation phase, it throws an error.
Here is the full log:

me@shisho:/media/me/tmpstore/SimpNet_PyTorch$ bash training_sequence.sh 
=> creating model 'simple_imagenet_3p'
=> Model : simple_imagenet_3p(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=[3, 3], stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 128, kernel_size=[3, 3], stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=(1, 1), ceil_mode=False)
    (13): Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (14): BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (15): ReLU(inplace)
    (16): Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (17): BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (18): ReLU(inplace)
    (19): Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (20): BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (21): ReLU(inplace)
    (22): Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (23): BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (24): ReLU(inplace)
    (25): Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (26): BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (27): ReLU(inplace)
    (28): Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (29): BatchNorm2d(512, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (30): ReLU(inplace)
    (31): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=(1, 1), ceil_mode=False)
    (32): Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1))
    (33): BatchNorm2d(2048, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (34): ReLU(inplace)
    (35): Conv2d(2048, 256, kernel_size=[1, 1], stride=(1, 1))
    (36): BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (37): ReLU(inplace)
    (38): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=(1, 1), ceil_mode=False)
    (39): Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
    (40): BatchNorm2d(256, eps=1e-05, momentum=0.05, affine=True, track_running_stats=True)
    (41): ReLU(inplace)
  )
  (classifier): Linear(in_features=256, out_features=1000, bias=True)
)
=> parameter : Namespace(arch='simple_imagenet_3p', batch_size=128, data='/media/me/SSD/ImageNet_DataSet', epochs=150, evaluate=False, lr=0.1, momentum=0.9, prefix='2018-06-30-6885', print_freq=200, resume='./snapshots/imagenet/simplenets/5mil_3p/checkpoint.simple_imagenet_3p.2018-06-27-1781_2018-06-27_13-15-13.pth.tar', save_dir='./snapshots/imagenet/simplenets/5mil_3p/', start_epoch=86, train_dir_name='training_set_t12/', val_dir_name='imagenet_val/', weight_decay=1e-05, workers=12)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           1,792
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
            Conv2d-4          [-1, 128, 56, 56]          73,856
       BatchNorm2d-5          [-1, 128, 56, 56]             256
              ReLU-6          [-1, 128, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]         147,584
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10          [-1, 128, 56, 56]         147,584
      BatchNorm2d-11          [-1, 128, 56, 56]             256
             ReLU-12          [-1, 128, 56, 56]               0
        MaxPool2d-13          [-1, 128, 28, 28]               0
           Conv2d-14          [-1, 128, 28, 28]         147,584
      BatchNorm2d-15          [-1, 128, 28, 28]             256
             ReLU-16          [-1, 128, 28, 28]               0
           Conv2d-17          [-1, 128, 28, 28]         147,584
      BatchNorm2d-18          [-1, 128, 28, 28]             256
             ReLU-19          [-1, 128, 28, 28]               0
           Conv2d-20          [-1, 256, 28, 28]         295,168
      BatchNorm2d-21          [-1, 256, 28, 28]             512
             ReLU-22          [-1, 256, 28, 28]               0
           Conv2d-23          [-1, 256, 28, 28]         590,080
      BatchNorm2d-24          [-1, 256, 28, 28]             512
             ReLU-25          [-1, 256, 28, 28]               0
           Conv2d-26          [-1, 256, 28, 28]         590,080
      BatchNorm2d-27          [-1, 256, 28, 28]             512
             ReLU-28          [-1, 256, 28, 28]               0
           Conv2d-29          [-1, 512, 28, 28]       1,180,160
      BatchNorm2d-30          [-1, 512, 28, 28]           1,024
             ReLU-31          [-1, 512, 28, 28]               0
        MaxPool2d-32          [-1, 512, 14, 14]               0
           Conv2d-33         [-1, 2048, 14, 14]       1,050,624
      BatchNorm2d-34         [-1, 2048, 14, 14]           4,096
             ReLU-35         [-1, 2048, 14, 14]               0
           Conv2d-36          [-1, 256, 14, 14]         524,544
      BatchNorm2d-37          [-1, 256, 14, 14]             512
             ReLU-38          [-1, 256, 14, 14]               0
        MaxPool2d-39            [-1, 256, 7, 7]               0
           Conv2d-40            [-1, 256, 7, 7]         590,080
      BatchNorm2d-41            [-1, 256, 7, 7]             512
             ReLU-42            [-1, 256, 7, 7]               0
           Linear-43                 [-1, 1000]         257,000
simplenetv1_imagenet_3p-44                 [-1, 1000]               0
================================================================
Total params: 5,752,808
Trainable params: 5,752,808
Non-trainable params: 0
----------------------------------------------------------------
None
FLOPs: 3830.96M, Params: 5.75M
{'milestones': [30, 60, 90, 130, 150], 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': -1}
{'milestones': [30, 60, 90, 130, 150], 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 86}
=> loading checkpoint './snapshots/imagenet/simplenets/5mil_3p/checkpoint.simple_imagenet_3p.2018-06-27-1781_2018-06-27_13-15-13.pth.tar'
=> loaded checkpoint './snapshots/imagenet/simplenets/5mil_3p/checkpoint.simple_imagenet_3p.2018-06-27-1781_2018-06-27_13-15-13.pth.tar' (epoch 86)
/home/me/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py:397: UserWarning: The use of the transforms.RandomSizedCrop transform is deprecated, please use transforms.RandomResizedCrop instead.
  "please use transforms.RandomResizedCrop instead.")
/home/me/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py:156: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  "please use transforms.Resize instead.")

==>>[2018-06-30 13:31:35] [Epoch=086/150] [Need: 00:00:00] [learning_rate=0.0010000000000000002] [Best : Accuracy(T1/T5)=65.74/86.39, Error=34.26/13.61]
imagenet_train.py:317: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  losses.update(loss.data[0], input.size(0))
Epoch: [86][0/10010]	Time 5.993 (5.993)	Data 2.479 (2.479)	Loss 1.4199 (1.4199)	Prec@1 64.062 (64.062)	Prec@5 85.938 (85.938)
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 2555904 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
Epoch: [86][200/10010]	Time 0.393 (0.421)	Data 0.000 (0.013)	Loss 1.5897 (1.6016)	Prec@1 64.844 (63.040)	Prec@5 78.906 (83.190)
Epoch: [86][400/10010]	Time 0.399 (0.409)	Data 0.000 (0.008)	Loss 1.3367 (1.5933)	Prec@1 67.969 (63.310)	Prec@5 85.156 (83.241)
Epoch: [86][600/10010]	Time 0.398 (0.405)	Data 0.000 (0.007)	Loss 1.4794 (1.5994)	Prec@1 66.406 (63.111)	Prec@5 84.375 (83.191)
Epoch: [86][800/10010]	Time 0.396 (0.403)	Data 0.000 (0.006)	Loss 1.4389 (1.5995)	Prec@1 71.094 (63.136)	Prec@5 84.375 (83.192)
Epoch: [86][1000/10010]	Time 0.401 (0.403)	Data 0.000 (0.005)	Loss 1.4466 (1.5962)	Prec@1 68.750 (63.204)	Prec@5 86.719 (83.206)
Epoch: [86][1200/10010]	Time 0.395 (0.402)	Data 0.000 (0.005)	Loss 1.0079 (1.5964)	Prec@1 77.344 (63.150)	Prec@5 91.406 (83.184)
Epoch: [86][1400/10010]	Time 0.398 (0.401)	Data 0.000 (0.005)	Loss 1.6049 (1.5974)	Prec@1 62.500 (63.175)	Prec@5 78.906 (83.177)
Epoch: [86][1600/10010]	Time 0.397 (0.401)	Data 0.000 (0.005)	Loss 1.3969 (1.5955)	Prec@1 66.406 (63.194)	Prec@5 85.938 (83.189)
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 2555904 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
Epoch: [86][1800/10010]	Time 0.397 (0.400)	Data 0.000 (0.005)	Loss 1.7823 (1.5972)	Prec@1 57.812 (63.106)	Prec@5 82.812 (83.157)
Epoch: [86][2000/10010]	Time 0.399 (0.400)	Data 0.000 (0.004)	Loss 1.2100 (1.5991)	Prec@1 69.531 (63.081)	Prec@5 89.844 (83.143)
Epoch: [86][2200/10010]	Time 0.399 (0.400)	Data 0.000 (0.004)	Loss 1.4692 (1.5992)	Prec@1 71.875 (63.116)	Prec@5 85.156 (83.132)
Epoch: [86][2400/10010]	Time 0.403 (0.400)	Data 0.000 (0.004)	Loss 1.9854 (1.5993)	Prec@1 63.281 (63.113)	Prec@5 76.562 (83.130)
Epoch: [86][2600/10010]	Time 0.396 (0.400)	Data 0.000 (0.004)	Loss 1.5382 (1.5992)	Prec@1 61.719 (63.114)	Prec@5 82.031 (83.137)
Epoch: [86][2800/10010]	Time 0.398 (0.400)	Data 0.000 (0.004)	Loss 1.8942 (1.5994)	Prec@1 57.812 (63.111)	Prec@5 82.812 (83.150)
Epoch: [86][3000/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 1.4567 (1.6003)	Prec@1 60.938 (63.071)	Prec@5 89.844 (83.136)
Epoch: [86][3200/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.4535 (1.6001)	Prec@1 62.500 (63.056)	Prec@5 84.375 (83.123)
Epoch: [86][3400/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.5194 (1.5993)	Prec@1 65.625 (63.083)	Prec@5 86.719 (83.141)
Epoch: [86][3600/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.4627 (1.5993)	Prec@1 65.625 (63.069)	Prec@5 85.938 (83.140)
Epoch: [86][3800/10010]	Time 0.404 (0.399)	Data 0.000 (0.004)	Loss 1.5466 (1.5986)	Prec@1 65.625 (63.077)	Prec@5 82.031 (83.148)
Epoch: [86][4000/10010]	Time 0.397 (0.399)	Data 0.000 (0.004)	Loss 1.2347 (1.6004)	Prec@1 71.875 (63.050)	Prec@5 89.062 (83.129)
Epoch: [86][4200/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.7307 (1.6003)	Prec@1 62.500 (63.045)	Prec@5 80.469 (83.129)
Epoch: [86][4400/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.5589 (1.6006)	Prec@1 64.844 (63.049)	Prec@5 82.812 (83.133)
Epoch: [86][4600/10010]	Time 0.400 (0.399)	Data 0.000 (0.004)	Loss 1.5220 (1.6007)	Prec@1 65.625 (63.036)	Prec@5 82.031 (83.125)
Epoch: [86][4800/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 1.4993 (1.6002)	Prec@1 60.938 (63.055)	Prec@5 86.719 (83.121)
Epoch: [86][5000/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.5266 (1.6015)	Prec@1 68.750 (63.030)	Prec@5 87.500 (83.105)
Epoch: [86][5200/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 1.4016 (1.6017)	Prec@1 67.188 (63.019)	Prec@5 84.375 (83.107)
Epoch: [86][5400/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.7396 (1.6020)	Prec@1 60.938 (63.015)	Prec@5 82.031 (83.105)
Epoch: [86][5600/10010]	Time 0.400 (0.399)	Data 0.000 (0.004)	Loss 1.3337 (1.6012)	Prec@1 68.750 (63.033)	Prec@5 86.719 (83.124)
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 2555904 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
Epoch: [86][5800/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.8474 (1.6011)	Prec@1 61.719 (63.029)	Prec@5 81.250 (83.120)
Epoch: [86][6000/10010]	Time 0.397 (0.399)	Data 0.000 (0.004)	Loss 1.6060 (1.6006)	Prec@1 60.156 (63.038)	Prec@5 85.156 (83.126)
Epoch: [86][6200/10010]	Time 0.397 (0.399)	Data 0.000 (0.004)	Loss 1.5556 (1.6007)	Prec@1 66.406 (63.030)	Prec@5 85.156 (83.120)
Epoch: [86][6400/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 2.0969 (1.6002)	Prec@1 58.594 (63.036)	Prec@5 75.781 (83.131)
Epoch: [86][6600/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.5191 (1.6003)	Prec@1 69.531 (63.045)	Prec@5 85.938 (83.137)
Epoch: [86][6800/10010]	Time 0.400 (0.399)	Data 0.001 (0.004)	Loss 1.3767 (1.6005)	Prec@1 64.844 (63.031)	Prec@5 84.375 (83.126)
Epoch: [86][7000/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.3647 (1.6005)	Prec@1 63.281 (63.045)	Prec@5 85.156 (83.124)
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 2555904 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 2555904 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
Epoch: [86][7200/10010]	Time 0.400 (0.399)	Data 0.000 (0.004)	Loss 1.7624 (1.6007)	Prec@1 60.938 (63.038)	Prec@5 79.688 (83.120)
Epoch: [86][7400/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 1.4949 (1.6004)	Prec@1 65.625 (63.049)	Prec@5 83.594 (83.125)
Epoch: [86][7600/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.3772 (1.6004)	Prec@1 67.188 (63.047)	Prec@5 85.156 (83.131)
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 19660800 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 18481152 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 37093376 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 39976960 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 34865152 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:742: UserWarning: Corrupt EXIF data.  Expecting to read 12 bytes but only got 10. 
  warnings.warn(str(msg))
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:725: UserWarning: Possibly corrupt EXIF data.  Expecting to read 1835008 bytes but only got 0. Skipping tag 0
  " Skipping tag %s" % (size, len(data), tag))
Epoch: [86][7800/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 1.3626 (1.6007)	Prec@1 67.188 (63.041)	Prec@5 87.500 (83.124)
Epoch: [86][8000/10010]	Time 0.397 (0.399)	Data 0.001 (0.004)	Loss 1.9515 (1.6005)	Prec@1 59.375 (63.044)	Prec@5 78.125 (83.125)
Epoch: [86][8200/10010]	Time 0.397 (0.399)	Data 0.000 (0.004)	Loss 1.5302 (1.6005)	Prec@1 65.625 (63.046)	Prec@5 80.469 (83.126)
Epoch: [86][8400/10010]	Time 0.402 (0.399)	Data 0.000 (0.004)	Loss 1.5629 (1.6012)	Prec@1 61.719 (63.028)	Prec@5 84.375 (83.113)
Epoch: [86][8600/10010]	Time 0.398 (0.399)	Data 0.000 (0.004)	Loss 1.4765 (1.6010)	Prec@1 67.188 (63.023)	Prec@5 87.500 (83.117)
/home/me/anaconda3/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:742: UserWarning: Corrupt EXIF data.  Expecting to read 4 bytes but only got 0. 
  warnings.warn(str(msg))
Epoch: [86][8800/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.7422 (1.6007)	Prec@1 60.938 (63.023)	Prec@5 79.688 (83.122)
Epoch: [86][9000/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.7085 (1.6009)	Prec@1 63.281 (63.021)	Prec@5 82.031 (83.119)
Epoch: [86][9200/10010]	Time 0.394 (0.399)	Data 0.000 (0.004)	Loss 1.7975 (1.6009)	Prec@1 61.719 (63.017)	Prec@5 78.125 (83.119)
Epoch: [86][9400/10010]	Time 0.396 (0.399)	Data 0.000 (0.004)	Loss 1.8066 (1.6009)	Prec@1 57.812 (63.016)	Prec@5 85.156 (83.120)
Epoch: [86][9600/10010]	Time 0.399 (0.399)	Data 0.000 (0.004)	Loss 1.8106 (1.6010)	Prec@1 58.594 (63.023)	Prec@5 79.688 (83.121)
Epoch: [86][9800/10010]	Time 0.397 (0.399)	Data 0.000 (0.004)	Loss 1.2793 (1.6008)	Prec@1 64.844 (63.027)	Prec@5 89.844 (83.122)
Epoch: [86][10000/10010]	Time 0.400 (0.399)	Data 0.000 (0.004)	Loss 1.5781 (1.6010)	Prec@1 63.281 (63.019)	Prec@5 85.156 (83.124)
imagenet_train.py:354: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
  input_var = torch.autograd.Variable(input, volatile=True)
imagenet_train.py:355: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
  target_var = torch.autograd.Variable(target, volatile=True)
imagenet_train.py:363: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  losses.update(loss.data[0], input.size(0))
Test: [0/391]	Time 3.970 (3.970)	Loss 0.7654 (0.7654)	Prec@1 85.938 (85.938)	Prec@5 92.188 (92.188)
THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "imagenet_train.py", line 523, in <module>
    main()
  File "imagenet_train.py", line 208, in main
    prec1,prec5, val_loss = validate(val_loader, model, criterion, log)
  File "imagenet_train.py", line 358, in validate
    output = model(input_var)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 112, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/me/tmpstore/SimpNet_PyTorch/models/simplenet_v1_p3_imgnet.py", line 53, in forward
    out = self.features(x)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58

More information:
GPU: GTX 1080
RAM: 18 GB
OS information:

x86_64
Kernel version: 4.13.0-45-generic
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=16.04
DISTRIB_CODENAME=xenial
DISTRIB_DESCRIPTION="Ubuntu 16.04.4 LTS"
NAME="Ubuntu"
VERSION="16.04.4 LTS (Xenial Xerus)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 16.04.4 LTS"
VERSION_ID="16.04"
HOME_URL="http://www.ubuntu.com/"
SUPPORT_URL="http://help.ubuntu.com/"
BUG_REPORT_URL="http://bugs.launchpad.net/ubuntu/"
VERSION_CODENAME=xenial
UBUNTU_CODENAME=xenial
gcc-version : gcc (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609

What is happening here? Does version 0.4 use more memory than earlier versions?

Pay attention to this warning: volatile now does nothing, so you should wrap your validation loop in a with torch.no_grad(): block.

Without it, the graph built in validate won't be freed until the output and loss variables are overwritten in the next iteration, which effectively doubles memory usage.
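To see why this matters, here is a minimal sketch that checks whether autograd records a graph. The tiny model and random data are hypothetical stand-ins, not the SimpNet model from this thread. Note that model.eval() alone does not disable gradient tracking; it only switches layers such as BatchNorm and Dropout to inference behaviour. Only torch.no_grad() stops the graph (and the intermediate tensors it keeps for backward) from being built.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model and data, only for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(4, 16)
targets = torch.randint(0, 10, (4,))

model.eval()  # changes BatchNorm/Dropout behaviour, but does NOT disable autograd

# Without no_grad: autograd records the graph, and the intermediate tensors it
# saved for backward stay alive for as long as `loss_tracked` is referenced.
loss_tracked = criterion(model(inputs), targets)
print(loss_tracked.requires_grad, loss_tracked.grad_fn is not None)  # True True

# With no_grad: no graph is recorded, so nothing extra is kept alive.
with torch.no_grad():
    loss_free = criterion(model(inputs), targets)
print(loss_free.requires_grad, loss_free.grad_fn is None)  # False True
```

Both losses have the same value; the only difference is whether a backward graph is attached to the result.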


Thank you very much. Is this OK now?

import time
import torch

# AverageMeter, accuracy, print_log and args are defined elsewhere in imagenet_train.py
def validate(val_loader, model, criterion, log):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():  # for versions >= 0.4; replaces volatile=True
        for i, (input, target) in enumerate(val_loader):
            # non_blocking replaces the old async argument ('async' is a reserved word from Python 3.7)
            target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(input)    # pre-0.4 code passed volatile=True here
            target_var = torch.autograd.Variable(target)  # pre-0.4 code passed volatile=True here

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))  # .item() replaces loss.data[0] for 0-dim tensors
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print_log('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5), log)

        print_log(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss@ {error:.3f}'.format(top1=top1, top5=top5, error=losses.avg), log)

    return top1.avg, top5.avg, losses.avg

Yeah this is fine.

Also, with 0.4 you don't even need the torch.autograd.Variable wrappers anymore. The 0.4.0 migration guide should be a helpful read: https://pytorch.org/2018/04/22/0_4_0-migration-guide.html

:slight_smile:
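Following that advice, a validation loop for PyTorch 0.4 and later needs neither Variable nor volatile. The sketch below is a minimal example under stated assumptions, not a drop-in replacement for the script above: the tiny model, the random dataset and the top-1-only accuracy are hypothetical stand-ins, and it is written for a current PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def validate(loader, model, criterion, device):
    """Minimal validation loop without Variable or volatile (PyTorch >= 0.4 style)."""
    model.eval()                          # BatchNorm/Dropout in inference mode
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():                 # do not record an autograd graph
        for inputs, targets in loader:
            # plain tensors go straight into the model; no Variable wrapper needed
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            output = model(inputs)
            loss = criterion(output, targets)
            n = inputs.size(0)
            total_loss += loss.item() * n  # .item() instead of loss.data[0]
            correct += (output.argmax(dim=1) == targets).sum().item()
            seen += n
    return total_loss / seen, 100.0 * correct / seen


# Usage with hypothetical stand-in data and model.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10)).to(device)
data = TensorDataset(torch.randn(100, 16), torch.randint(0, 10, (100,)))
loader = DataLoader(data, batch_size=32)
val_loss, top1 = validate(loader, model, nn.CrossEntropyLoss(), device)
print("loss {:.4f}  top-1 {:.2f}%".format(val_loss, top1))
```

The loss is weighted by batch size so that a smaller final batch does not skew the average. non_blocking=True only overlaps the host-to-GPU copy with computation when the DataLoader uses pinned memory (pin_memory=True); on CPU it has no effect.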


Thank you very much, sir, I really appreciate it :slight_smile: