Calculated FLOPs in PyTorch and TensorFlow are not equal?

Given the same model, I found that the calculated FLOPs in PyTorch and TensorFlow are different. I used keras_flops (keras-flops · PyPI) in TensorFlow and ptflops (ptflops · PyPI) in PyTorch to calculate FLOPs. Does TensorFlow apply some tricks to speed up the computation so that fewer FLOPs are measured? How can PyTorch and TensorFlow report different FLOPs for the same model?

The FLOPs report from ptflops in PyTorch is:

Model_1(
  0.013 M, 100.000% Params, 45.486 GMac, 100.000% MACs, 
  (begin): Sequential(
    0.002 M, 11.804% Params, 0.851 GMac, 1.870% MACs, 
    (0): Conv2d(0.001 M, 11.367% Params, 0.819 GMac, 1.801% MACs, 1, 56, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): PReLU(0.0 M, 0.437% Params, 0.032 GMac, 0.069% MACs, num_parameters=56)
  )
  (middle): Sequential(
    0.007 M, 52.775% Params, 3.803 GMac, 8.360% MACs, 
    (0): Conv2d(0.001 M, 5.340% Params, 0.385 GMac, 0.846% MACs, 56, 12, kernel_size=(1, 1), stride=(1, 1))
    (1): PReLU(0.0 M, 0.094% Params, 0.007 GMac, 0.015% MACs, num_parameters=12)
    (2): Conv2d(0.001 M, 10.212% Params, 0.736 GMac, 1.618% MACs, 12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): PReLU(0.0 M, 0.094% Params, 0.007 GMac, 0.015% MACs, num_parameters=12)
    (4): Conv2d(0.001 M, 10.212% Params, 0.736 GMac, 1.618% MACs, 12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): PReLU(0.0 M, 0.094% Params, 0.007 GMac, 0.015% MACs, num_parameters=12)
    (6): Conv2d(0.001 M, 10.212% Params, 0.736 GMac, 1.618% MACs, 12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): PReLU(0.0 M, 0.094% Params, 0.007 GMac, 0.015% MACs, num_parameters=12)
    (8): Conv2d(0.001 M, 10.212% Params, 0.736 GMac, 1.618% MACs, 12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): PReLU(0.0 M, 0.094% Params, 0.007 GMac, 0.015% MACs, num_parameters=12)
    (10): Conv2d(0.001 M, 5.684% Params, 0.409 GMac, 0.900% MACs, 12, 56, kernel_size=(1, 1), stride=(1, 1))
    (11): PReLU(0.0 M, 0.437% Params, 0.032 GMac, 0.069% MACs, num_parameters=56)
  )
  (final): ConvTranspose2d(0.005 M, 35.420% Params, 40.833 GMac, 89.770% MACs, 56, 1, kernel_size=(9, 9), stride=(4, 4), padding=(4, 4), output_padding=(3, 3))
)
Computational complexity:       45.49 GMac
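
For reference, a per-layer breakdown like the one above is what ptflops prints; a minimal sketch of the typical invocation (Model_1 is assumed to be the model definition, which is not shown in this post):

import torch
from ptflops import get_model_complexity_info

model = Model_1()  # hypothetical: the model definition is not shown here
macs, params = get_model_complexity_info(
    model,
    (1, 750, 750),              # (channels, height, width) of a single input
    as_strings=True,
    print_per_layer_stat=True,  # prints the per-layer report shown above
)
print('Computational complexity:', macs)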

My model in TensorFlow:

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, PReLU

d = 56
s = 12

inp = Input((750, 750, 1))
x = Conv2D(d, (5, 5), padding='same')(inp)
x = PReLU()(x)

x = Conv2D(s, (1, 1), padding='valid')(x)
x = PReLU()(x)

x = Conv2D(s, (3, 3), padding='same')(x)
x = PReLU()(x)
x = Conv2D(s, (3, 3), padding='same')(x)
x = PReLU()(x)
x = Conv2D(s, (3, 3), padding='same')(x)
x = PReLU()(x)
x = Conv2D(s, (3, 3), padding='same')(x)
x = PReLU()(x)

x = Conv2D(d, (1, 1), padding='same')(x)
x = PReLU()(x)
out = Conv2DTranspose(1, (9, 9), strides=(4, 4), padding='same', output_padding=3)(x)
model = Model(inp, out)
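
The report below was presumably produced via keras_flops; a minimal sketch of how that is done, using keras-flops' documented get_flops helper:

from keras_flops import get_flops

# keras-flops wraps the TF profiler and returns total float_ops for one forward pass
flops = get_flops(model, batch_size=1)
print(f"FLOPs: {flops / 1e9:.1f} G")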

The FLOPs output in TensorFlow is:

Profile:

node name | # float_ops
Conv2D                   8.92b float_ops (100.00%, 61.95%)
Conv2DBackpropInput      5.10b float_ops (38.05%, 35.44%)
Neg                      180.00m float_ops (2.61%, 1.25%)
BiasAdd                  105.75m float_ops (1.36%, 0.73%)
Mul                      90.00m float_ops (0.63%, 0.63%)

======================End of Report==========================
The total FLOPs is 14.3 GFLOPs.

It looks like the ConvTranspose2d is responsible for the bulk of the FLOPs in the PyTorch result. Can you verify that this number makes sense?

Thanks.
The FLOPs for the deconvolution is:

Cout * (1 + Cin * k * k) * Hout * Wout
= 1 * (1 + 56 * 9 * 9) * 3000 * 3000
= 40.83 GFLOPs

(Calculating a convolution neural network computation float (Others-Community))
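
As a quick sanity check of that arithmetic (a minimal sketch; 3000 = 750 × stride 4, and the "1 +" term is the bias):

Cin, Cout, k = 56, 1, 9
Hout = Wout = 750 * 4  # stride-4 upsampling of the 750x750 input
macs = Cout * (1 + Cin * k * k) * Hout * Wout
print(f"{macs / 1e9:.2f} G")  # 40.83 G, matching the ptflops figure for the final layer
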
I think PyTorch calculates it right, but TensorFlow doesn't.
Just wondering why?

Can anyone help, please?

Bugs can happen, and since you've checked the calculation and think the TF calculation is wrong, I'm sure the TF devs would be happy about an issue (and a fix).

Thank you for the reply.
Do you think the PyTorch FLOPs calculation and my calculation are correct?
It seems weird that only TensorFlow has a problem calculating FLOPs.

Also, are there any PyTorch tools that can calculate the FLOPs for the PixelShuffle operation?
I cannot find an equation to calculate it myself.
Thanks.

Hello @hcleung3325, I do not know if this is still a problem. However, TensorFlow computes FLOPs, while the tools that report "FLOPs" for PyTorch actually calculate MACs. A multiply-accumulate (MAC) is one multiplication plus one addition, so FLOPs ≈ 2 × MACs.

Disclaimer: I cannot claim to be familiar with all PyTorch tools. However, I am aware that torchprofile and fvcore are capable of computing MACs.

PyTorch

According to this Medium post, it is possible to compute MACs and FLOPs using the method mentioned there. Therefore, I believe you calculated MACs that were close to PyTorch's figure. I would suggest using fvcore, the official/semi-official FLOPs counter provided by Facebook/Meta, instead of torchprofile, which was employed in the Medium post.

!pip -q install fvcore
import torch
import torchvision
from fvcore.nn import FlopCountAnalysis

convnexttiny_weights = torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
model = torchvision.models.convnext_tiny(weights=convnexttiny_weights)

inputs = (torch.randn(1, 3, 224, 224), )

# Despite the class name, FlopCountAnalysis counts one fused multiply-add as one "flop", i.e. MACs
macs = FlopCountAnalysis(model, inputs)
macs = macs.total()
flops = macs * 2
print(f'MACs = {macs / 1e+9:,} G')
print(f'FLOPs = {flops / 1e+9:,} G')
MACs = 4.470437376 G
FLOPs = 8.940874752 G

Your model

d = 56
s = 12
model = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=1, out_channels=d, kernel_size=(5, 5), padding=2),
    torch.nn.PReLU(),

    torch.nn.Conv2d(in_channels=d, out_channels=s, kernel_size=(1, 1), padding=0),
    torch.nn.PReLU(),

    torch.nn.Conv2d(in_channels=s, out_channels=s, kernel_size=(3, 3), padding=1),
    torch.nn.PReLU(),

    torch.nn.Conv2d(in_channels=s, out_channels=s, kernel_size=(3, 3), padding=1),
    torch.nn.PReLU(),

    torch.nn.Conv2d(in_channels=s, out_channels=s, kernel_size=(3, 3), padding=1),
    torch.nn.PReLU(),

    torch.nn.Conv2d(in_channels=s, out_channels=s, kernel_size=(3, 3), padding=1),
    torch.nn.PReLU(),

    torch.nn.Conv2d(in_channels=s, out_channels=d, kernel_size=(1, 1), padding=0),
    torch.nn.PReLU(),

    torch.nn.ConvTranspose2d(in_channels=d, out_channels=1, kernel_size=(9, 9), stride=(4, 4), padding=4, output_padding=3)
)

inputs = (torch.randn(1, 1, 750, 750), )

macs = FlopCountAnalysis(model, inputs)

macs = macs.total()
flops = macs * 2

print(f'MACs = {macs / 1e+9:,} G')
print(f'FLOPs = {flops / 1e+9:,} G')
MACs = 7.011 G
FLOPs = 14.022 G
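
That total can be reproduced by hand (a sketch: MACs = Cout · Cin · k² · H · W per convolution at 750×750 resolution, with the transposed convolution counted per input pixel and bias/PReLU ignored, which appears to be fvcore's convention):

H = W = 750
macs = 0
macs += 56 * 1 * 5 * 5 * H * W         # 5x5 conv:        0.7875 G
macs += 12 * 56 * 1 * 1 * H * W        # 1x1 squeeze:     0.378  G
macs += 4 * (12 * 12 * 3 * 3 * H * W)  # four 3x3 convs:  2.916  G
macs += 56 * 12 * 1 * 1 * H * W        # 1x1 expand:      0.378  G
macs += 1 * 56 * 9 * 9 * H * W         # ConvTranspose2d: 2.5515 G (per input pixel)
print(f"{macs / 1e9} G")               # 7.011 G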

Just keep in mind that if you set padding='same' or padding='valid' in PyTorch, it gives you:

MACs = 2.5515 G
FLOPs = 5.103 G

which I think is wrong.
I calculated the padding values from the formulas in Conv2d and ConvTranspose2d (see the shape check below).
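
A quick shape check of those padding choices (a sketch; for the stride-1 convs, 'same' output needs padding = (k − 1) / 2, and ConvTranspose2d gives Hout = (Hin − 1)·stride − 2·padding + k + output_padding = 749·4 − 8 + 9 + 3 = 3000):

import torch

x = torch.randn(1, 1, 750, 750)
print(torch.nn.Conv2d(1, 56, 5, padding=2)(x).shape)  # (1, 56, 750, 750)
deconv = torch.nn.ConvTranspose2d(56, 1, 9, stride=4, padding=4, output_padding=3)
print(deconv(torch.randn(1, 56, 750, 750)).shape)     # (1, 1, 3000, 3000)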

TensorFlow

According to the information provided in this GitHub issue, you can compute MACs and FLOPs in TensorFlow using the following code snippet:

import tensorflow as tf
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder

def get_flops(model):
  forward_pass = tf.function(model.call, input_signature=[tf.TensorSpec(shape=(1,) + model.input_shape[1:])])
  graph_info = profile(forward_pass.get_concrete_function().graph, options=ProfileOptionBuilder.float_operation())
  flops = graph_info.total_float_ops
  return flops

model = tf.keras.applications.ConvNeXtTiny()

# model.compile(optimizer='adam', loss='bce', metrics=['accuracy'])

flops = get_flops(model)
macs = flops / 2
print(f"MACs: {macs / 1e+9:,} G")
print(f"FLOPs: {flops / 1e+9:,} G")
MACs: 4.3900329795 G
FLOPs: 8.780065959 G

Just keep in mind that the source computes MACs.

Your model

d = 56
s = 12

model = tf.keras.Sequential([
    tf.keras.Input((750, 750, 1)),

    tf.keras.layers.Conv2D(d, (5, 5), padding='same'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2D(s, (1, 1), padding='valid'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2D(s, (3, 3), padding='same'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2D(s, (3, 3), padding='same'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2D(s, (3, 3), padding='same'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2D(s, (3, 3), padding='same'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2D(d, (1, 1), padding='same'),
    tf.keras.layers.PReLU(),

    tf.keras.layers.Conv2DTranspose(1, (9, 9), strides=(4, 4), padding='same', output_padding=3)
])

def get_flops(model):
  forward_pass = tf.function(model.call, input_signature=[tf.TensorSpec(shape=(1,) + model.input_shape[1:])])
  graph_info = profile(forward_pass.get_concrete_function().graph, options=ProfileOptionBuilder.float_operation())
  flops = graph_info.total_float_ops
  return flops

flops = get_flops(model)
macs = flops / 2
print(f"MACs: {macs / 1e+9:,} G")
print(f"FLOPs: {flops / 1e+9:,} G")
MACs: 7.257375 G
FLOPs: 14.51475 G

If your question is about the small remaining difference between the two frameworks (14.022 vs 14.515 GFLOPs), I think the developers have to help us.

ConvNeXt

If you refer to Figure 2 of the ConvNeXt paper, you will notice that they report MACs instead of FLOPs for ConvNeXt-Tiny. ConvNeXt code: tensorflow, pytorch

I executed the code on Colab.
I kindly request that someone correct me if I am mistaken.