CUDA memory error when using fastmoe

I am trying to use the FMoE layer from fastmoe in a ViT-Base model for a simple classification task. However, CUDA memory usage grows gradually and eventually causes an out-of-memory error. Digging deeper, I observe a small increase in memory after every loss.backward() call. Here is what the memory growth looks like:

Training (50/10000 Steps) Data time=0.03(0.04) Batch time=1.00(0.97) Memory=8588.2(8440.4)
Training (51/10000 Steps) Data time=0.04(0.04) Batch time=0.78(0.96) Memory=8590.5(8443.3)
Training (52/10000 Steps) Data time=0.03(0.04) Batch time=1.13(0.97) Memory=8602.8(8446.3)
Training (53/10000 Steps) Data time=0.04(0.04) Batch time=0.67(0.96) Memory=8602.8(8449.2)
Training (54/10000 Steps) Data time=0.08(0.04) Batch time=1.13(0.96) Memory=8602.8(8452.0)
Training (55/10000 Steps) Data time=0.02(0.04) Batch time=0.85(0.96) Memory=8602.8(8454.7)
Training (56/10000 Steps) Data time=0.06(0.04) Batch time=0.96(0.96) Memory=8602.8(8457.3)
Training (57/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8602.8(8459.9)
Training (58/10000 Steps) Data time=0.06(0.04) Batch time=0.81(0.96) Memory=8623.7(8462.5)
Training (59/10000 Steps) Data time=0.04(0.04) Batch time=1.08(0.96) Memory=8623.7(8465.2)
Training (60/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.96) Memory=8623.7(8467.8)
Training (61/10000 Steps) Data time=0.03(0.04) Batch time=1.11(0.96) Memory=8623.7(8470.4)
Training (62/10000 Steps) Data time=0.02(0.04) Batch time=0.77(0.96) Memory=8623.7(8472.8)
Training (63/10000 Steps) Data time=0.04(0.04) Batch time=1.10(0.96) Memory=8655.3(8475.4)
Training (64/10000 Steps) Data time=0.02(0.04) Batch time=0.88(0.96) Memory=8655.3(8478.2)
Training (65/10000 Steps) Data time=0.04(0.04) Batch time=0.92(0.96) Memory=8667.8(8481.0)
Training (66/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8667.8(8483.8)
Training (67/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.95) Memory=8667.8(8486.6)
Training (68/10000 Steps) Data time=0.03(0.04) Batch time=1.10(0.96) Memory=8667.8(8489.2)
Training (69/10000 Steps) Data time=0.02(0.04) Batch time=0.70(0.95) Memory=8667.8(8491.8)
Training (70/10000 Steps) Data time=0.06(0.04) Batch time=1.09(0.96) Memory=8667.8(8494.3)
Training (71/10000 Steps) Data time=0.02(0.04) Batch time=0.89(0.95) Memory=8667.8(8496.7)
Training (72/10000 Steps) Data time=0.05(0.04) Batch time=0.86(0.95) Memory=8725.6(8499.5)
Training (73/10000 Steps) Data time=0.03(0.04) Batch time=1.01(0.95) Memory=8725.6(8502.5)
Training (74/10000 Steps) Data time=0.04(0.04) Batch time=0.71(0.95) Memory=8725.6(8505.5)
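
In the log above, the Memory value rises from 8588.2 at step 50 to 8725.6 at step 74, and it does so in discrete jumps rather than on every step. To see at which steps the jumps happen, the log can be parsed with a short script. This is only a sketch: the regular expression assumes the exact log format shown above, the helper names parse_memory and find_jumps are made up for illustration, and only the first four log lines are included to keep it short.

```python
import re

# First four lines of the training log shown above.
LOG_LINES = [
    "Training (50/10000 Steps) Data time=0.03(0.04) Batch time=1.00(0.97) Memory=8588.2(8440.4)",
    "Training (51/10000 Steps) Data time=0.04(0.04) Batch time=0.78(0.96) Memory=8590.5(8443.3)",
    "Training (52/10000 Steps) Data time=0.03(0.04) Batch time=1.13(0.97) Memory=8602.8(8446.3)",
    "Training (53/10000 Steps) Data time=0.04(0.04) Batch time=0.67(0.96) Memory=8602.8(8449.2)",
]

# Captures the step number and the current Memory value (the number before the parentheses).
PATTERN = re.compile(r"Training \((\d+)/\d+ Steps\).*Memory=([\d.]+)\([\d.]+\)")


def parse_memory(lines):
    """Return a list of (step, memory_mb) tuples for every line that matches."""
    records = []
    for line in lines:
        match = PATTERN.search(line)
        if match:
            records.append((int(match.group(1)), float(match.group(2))))
    return records


def find_jumps(records):
    """Return (step, delta_mb) for every step whose value is higher than the previous one."""
    jumps = []
    for (_, prev_mb), (step, cur_mb) in zip(records, records[1:]):
        delta = round(cur_mb - prev_mb, 1)
        if delta > 0:
            jumps.append((step, delta))
    return jumps


records = parse_memory(LOG_LINES)
jumps = find_jumps(records)
for step, delta in jumps:
    print(f"step {step}: logged memory grew by {delta} MB")
```

Running the same functions over the full log would list every step at which the logged value increased, which may help when comparing against a run with the regular Mlp layer.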

Here is my training loop. If I replace the FMoE Mlp layer with a regular Mlp layer, the training works fine.

    for step, batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        # Move every tensor in the batch to the training device.
        batch = tuple(t.to(args.device) for t in batch)

        x, y = batch

        # Forward pass and loss.
        pred = model(x, 0)
        loss_ = loss_fct(pred, y)

        # Backward pass, gradient all-reduce, and parameter update.
        loss_.backward()
        model.allreduce_params()
        optimizer.step()
        optimizer.zero_grad()

        # max_memory_allocated() returns the peak allocation so far, in bytes.
        MB = 1024 * 1024
        memory_meter.update(torch.cuda.max_memory_allocated() / MB)
        log.info("Training ({}/{} Steps)\tData time={:.2f}({:.2f})\tBatch time={:.2f}({:.2f})\tMemory={:.1f}({:.1f})".format(
            global_step, t_total, data_time.val, data_time.avg, batch_time.val, batch_time.avg, memory_meter.val, memory_meter.avg))
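
One note on the measurement itself: torch.cuda.max_memory_allocated() returns the peak allocation since the start of the program (or since the last call to torch.cuda.reset_peak_memory_stats()), so the logged value can never decrease. Logging the current allocation next to a per-step peak can make it easier to tell a steady leak from occasional spikes. Below is a minimal sketch of that idea; the helper name cuda_memory_snapshot and the dictionary keys are hypothetical, and it returns None on a machine without CUDA.

```python
import torch

MB = 1024 * 1024


def cuda_memory_snapshot(reset_peak=True):
    """Return current, reserved, and peak CUDA memory in MB, or None without CUDA.

    Hypothetical helper for illustration; it is not part of PyTorch or fastmoe.
    """
    if not torch.cuda.is_available():
        return None
    snapshot = {
        # Memory currently occupied by live tensors.
        "allocated_mb": torch.cuda.memory_allocated() / MB,
        # Memory held by PyTorch's caching allocator (includes cached free blocks).
        "reserved_mb": torch.cuda.memory_reserved() / MB,
        # Peak allocation since the start of the program or the last reset.
        "peak_mb": torch.cuda.max_memory_allocated() / MB,
    }
    if reset_peak:
        # Start a fresh peak measurement so the next snapshot reflects one step only.
        torch.cuda.reset_peak_memory_stats()
    return snapshot


snapshot = cuda_memory_snapshot()
print(snapshot)
```

Calling the helper once per step, after optimizer.zero_grad(), would give a per-step peak together with the memory that is still allocated once the step has finished.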

Can you please help me with what might be causing this? Thanks.

Are you appending any outputs or losses to a list, or accumulating them somewhere without detaching? Growing device memory often points to an issue in the script where tensors that are still attached to a computation graph get accumulated, which prevents PyTorch from freeing the intermediate activations.
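
To illustrate the pattern being asked about, here is a small CPU-only sketch. The model, the data, and the list names are made up for the example and have nothing to do with the original script. It only shows that a stored loss tensor still carries a grad_fn (a reference to its autograd graph), while .item() and .detach() do not.

```python
import torch

torch.manual_seed(0)

# Toy model and data, purely illustrative.
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

attached_history = []   # stores loss tensors that still reference their graph
float_history = []      # stores plain Python floats
detached_history = []   # stores tensors cut off from the graph

for _ in range(3):
    loss = torch.nn.functional.mse_loss(model(x), y)
    attached_history.append(loss)
    float_history.append(loss.item())
    detached_history.append(loss.detach())

# A tensor with a grad_fn keeps a reference to the graph that produced it.
print([t.grad_fn is not None for t in attached_history])
print([type(v).__name__ for v in float_history])
print([t.grad_fn is None for t in detached_history])
```

If anything like attached_history exists in the training script (including running sums such as total_loss += loss), switching to .item() or .detach() is the usual fix.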

No, I make sure to call .item() before appending to lists, to avoid keeping the computation graph alive. Also, the issue does not occur when I run the exact same script with a regular Mlp layer instead of the FMoE layer from fastmoe.

Ah OK. Did you already create an issue in their GitHub repository, since the layer seems to be the cause?

Yes, I posted an issue. But this is really confusing for me since I am able to use fastmoe as part of other code repositories.