GPU memory usage increases and training speed slows down

Hi, all!
I am new to PyTorch and I ran into a strange problem while training my model on a GPU. The GPU memory usage increases gradually during training and eventually becomes stable. At the beginning it consumes about 4 GB of GPU memory, and this grows to around 7 GB. This happens within the first epoch, after which the memory usage stays stable.
Meanwhile, the training speed slows down unacceptably after every epoch. Over four epochs it costs 1000, 2000, 3500, and 5200 ms per batch, respectively. Here is the code snippet:

import time

import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def train_epoch(model, data_train, optimizer, criterion_cls, args):
    model.train()
    total_loss = 0
    cnt = 0
    label_list = list()
    pred_list = list()
    start_time = time.time()

    for batch, batch_data in enumerate(data_train):
        data, targets = batch_data
        padding = torch.zeros(data.size(0), args.window_size - 1).long()
        data = torch.cat([padding, data], dim=1)
        cnt += targets.size(0)
        label = get_label(targets)
        if args.cuda:
            data = data.cuda()
            label = label.cuda()

        output = model(data)

        loss = criterion_cls(output, label)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        prediction = torch.max(output, 1)[1]
        label_list.extend(label.tolist())
        pred_list.extend(prediction.tolist())

        if (batch + 1) % args.log_interval == 0:
            p = precision_score(label_list, pred_list, average='weighted')
            r = recall_score(label_list, pred_list, average='weighted')
            f = f1_score(label_list, pred_list, average='weighted')
            acc = accuracy_score(label_list, pred_list)

            elapsed = time.time() - start_time
            train_loss = total_loss / cnt
            # logging/printing of these values presumably follows here (omitted in the post)

            cnt = 0
            total_loss = 0
            label_list.clear()
            pred_list.clear()
            start_time = time.time()

The evaluation code is similar, and torch.no_grad() is used there as well. I guess the problem is not related to a specific model, because I encounter this issue with different models (CNN, RNN, etc.).
Could someone help me out? Thanks a lot!!

Hi,
Could you try the following?

prediction = torch.max(output.detach().cpu(), 1)[1]

Since you are appending the predictions to a list, the whole computational graph is being stored (and it is probably allocated on the GPU, right?).

Thanks for your suggestion! But it doesn’t solve my problem. GPU memory and training time still increase, as in the previous version.

Is it possible that you are doing something similar with any other variable, like the labels? If the labels are allocated on the GPU, could you move them to CPU before appending them to the list? In general you should not store CUDA tensors without moving them to CPU first, nor tensors that carry gradient history without detaching them (and moving them to CPU). That is typically the cause of this kind of problem. Could you have a look at the rest of your code to see if you are doing something like that?
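A minimal sketch of the pattern being suggested, reusing the variable names from the snippet above (only the bookkeeping lines change):

# Detach from the autograd graph and move to CPU before any bookkeeping.
prediction = torch.max(output.detach().cpu(), 1)[1]
# .tolist() converts to plain Python numbers, so nothing CUDA-related is kept.
label_list.extend(label.cpu().tolist())
pred_list.extend(prediction.tolist())
# .item() likewise returns a plain Python float.
total_loss += loss.item()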

Yeah, I agree with you. A possible reason for the training time increase is that the computational graph is not freed correctly. But I think only loss and output carry gradient history; other parts like the labels are not tracked by the graph, right? I tried deleting all the intermediate tensors, but that doesn’t work either! And why does the GPU memory increase up to a certain amount and then stay there, rather than keep growing and eventually cause an OOM error?

The problem is that even if the labels are not tracked, you are appending them to a list, so GPU memory keeps increasing until it reaches a limit (which corresponds to the number of iterations per epoch). If you reduce the batch size you should probably observe even more memory consumption, as more labels would be allocated.

Anyway, that does not explain the slowdown. Is it possible that your machine is being slowed down for some other reason (the DataLoader, for example)? Could you split the timing into data loading and GPU computation? You could, for example, be filling up RAM and starting to use swap memory, or something like that.
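For example, something along these lines (just a sketch reusing the names from the first snippet; torch.cuda.synchronize() is needed because CUDA kernels are launched asynchronously):

end = time.time()
for batch, batch_data in enumerate(data_train):
    # Time spent waiting on the DataLoader.
    data_time = time.time() - end

    data, targets = batch_data
    data = data.cuda()
    label = get_label(targets).cuda()

    output = model(data)
    loss = criterion_cls(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Wait for all queued CUDA work to finish before reading the clock.
    torch.cuda.synchronize()
    compute_time = time.time() - end - data_time
    end = time.time()
    print(batch, 'data %.3fs' % data_time, 'compute %.3fs' % compute_time)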

At least in the snippet you pasted I cannot see more errors than the ones already mentioned. If you properly detached the tensors after my suggestion, the error may come from somewhere else.

As an extreme measure, you could generate data randomly and run a minimal example with just the model and a simple training loop, to rule out issues in the model (see the sketch below).
If that works (which it probably will), start adding back the variables you save (accuracy and so on). If it still works, check the dataloader.
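A rough sketch of such a minimal run (vocab_size, seq_len, batch_size and num_classes are placeholders, not values from the thread):

model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(1000):
    # Random data with fixed shapes: no DataLoader, no metric bookkeeping.
    data = torch.randint(0, vocab_size, (batch_size, seq_len)).cuda()
    label = torch.randint(0, num_classes, (batch_size,)).cuda()

    output = model(data)
    loss = criterion(output, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % 100 == 0:
        # If memory or per-step time still grows here, the model is suspect;
        # otherwise look at the saved metrics or the DataLoader.
        print(step + 1, '%.1f MB' % (torch.cuda.memory_allocated() / 1024 ** 2))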

I met exactly the same problem when training a baseline Faster R-CNN (e2e) model using maskrcnn-benchmark with two GPUs (NVIDIA P100) on a Linux server. Here’s my training code snippet:

logger.info("Start training")
meters = MetricLogger(delimiter="  ")
max_iter = len(data_loader)
start_iter = arguments["iteration"]
model.train()
start_training_time = time.time()
end = time.time()

for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
    if any(len(target) < 1 for target in targets):
        logger.error(f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}" )
        continue
    data_time = time.time() - end
    iteration = iteration + 1
    arguments["iteration"] = iteration

    scheduler.step()

    copy_time = time.time()
    images = images.to(device)
    targets = [target.to(device) for target in targets]
    copy_time = time.time() - copy_time

    forward_time = time.time()
    loss_dict = model(images, targets)
    forward_time = time.time() - forward_time

    losses = sum(loss for loss in loss_dict.values())

    # reduce losses over all GPUs for logging purposes
    reduce_time = time.time()
    loss_dict_reduced = reduce_loss_dict(loss_dict)
    reduce_time = time.time() - reduce_time
    losses_reduced = sum(loss for loss in loss_dict_reduced.values())
    meters.update(loss=losses_reduced, **loss_dict_reduced)

    backward_time = time.time()
    optimizer.zero_grad()
    # Note: If mixed precision is not used, this ends up doing nothing
    # Otherwise apply loss scaling for mixed-precision recipe
    with amp.scale_loss(losses, optimizer) as scaled_losses:
        scaled_losses.backward()
    optimizer.step()
    backward_time = time.time() - backward_time

    batch_time = time.time() - end
    end = time.time()
    meters.update(
        time=batch_time, data=data_time,
        cpu2gpu=copy_time, forward=forward_time,
        reduce=reduce_time, backward=backward_time
    )

I print the time consumption of each step every 20 iterations, as shown below. It can be seen that the forward time, the backward time, and the GPU memory consumption increase gradually as training goes on, compared to the other timing metrics. Just like the situation described by @NeoZ, they become stable after a number of iterations (~9000 iterations in my case). I’m sure the problem is not in my dataloader, since the data loading and preprocessing time is very short and stable throughout. So, does someone know the real cause of this situation? I really appreciate any help, thanks!

maskrcnn/maskrcnn_benchmark/engine/trainer.py:  50 INFO: Start training
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:40:43  iter: 20  loss: 7.0394 (7.4629)  loss_box_reg: 0.0068 (0.0173)  loss_classifier: 0.0740 (0.1274)  loss_objectness: 0.5356 (0.5280)  loss_rpn_box_reg: 0.0187 (0.0411)  time: 0.3847 (0.4568)  data: 0.0066 (0.0320)  cpu2gpu: 0.1230 (0.1184)  forward: 0.1665 (0.2189)  reduce: 0.0003 (0.0003)  backward: 0.0706 (0.0716)  lr: 0.000680  max mem: 2981
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:05:50  iter: 40  loss: 5.0266 (6.2484)  loss_box_reg: 0.0653 (0.0478)  loss_classifier: 0.1414 (0.1588)  loss_objectness: 0.1579 (0.3589)  loss_rpn_box_reg: 0.0260 (0.0363)  time: 0.4380 (0.4510)  data: 0.0073 (0.0199)  cpu2gpu: 0.1366 (0.1265)  forward: 0.1876 (0.2043)  reduce: 0.0003 (0.0003)  backward: 0.0895 (0.0857)  lr: 0.000860  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 20:51:38  iter: 60  loss: 4.3640 (5.6448)  loss_box_reg: 0.0925 (0.0641)  loss_classifier: 0.1535 (0.1631)  loss_objectness: 0.1048 (0.2927)  loss_rpn_box_reg: 0.0241 (0.0363)  time: 0.4394 (0.4487)  data: 0.0072 (0.0159)  cpu2gpu: 0.1247 (0.1264)  forward: 0.1841 (0.1993)  reduce: 0.0003 (0.0003)  backward: 0.1036 (0.0924)  lr: 0.001040  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 20:57:19  iter: 80  loss: 4.0550 (5.2782)  loss_box_reg: 0.0893 (0.0738)  loss_classifier: 0.1413 (0.1625)  loss_objectness: 0.1033 (0.2534)  loss_rpn_box_reg: 0.0250 (0.0349)  time: 0.4273 (0.4497)  data: 0.0075 (0.0138)  cpu2gpu: 0.1311 (0.1279)  forward: 0.1906 (0.1984)  reduce: 0.0003 (0.0003)  backward: 0.1000 (0.0961)  lr: 0.001220  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:05:04  iter: 100  loss: 4.0677 (5.0424)  loss_box_reg: 0.1020 (0.0804)  loss_classifier: 0.1369 (0.1593)  loss_objectness: 0.0844 (0.2275)  loss_rpn_box_reg: 0.0235 (0.0344)  time: 0.4470 (0.4510)  data: 0.0076 (0.0126)  cpu2gpu: 0.1261 (0.1293)  forward: 0.1866 (0.1970)  reduce: 0.0003 (0.0003)  backward: 0.1029 (0.0984)  lr: 0.001400  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 20:59:19  iter: 120  loss: 3.9398 (4.8649)  loss_box_reg: 0.0849 (0.0825)  loss_classifier: 0.1308 (0.1548)  loss_objectness: 0.0735 (0.2028)  loss_rpn_box_reg: 0.0188 (0.0328)  time: 0.4265 (0.4500)  data: 0.0070 (0.0117)  cpu2gpu: 0.1242 (0.1294)  forward: 0.1829 (0.1954)  reduce: 0.0003 (0.0003)  backward: 0.0916 (0.0994)  lr: 0.001580  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:24:35  iter: 140  loss: 3.8684 (4.7239)  loss_box_reg: 0.1126 (0.0892)  loss_classifier: 0.1458 (0.1559)  loss_objectness: 0.0552 (0.1874)  loss_rpn_box_reg: 0.0171 (0.0323)  time: 0.4648 (0.4543)  data: 0.0071 (0.0111)  cpu2gpu: 0.1297 (0.1303)  forward: 0.1918 (0.1958)  reduce: 0.0003 (0.0003)  backward: 0.1004 (0.1027)  lr: 0.001760  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:23:17  iter: 160  loss: 3.8812 (4.6223)  loss_box_reg: 0.0921 (0.0911)  loss_classifier: 0.1369 (0.1539)  loss_objectness: 0.0678 (0.1738)  loss_rpn_box_reg: 0.0173 (0.0324)  time: 0.4534 (0.4541)  data: 0.0070 (0.0107)  cpu2gpu: 0.1312 (0.1313)  forward: 0.1744 (0.1944)  reduce: 0.0003 (0.0003)  backward: 0.0958 (0.1026)  lr: 0.001940  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:31:55  iter: 180  loss: 3.8045 (4.5327)  loss_box_reg: 0.1131 (0.0946)  loss_classifier: 0.1331 (0.1534)  loss_objectness: 0.0545 (0.1619)  loss_rpn_box_reg: 0.0187 (0.0314)  time: 0.4682 (0.4555)  data: 0.0079 (0.0103)  cpu2gpu: 0.1255 (0.1309)  forward: 0.1945 (0.1948)  reduce: 0.0003 (0.0003)  backward: 0.1132 (0.1045)  lr: 0.002120  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 21:49:16  iter: 200  loss: 4.0134 (4.4732)  loss_box_reg: 0.1317 (0.1001)  loss_classifier: 0.1635 (0.1552)  loss_objectness: 0.0777 (0.1560)  loss_rpn_box_reg: 0.0341 (0.0325)  time: 0.4797 (0.4585)  data: 0.0072 (0.0101)  cpu2gpu: 0.1337 (0.1317)  forward: 0.2020 (0.1957)  reduce: 0.0003 (0.0003)  backward: 0.1211 (0.1060)  lr: 0.002300  max mem: 4058
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 22:01:37  iter: 220  loss: 3.6355 (4.4004)  loss_box_reg: 0.1289 (0.1039)  loss_classifier: 0.1628 (0.1564)  loss_objectness: 0.0662 (0.1485)  loss_rpn_box_reg: 0.0178 (0.0318)  time: 0.4608 (0.4606)  data: 0.0076 (0.0099)  cpu2gpu: 0.1258 (0.1314)  forward: 0.1935 (0.1965)  reduce: 0.0003 (0.0003)  backward: 0.1060 (0.1082)  lr: 0.002480  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 22:15:52  iter: 240  loss: 3.5626 (4.3338)  loss_box_reg: 0.1481 (0.1078)  loss_classifier: 0.1520 (0.1577)  loss_objectness: 0.0474 (0.1409)  loss_rpn_box_reg: 0.0153 (0.0308)  time: 0.4813 (0.4630)  data: 0.0072 (0.0097)  cpu2gpu: 0.1318 (0.1322)  forward: 0.1910 (0.1967)  reduce: 0.0003 (0.0003)  backward: 0.1078 (0.1095)  lr: 0.002660  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 22:31:14  iter: 260  loss: 3.9373 (4.2989)  loss_box_reg: 0.1472 (0.1118)  loss_classifier: 0.1647 (0.1597)  loss_objectness: 0.0552 (0.1351)  loss_rpn_box_reg: 0.0315 (0.0309)  time: 0.4451 (0.4655)  data: 0.0080 (0.0096)  cpu2gpu: 0.1201 (0.1323)  forward: 0.1979 (0.1972)  reduce: 0.0003 (0.0003)  backward: 0.1123 (0.1113)  lr: 0.002840  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 22:47:43  iter: 280  loss: 3.2553 (4.2348)  loss_box_reg: 0.1327 (0.1144)  loss_classifier: 0.1449 (0.1603)  loss_objectness: 0.0445 (0.1287)  loss_rpn_box_reg: 0.0198 (0.0303)  time: 0.5066 (0.4683)  data: 0.0078 (0.0095)  cpu2gpu: 0.1332 (0.1328)  forward: 0.2082 (0.1984)  reduce: 0.0003 (0.0003)  backward: 0.1288 (0.1128)  lr: 0.003020  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 23:10:17  iter: 300  loss: 3.4788 (4.1915)  loss_box_reg: 0.1916 (0.1201)  loss_classifier: 0.2173 (0.1643)  loss_objectness: 0.0662 (0.1252)  loss_rpn_box_reg: 0.0345 (0.0307)  time: 0.5200 (0.4721)  data: 0.0073 (0.0094)  cpu2gpu: 0.1331 (0.1333)  forward: 0.2076 (0.1991)  reduce: 0.0003 (0.0003)  backward: 0.1387 (0.1150)  lr: 0.003200  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 23:19:29  iter: 320  loss: 3.1697 (4.1298)  loss_box_reg: 0.1373 (0.1209)  loss_classifier: 0.1565 (0.1634)  loss_objectness: 0.0377 (0.1201)  loss_rpn_box_reg: 0.0175 (0.0307)  time: 0.4782 (0.4737)  data: 0.0076 (0.0093)  cpu2gpu: 0.1317 (0.1335)  forward: 0.2023 (0.1997)  reduce: 0.0003 (0.0003)  backward: 0.1177 (0.1160)  lr: 0.003380  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 23:27:53  iter: 340  loss: 3.5830 (4.0903)  loss_box_reg: 0.1246 (0.1226)  loss_classifier: 0.1451 (0.1639)  loss_objectness: 0.0364 (0.1158)  loss_rpn_box_reg: 0.0189 (0.0305)  time: 0.4817 (0.4751)  data: 0.0072 (0.0092)  cpu2gpu: 0.1341 (0.1342)  forward: 0.1938 (0.1998)  reduce: 0.0003 (0.0003)  backward: 0.1007 (0.1163)  lr: 0.003560  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 23:34:03  iter: 360  loss: 3.4124 (4.0554)  loss_box_reg: 0.1252 (0.1234)  loss_classifier: 0.1601 (0.1637)  loss_objectness: 0.0377 (0.1120)  loss_rpn_box_reg: 0.0241 (0.0300)  time: 0.4882 (0.4762)  data: 0.0073 (0.0092)  cpu2gpu: 0.1313 (0.1344)  forward: 0.1970 (0.2001)  reduce: 0.0003 (0.0003)  backward: 0.1131 (0.1170)  lr: 0.003740  max mem: 4379
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 1 day, 23:47:14  iter: 380  loss: 3.2500 (4.0164)  loss_box_reg: 0.1572 (0.1261)  loss_classifier: 0.1803 (0.1663)  loss_objectness: 0.0479 (0.1090)  loss_rpn_box_reg: 0.0216 (0.0296)  time: 0.5071 (0.4784)  data: 0.0075 (0.0091)  cpu2gpu: 0.1306 (0.1353)  forward: 0.2054 (0.2004)  reduce: 0.0003 (0.0003)  backward: 0.1204 (0.1178)  lr: 0.003920  max mem: 4379

…………

maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:06:54  iter: 1500  loss: 2.4801 (3.1181)  loss_box_reg: 0.1295 (0.1449)  loss_classifier: 0.1653 (0.1798)  loss_objectness: 0.0367 (0.0623)  loss_rpn_box_reg: 0.0202 (0.0271)  time: 0.5593 (0.5233)  data: 0.0081 (0.0084)  cpu2gpu: 0.1254 (0.1423)  forward: 0.2159 (0.2126)  reduce: 0.0003 (0.0003)  backward: 0.1503 (0.1415)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:07:50  iter: 1520  loss: 2.3080 (3.1081)  loss_box_reg: 0.1508 (0.1449)  loss_classifier: 0.1658 (0.1796)  loss_objectness: 0.0327 (0.0620)  loss_rpn_box_reg: 0.0207 (0.0274)  time: 0.4965 (0.5235)  data: 0.0078 (0.0084)  cpu2gpu: 0.1257 (0.1423)  forward: 0.2098 (0.2126)  reduce: 0.0003 (0.0003)  backward: 0.1278 (0.1416)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:08:53  iter: 1540  loss: 2.4779 (3.1005)  loss_box_reg: 0.1054 (0.1446)  loss_classifier: 0.1538 (0.1793)  loss_objectness: 0.0487 (0.0620)  loss_rpn_box_reg: 0.0170 (0.0275)  time: 0.5419 (0.5237)  data: 0.0076 (0.0084)  cpu2gpu: 0.1201 (0.1423)  forward: 0.2264 (0.2127)  reduce: 0.0003 (0.0003)  backward: 0.1441 (0.1418)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:09:55  iter: 1560  loss: 2.5180 (3.0927)  loss_box_reg: 0.1225 (0.1444)  loss_classifier: 0.1547 (0.1790)  loss_objectness: 0.0241 (0.0617)  loss_rpn_box_reg: 0.0229 (0.0274)  time: 0.5217 (0.5239)  data: 0.0083 (0.0084)  cpu2gpu: 0.1211 (0.1421)  forward: 0.2227 (0.2129)  reduce: 0.0003 (0.0003)  backward: 0.1403 (0.1421)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:11:58  iter: 1580  loss: 2.5792 (3.0861)  loss_box_reg: 0.1215 (0.1444)  loss_classifier: 0.1706 (0.1791)  loss_objectness: 0.0316 (0.0616)  loss_rpn_box_reg: 0.0158 (0.0275)  time: 0.5264 (0.5243)  data: 0.0083 (0.0084)  cpu2gpu: 0.1396 (0.1422)  forward: 0.2116 (0.2130)  reduce: 0.0003 (0.0003)  backward: 0.1335 (0.1422)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:14:21  iter: 1600  loss: 2.5311 (3.0786)  loss_box_reg: 0.1521 (0.1445)  loss_classifier: 0.1792 (0.1791)  loss_objectness: 0.0367 (0.0613)  loss_rpn_box_reg: 0.0179 (0.0274)  time: 0.5520 (0.5247)  data: 0.0078 (0.0084)  cpu2gpu: 0.1319 (0.1423)  forward: 0.2236 (0.2131)  reduce: 0.0003 (0.0003)  backward: 0.1474 (0.1424)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:15:34  iter: 1620  loss: 2.4053 (3.0714)  loss_box_reg: 0.1476 (0.1445)  loss_classifier: 0.2073 (0.1794)  loss_objectness: 0.0326 (0.0611)  loss_rpn_box_reg: 0.0178 (0.0275)  time: 0.5354 (0.5250)  data: 0.0077 (0.0084)  cpu2gpu: 0.1224 (0.1423)  forward: 0.2194 (0.2132)  reduce: 0.0003 (0.0003)  backward: 0.1756 (0.1426)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:19:16  iter: 1640  loss: 2.6117 (3.0665)  loss_box_reg: 0.1539 (0.1448)  loss_classifier: 0.1869 (0.1797)  loss_objectness: 0.0254 (0.0609)  loss_rpn_box_reg: 0.0142 (0.0274)  time: 0.5810 (0.5256)  data: 0.0082 (0.0084)  cpu2gpu: 0.1215 (0.1423)  forward: 0.2371 (0.2134)  reduce: 0.0003 (0.0003)  backward: 0.1808 (0.1430)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:21:34  iter: 1660  loss: 2.4975 (3.0600)  loss_box_reg: 0.1336 (0.1448)  loss_classifier: 0.1531 (0.1797)  loss_objectness: 0.0359 (0.0607)  loss_rpn_box_reg: 0.0204 (0.0275)  time: 0.5511 (0.5260)  data: 0.0083 (0.0084)  cpu2gpu: 0.1372 (0.1425)  forward: 0.2126 (0.2135)  reduce: 0.0003 (0.0003)  backward: 0.1353 (0.1431)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:22:11  iter: 1680  loss: 2.4077 (3.0530)  loss_box_reg: 0.1154 (0.1447)  loss_classifier: 0.1439 (0.1795)  loss_objectness: 0.0327 (0.0605)  loss_rpn_box_reg: 0.0170 (0.0274)  time: 0.5093 (0.5262)  data: 0.0083 (0.0084)  cpu2gpu: 0.1224 (0.1425)  forward: 0.2048 (0.2136)  reduce: 0.0003 (0.0003)  backward: 0.1452 (0.1432)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:24:46  iter: 1700  loss: 2.3313 (3.0448)  loss_box_reg: 0.1261 (0.1445)  loss_classifier: 0.1429 (0.1793)  loss_objectness: 0.0257 (0.0602)  loss_rpn_box_reg: 0.0145 (0.0273)  time: 0.5361 (0.5266)  data: 0.0080 (0.0084)  cpu2gpu: 0.1421 (0.1428)  forward: 0.1962 (0.2136)  reduce: 0.0003 (0.0003)  backward: 0.1095 (0.1432)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:25:25  iter: 1720  loss: 2.4918 (3.0383)  loss_box_reg: 0.1289 (0.1444)  loss_classifier: 0.1790 (0.1792)  loss_objectness: 0.0270 (0.0598)  loss_rpn_box_reg: 0.0172 (0.0273)  time: 0.5318 (0.5268)  data: 0.0080 (0.0084)  cpu2gpu: 0.1240 (0.1427)  forward: 0.2130 (0.2136)  reduce: 0.0003 (0.0003)  backward: 0.1294 (0.1433)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:26:39  iter: 1740  loss: 2.4167 (3.0315)  loss_box_reg: 0.1342 (0.1443)  loss_classifier: 0.1781 (0.1791)  loss_objectness: 0.0406 (0.0599)  loss_rpn_box_reg: 0.0182 (0.0275)  time: 0.5242 (0.5270)  data: 0.0081 (0.0084)  cpu2gpu: 0.1216 (0.1427)  forward: 0.2113 (0.2137)  reduce: 0.0003 (0.0003)  backward: 0.1244 (0.1435)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:27:20  iter: 1760  loss: 2.1650 (3.0233)  loss_box_reg: 0.1069 (0.1441)  loss_classifier: 0.1526 (0.1789)  loss_objectness: 0.0277 (0.0597)  loss_rpn_box_reg: 0.0148 (0.0274)  time: 0.5301 (0.5271)  data: 0.0077 (0.0084)  cpu2gpu: 0.1325 (0.1427)  forward: 0.2103 (0.2137)  reduce: 0.0003 (0.0003)  backward: 0.1286 (0.1436)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:28:23  iter: 1780  loss: 2.3300 (3.0167)  loss_box_reg: 0.0959 (0.1442)  loss_classifier: 0.1505 (0.1788)  loss_objectness: 0.0285 (0.0596)  loss_rpn_box_reg: 0.0154 (0.0274)  time: 0.5150 (0.5273)  data: 0.0080 (0.0084)  cpu2gpu: 0.1387 (0.1429)  forward: 0.2046 (0.2138)  reduce: 0.0003 (0.0003)  backward: 0.1238 (0.1436)  lr: 0.005000  max mem: 4976
maskrcnn/maskrcnn_benchmark/engine/trainer.py: 122 INFO: eta: 2 days, 4:27:41  iter: 1800  loss: 2.1726 (3.0082)  loss_box_reg: 0.1228 (0.1441)  loss_classifier: 0.1485 (0.1787)  loss_objectness: 0.0288 (0.0594)  loss_rpn_box_reg: 0.0167 (0.0273)  time: 0.4875 (0.5273)  data: 0.0081 (0.0084)  cpu2gpu: 0.1205 (0.1428)  forward: 0.2088 (0.2138)  reduce: 0.0003 (0.0003)  backward: 0.1166 (0.1436)  lr: 0.005000  max mem: 4976

Could you check whether you see a valid grad_fn if you print(losses_reduced)?
I’m not sure what reduce_loss_dict does, but if it doesn’t detach the loss, you are storing the whole computation graph in your meters.
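For example (just a quick sanity check, not part of the original code):

# A non-None grad_fn means the tensor still carries the autograd graph,
# so keeping it around (e.g. inside the meters) keeps that graph alive.
print(losses_reduced)
print(losses_reduced.grad_fn)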


@ptrblck Thanks for your reply. reduce_loss_dict() is used to reduce all the losses computed on the two GPUs for logging purposes. Here’s the code snippet:

def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses

As you said, I’ve also noticed that the loss is not detached in reduce_loss_dict. However, when the losses are stored by the meters, they are detached via item(). I also tried modifying the second-to-last line in reduce_loss_dict to reduced_losses = {k: v.item() for k, v in zip(loss_names, all_losses)}, but it doesn’t help.

The code snippet of meters.update():

class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

Have you solved the problem?