Equivalent of torch.unsqueeze() in TensorFlow

I have the following line of code:

mu, log_scale = torch.unsqueeze(output[:, 0, :], dim=1), torch.unsqueeze(output[:, 1, :], dim=1)

output is a 3-dimensional tensor.

Could someone help me convert it to equivalent TensorFlow code? I'm new to both PyTorch and TensorFlow.
Thanks for your help.

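In TF, tf.expand_dims() is the equivalent of torch.unsqueeze(). The dim argument is called axis instead, so torch.unsqueeze(x, dim=1) becomes tf.expand_dims(x, axis=1).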
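Here is a minimal sketch of the converted line. It assumes TensorFlow 2.x with eager execution. The (4, 2, 10) shape of output is made up for illustration, since the question only says the tensor is 3-D:

```python
import tensorflow as tf

# Hypothetical 3-D tensor of shape (4, 2, 10); the real shape is not given in the question.
output = tf.reshape(tf.range(80, dtype=tf.float32), (4, 2, 10))

# torch.unsqueeze(x, dim=1) -> tf.expand_dims(x, axis=1)
# output[:, 0, :] has shape (4, 10); expand_dims inserts a new axis at position 1.
mu = tf.expand_dims(output[:, 0, :], axis=1)         # shape (4, 1, 10)
log_scale = tf.expand_dims(output[:, 1, :], axis=1)  # shape (4, 1, 10)

print(mu.shape, log_scale.shape)
```

Slicing with a range instead of a single index, e.g. output[:, 0:1, :], also keeps the middle axis. It gives the same (4, 1, 10) result without a separate expand_dims call.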
