Hard sample mining with function decoration

Hi,

I’m trying to speed up my training process in the following way: I run the first forward pass of my model in no_grad mode, then calculate the loss and start the backward pass. Once the backward pass reaches the output of the model forward, I resample, let’s say, 30% of the inputs based on their output gradient norm and redo the forward pass for just those samples, this time building the computational graph.

I have reason to believe that in my actual training problem only a fraction of the samples induce a significant gradient, so this kind of scheme might be useful.

I implement this logic by decorating model.forward with HSMWrap, which uses my custom HSM(torch.autograd.Function) to carry out the resampling (see the code below).

Running on CPU, this reduces the iteration runtime from 0.1998 s to 0.1383 s. This is close to what I expect: the first (no_grad) forward pass still costs ~33% of a baseline iteration, and the subsequent forward and backward passes on 30% of the samples should cost ~30% of a baseline iteration, so ~0.63 of the original runtime in total.
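As a quick back-of-the-envelope check (the 33%/30% fractions are just my rough estimates):

baseline = 0.1998                    # s / iteration without the wrapper
expected = baseline * (0.33 + 0.30)  # no_grad full-batch forward + forward/backward on 30% of the batch
print(expected)                      # ~0.126 s; the measured 0.1383 s is ~0.69 of the baseline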

But turning to GPU, I see no noticeable runtime reduction.

import torch
import torch.nn as nn
import time
from torch.autograd import profiler
from functools import wraps
from typing import Callable

def HSMWrap(fn):
    @wraps(fn)
    def wrapper(x):
        return HSM.apply(x, fn)
    return wrapper

class HSM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, original_fun):
        ctx.x = x
        ctx.original_fun = original_fun

        with torch.no_grad(): # first forward pass, without building the computational graph
            y = original_fun(x)

        y.requires_grad_(True)
        return y

    @staticmethod
    def backward(ctx, dy):
        B = ctx.x.shape[:-1].numel() # batch size
        G = dy.norm(dim=-1)
        b = int(B*0.3) # draw 30% of the samples proportionally to their output gradient norm
        idx = torch.multinomial(G, b) # draw the hard indices
        with torch.set_grad_enabled(True):
            y = ctx.original_fun(ctx.x[idx]) # second forward pass on the hard subset, now building the computational graph
            torch.autograd.backward(y, dy[idx]) # parameter gradients are accumulated here
        return None, None

# Define a simple MLP model
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_hidden=3):
        super(SimpleMLP, self).__init__()
        # Add hidden layers
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_size, hidden_size))
        self.layers.append(nn.ReLU())
        for _ in range(1, num_hidden):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(hidden_size, output_size))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Create an instance of your MLP model and move it to CUDA device
input_size = 3
output_size = 3
hidden_size = 64  # Example hidden layer size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleMLP(input_size, hidden_size, output_size, num_hidden=3).to(device)

# Wrap the forward method with HSMWrap
model.forward = HSMWrap(model.forward)

# Now, prepare synthetic data and perform training iterations
num_iterations = 1000
batch_size = 2**18
# Define a loss function (e.g., MSE) and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Generate random input data and labels on CUDA device
inputs = torch.randn(batch_size, input_size).to(device)
inputs.requires_grad_(True)
labels = torch.randn(batch_size, output_size).to(device)

# Start profiling
with profiler.profile(record_shapes=True, use_cuda=True) as prof:
    for i in range(num_iterations):
        # prof.step()  # Need to call this at each step to notify profiler of steps' boundary.
        iter_start = time.time()
        # Generate random input data and labels on CUDA device
        
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        with torch.profiler.record_function('forward'):
            outputs = model(inputs)
        # Compute loss
        loss = criterion(outputs, labels)
        # Backward pass
        with torch.profiler.record_function('backward'):
            loss.backward()
        # Update weights
        optimizer.step()
        torch.cuda.synchronize()
        print(f"\rIteration: [{i + 1}/{num_iterations}], Time: {(time.time()-iter_start):.4}", end="")

# Print and analyze profiling results
print("")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
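As a side note, a minimal sanity check (just a sketch, run outside the timed loop; x_check and t_check are placeholder tensors) to confirm that parameter gradients still flow through the wrapped forward:

x_check = torch.randn(256, input_size, device=device)
t_check = torch.randn(256, output_size, device=device)
optimizer.zero_grad()
loss_check = criterion(model(x_check), t_check)
loss_check.backward()  # triggers HSM.backward, which re-runs the forward on the hard 30% subset
assert all(p.grad is not None for p in model.parameters())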

Without wrapping (i.e. commenting out model.forward = HSMWrap(model.forward)), this prints:

Iteration: [1000/1000], Time: 0.01181

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                forward         0.75%      95.759ms         3.08%     390.711ms     390.711us      27.806ms         0.17%        6.477s       6.477ms          1000  
                                           aten::linear         0.43%      54.461ms         1.80%     228.364ms      57.091us      24.601ms         0.15%        5.549s       1.387ms          4000  
                                            aten::addmm         0.68%      86.382ms         0.86%     109.654ms      27.413us        5.481s        32.86%        5.481s       1.370ms          4000  
                                               backward        10.37%        1.315s        10.59%        1.344s       1.344ms        4.859s        29.13%        4.874s       4.874ms          1000  
    autograd::engine::evaluate_function: AddmmBackward0         0.64%      80.717ms         3.64%     461.371ms     115.343us      25.974ms         0.16%        3.119s     779.675us          4000  
                                         AddmmBackward0         0.87%     110.012ms         2.57%     325.843ms      81.461us      46.677ms         0.28%        2.472s     618.055us          4000  
                                               aten::mm         0.58%      73.359ms         0.72%      91.554ms      11.444us        2.348s        14.08%        2.348s     293.555us          8000  
     autograd::engine::evaluate_function: ReluBackward0         0.16%      20.259ms         0.48%      60.591ms      20.197us       8.433ms         0.05%        1.458s     485.893us          3000  
                                          ReluBackward0         0.14%      17.769ms         0.32%      40.332ms      13.444us       8.248ms         0.05%        1.449s     483.082us          3000  
                               aten::threshold_backward         0.12%      15.319ms         0.18%      22.557ms       7.519us        1.441s         8.64%        1.441s     480.332us          3000  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 12.687s
Self CUDA time total: 16.682s

With wrapping on:

Iteration: [1000/1000], Time: 0.01107

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         0.53%      98.896ms         2.12%     392.019ms      49.002us      42.655ms         0.28%        6.494s     811.807us          8000  
                                            aten::addmm         0.81%     150.811ms         0.98%     181.945ms      22.743us        6.384s        41.51%        6.384s     798.027us          8000  
                                                forward         0.15%      28.466ms         2.08%     384.580ms     384.580us      13.552ms         0.09%        5.933s       5.933ms          1000  
                                                    HSM         0.52%      95.585ms         1.92%     356.114ms     356.114us      25.659ms         0.17%        5.919s       5.919ms          1000  
                                               backward        43.09%        7.984s        43.32%        8.026s       8.026ms        4.485s        29.16%        4.500s       4.500ms          1000  
       autograd::engine::evaluate_function: HSMBackward         0.06%      10.196ms        40.71%        7.543s       7.543ms       2.544ms         0.02%        4.417s       4.417ms          1000  
                                            HSMBackward         1.98%     366.338ms        40.66%        7.533s       7.533ms      69.772ms         0.45%        4.415s       4.415ms          1000  
                                             aten::relu         0.27%      50.232ms         0.58%     106.960ms      17.827us      15.755ms         0.10%        1.180s     196.715us          6000  
                                        aten::clamp_min         0.19%      35.375ms         0.31%      56.703ms       9.450us        1.165s         7.57%        1.165s     194.089us          6000  
    autograd::engine::evaluate_function: AddmmBackward0         0.44%      81.859ms         2.39%     443.330ms     110.832us      24.625ms         0.16%        1.078s     269.508us          4000  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 18.529s
Self CUDA time total: 15.381s

I’m asking for help in understanding why I don’t see the expected runtime reduction on GPU. Is some function perhaps so slow to run on GPU that it’s blocking other parts?

You could check if your current code is e.g. CPU-limited and if the CPU is thus unable to saturate the GPU with kernel launches. If so, reducing the compute workload on the GPU might not show significant speedups, and you would need to tackle the bottleneck (i.e. the CPU) first. Profile the code with a visual profiler, e.g. Nsight Systems, and check if gaps (whitespace) are visible between the kernel launches. If so, check if applying CUDA Graphs is feasible to reduce the CPU overhead.
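For example, the regions of interest can be marked with NVTX ranges so they are easy to find on the Nsight Systems timeline (a minimal sketch; the range names and the script name are arbitrary):

# inside the training loop
torch.cuda.nvtx.range_push("forward")
outputs = model(inputs)
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("backward")
loss = criterion(outputs, labels)
loss.backward()
torch.cuda.nvtx.range_pop()

# then run the script under Nsight Systems, e.g.:
#   nsys profile -o hsm_trace python train_hsm.py
# and look for gaps between kernel launches in the CUDA API row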

Thanks @ptrblck for the quick reply! Am I supposed to be looking at the CUDA API row? There seems to be quite a lot of whitespace between the launches (this is the profile with wrapping on)

For comparison, here’s the visualization without wrapping:

Trying out CUDA Graphs, I’m able to remove those gaps from the CUDA API row:

But the runtime reduction still isn’t as expected:

  • with wrapping on, training takes: 2.3759379386901855 seconds
  • without wrapping, training takes: 2.795189142227173 seconds

So the runtime drops to ~85% of the baseline, but not to the expected 60-70%, so I suspect something is still causing overhead.

What do you think @ptrblck, are CUDA Graphs even applicable here? Doesn’t my selection of the hardest input elements based on their gradient norms constitute some kind of “dynamic control flow”? It doesn’t throw an error, but I’m just wondering (see the small standalone check after the full script below). So I changed the code to the following to apply CUDA Graphs:

import torch
import torch.nn as nn
import time
from torch.autograd import profiler
from functools import wraps
from typing import Callable

import argparse

parser = argparse.ArgumentParser(description="MLP Training")
parser.add_argument("--hsm", action="store_true", help="Wrap model forward with HSMWrap")
args = parser.parse_args()

# Function decorator
def HSMWrap(fn):
    @wraps(fn)
    def wrapper(x):
        x2 = x.clone()
        x2.requires_grad_(True)
        return HSM.apply(x2, fn)
    return wrapper

class HSM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, original_fun):
        # x.requires_grad_(True)
        ctx.x = x
        ctx.original_fun = original_fun
        with torch.no_grad(): # first forward pass, without building the computational graph
            y = original_fun(x)
        y.requires_grad_(True)
        return y # .detach()

    @staticmethod
    def backward(ctx, dy):
        B = ctx.x.shape[:-1].numel() # batch size
        G = torch.sqrt((dy).pow(2).sum(1)) # dy.norm(dim=-1)
        b = int(B*0.3) # take 30% of the samples
        # idx = torch.multinomial(G, b) # draw the hard indices proportionally to their gradient norm
        idx = torch.topk(G, b, sorted=False)[1] # instead, take the samples with the largest output gradient norm
        # print(idx)
        hard_input = ctx.x[idx]
        hard_input.requires_grad_(True)
        with torch.set_grad_enabled(True):
            y = ctx.original_fun(hard_input) # second forward pass, now building the computational graph
            torch.autograd.backward(y, dy[idx])
        return None, None

# Define a simple MLP model
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_hidden=3):
        super(SimpleMLP, self).__init__()
        # Add hidden layers
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_size, hidden_size))
        self.layers.append(nn.ReLU())
        for _ in range(1, num_hidden):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(hidden_size, output_size))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Create an instance of your MLP model and move it to CUDA device
input_size = 3
output_size = 3
hidden_size = 64  # Example hidden layer size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleMLP(input_size, hidden_size, output_size, num_hidden=8).to(device)

# Wrap the forward method with HSMWrap
if args.hsm:
    model.forward = HSMWrap(model.forward)
# Now, prepare synthetic data and perform training iterations
num_iterations = 100
batch_size = 2**18

# Define a loss function (e.g., MSE) and optimizer
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

log_name = "hsm" if args.hsm else "default"

static_input = torch.randn(batch_size, input_size).to(device)
# static_input.requires_grad_(True)
static_target = torch.randn(batch_size, output_size).to(device)

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

t_start = time.time()

for i in range(num_iterations):
    g.replay()
    torch.cuda.synchronize()

# torch.cuda.synchronize()
print(f"training takes: {time.time() - t_start} seconds")