I am trying to create an encoder-decoder network that maps a vector u (dim = 300) in one vector space to a vector v (dim = 300) in another. The training data is a collection of vector pairs (u, v) such that both represent the same object in the two vector spaces. Typically, we would just ensure that the reconstructed vector (_v) and the original vector (v) have a cosine similarity of 1:

cos(v, _v) = 1
However, I am trying to extend this criterion so that the reconstructed vector _v also has the same cosine similarity to the other vectors vj in the transformed vector space.
Essentially, I'm insisting not only that the reconstructed and original vectors be parallel to each other, but also that the reconstructed vector have the same alignment to the other vectors in that space.
I do this by computing the original cosine similarity between v and vj and the reconstructed cosine similarity between _v and vj, and then optimizing their difference toward zero:
loss = |cos(v, vj) - cos(_v, vj)|
This is repeated for many vectors vj for each vector v in a batch.
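In plain PyTorch terms, the per-batch quantity I'm after could be sketched like this (a minimal sketch using torch.nn.functional.cosine_similarity; the function name pairwise_alignment_loss is mine):

```python
import torch
import torch.nn.functional as F

def pairwise_alignment_loss(v, v_hat, vj):
    """Mean absolute difference between the original and reconstructed
    cosine similarities against the comparison vectors vj.
    v, v_hat, vj all have shape [batch, 300]."""
    orig_sim = F.cosine_similarity(v, vj, dim=1)       # cos(v, vj),  shape [batch]
    recon_sim = F.cosine_similarity(v_hat, vj, dim=1)  # cos(_v, vj), shape [batch]
    return torch.mean(torch.abs(orig_sim - recon_sim))
```

When v_hat equals v the loss is exactly zero, which is the target of the optimization.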
Here is the PyTorch implementation of the loss function:
cosine_loss = torch.nn.CosineEmbeddingLoss(reduction='none')

def loss_fn(train_cog, train_y, train_r):
    condition = torch.ones(train_y.shape[0]).to(device)
    # CosineEmbeddingLoss with target 1 returns 1 - cos(train_y, train_r),
    # so train_cog + loss - 1 == cos(v, vj) - cos(_v, vj)
    sim_sum = torch.sum(torch.abs(train_cog + cosine_loss(train_y, train_r, condition) - 1)) / train_y.shape[0]
    return sim_sum
where train_y is a batch of 32 reconstructed vectors, shape [32, 300], and train_r holds the vector paired with each vector in the batch, so train_r is also [32, 300].
train_cog is the original cosine similarity between each original vector and the corresponding vector in train_r, shape [32, 1].
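For reference, the shapes can be sanity-checked off-GPU with random tensors (names as above, CPU-only sketch). Note that adding the [32, 1] train_cog to the [32]-shaped per-pair loss broadcasts to [32, 32], which may not be what I intend:

```python
import torch

cosine_loss = torch.nn.CosineEmbeddingLoss(reduction='none')

train_y = torch.randn(32, 300)   # reconstructed vectors _v
train_r = torch.randn(32, 300)   # comparison vectors vj
train_cog = torch.rand(32, 1)    # original similarities cos(v, vj)

condition = torch.ones(train_y.shape[0])
per_pair = cosine_loss(train_y, train_r, condition)  # shape [32]
combined = train_cog + per_pair - 1                  # broadcasts to [32, 32]
print(per_pair.shape, combined.shape)
```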
However, this program keeps running into a CUDA illegal memory access error:
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
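Following that suggestion, I re-run the script with synchronous kernel launches so the stack trace points at the op that actually faulted (train.py stands in for my training entry point):

```shell
# Force synchronous CUDA kernel launches for an accurate stack trace
CUDA_LAUNCH_BLOCKING=1 python train.py
```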
Is there something wrong with what I am doing? Can anyone suggest modifications to the code, or a better way of achieving this?