I am implementing a network for point clouds. However, the forward pass of the feature_extraction module is much slower than that of the detection module. Below is the network implementation. Is there a way to speed it up?
import logging

import numpy as np
import torch
import torch.nn as nn

import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MEF

# get_norm, get_block, and ResUNetBN2C are defined elsewhere in the project.

class ResUNet2(ME.MinkowskiNetwork):
    NORM_TYPE = None
    BLOCK_NORM_TYPE = 'BN'
    CHANNELS = [None, 32, 64, 128, 256]
    TR_CHANNELS = [None, 32, 64, 64, 128]

    # To use the model, initialize_coords must be called before the forward pass.
    # Once the data is processed, call clear to reset the model before calling
    # initialize_coords again.
    def __init__(self,
                 in_channels=3,
                 out_channels=32,
                 bn_momentum=0.1,
                 normalize_feature=None,
                 conv1_kernel_size=None,
                 D=3):
        ME.MinkowskiNetwork.__init__(self, D)
        NORM_TYPE = self.NORM_TYPE
        BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
        CHANNELS = self.CHANNELS
        TR_CHANNELS = self.TR_CHANNELS
        self.normalize_feature = normalize_feature

        self.conv1 = ME.MinkowskiConvolution(
            in_channels=in_channels,
            out_channels=CHANNELS[1],
            kernel_size=conv1_kernel_size,
            stride=1,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D)
        self.block1 = get_block(
            BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D)

        self.conv2 = ME.MinkowskiConvolution(
            in_channels=CHANNELS[1],
            out_channels=CHANNELS[2],
            kernel_size=3,
            stride=2,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D)
        self.block2 = get_block(
            BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D)

        self.conv3 = ME.MinkowskiConvolution(
            in_channels=CHANNELS[2],
            out_channels=CHANNELS[3],
            kernel_size=3,
            stride=2,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D)
        self.block3 = get_block(
            BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D)

        self.conv4 = ME.MinkowskiConvolution(
            in_channels=CHANNELS[3],
            out_channels=CHANNELS[4],
            kernel_size=3,
            stride=2,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D)
        self.block4 = get_block(
            BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D)

        self.conv4_tr = ME.MinkowskiConvolutionTranspose(
            in_channels=CHANNELS[4],
            out_channels=TR_CHANNELS[4],
            kernel_size=3,
            stride=2,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)
        self.block4_tr = get_block(
            BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)

        self.conv3_tr = ME.MinkowskiConvolutionTranspose(
            in_channels=CHANNELS[3] + TR_CHANNELS[4],
            out_channels=TR_CHANNELS[3],
            kernel_size=3,
            stride=2,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)
        self.block3_tr = get_block(
            BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)

        self.conv2_tr = ME.MinkowskiConvolutionTranspose(
            in_channels=CHANNELS[2] + TR_CHANNELS[3],
            out_channels=TR_CHANNELS[2],
            kernel_size=3,
            stride=2,
            dilation=1,
            has_bias=False,
            dimension=D)
        self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)
        self.block2_tr = get_block(
            BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)

        self.conv1_tr = ME.MinkowskiConvolution(
            in_channels=CHANNELS[1] + TR_CHANNELS[2],
            out_channels=TR_CHANNELS[1],
            kernel_size=1,
            stride=1,
            dilation=1,
            has_bias=False,
            dimension=D)
        # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D)

        self.final = ME.MinkowskiConvolution(
            in_channels=TR_CHANNELS[1],
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            dilation=1,
            has_bias=True,
            dimension=D)
    def forward(self, x):
        # Encoder: strided convolutions downsample 1x -> 2x -> 4x -> 8x
        out_s1 = self.conv1(x)
        out_s1 = self.norm1(out_s1)
        out_s1 = self.block1(out_s1)
        out = MEF.relu(out_s1)

        out_s2 = self.conv2(out)
        out_s2 = self.norm2(out_s2)
        out_s2 = self.block2(out_s2)
        out = MEF.relu(out_s2)

        out_s4 = self.conv3(out)
        out_s4 = self.norm3(out_s4)
        out_s4 = self.block3(out_s4)
        out = MEF.relu(out_s4)

        out_s8 = self.conv4(out)
        out_s8 = self.norm4(out_s8)
        out_s8 = self.block4(out_s8)
        out = MEF.relu(out_s8)

        # Decoder: transposed convolutions upsample, with skip connections
        out = self.conv4_tr(out)
        out = self.norm4_tr(out)
        out = self.block4_tr(out)
        out_s4_tr = MEF.relu(out)

        out = ME.cat(out_s4_tr, out_s4)

        out = self.conv3_tr(out)
        out = self.norm3_tr(out)
        out = self.block3_tr(out)
        out_s2_tr = MEF.relu(out)

        out = ME.cat(out_s2_tr, out_s2)

        out = self.conv2_tr(out)
        out = self.norm2_tr(out)
        out = self.block2_tr(out)
        out_s1_tr = MEF.relu(out)

        out = ME.cat(out_s1_tr, out_s1)
        out = self.conv1_tr(out)
        out = MEF.relu(out)
        out = self.final(out)

        if self.normalize_feature:
            return ME.SparseTensor(
                out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
                coords_key=out.coords_key,
                coords_manager=out.coords_man)
        else:
            return out
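One aside on the normalize_feature branch above: dividing by the raw L2 norm yields NaNs for any all-zero feature row. torch.nn.functional.normalize computes the same unit-length features with an eps guard; a drop-in sketch reusing the SparseTensor kwargs from the code above:

import torch.nn.functional as F

if self.normalize_feature:
    # Same result as out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
    # but the norm is clamped by eps, so all-zero rows stay finite.
    return ME.SparseTensor(
        F.normalize(out.F, p=2, dim=1),
        coords_key=out.coords_key,
        coords_manager=out.coords_man)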
class Detection(nn.Module):

    def __init__(self, batch_size, radius=10, device='cpu'):
        super(Detection, self).__init__()
        self.batch_size = batch_size
        self.radius = radius
        self.device = device

    def forward(self, coords, features):
        batch_size = self.batch_size
        device = self.device
        radius = self.radius
        score = []
        for i in range(batch_size):
            score.append(self._detection_score(coords=coords[i],
                                               feature=features[i],
                                               radius=radius,
                                               device=device))
        return score

    def _detection_score(self, coords=None, feature=None, radius=None, device=None):
        # For every point i and feature map j:
        #   beta[i, j]  - response relative to the channel-wise maximum of point i
        #   alpha[i, j] - soft local-max score of point i within `radius` in map j
        alpha = torch.zeros(feature.shape[0], feature.shape[1], device=device)
        beta = torch.zeros(feature.shape[0], feature.shape[1], device=device)
        for i in range(feature.shape[0]):
            for j in range(feature.shape[1]):
                beta[i, j] = feature[i, j] / torch.max(feature[i, :])
                center = coords[i, :]
                feature_k = feature[:, j]
                alpha[i, j] = self._alpha_score(center, coords, feature_k, radius, device)
        # torch.max with a dim argument returns (values, indices); keep the values
        gamma = torch.max(alpha * beta, dim=1).values
        score = gamma / torch.sum(gamma)
        return score

    def _alpha_score(self, center, coords=None, feature_k=None, radius=None, device=None):
        # Neighborhood test: squared Euclidean distance against radius**2,
        # so both sides of the comparison are on the same scale
        mask = np.zeros(coords.shape[0])
        for i in range(coords.shape[0]):
            mask[i] = torch.sum((coords[i, :].float() - center.float()) ** 2) <= radius ** 2
        # index of the center point itself (assumes unique coordinates)
        idx = ((coords[:, 0] == center[0]) &
               (coords[:, 1] == center[1]) &
               (coords[:, 2] == center[2])).nonzero()
        mask = np.where(mask == 1)
        feature_i = feature_k[idx]
        feature_in_neighbor = feature_k[mask].to(device)
        score = torch.exp(feature_i) / torch.sum(torch.exp(feature_in_neighbor))
        return score
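For reference, the nested loops above run one tiny GPU op per (point, channel) pair, so _detection_score costs O(N^2 * C) Python-level iterations per cloud. The same alpha/beta/gamma computation can be expressed with batched tensor ops. Below is a hypothetical standalone sketch, not the author's API: the function name is mine, it assumes coordinates are unique per point (so the self-lookup in _alpha_score reduces to the point itself) and it materializes an N x N neighbor mask, which only fits in memory for moderate N:

import torch

def detection_score_vectorized(coords, feature, radius, device):
    coords = coords.float().to(device)
    feature = feature.to(device)
    # (N, N) pairwise distances; neighbors are points within `radius`
    neighbor = torch.cdist(coords, coords, p=2) <= radius
    exp_f = torch.exp(feature)                                  # (N, C)
    # denom[i, k] = sum of exp(feature[j, k]) over the neighbors j of i
    denom = neighbor.float() @ exp_f                            # (N, C)
    alpha = exp_f / denom                                       # soft local-max
    beta = feature / feature.max(dim=1, keepdim=True).values    # channel saliency
    gamma = (alpha * beta).max(dim=1).values                    # (N,)
    return gamma / gamma.sum()

On a few thousand points this replaces N * C Python-level launches per cloud with a handful of batched ones.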
class JointNet(nn.Module):
    def __init__(self,
                 device,
                 batch_size=4,
                 radius=10,
                 in_channels=3,
                 out_channels=32,
                 bn_momentum=0.1,
                 normalize_feature=None,
                 conv1_kernel_size=None,
                 backbone_model=ResUNetBN2C,
                 D=3):
        super(JointNet, self).__init__()
        self.batch_size = batch_size
        self.radius = radius
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bn_momentum = bn_momentum
        self.normalize_feature = normalize_feature
        self.conv1_kernel_size = conv1_kernel_size
        self.backbone_model = backbone_model
        self.device = device

        self.feature_extraction0 = backbone_model(
            in_channels,
            out_channels,
            bn_momentum=bn_momentum,
            normalize_feature=normalize_feature,
            conv1_kernel_size=conv1_kernel_size,
            D=3)
        self.feature_extraction1 = backbone_model(
            in_channels,
            out_channels,
            bn_momentum=bn_momentum,
            normalize_feature=normalize_feature,
            conv1_kernel_size=conv1_kernel_size,
            D=3)
        self.detection0 = Detection(batch_size, radius, device=device)
        self.detection1 = Detection(batch_size, radius, device=device)

    def forward(self, x0, x1):
        sparse0 = self.feature_extraction0(x0)
        sparse1 = self.feature_extraction1(x1)
        logging.info("Feature extraction done")
        # Split the batched sparse tensors into per-sample coordinate/feature lists
        batch_C1, batch_F1 = [], []
        batch_C0, batch_F0 = [], []
        for i in range(self.batch_size):
            batch_C1.append(sparse1.coordinates_at(i))
            batch_F1.append(sparse1.features_at(i))
            batch_C0.append(sparse0.coordinates_at(i))
            batch_F0.append(sparse0.features_at(i))
        logging.info("Coordinate separation done")
        score0 = self.detection0(batch_C0, batch_F0)
        score1 = self.detection1(batch_C1, batch_F1)
        return {
            'feature0': sparse0,
            'feature1': sparse1,
            'score0': score0,
            'score1': score1
        }
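To verify which stage actually dominates, it is worth timing each module with explicit CUDA synchronization: CUDA kernels launch asynchronously, so a plain wall clock around one module can silently include the previous module's GPU work. A minimal sketch, assuming `model` is a JointNet instance and `x0` a prepared sparse input:

import time
import torch

def timed(label, fn, *args):
    # Synchronize so the wall clock measures only this stage's GPU work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"{label}: {time.perf_counter() - t0:.3f}s")
    return out

sparse0 = timed("feature_extraction0", model.feature_extraction0, x0)
batch_C0 = [sparse0.coordinates_at(i) for i in range(model.batch_size)]
batch_F0 = [sparse0.features_at(i) for i in range(model.batch_size)]
score0 = timed("detection0", model.detection0, batch_C0, batch_F0)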