Inconsistent running time issue

We observe a very strange case when computing a min value on the GPU with different input data. With a random tensor as input, the mask_sum.min() line below takes only 42.1 us per hit (case A). However, when we replace the input with the output of a Conv2d layer, the same line takes 3946.6 us per hit (case B). We have no idea why there is such a big difference in running time. For convenience, the code and profiling results are shown below. The results were obtained with PyTorch 1.0.1.post2 (py3.6_cuda10.0.130_cudnn7.4.2_2) on a TITAN Xp (driver version 415.27) running Ubuntu 16.04 (kernel 4.4.0-131-generic).
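
For reference, here is a simple standalone timer we can also use to cross-check the line_profiler numbers. It is only a sketch (the helper name time_gpu is ours); the point is that it calls torch.cuda.synchronize() before starting and before stopping the clock, so the measurement covers all CUDA work launched by fn.

import time

import torch


def time_gpu(fn, *args, repeat=100):
    # Wait for any previously launched CUDA work before starting the clock.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        fn(*args)
    # Wait for everything fn launched to actually finish on the GPU.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeat * 1e6  # average us per call

For case A this would be called as time_gpu(f, x), with x created on the GPU as in the script below.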

case A

Code:

import torch


@profile
def f(x):
    N, _ = x.shape
    therehold = torch.rand((N, 1), device=x.device)
    mask = x.ge(therehold)
    mask_sum = mask.sum(dim=1, keepdim=False)
    res = mask_sum.min()
    return res


device = torch.device("cuda:0")

for i in range(100):
    x = torch.rand((128, 56 * 56), device=device)
    res = f(x)

Output of line_profiler:

Timer unit: 1e-06 s

Total time: 0.012572 s
File: a.py
Function: f at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           @profile
     5                                           def f(x):
     6       100        218.0      2.2      1.7      N, _ = x.shape
     7       100       1273.0     12.7     10.1      therehold = torch.rand((N, 1), device=x.device)
     8       100       2279.0     22.8     18.1      mask = x.ge(therehold)
     9       100       4495.0     45.0     35.8      mask_sum = mask.sum(dim=1, keepdim=False)
    10       100       4212.0     42.1     33.5      res = mask_sum.min()
    11       100         95.0      0.9      0.8      return res

case B

Code:

import torch
import torch.nn as nn


@profile
def f(x):
    N, _ = x.shape
    therehold = torch.rand((N, 1), device=x.device)
    mask = x.ge(therehold)
    mask_sum = mask.sum(dim=1, keepdim=False)
    res = mask_sum.min()
    return res


device = torch.device("cuda:0")
conv = nn.Conv2d(128, 1, 3, 3).to(device)

for i in range(100):
    x = torch.rand((128, 128, 56, 56), device=device)
    x = conv(x)
    x = x.view(x.shape[0], -1)
    res = f(x)

Output of line_profiler:

Timer unit: 1e-06 s

Total time: 0.405717 s
File: b.py
Function: f at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     5                                           @profile
     6                                           def f(x):
     7       100        245.0      2.5      0.1      N, _ = x.shape
     8       100       2303.0     23.0      0.6      therehold = torch.rand((N, 1), device=x.device)
     9       100       3341.0     33.4      0.8      mask = x.ge(therehold)
    10       100       5022.0     50.2      1.2      mask_sum = mask.sum(dim=1, keepdim=False)
    11       100     394663.0   3946.6     97.3      res = mask_sum.min()
    12       100        143.0      1.4      0.0      return res
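
One check that might help narrow this down (shown only as a sketch, reusing f from case B): put a torch.cuda.synchronize() right after the convolution. If mask_sum.min() is genuinely slow on this input, its per-hit time should stay around 3946.6 us; if it is merely the first point at which the host has to wait for the asynchronous Conv2d kernels, most of that time should move out of f.

import torch
import torch.nn as nn

device = torch.device("cuda:0")
conv = nn.Conv2d(128, 1, 3, 3).to(device)

for i in range(100):
    x = torch.rand((128, 128, 56, 56), device=device)
    x = conv(x)
    # Force the convolution to finish here, so that any wait on its kernels
    # is not charged to a later line inside f.
    torch.cuda.synchronize()
    x = x.view(x.shape[0], -1)
    res = f(x)

Running with CUDA_LAUNCH_BLOCKING=1 set in the environment should have a similar effect, since it makes every kernel launch synchronous.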