Slow Forward Time

I am implementing a network for point clouds. However, the forward time of the Detection module is much slower than that of the feature-extraction module. Below is the network implementation. Is there a way to speed it up?

class ResUNet2(ME.MinkowskiNetwork):
  NORM_TYPE = None
  BLOCK_NORM_TYPE = 'BN'
  CHANNELS = [None, 32, 64, 128, 256]
  TR_CHANNELS = [None, 32, 64, 64, 128]

  # To use the model, must call initialize_coords before forward pass.
  # Once data is processed, call clear to reset the model before calling initialize_coords
  def __init__(self,
               in_channels=3,
               out_channels=32,
               bn_momentum=0.1,
               normalize_feature=None,
               conv1_kernel_size=None,
               D=3):
    ME.MinkowskiNetwork.__init__(self, D)
    NORM_TYPE = self.NORM_TYPE
    BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
    CHANNELS = self.CHANNELS
    TR_CHANNELS = self.TR_CHANNELS
    self.normalize_feature = normalize_feature
    self.conv1 = ME.MinkowskiConvolution(
        in_channels=in_channels,
        out_channels=CHANNELS[1],
        kernel_size=conv1_kernel_size,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D)

    self.block1 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D)

    self.conv2 = ME.MinkowskiConvolution(
        in_channels=CHANNELS[1],
        out_channels=CHANNELS[2],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D)

    self.block2 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D)

    self.conv3 = ME.MinkowskiConvolution(
        in_channels=CHANNELS[2],
        out_channels=CHANNELS[3],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D)

    self.block3 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D)

    self.conv4 = ME.MinkowskiConvolution(
        in_channels=CHANNELS[3],
        out_channels=CHANNELS[4],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D)

    self.block4 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D)

    self.conv4_tr = ME.MinkowskiConvolutionTranspose(
        in_channels=CHANNELS[4],
        out_channels=TR_CHANNELS[4],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)

    self.block4_tr = get_block(
        BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)

    self.conv3_tr = ME.MinkowskiConvolutionTranspose(
        in_channels=CHANNELS[3] + TR_CHANNELS[4],
        out_channels=TR_CHANNELS[3],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)

    self.block3_tr = get_block(
        BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)

    self.conv2_tr = ME.MinkowskiConvolutionTranspose(
        in_channels=CHANNELS[2] + TR_CHANNELS[3],
        out_channels=TR_CHANNELS[2],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)

    self.block2_tr = get_block(
        BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)

    self.conv1_tr = ME.MinkowskiConvolution(
        in_channels=CHANNELS[1] + TR_CHANNELS[2],
        out_channels=TR_CHANNELS[1],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)

    # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D)

    self.final = ME.MinkowskiConvolution(
        in_channels=TR_CHANNELS[1],
        out_channels=out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=True,
        dimension=D)

  def forward(self, x):
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat(out_s1_tr, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
      return ME.SparseTensor(
          out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
          coords_key=out.coords_key,
          coords_manager=out.coords_man)
    else:
      return out

class Detection(nn.Module):
  # To use the model, must call initialize_coords before forward pass.
  # Once data is processed, call clear to reset the model before calling initialize_coords
  def __init__(self,batch_size,radius=10,device='cpu'):
    super(Detection, self).__init__()

    self.batch_size = batch_size
    self.radius = radius
    self.device = device
    #self.len_batch = len_batch


  def forward(self,coords,features):
    #point_len = self.len_batch
    batch_size = self.batch_size
    device = self.device
    radius = self.radius

    score = []
    for i in range(batch_size):
      score.append(self._detection_score(coords=coords[i],
                                         feature=features[i],
                                         radius=radius,
                                         device=device)
                  )
      #score1.append(self._detection_score(coords=batch_C1[i],feature=batch_F1[i],radius=radius))

    return score


  def _detection_score(self,coords=None,feature=None,radius=None,device=None):
    #find all points in a cube whose center is the point
    #get alpha score in feature map k
    #print("Device :",device)

    alpha = torch.zeros(feature.shape[0], feature.shape[1], device=device)
    beta = torch.zeros(feature.shape[0], feature.shape[1], device=device)
    #print("alpha device:",alpha.device)
    #print("beta device:",beta.device)
    for i in range(feature.shape[0]):
      for j in range(feature.shape[1]):
        beta[i,j] = feature[i,j]/torch.max(feature[i,:])
        #logging.info(f"beta done at {i},{j}")
        center = coords[i,:]
        feature_k = feature[:,j]
        alpha[i,j] = self._alpha_score(center,coords,feature_k,radius,device)
        #logging.info(f"alpha done at {i},{j}")

    
    gamma = torch.max(alpha * beta, dim=1)[0]  # torch.max with dim returns (values, indices); keep the values
    score = gamma / torch.sum(gamma)

    return score

  def _alpha_score(self,center,coords=None,feature_k=None,radius=None,device=None):
    mask = np.zeros(coords.shape[0],)
    for i in range(coords.shape[0]):
      mask[i] = torch.dist(coords[i, :].float(), center.float(), 2) ** 2 <= radius ** 2  # compare squared distance to radius**2 consistently

    idx = ((coords[:,0]==center[0]) * (coords[:,1]==center[1]) * (coords[:,2]==center[2])).nonzero()
    #print(idx)
    mask = np.where(mask==1)
    feature_i = feature_k[idx]
    feature_in_neighbor = feature_k[mask].to(device)
    #print(feature_in_neighbor.device)
    score = torch.exp(feature_i)/torch.sum(torch.exp(feature_in_neighbor))
    #print(score.shape)

    return score

class JointNet(nn.Module):
  def __init__(self,
                device,
                batch_size=4,
                radius=10,
                in_channels=3,
                out_channels=32,
                bn_momentum=0.1,
                normalize_feature=None,
                conv1_kernel_size=None,
                backbone_model=ResUNetBN2C,
                D=3):
    super(JointNet, self).__init__()

    self.batch_size = batch_size
    self.radius = radius
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.bn_momentum = bn_momentum
    self.normalize_feature = normalize_feature
    self.conv1_kernel_size = conv1_kernel_size
    self.backbone_model = backbone_model
    self.device = device
    #self.len_batch = len_batch

    #model = load_model(backbone_model)
    self.feature_extraction0 = backbone_model(
                              in_channels,
                              out_channels,
                              bn_momentum=bn_momentum,
                              normalize_feature=normalize_feature,
                              conv1_kernel_size=conv1_kernel_size,
                              D=3)#.to(device)
    self.feature_extraction1 = backbone_model(
                              in_channels,
                              out_channels,
                              bn_momentum=bn_momentum,
                              normalize_feature=normalize_feature,
                              conv1_kernel_size=conv1_kernel_size,
                              D=3)#.to(device)

    self.detection0 = Detection(batch_size,radius,device=device)
    self.detection1 = Detection(batch_size,radius,device=device)

  def forward(self,x0,x1):
    #x0 = x0.to(self.device)
    #x1 = x1.to(self.device)
    #logging.info(f"input device:{x0.F.device}")
    sparse0 = self.feature_extraction0(x0)
    sparse1 = self.feature_extraction1(x1)
    logging.info(f"Feature Extraction Done")
    #logging.info(f"coord at output device:{sparse1.coordinates_at(0).device}")

    batch_C1, batch_F1 = [],[]
    batch_C0, batch_F0 = [],[]
    for i in range(self.batch_size):
      #logging.info(f"Before append device:{sparse0.C.device}")
      batch_C1.append(sparse1.coordinates_at(i))
      batch_F1.append(sparse1.features_at(i))
      batch_C0.append(sparse0.coordinates_at(i))
      batch_F0.append(sparse0.features_at(i))

    logging.info(f"Coord_seperation Done")
    #logging.info(f"After append device:{batch_C0[i].device}")
    score0 = self.detection0(batch_C0,batch_F0)
    score1 = self.detection1(batch_C1,batch_F1)

    return{
     'feature0': sparse0,
     'feature1': sparse1,
     'score0': score0,
     'score1': score1
    }

Hi,
I have two naive suggestions.

  1. It seems like your feature_extraction0 and feature_extraction1 are the same module. Why not construct just one and forward the combined input of x0 and x1 through it?
  2. Nvidia Apex is a nice tool to speed up training by using 16-bit float operations instead of the default 32-bit. (A sketch of both ideas follows this list.)
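A rough sketch of both ideas (assumptions: the JointNet definition from your post; whether Apex plays well with the sparse ops is something you would need to verify). Even if x0 and x1 stay separate, a single shared backbone uses one set of weights:

# Suggestion 1 (sketch): one shared backbone; both inputs go through the same weights
self.feature_extraction = backbone_model(in_channels, out_channels,
                                         bn_momentum=bn_momentum,
                                         normalize_feature=normalize_feature,
                                         conv1_kernel_size=conv1_kernel_size, D=3)
# ... then in forward():
sparse0 = self.feature_extraction(x0)
sparse1 = self.feature_extraction(x1)

# Suggestion 2 (sketch): the standard apex.amp mixed-precision pattern
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
with amp.scale_loss(loss, optimizer) as scaled_loss:  # inside the training loop
    scaled_loss.backward()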

I think it would be clearer if you described what ResUNetBN2C and x0, x1 are; there may be ways to speed things up depending on your case.

Hi,
I am aware of the Apex package, but I don't think that is the problem.
First, the feature-extraction module is a network that computes a sparse tensor. x0 and x1 are the inputs to the MinkowskiNetwork, with coordinates and features. See details here: https://stanfordvl.github.io/MinkowskiEngine/
After the feature-extraction module, the output has to be split into coords and features per batch item; the detection and description of the points are then computed, and the loss is generated from those.

The problem is that the feature-extraction forward() takes one second or less per batch, whereas the detection forward takes almost half an hour on the same batch, which is definitely abnormal.

As for input concatenation, x0 and x1 are of different sizes, so concatenating them at input time is not possible. I am looking forward to your reply. Thank you very much!

The inputs are 3D point clouds containing around 10k points each.


For each batch there are four Python loops (the batch loop in Detection.forward, the nested i/j loops in _detection_score, and the mask loop in _alpha_score) that I think are the main reason for the slow forward.
You can avoid most of this looping by using tensor operations. For example, the mask loop in _alpha_score can be replaced with

mask = ((coords.float() - center.float()[None]) ** 2).sum(dim=1) <= radius ** 2

which computes everything you need in one tensor expression instead of a loop.
The same concept can be applied to the other loops.
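For instance, the beta part of the nested loop in _detection_score collapses to a single line (a sketch; feature is N×C as in your code):

beta = feature / feature.max(dim=1, keepdim=True)[0]  # per-point max over channels, broadcast back over the row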

Thank you very much for your reply; I will modify the code accordingly. Also, does it matter if both CPU and CUDA tensors appear during the forward pass? The coords are on the CPU and the features are on CUDA.

I'm not sure whether that hurts your computation speed. If you're worried, you can move them all to CUDA :upside_down_face:
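For example (a sketch, using the names from your JointNet.forward):

device = batch_F0[0].device                      # the CUDA device the features already live on
batch_C0 = [c.to(device) for c in batch_C0]      # move the coords onto the same device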

Also, I have another question: if I need the distances between all pairs of points in coords, is there a function for that? nn.PairwiseDistance doesn't apply here. coords is N×3, with each row the coordinate of one point in 3D. If not, is there an efficient way to do it in PyTorch?

Sorry, I don't know whether such a function exists, but I can show you what I would do:

coords_A = coords.view(coords.shape[0], 1, 3).repeat(1, coords.shape[0], 1)
coords_B = coords.view(1, coords.shape[0], 3).repeat(coords.shape[0], 1, 1)
coords_confusion = torch.stack((coords_A, coords_B), dim=2)
every_dist = ((coords_confusion[:, :, 0, :] - coords_confusion[:, :, 1, :]) ** 2).sum(dim=2) ** 0.5
# now every_dist[i,j] would be the distance of coords[i] and coords[j]
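For reference, if your PyTorch version provides it, torch.cdist computes the same matrix in one call and avoids materializing the two repeated (N, N, 3) tensors:

every_dist = torch.cdist(coords.float(), coords.float(), p=2)  # every_dist[i, j] = distance of coords[i] and coords[j]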

Thank you very much!

Hi,

I tried to remove all the for loops. However, this didn't make it better. Also, while the code is running my computer becomes extremely slow, and I was wondering whether that means the computation is running on the CPU instead of CUDA. However, I did set model.to(device) in my trainer, and I have checked that the device is indeed CUDA. Any idea about this? Below is the modified code.
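One check I can add is to time the detection step with explicit synchronization and log the input devices (a sketch, with the names from JointNet.forward; CUDA launches asynchronously, so synchronize before reading the clock):

import time

torch.cuda.synchronize()
t0 = time.time()
score0 = self.detection0(batch_C0, batch_F0)
torch.cuda.synchronize()
logging.info(f"detection0 forward: {time.time() - t0:.3f}s; "
             f"coords on {batch_C0[0].device}, features on {batch_F0[0].device}")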

# -*- coding: future_fstrings -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import numpy as np
import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MEF
from model.common import get_norm

from model.residual_block import get_block


class ResUNet2(ME.MinkowskiNetwork):
  # ... identical to the ResUNet2 definition in the first post above ...


class ResUNetBN2(ResUNet2):
  NORM_TYPE = 'BN'


class ResUNetBN2B(ResUNet2):
  NORM_TYPE = 'BN'
  CHANNELS = [None, 32, 64, 128, 256]
  TR_CHANNELS = [None, 64, 64, 64, 64]


class ResUNetBN2C(ResUNet2):
  NORM_TYPE = 'BN'
  CHANNELS = [None, 32, 64, 128, 256]
  TR_CHANNELS = [None, 64, 64, 64, 128]


class ResUNetBN2D(ResUNet2):
  NORM_TYPE = 'BN'
  CHANNELS = [None, 32, 64, 128, 256]
  TR_CHANNELS = [None, 64, 64, 128, 128]


class ResUNetBN2E(ResUNet2):
  NORM_TYPE = 'BN'
  CHANNELS = [None, 128, 128, 128, 256]
  TR_CHANNELS = [None, 64, 128, 128, 128]


class ResUNetIN2(ResUNet2):
  NORM_TYPE = 'BN'
  BLOCK_NORM_TYPE = 'IN'


class ResUNetIN2B(ResUNetBN2B):
  NORM_TYPE = 'BN'
  BLOCK_NORM_TYPE = 'IN'


class ResUNetIN2C(ResUNetBN2C):
  NORM_TYPE = 'BN'
  BLOCK_NORM_TYPE = 'IN'


class ResUNetIN2D(ResUNetBN2D):
  NORM_TYPE = 'BN'
  BLOCK_NORM_TYPE = 'IN'

class ResUNetIN2E(ResUNetBN2E):
  NORM_TYPE = 'BN'
  BLOCK_NORM_TYPE = 'IN'

class Detection(nn.Module):
  # To use the model, must call initialize_coords before forward pass.
  # Once data is processed, call clear to reset the model before calling initialize_coords
  def __init__(self,batch_size,radius=10,device='cpu'):
    super(Detection, self).__init__()

    self.batch_size = batch_size
    self.radius = radius
    self.device = device
    self.pdist = nn.PairwiseDistance(p=2)
    #self.len_batch = len_batch


  def forward(self,coords,features):
    #point_len = self.len_batch
    batch_size = self.batch_size
    device = self.device
    radius = self.radius

    score = []
    for i in range(batch_size):
      score.append(self._detection_score(coords=coords[i],
                                         feature=features[i],
                                         radius=radius,
                                         device=device)
                  )
      #score1.append(self._detection_score(coords=batch_C1[i],feature=batch_F1[i],radius=radius))

    return score


  def _detection_score(self,coords=None,feature=None,radius=None,device=None):
    #find all points in a cube whose center is the point
    #get alpha score in feature map k
    feature = F.relu(feature)
    max_local = torch.max(feature, dim=1)[0]   # per-point max over channels
    beta = feature / max_local.unsqueeze(1)
    logging.info("Beta Done")

    coords_A = coords.view(coords.shape[0], 1, 3).repeat(1, coords.shape[0], 1)
    coords_B = coords.view(1, coords.shape[0], 3).repeat(coords.shape[0], 1, 1)
    coords_confusion = torch.stack((coords_A, coords_B), dim=2)
    every_dist = ((coords_confusion[:, :, 0, :] - coords_confusion[:, :, 1, :]) ** 2).sum(dim=2) ** 0.5


    neighbors = torch.topk(every_dist, 9, largest=False, dim=1).indices  # indices of each point's 9 nearest neighbors
    neighbor9_feature = feature[neighbors]                               # (N, 9, C): neighbor features per point
    exp_feature = torch.exp(feature)
    exp_neighbor = torch.sum(torch.exp(neighbor9_feature), dim=1)        # (N, C)
    alpha = exp_feature / exp_neighbor
    logging.info("Alpha Done")

    gamma = torch.max(alpha * beta, dim=1)[0]  # torch.max with dim returns (values, indices); keep the values
    logging.info("Gamma Done")
    score = gamma / torch.norm(gamma)

    return score

class JointNet(nn.Module):
  # ... identical to the JointNet definition in the first post above ...

I have a forward time of about 6 seconds. Any idea how to reduce this?