Enable and disable the PyTorch profiler for different parts of the training loop

I am trying to analyze operator performance using torch.profiler. I only want to record the operators used in the model (forward and backward) for several iterations, but not the operators used for generating data, moving data to the device, etc.

    # DEFINE A profiler here as prof and disable it initially
    for i, (images, target) in enumerate(train_loader):
      
        train_iters = train_iters + 1

        # set data to device
        images = images.to(device)
        target = target.to(device)

        # START TO TRACE
        # compute output
        output = model(images)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # PAUSE THE TRACING USING THE PROFILER
    # After several iterations, print the profiling results
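
For completeness, the surrounding setup looks roughly like this; the import, dataset, and model below are just stand-ins for my real image dataset and CNN, so that the loop above is self-contained:

    import torch
    import torch.nn as nn
    from torch.autograd import profiler
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stand-in data and model; my real setup is an image dataset and a CNN,
    # but the loop structure is the same.
    train_loader = DataLoader(
        TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,))),
        batch_size=32,
    )
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    train_iters = 0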

What I have tried:

  1. Wrapping the whole loop in the profiler records things I don't want (operators used for generating data and for moving data to the device):
    with profiler.profile(use_cuda=True, record_shapes=True) as prof:
        for i, (images, target) in enumerate(train_loader):
        
            train_iters = train_iters + 1

            # ============PROFILE FROM HERE FOR SEVERAL ITERATIONS==========
            # set data to device
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # ============PROFILE END==========
    print(prof.key_averages(group_by_input_shape=True))
  2. Using `profiler.__enter__` and `profiler.__exit__`, but this can't be used to record multiple iterations; roughly what I tried is shown below.
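
This is roughly what I tried for (2), simplified; the setup is the same as above, I just tried to enter and exit the profiler around the model part of every iteration:

    prof = profiler.profile(use_cuda=True, record_shapes=True)
    for i, (images, target) in enumerate(train_loader):
        train_iters = train_iters + 1

        # data loading and host-to-device copies: should NOT be recorded
        images = images.to(device)
        target = target.to(device)

        prof.__enter__()  # start/resume tracing; only works for the first iteration
        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prof.__exit__(None, None, None)  # stop tracing; re-entering on later iterations fails for me

    print(prof.key_averages(group_by_input_shape=True))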

So how can I pause the profiler for some lines and resume it when I want? Looking forward to your reply. Thanks.