# Minimal reproduction: hand a bare unittest.mock call object to torch.equal
# in place of a tensor and print whatever comes back.
from unittest import mock

import torch

# NOTE(review): the reported output is a `call.__torch_function__(...)` object
# rather than a bool. Presumably mock._Call answers the `__torch_function__`
# attribute lookup, so torch dispatches to it as an override -- confirm
# against the torch.overrides documentation.
print(torch.equal(mock._Call(), torch.zeros((2, 2))))
The output looks like this:
call.__torch_function__(<built-in method equal of type object at 0x10cc80770>, (<class 'unittest.mock._Call'>,), (call(), tensor([[0., 0.], [0., 0.]])))
Is torch.equal working correctly in this case? I found this because of a typo in a unit test like the following.
import unittest
from unittest.mock import MagicMock, patch

import torch


def calc_mean(input_tensor):
    """Return ``torch.mean(input_tensor)``.

    ``torch.mean`` is looked up on the ``torch`` module at call time, so
    patching ``"torch.mean"`` replaces the function used here.
    """
    means = torch.mean(input_tensor)
    return means


class Test(unittest.TestCase):
    """Reproduces an assertion that passes even though it compares the wrong object."""

    @patch("torch.mean")
    def test_sample(self, mock_torch_mean: MagicMock):
        """Call ``calc_mean`` with ``torch.mean`` patched and inspect the recorded call."""
        mock_torch_mean.return_value = 1.0
        input_tensor = torch.zeros((2, 2))

        result = calc_mean(input_tensor=input_tensor)

        # The typo is on the assertion below: the intended expression was
        # mock_torch_mean.call_args.args, but the whole call_args object
        # (a unittest.mock call object, not a tensor) is passed instead.
        # NOTE(review): call_args.args is a tuple of the positional arguments,
        # so the tensor itself would be call_args.args[0] -- confirm.
        # However, self.assertTrue still passes in this case.
        self.assertTrue(torch.equal(mock_torch_mean.call_args, input_tensor))


if __name__ == "__main__":
    unittest.main()