How to reduce the for loop with torch.einsum function?

I’ve tried to get a derivative of the output with respect to input with the convolution layer.

  1. The function of ‘gradient_higer_wrt_lowwer_layer’ of ‘CNNLayer’ is coded with ‘for’ loop but this is too slow.

  2. The function of ‘gradient_higer_wrt_lowwer_layer’ of ‘OptimizedCNNLayer’ is coded with torch.einsum() function to reduce the computing cost.

  3. The issue always occurred with torch.einsum function. Even though the code run, the result of CNNLayer and OptimizedCNNLayer are different.

  4. Also, I got same issue with conv2dTransposed Layer.

Can anyone solve this problem?

Here is my code in jupyter Notebook (conv2d layer).

#%%
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, padding=0):
        super(CNNLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=padding
        )
        # self.batch = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(True)

    def forward(self, input):
        output1 = self.conv(input)
        # output = self.batch(output)
        output = self.relu(output1)

        grad_act = self.gradient_leakyrelu(output1)
        grad_layer = self.gradient_higher_wrt_lowwer_layer(input, grad_act, 2)
        return output, grad_layer

    def gradient_leakyrelu(self, input):
        # pdb.set_trace()
        return torch.where(input > 0.0, 1.0, 0.01)


    def gradient_higher_wrt_lowwer_layer(self, input, act_grad, stride):
        weight = self.conv.weight
        # bias = self.conv.bias

        b, c, h, w = input.size()
        M, _, k, k = weight.size() # M: out_channels, k: kernel size
        h_out = act_grad.size(2) # (h - k) // stride + 1
        w_out = act_grad.size(3) # (w - k) // stride + 1

        # gradient of output and input
        grad_input = torch.zeros_like(input)
        # grad_bias = torch.zeros_like(bias)

        for oc in range(M):
            for i in range(0, h_out * stride, stride):
                for j in range(0, w_out * stride, stride):
                    for ic in range(c):
                        # pdb.set_trace()
                        grad_input[:, ic, i:i+k, j:j+k] = grad_input[:, ic, i:i+k, j:j+k] + \
                                                          weight[oc, ic].unsqueeze(0) * act_grad[:, oc, i//stride, j//stride].view(b,1,1)

                    # grad_bias[oc] = grad_bias[oc] + act_grad[:, oc, i//stride, j//stride]
        return grad_input

class OptimizedCNNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, padding=0):
        super(OptimizedCNNLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=padding
        )
        self.relu = nn.LeakyReLU(True)

    def forward(self, input):
        output1 = self.conv(input)
        output = self.relu(output1)

        grad_act = self.gradient_leakyrelu(output1)
        grad_layer = self.optimized_gradient_higher_wrt_lower_layer(input, grad_act, self.conv.stride[0])
        return output, grad_layer

    def gradient_leakyrelu(self, input):
        """Derivative of Leaky ReLU function"""
        return torch.where(input > 0.0, 1.0, 0.01)

    def optimized_gradient_higher_wrt_lower_layer(self, input, act_grad, stride):
        """Optimized computation of gradient of CNN layer"""
        weight = self.conv.weight  # Shape: (out_channels, in_channels, kernel_size, kernel_size)

        b, c, h, w = input.shape
        M, _, k, k = weight.shape  # Correct order: (out_channels, in_channels, kernel_size, kernel_size)
        h_out = act_grad.size(2)  # Derived from activation gradient
        w_out = act_grad.size(3)

        # **Expand act_grad to match spatial positions**
        act_grad_expanded = act_grad.repeat_interleave(stride, dim=2).repeat_interleave(stride, dim=3)

        # **Flatten weights for efficient einsum**
        weight_flat = weight.view(M, c, k * k)

        # **Use unfold() to extract patches**
        input_unfold = F.unfold(input, kernel_size=k, stride=stride)  # Shape: (B, C*k*k, H_out*W_out)
        input_unfold = input_unfold.view(b, c, k * k, h_out, w_out)  # Reshape for broadcasting

        # **Compute gradient using optimized einsum**
        grad_unfolded = torch.einsum("bmij,mcq->bcqij", act_grad_expanded, weight_flat)

        # **Fix the reshape for fold() - Corrected Shape**
        grad_unfolded = grad_unfolded.contiguous().view(b, c * k * k, h_out * w_out)  # Correct shape before folding

        # **Fold back to original shape**
        grad_input = F.fold(grad_unfolded, output_size=(h, w), kernel_size=k, stride=stride)

        return grad_input

#%%
# Example Usage
B, C, H, W, K = 2, 3, 8, 8, 4  # Batch=2, Channels=3, Spatial=8x8, Kernel=4x4
M = 2  # Number of output channels
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize tensors
a = torch.rand(B, C, H, W, dtype=torch.float32)
w = torch.rand(M, C, K, K, dtype=torch.float32)

# Create model
model0 = CNNLayer(C, M)
model1 = OptimizedCNNLayer(C, M)

