What does torch._C._infer_size() do?

location: torch\distributions\kl.py

combined_batch_shape = torch._C._infer_size(other._unbroadcasted_scale_tril.shape[:-2], self._unbroadcasted_scale_tril.shape[:-2])

torch.distributions.kl — PyTorch 2.1 documentation

This internal method seems to be used to compute the size needed to expand the tensors as described in the comment:

    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)

How does this internal function differ from the way the MultivariateNormal distribution expands its tensors, for example in torch.distributions.multivariate_normal.MultivariateNormal.expand()?

Thank you very much! :laughing:

_infer_size expects two torch.Size objects as its input and returns the torch.Size they broadcast to, which can then be used for further shape processing of tensors. It only computes a shape; it does not touch any tensor data.
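To make the shape rule concrete, here is a minimal pure-Python sketch of the standard right-aligned broadcasting rule. The function name infer_broadcast_shape is hypothetical and this is not PyTorch's implementation; it only illustrates the kind of result a shape-inference helper produces for two shapes.

```python
def infer_broadcast_shape(a, b):
    """Hypothetical helper: broadcast two shapes with the right-aligned rule."""
    ndim = max(len(a), len(b))
    result = [1] * ndim
    # Walk both shapes from the trailing dimension backwards.
    for i in range(1, ndim + 1):
        dim_a = a[-i] if i <= len(a) else 1  # missing leading dims count as 1
        dim_b = b[-i] if i <= len(b) else 1
        if dim_a != dim_b and dim_a != 1 and dim_b != 1:
            raise ValueError(
                f"shapes {tuple(a)} and {tuple(b)} are not broadcastable"
            )
        # A size-1 dimension stretches to match the other one.
        result[-i] = dim_b if dim_a == 1 else dim_a
    return tuple(result)


print(infer_broadcast_shape((3, 1), (1, 4)))
print(infer_broadcast_shape((), (5,)))
```

If you would rather not depend on a private function, the public torch.broadcast_shapes applies the same broadcasting rule to shapes and returns a torch.Size.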

torch.distributions.multivariate_normal.MultivariateNormal.expand is documented as:

Returns a new distribution instance (or populates an existing instance
provided by a derived class) with batch dimensions expanded to
batch_shape. This method calls :class:`~torch.Tensor.expand` on
the distribution’s parameters. As such, this does not allocate new
memory for the expanded distribution instance. Additionally,
this does not repeat any args checking or parameter broadcasting in
__init__.py, when an instance is first created.
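In contrast to the shape-only helper, expand acts on a distribution object. A short sketch, with parameter values chosen arbitrarily for illustration:

```python
import torch
from torch.distributions import MultivariateNormal

# A single 2-dimensional Gaussian: batch_shape is empty, event_shape is (2,).
base = MultivariateNormal(torch.zeros(2), scale_tril=torch.eye(2))

# Expand the batch dimensions to (3,); the event shape is left alone.
expanded = base.expand(torch.Size([3]))

print(base.batch_shape, base.event_shape)
print(expanded.batch_shape, expanded.event_shape)
print(expanded.loc.shape)

# The parameters are expanded views, so this checks whether the expanded
# loc still points at the same underlying storage as the original.
print(expanded.loc.data_ptr() == base.loc.data_ptr())
```

So _infer_size answers "what shape do these two shapes broadcast to?", while expand returns a distribution whose parameters have been expanded to a batch shape you already know.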

Thank you for your answer. I think I understand the difference between the two now. I am trying to write my own function that reproduces what _infer_size() does.

Thank you very much for your help! :hugs: