I want to calculate the FLOPs in a linear layer for a given input and output size.
For example, take a linear layer defined as torch.nn.Linear(3, 2) and an input tensor x = torch.randn(4, 384, 3, dtype=torch.float16, device=device). The layer first does a matmul on the input tensor, which costs 4 * 384 * 3 * 2 * 2 operations (the trailing * 2 counting each multiply-add pair as two FLOPs).
Then the bias addition takes place, which adds another 4 * 384 * 2 operations.
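In code, my count looks like this (just a small sketch of my by-hand calculation; I use float32 on CPU here purely for illustration, since the shapes are what matter):

```python
import torch

# Shapes from the example above
batch, seq, in_features, out_features = 4, 384, 3, 2

layer = torch.nn.Linear(in_features, out_features)
x = torch.randn(batch, seq, in_features)  # float32 on CPU just for this sketch
y = layer(x)
print(y.shape)  # torch.Size([4, 384, 2])

# Matmul: each of the batch * seq * out_features outputs needs in_features
# multiply-add pairs, counted here as 2 * in_features FLOPs.
matmul_flops = batch * seq * in_features * out_features * 2  # 4 * 384 * 3 * 2 * 2 = 18,432
# Bias: one add per output element.
bias_flops = batch * seq * out_features                      # 4 * 384 * 2 = 3,072
print(matmul_flops, bias_flops, matmul_flops + bias_flops)
```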
Is this calculation right? Is it what the linear layer actually does, or are there some other optimizations?