# Compute gradient db/da manually
output0, grad_layer0 = model0(a)
output1, grad_layer1 = model1(a)
print(grad_layer0 == grad_layer1)

Here is my code in jupyter Notebook (conv2dTransposed layer).

#%%
class CNNTrasnposedLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2):
        super(CNNTrasnposedLayer, self).__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1
        )
        # self.batch = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(True)

    def forward(self, input):
        output1 = self.conv(input)
        # output = self.batch(output)
        output = self.relu(output1)

        grad_act = self.gradient_leakyrelu(output1)
        grad_layer = self.gradient_higher_wrt_lowwer_layer(input, grad_act, 2, 1)
        return output, grad_layer

    def gradient_leakyrelu(self, input):
        # pdb.set_trace()
        return torch.where(input > 0.0, 1.0, 0.01)

    def gradient_higher_wrt_lowwer_layer(self, input, act_grad, stride=2, padding=1):

        weight = self.conv.weight
        # bias = self.conv.bias
        # print(input.size())
        # print(act_grad.size())
        # print(weight.size())
        b, c, h, w = input.size()
        _, M, k, k = weight.size() # M: in_channels, _: out_channels, k: kernel size
        h_out = (h - 1) * stride - 2 * padding + k
        w_out = (w - 1) * stride - 2 * padding + k

        grad_input = torch.zeros_like(input)

        for ic in range(c):
            for oc in range(M):
                for i in range(h):
                    for j in range(w):
                        out_x = i * stride - padding
                        out_y = j * stride - padding
                        for kx in range(k):
                            for ky in range(k):
                                # print(grad_input[:, ic, i, j].size())
                                # print(act_grad[:, oc, out_x + kx, out_y + ky].size())
                                # print(weight[oc, ic, kx, ky].size())
                                if 0 <= out_x + kx < h_out and 0 <= out_y + ky < w_out:
                                    grad_input[:, ic, i, j] = (grad_input[:, ic, i, j] + \
                                        act_grad[:, oc, out_x + kx, out_y + ky] * weight[ic, oc, kx, ky]
                                    )

        return grad_input

class OptimizedCNNTrasnposedLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2):
        super(OptimizedCNNTrasnposedLayer, self).__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1
        )
        # self.batch = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(True)

    def forward(self, input):
        output1 = self.conv(input)
        # output = self.batch(output)
        output = self.relu(output1)

        grad_act = self.gradient_leakyrelu(output1)
        grad_layer = self.gradient_higher_wrt_lowwer_layer(input, grad_act, 2, 1)
        return output, grad_layer

    def gradient_leakyrelu(self, input):
        # pdb.set_trace()
        return torch.where(input > 0.0, 1.0, 0.01)

    def gradient_higher_wrt_lowwer_layer(self, input, act_grad, stride=2, padding=1):

        weight = self.conv.weight
        # bias = self.conv.bias
        # print(input.size())
        # print(act_grad.size())
        # print(weight.size())
        b, c, h, w = input.size()
        _, M, k, k = weight.size() # M: in_channels, _: out_channels, k: kernel size

        grad_input = torch.zeros_like(input)

        for ic in range(c):
            for oc in range(M):
                # Compute spatial influence of weight kernels
                act_grad_unfold = act_grad[:, oc].unfold(1, k, 1).unfold(2, k, 1)  # Extract patches
                weight_unfold = weight[ic, oc].view(k, k)  # Reshape kernel

                # Multiply and sum efficiently using einsum
                grad_input[:, ic] = grad_input[:, ic] + torch.einsum("bijk,kl->bijl", act_grad_unfold, weight_unfold)

        return grad_input


#%%
# Example Usage
B, C, H, W, K = 2, 3, 8, 8, 4  # Batch=2, Channels=3, Spatial=8x8, Kernel=4x4
M = 2  # Number of output channels
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize tensors
a = torch.rand(B, C, H, W, dtype=torch.float32)
w = torch.rand(M, C, K, K, dtype=torch.float32)

# Create model
model0 = CNNTrasnposedLayer(C, M)
model1 = OptimizedCNNTrasnposedLayer(C, M)

# Compute gradient db/da manually
output0, grad_layer0 = model0(a)
output1, grad_layer1 = model1(a)
print(grad_layer0.size())
print(grad_layer1.size())
print(grad_layer0 == grad_layer1)

Hi Maroo!

Just to check:

Here you are performing an exact equality test on floating-point numbers. Do
grad_layer0 and grad_layer1 differ by more than some reasonable round-off
error or are they actually “equal” within round-off error?

I haven’t looked at your code in any detail, but assuming that your two computations
are mathematically equivalent, there is still no reason to expect them to be numerically
equivalent and they will likely differ by round-off error. (Note that in general with
round-off error it is not typically the case that one result is better than another – they’re
both equally good floating-point approximations to the “true” result, just different.)

Best.

K. Frank