PyTorch Profiler OOMs for any significant size

At any significant model size, this profiler runs my laptop out of memory (OOM). I had to reduce the hidden sizes from 256 to 64 and cut batches, batch_size and sequence_length to very low values. That badly skews the profiling results, and now profiler.step() dominates the time taken.
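For reference, a minimal sketch of the pattern from the PyTorch profiler docs: call `prof.step()` once per training batch, so each schedule step covers exactly one batch. The model, optimizer, data shapes and `num_batches` below are hypothetical stand-ins, not the trainer from this post, and `on_trace_ready` is omitted so nothing is written to disk.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, schedule

# Hypothetical stand-in for the real model/trainer: a small MLP on random data.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

# (wait + warmup + active) * repeat = (1 + 1 + 3) * 2 = 10 scheduled steps
num_batches = 10

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=2),
    record_shapes=False,  # skip recording input shapes for every op
) as prof:
    for _ in range(num_batches):
        x = torch.randn(16, 32)
        y = torch.randn(16, 1)
        loss = nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        prof.step()  # advance the profiler schedule once per batch

# Only CPU activity was profiled, so rank rows by CPU self time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```

If `prof.step()` is instead called once per epoch, each recorded step spans a whole epoch of ops, so far more events land in every active window.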

```python
import torch

if config.get('profile', False):
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            # torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler'),
        record_shapes=True,
        profile_memory=False,
        with_stack=False,
    ) as prof:
        # prof is passed in so that trainer.train() can call prof.step()
        best_eval_return = trainer.train(
            epochs=8,
            batches=4,
            bs=4,
            sl=4,
            eval_epochs=12345,
            new_best_model_cb=None,
            profiler=prof,
        )

    # Note: CUDA activity is commented out above, so CUDA times are zero and this
    # sort key does not rank rows by cost; "self_cpu_time_total" would.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```
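To see which steps the schedule in the snippet above actually records, the schedule can be called directly: `torch.profiler.schedule` returns a function mapping a step index (advanced by `prof.step()`) to a `ProfilerAction`. This sketch only evaluates that function and profiles nothing. With `wait=1, warmup=1, active=3, repeat=2`, each 5-step cycle records 3 steps, the last of which is `RECORD_AND_SAVE` (the end of the active window, when `on_trace_ready`, here `tensorboard_trace_handler`, receives the trace). After two cycles every further step is `NONE`.

```python
from torch.profiler import ProfilerAction, schedule

# Same schedule arguments as in the profiling snippet.
sched = schedule(wait=1, warmup=1, active=3, repeat=2)

# Action the profiler takes at each step index.
actions = [sched(step) for step in range(12)]
for step, action in enumerate(actions):
    print(step, action.name)
```

Only 6 of these steps are recorded, so how much work ends up in each recorded step depends entirely on how often `trainer.train` calls `prof.step()`.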
The relevant part of the config, with the hidden sizes already reduced to 64:

```json
"model_env": {
    "_": "create_rssm_env",
    "hidden_size": 64,
    "state_size": 30,
    "base_depth": 32,
    "uncertainty_predictor": null,
    "uncertainty_scale": null,
    "dynamic_factor": null,
    "dynamic_uncertainty_model": null
},
"actor": {
    "_": "TanhActor",
    "hidden_size": 64,
    "layer_num": 3
},
"critic": {
    "_": "EnsembleVCritic",
    "hidden_size": 64,
    "layer_num": 3,
    "ensemble_size": 2
},
```
Resulting profiler table (top 10 rows):

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        41.61%     171.313ms        82.08%     337.931ms      67.586ms             5  
                                            aten::empty         0.97%       3.976ms         0.97%       3.976ms       0.964us          4126  
                                          aten::random_         0.01%      42.356us         0.01%      42.356us      10.589us             4  
                                             aten::item         0.20%     820.557us         1.49%       6.144ms       2.467us          2490  
                              aten::_local_scalar_dense         1.29%       5.323ms         1.29%       5.323ms       2.138us          2490  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         1.86%       7.660ms         3.53%      14.520ms     726.009us            20  
                                            aten::slice         0.51%       2.115ms         0.60%       2.480ms       2.214us          1120  
                                       aten::as_strided         0.75%       3.104ms         0.75%       3.104ms       0.301us         10302  
                                            aten::stack         0.62%       2.539ms         2.14%       8.828ms      25.224us           350  
                                              aten::cat         2.38%       9.798ms         2.49%      10.270ms      10.902us           942  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 411.708ms
```
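As a check on how the columns relate, the ProfilerStep* row can be recomputed from its own values: both percentage columns are divided by the "Self CPU time total" footer, and "CPU time avg" is "CPU total" divided by "# of Calls". By the usual definition of self time, ProfilerStep*'s 171.313 ms of self CPU is time inside the profiled steps that is not attributed to any recorded child op. The numbers below are copied from the table; nothing else is assumed.

```python
# Values copied from the ProfilerStep* row and the table footer.
self_cpu_ms = 171.313
cpu_total_ms = 337.931
calls = 5
self_cpu_time_total_ms = 411.708

# Both percentage columns share the "Self CPU time total" denominator.
self_cpu_pct = 100 * self_cpu_ms / self_cpu_time_total_ms
cpu_total_pct = 100 * cpu_total_ms / self_cpu_time_total_ms
# Average is total CPU time per call.
avg_ms = cpu_total_ms / calls

print(f"Self CPU %:   {self_cpu_pct:.2f}%")    # 41.61%
print(f"CPU total %:  {cpu_total_pct:.2f}%")   # 82.08%
print(f"CPU time avg: {avg_ms:.3f}ms")         # 67.586ms
```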

What is the issue?