Numerical Discrepancy between conv2d() and unfold()+tensordot() Seems Large

Hi All,
I encountered a large numerical discrepancy between functional.conv2d() and functional.unfold() + torch.tensordot(), which is causing me trouble. My PyTorch version is 1.5.0. Here is a minimal example to reproduce it:

import math
import numpy as np
import torch
import torch.optim as optim
import torch.backends.cudnn
import torch.nn as nn
import torch.nn.functional as F
import argparse

    
def main():
    ## parse arguments
    parser = argparse.ArgumentParser(description='conv2d & tensordot discrepancy test.')
    parser.add_argument('--device', default="cuda", type=str, choices=['cuda', 'cpu'], help='cuda or cpu')
    args = parser.parse_args()

    ## improve reproducibility
    np.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = False # setting this to True doubles the error
    torch.backends.cudnn.benchmark = False
    
    x = torch.randn(128, 64, 32, 32).to(args.device)  # input: B, C, H, W
    w = torch.randn(64, 64, 3, 3).to(args.device)     # kernel: Cout, Cin, K, K
    
    # unfold to (B, Cin*K*K, H*W) patches, then view them as (B, Cin*K*K, H, W)
    a = F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1).reshape(-1, 64*3**2, 32, 32)
    # contract the Cin*K*K dimension with the flattened kernel; result is (B, Cout, H, W)
    b = torch.tensordot(w.reshape(64, -1), a.transpose(0, 1), dims=1).transpose(0, 1)

    c = F.conv2d(x, weight=w, bias=None, stride=1, padding=1, dilation=1, groups=1)
    
    diff = c - b
    
    ratio = diff / torch.max(b.abs(),c.abs())
    
    print("xmax={},xmin={}".format(x.abs().max(),x.abs().min()))
    print("wmax={},wmin={}".format(w.abs().max(),w.abs().min()))
    print("bmax={},bmin={}".format(b.abs().max(),b.abs().min()))
    print("cmax={},cmin={}".format(c.abs().max(),c.abs().min()))
    print("diffmax={},diffmin={}".format(diff.abs().max(),diff.abs().min()))
    print("ratiomax={},ratiomin={}".format(ratio.abs().max(),ratio.abs().min()))
    
    
if __name__ == '__main__':
    main()

The output for device="cuda":

xmax=5.237013339996338,xmin=2.282601769820758e-07
wmax=4.395885944366455,wmin=4.590053504216485e-05
bmax=138.94764709472656,bmin=4.302736215322511e-06
cmax=138.9477081298828,cmin=1.5795230865478516e-06
diffmax=0.0010390281677246094,diffmin=0.0
ratiomax=1.754783272743225,ratiomin=0.0

The output for device="cpu":

xmax=5.237013339996338,xmin=2.282601769820758e-07
wmax=4.395885944366455,wmin=4.590053504216485e-05
bmax=138.94773864746094,bmin=1.9371509552001953e-06
cmax=138.94757080078125,cmin=3.674944991871598e-06
diffmax=0.0001678466796875,diffmin=0.0
ratiomax=0.6250920295715332,ratiomin=0.0

The magnitude of the difference seems large to me. What is causing this?

In my application the unfold()+tensordot() approach gives much better results, yet it is much slower than conv2d().
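
A crude way to see the speed gap on this example is something like the following timing sketch (it reuses x and w from the script above; it is illustrative only, not the measurement from my application):

import time

def bench(fn, n_iter=20):
    # crude timing helper; synchronize so pending CUDA kernels are included in the measurement
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n_iter

t_conv = bench(lambda: F.conv2d(x, w, None, 1, 1))
t_unfold = bench(lambda: torch.tensordot(
    w.reshape(64, -1),
    F.unfold(x, kernel_size=3, padding=1).reshape(-1, 64 * 9, 32, 32).transpose(0, 1),
    dims=1).transpose(0, 1))
print("conv2d: {:.4f}s  unfold+tensordot: {:.4f}s".format(t_conv, t_unfold))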

Thanks !

The way I typically look at this is to relate the maximal difference to the magnitude of the results (0.000168 / 139 for the CPU run above) in order to avoid the problem that element-wise relative differences have with cancellation (or you could restrict the test to positive inputs). This gives about 1.2e-6, which is standard fare for float32 computation.
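
In code, with diff and c from your script (a small sketch; the numbers refer to the CPU run above):

# Relate the maximum absolute difference to the overall magnitude of the result
# instead of an element-wise relative error, which blows up wherever b and c
# nearly cancel to zero.
rel_to_magnitude = diff.abs().max() / c.abs().max()  # ~0.000168 / 139 ~ 1.2e-6
print(rel_to_magnitude)
print(torch.finfo(torch.float32).eps)  # ~1.19e-7, so the error is on the order of 10 eps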

A thing to "know" (well, let's say to be anecdotally aware of) is that efficient methods for computing convolutions, most prominently the Winograd method, may be more prone to numerical precision issues than their naïve counterparts (im2col is just a rearrangement of the inputs to the defining formula). Intuitively this isn't surprising: Winograd cleverly exploits algebraic relations between partial results, and those relations only hold approximately in floating-point arithmetic.

Part of the price you pay for im2col is the materialization of very large matrices, which is costly in terms of memory but also time (because of limited memory bandwidth).
If you're very determined, you could try to get a good kernel for direct convolution from TVM or use the algorithm selection of the MIOpen / cuDNN libraries to your advantage.

One way I always use to verify that everything is all right is to compute a float64 reference and compare both float32 results against it (provided both use the same code path).
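
For example, with the tensors from your script (a sketch; conv2d in double precision serves as the reference):

# float64 reference: same operation, just carried out in double precision
ref = F.conv2d(x.double(), weight=w.double(), bias=None, stride=1, padding=1)

# maximum deviation of each float32 result from the reference
print((c.double() - ref).abs().max())  # conv2d path
print((b.double() - ref).abs().max())  # unfold + tensordot path

Whichever result stays closer to the reference is the more accurate one for this input.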

Best regards

Thomas

Thanks for the clarification! :+1: