Getting device errors at line all_rays_cos += cos(input_rays, target_ray)
in the below function.
I checked both tensors input_rays
and target_ray
, both are on cuda:0
. I do not know which tensor is using CPU. Any help is very much useful? Thanks!
def compute_feat_weights(src_rays, all_rays, device=None):
assert src_rays.shape[0] == all_rays.shape[0]
assert src_rays.shape[-1] == all_rays.shape[-1]
SB = src_rays.shape[0]
NS = src_rays.shape[1]
H, W = src_rays.shape[2:4]
target_ray_batch_size = all_rays.shape[1]
input_rays = src_rays.view(SB, NS, -1, 8)
cos = nn.CosineSimilarity(dim=3, eps=1e-6).to(device) # similarity func
all_rays_cos = torch.zeros(SB, NS, H*W)
for idx in range(0, target_ray_batch_size):
target_ray = all_rays[:,idx].unsqueeze(0).unsqueeze(1)
print(input_rays.device, target_ray.device)
**all_rays_cos += cos(input_rays, target_ray)**
feat_weights = torch.sum(all_rays_cos, dim=2)
norm_feat_werigts = (feat_weights - torch.mean(feat_weights)) / torch.std(feat_weights)
return norm_feat_werigts.view(SB, NS, 1)`Preformatted text`