Backpropagation is so slow with ResNet

Hi there,

I’m trying to embed my data with ResNet34 as the embedding model and triplet loss as the loss function. The backpropagation step of my training loop takes an extremely long time, and I need help making it faster. My loss function and training loop are below:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm


class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        # squared Euclidean distances between anchor/positive and anchor/negative embeddings
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()


# device, train_loader, epochs, and init_weights are defined elsewhere in my script
model = torchvision.models.resnet34()
model.conv1 = nn.Conv2d(1, 64, kernel_size=1)  # accept single-channel input
model.apply(init_weights)
model = torch.jit.script(model).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = torch.jit.script(TripletLoss())
model.train()
for epoch in tqdm(range(epochs), desc="Epochs"):
    running_loss = []
    for step, (anchor_img, positive_img, negative_img, anchor_label) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)
        print(negative_img.shape, positive_img.shape)
        optimizer.zero_grad()
        anchor_out = model(anchor_img.float())
        positive_out = model(positive_img.float())
        negative_out = model(negative_img.float())

        loss = criterion(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())  # track the loss per step

How slow is the backward pass?
On my system (RTX 3090, recent nightly pip wheel) I get:

forward pass: 0.085
loss calculation:  0.001
backward pass 0.177
optimizer step 0.007
forward pass: 0.092
loss calculation:  0.001
backward pass 0.175
optimizer step 0.006
forward pass: 0.087
loss calculation:  0.001
backward pass 0.182
optimizer step 0.007
forward pass: 0.089
loss calculation:  0.001
backward pass 0.175
optimizer step 0.007
forward pass: 0.085
loss calculation:  0.001
backward pass 0.190
optimizer step 0.008
forward pass: 0.087
loss calculation:  0.001
backward pass 0.174
optimizer step 0.007
forward pass: 0.089
loss calculation:  0.001
backward pass 0.189
optimizer step 0.009
forward pass: 0.088
loss calculation:  0.001
backward pass 0.174
optimizer step 0.010
forward pass: 0.092
loss calculation:  0.001
backward pass 0.194
optimizer step 0.007
forward pass: 0.088
loss calculation:  0.000
backward pass 0.182
optimizer step 0.009

using this code:

import time

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models


class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)

        return losses.mean()

device = "cuda"
model = models.resnet34()
model.conv1 = nn.Conv2d(1, 64, kernel_size=1)
model = torch.jit.script(model).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = torch.jit.script(TripletLoss())

model.train()

anchor_img = torch.randn(16, 1, 224, 224).to(device)
positive_img = torch.randn_like(anchor_img)
negative_img = torch.randn_like(anchor_img)

# warmup
for _ in range(10):
    optimizer.zero_grad()
    anchor_out = model(anchor_img)
    positive_out = model(positive_img)
    negative_out = model(negative_img)
    
    loss = criterion(anchor_out, positive_out, negative_out)
    loss.backward()
    optimizer.step()
    
        
nb_iters = 10
for _ in range(nb_iters):
    optimizer.zero_grad()
    
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    anchor_out = model(anchor_img)
    positive_out = model(positive_img)
    negative_out = model(negative_img)
    
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    loss = criterion(anchor_out, positive_out, negative_out)
    
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    loss.backward()
    
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    optimizer.step()
    
    torch.cuda.synchronize()
    t4 = time.perf_counter()

    print("forward pass: {:.3f}".format(t1 - t0))
    print("loss calculation:  {:.3f}".format(t2 - t1))
    print("backward pass {:.3f}".format(t3 - t2))
    print("optimizer step {:.3f}".format(t4 - t3))

I wasn’t using CUDA because I didn’t have it installed on my local machine. I tried the code you provided on Google Colab and it was fast. Thank you very much for your help!
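In case someone else hits the same thing, this is just an illustrative snippet, but it’s a quick way to confirm whether you are actually running on the GPU before training:

import torch

print("CUDA build:", torch.version.cuda)          # None means a CPU-only PyTorch build
print("CUDA available:", torch.cuda.is_available())

device = "cuda" if torch.cuda.is_available() else "cpu"
print("training on:", device)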