def compute_distances_no_loops(x_train, x_test):
    """
    Computes the squared Euclidean distance between each element of the training
    set and each element of the test set. Images should be flattened and treated
    as vectors.

    This implementation should not use any Python loops. For memory-efficiency,
    it also should not create any large intermediate tensors; in particular you
    should not create any intermediate tensors with O(num_train*num_test)
    elements.

    Inputs:
    - x_train: Torch tensor of shape (num_train, C, H, W)
    - x_test: Torch tensor of shape (num_test, C, H, W)

    Returns:
    - dists: Torch tensor of shape (num_train, num_test) where dists[i, j] is the
      squared Euclidean distance between the ith training point and the jth test
      point.
    """
    # Initialize dists to be a tensor of shape (num_train, num_test) with the
    # same datatype and device as x_train
    num_train = x_train.shape[0]
    num_test = x_test.shape[0]
    dists = x_train.new_zeros(num_train, num_test)
    ##############################################################################
    # TODO: Implement this function without using any explicit loops and without
    # creating any intermediate tensors with O(num_train * num_test) elements.
    # You may not use torch.norm (or its instance method variant), nor any
    # functions from torch.nn or torch.nn.functional.
    # HINT: Try to formulate the Euclidean distance using two broadcast sums
    # and a matrix multiply.
    ##############################################################################
    # Replace "pass" statement with your code
    # Output: sqrt((x-y)^2)
    # (x-y)^2 = x^2 + y^2 - 2xy
    test_sum = torch.sum((x_test).pow(2), dim=1)
    train_sum = torch.sum((x_train).pow(2), dim=1)
    inner_product = (x_test)*(x_train)
    dists = torch.sqrt(test_sum + train_sum - 2*inner_product)  # broadcast
    ##############################################################################
    # END OF YOUR CODE
    ##############################################################################
    return dists
So I am supposed to compute the dists tensor without using any loops, relying only on broadcasting and matrix multiplication. I think some reshaping is needed, but I can't work out how to do it. I have already implemented this with two loops and with one loop.
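
For reference, here is one possible way the reshaping could go. It is a minimal sketch and not necessarily the intended solution. It assumes the inputs are first flattened to shape (N, D) and uses the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y. The function name `squared_dists_sketch` is made up for illustration. Two things differ from the attempt above. The cross term is a matrix multiply (`@`), not an elementwise `*`. And the two squared-norm vectors are shaped as a column and a row, so that adding them broadcasts to (num_train, num_test).

```python
import torch

def squared_dists_sketch(x_train, x_test):
    # Hypothetical helper name; flatten each image to a row vector: (N, C, H, W) -> (N, D)
    train = x_train.reshape(x_train.shape[0], -1)
    test = x_test.reshape(x_test.shape[0], -1)

    # Squared norms, shaped so they broadcast against each other
    train_sq = (train ** 2).sum(dim=1, keepdim=True)  # (num_train, 1)
    test_sq = (test ** 2).sum(dim=1).unsqueeze(0)     # (1, num_test)

    # Cross term via a matrix multiply: entry [i, j] is train[i] . test[j]
    cross = train @ test.t()                          # (num_train, num_test)

    # Column + row broadcasts to (num_train, num_test)
    return train_sq + test_sq - 2 * cross
```

No `torch.sqrt` is applied, because the docstring asks for the squared distance. If the square root were wanted, rounding can leave tiny negative entries, so clamping at zero first would be prudent.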