Dear all,
In my implementation, training consumes an enormous amount of memory. During the forward pass, the total memory reported by torch.cuda.memory_allocated(self.device) is 71874058.139, which is about 71 MB. During training, however, the model cannot even fit into a 32 GB GPU.
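For reference, this is roughly how I take the measurement (a simplified sketch; model, device and the inputs stand in for my actual objects):
import logging
import torch
before = torch.cuda.memory_allocated(device)
out = model(sinput0, sinput1, len_batch)  # forward pass only
after = torch.cuda.memory_allocated(device)
logging.info(f"forward allocated: {(after - before) / 2**20:.1f} MiB")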
Here is my network; the JointNet class at the bottom is the full model.
# -*- coding: future_fstrings -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import numpy as np
import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MEF
from model.common import get_norm
from model.residual_block import get_block
class ResUNet2(ME.MinkowskiNetwork):
NORM_TYPE = None
BLOCK_NORM_TYPE = 'BN'
CHANNELS = [None, 32, 64, 128, 256]
TR_CHANNELS = [None, 32, 64, 64, 128]
# To use the model, must call initialize_coords before forward pass.
# Once data is processed, call clear to reset the model before calling initialize_coords
def __init__(self,
in_channels=3,
out_channels=32,
bn_momentum=0.1,
normalize_feature=None,
conv1_kernel_size=None,
D=3):
ME.MinkowskiNetwork.__init__(self, D)
NORM_TYPE = self.NORM_TYPE
BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
CHANNELS = self.CHANNELS
TR_CHANNELS = self.TR_CHANNELS
self.normalize_feature = normalize_feature
self.conv1 = ME.MinkowskiConvolution(
in_channels=in_channels,
out_channels=CHANNELS[1],
kernel_size=conv1_kernel_size,
stride=1,
dilation=1,
has_bias=False,
dimension=D)
self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D)
self.block1 = get_block(
BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D)
self.conv2 = ME.MinkowskiConvolution(
in_channels=CHANNELS[1],
out_channels=CHANNELS[2],
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
dimension=D)
self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D)
self.block2 = get_block(
BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D)
self.conv3 = ME.MinkowskiConvolution(
in_channels=CHANNELS[2],
out_channels=CHANNELS[3],
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
dimension=D)
self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D)
self.block3 = get_block(
BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D)
self.conv4 = ME.MinkowskiConvolution(
in_channels=CHANNELS[3],
out_channels=CHANNELS[4],
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
dimension=D)
self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D)
self.block4 = get_block(
BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D)
self.conv4_tr = ME.MinkowskiConvolutionTranspose(
in_channels=CHANNELS[4],
out_channels=TR_CHANNELS[4],
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
dimension=D)
self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)
self.block4_tr = get_block(
BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)
self.conv3_tr = ME.MinkowskiConvolutionTranspose(
in_channels=CHANNELS[3] + TR_CHANNELS[4],
out_channels=TR_CHANNELS[3],
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
dimension=D)
self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)
self.block3_tr = get_block(
BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)
self.conv2_tr = ME.MinkowskiConvolutionTranspose(
in_channels=CHANNELS[2] + TR_CHANNELS[3],
out_channels=TR_CHANNELS[2],
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
dimension=D)
self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)
self.block2_tr = get_block(
BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)
self.conv1_tr = ME.MinkowskiConvolution(
in_channels=CHANNELS[1] + TR_CHANNELS[2],
out_channels=TR_CHANNELS[1],
kernel_size=1,
stride=1,
dilation=1,
has_bias=False,
dimension=D)
# self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D)
self.final = ME.MinkowskiConvolution(
in_channels=TR_CHANNELS[1],
out_channels=out_channels,
kernel_size=1,
stride=1,
dilation=1,
has_bias=True,
dimension=D)
def forward(self, x):
out_s1 = self.conv1(x)
out_s1 = self.norm1(out_s1)
out_s1 = self.block1(out_s1)
out = MEF.relu(out_s1)
out_s2 = self.conv2(out)
out_s2 = self.norm2(out_s2)
out_s2 = self.block2(out_s2)
out = MEF.relu(out_s2)
out_s4 = self.conv3(out)
out_s4 = self.norm3(out_s4)
out_s4 = self.block3(out_s4)
out = MEF.relu(out_s4)
out_s8 = self.conv4(out)
out_s8 = self.norm4(out_s8)
out_s8 = self.block4(out_s8)
out = MEF.relu(out_s8)
out = self.conv4_tr(out)
out = self.norm4_tr(out)
out = self.block4_tr(out)
out_s4_tr = MEF.relu(out)
out = ME.cat(out_s4_tr, out_s4)
out = self.conv3_tr(out)
out = self.norm3_tr(out)
out = self.block3_tr(out)
out_s2_tr = MEF.relu(out)
out = ME.cat(out_s2_tr, out_s2)
out = self.conv2_tr(out)
out = self.norm2_tr(out)
out = self.block2_tr(out)
out_s1_tr = MEF.relu(out)
out = ME.cat(out_s1_tr, out_s1)
out = self.conv1_tr(out)
out = MEF.relu(out)
out = self.final(out)
if self.normalize_feature:
return ME.SparseTensor(
out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
coords_key=out.coords_key,
coords_manager=out.coords_man)
else:
return out
class ResUNetBN2(ResUNet2):
NORM_TYPE = 'BN'
class ResUNetBN2B(ResUNet2):
NORM_TYPE = 'BN'
CHANNELS = [None, 32, 64, 128, 256]
TR_CHANNELS = [None, 64, 64, 64, 64]
class ResUNetBN2C(ResUNet2):
NORM_TYPE = 'BN'
CHANNELS = [None, 32, 64, 128, 256]
TR_CHANNELS = [None, 64, 64, 64, 128]
class ResUNetBN2D(ResUNet2):
NORM_TYPE = 'BN'
CHANNELS = [None, 32, 64, 128, 256]
TR_CHANNELS = [None, 64, 64, 128, 128]
class ResUNetBN2E(ResUNet2):
NORM_TYPE = 'BN'
CHANNELS = [None, 128, 128, 128, 256]
TR_CHANNELS = [None, 64, 128, 128, 128]
class ResUNetIN2(ResUNet2):
NORM_TYPE = 'BN'
BLOCK_NORM_TYPE = 'IN'
class ResUNetIN2B(ResUNetBN2B):
NORM_TYPE = 'BN'
BLOCK_NORM_TYPE = 'IN'
class ResUNetIN2C(ResUNetBN2C):
NORM_TYPE = 'BN'
BLOCK_NORM_TYPE = 'IN'
class ResUNetIN2D(ResUNetBN2D):
NORM_TYPE = 'BN'
BLOCK_NORM_TYPE = 'IN'
class ResUNetIN2E(ResUNetBN2E):
NORM_TYPE = 'BN'
BLOCK_NORM_TYPE = 'IN'
class Detection(nn.Module):
def __init__(self,radius=10,device='cpu'):
super(Detection, self).__init__()
self.device = device
#self.len_batch = len_batch
def forward(self,coords,features,len_batch):
#point_len = self.len_batch
device = self.device
score = (torch.Tensor()).to(device)
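# compute a score for each sample in the batch and concatenate
# the results into one flat tensor along the point dimension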
for i in range(len_batch):
batch_score = self._detection_score(coords[i],features[i])
#print(batch_score.device)
score = torch.cat((score,batch_score),0)
#print(score.shape)
#score1.append(self._detection_score(coords=batch_C1[i],feature=batch_F1[i],radius=radius))
return score
def _detection_score(self,coords=None,feature=None):
# find all points in a cube centered on the current point
# then get the alpha score in feature map k
feature = F.relu(feature)
max_local = torch.max(feature,dim=1)[0]
beta = feature/max_local.unsqueeze(1)
del max_local
#logging.info(f"Beta Done")
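# NOTE: coords_A, coords_B and coords_confusion below are (N, N, 3),
# (N, N, 3) and (N, N, 2, 3) tensors, so this step alone allocates
# O(N^2) memory before any of it can be freed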
coords_A = (coords.view(coords.shape[0], 1, 3).repeat(1, coords.shape[0], 1)).short()
coords_B = (coords.view(1, coords.shape[0], 3).repeat(coords.shape[0], 1, 1)).short()
coords_confusion = (torch.stack((coords_A, coords_B), dim=2)).short()
del coords_A,coords_B
every_dist = (((coords_confusion[:, :, 0, :] - coords_confusion[:, :, 1, :]) ** 2).sum(dim=2) ** 0.5)
neighbors = (torch.topk(every_dist,1,largest=False,dim=1).indices)
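# with k=1, each point's nearest neighbor is the point itself (distance 0)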
del every_dist
neighbor9_feature = (feature[neighbors,:])[0]
del neighbors
exp_feature = torch.exp(feature)
exp_neighbor = torch.sum(torch.exp(neighbor9_feature),dim=0)
alpha = exp_feature/exp_neighbor
del exp_feature,exp_neighbor
#logging.info(f"Alpha Done")
gamma = torch.max(alpha*beta,dim=1).values
del alpha,beta
#logging.info(f"Gamma Done, gamma dimension{gamma.shape}")
score = gamma/torch.norm(gamma)
del gamma
#print(score.device)
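# empty_cache() only returns unused cached blocks to the driver;
# it cannot free tensors that autograd still needs for the backward pass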
torch.cuda.empty_cache()
return score
class JointNet(nn.Module):
def __init__(self,
device,
batch_size=4,
in_channels=3,
out_channels=32,
bn_momentum=0.1,
normalize_feature=None,
conv1_kernel_size=None,
backbone_model=ResUNetBN2C,
D=3):
super(JointNet, self).__init__()
self.batch_size = batch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.bn_momentum = bn_momentum
self.normalize_feature = normalize_feature
self.conv1_kernel_size = conv1_kernel_size
self.backbone_model = backbone_model
self.device = device
#self.batch_len = batch_len
#self.len_batch = len_batch
#model = load_model(backbone_model)
self.feature_extraction0 = backbone_model(
in_channels,
out_channels,
bn_momentum=bn_momentum,
normalize_feature=normalize_feature,
conv1_kernel_size=conv1_kernel_size,
D=3)
self.feature_extraction1 = backbone_model(
in_channels,
out_channels,
bn_momentum=bn_momentum,
normalize_feature=normalize_feature,
conv1_kernel_size=conv1_kernel_size,
D=3)#.to(device)
self.detection0 = Detection(device=device)
self.detection1 = Detection(device=device)
def forward(self,x0,x1,len_batch):
#x0 = x0.to(self.device)
#x1 = x1.to(self.device)
#logging.info(f"input device:{x0.F.device}")
#print(len_batch)
sparse0 = self.feature_extraction0(x0)
sparse1 = self.feature_extraction1(x1)
#logging.info(f"Feature Extraction Done")
#logging.info(f"coord at output device:{sparse1.coordinates_at(0).device}")
coord0 = (sparse0.C.short()).to(self.device)
feature0 = sparse0.F
coord1 = (sparse1.C.short()).to(self.device)
feature1 = sparse1.F
del sparse0,sparse1
torch.cuda.empty_cache()
batch_C1, batch_F1 = [],[]
batch_C0, batch_F0 = [],[]
start_idx = np.zeros((2,),dtype=int)
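# len_batch[i] holds the point counts of sample i in x0 and x1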
for i in range(len(len_batch)):
end_idx = start_idx + np.array(len_batch[i],dtype=int)
#print(start_idx,end_idx)
#logging.info(f"Before append device:{sparse0.C.device}")
C0 = coord0[start_idx[0]:end_idx[0],1:4]
C1 = coord1[start_idx[1]:end_idx[1],1:4]
F0 = feature0[start_idx[0]:end_idx[0],:]
F1 = feature1[start_idx[1]:end_idx[1],:]
#print(C0.shape,C1.shape,F0.shape,F1.shape)
batch_C1.append(C1)
batch_F1.append(F1)
batch_C0.append(C0)
batch_F0.append(F0)
del C0,C1,F0,F1
torch.cuda.empty_cache()
start_idx = end_idx
#logging.info(f"Coord_seperation Done")
#logging.info(f"After append device:{batch_C0[i].device}")
score0 = self.detection0(batch_C0,batch_F0,len(len_batch))
score1 = self.detection1(batch_C1,batch_F1,len(len_batch))
return{
'feature0': feature0,
'feature1': feature1,
'score0': score0,
'score1': score1
}
And this is my trainer.
def _train_epoch(self, epoch):
config = self.config
gc.collect()
self.model.train()
# Epoch starts from 1
total_loss = 0
total_num = 0.0
data_loader = self.data_loader
data_loader_iter = self.data_loader.__iter__()
iter_size = self.iter_size
data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer()
pos_dist_meter, neg_dist_meter, mem_meter = AverageMeter(), AverageMeter(), AverageMeter()
start_iter = (epoch - 1) * (len(data_loader) // iter_size)
for curr_iter in range(len(data_loader) // iter_size):
self.optimizer.zero_grad()
batch_loss = 0
data_time = 0
total_timer.tic()
for iter_idx in range(iter_size):
data_timer.tic()
input_dict = data_loader_iter.next()
data_time += data_timer.toc(average=False)
# pairs consist of (xyz1 index, xyz0 index)
len_batch = input_dict['len_batch']
sinput0 = ME.SparseTensor(
input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device)
sinput1 = ME.SparseTensor(
input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device)
out = self.model(sinput0,sinput1,len_batch)
pos_pairs = input_dict['correspondences']
loss, neg_dist, pos_dist = self.joint_loss(
out['feature0'],
out['feature1'],
out['score0'],
out['score1'],
pos_pairs,
len_batch = input_dict['len_batch'],
batch_size = config.batch_size,
num_pos=self.config.num_pos_per_batch * self.config.batch_size,
num_hn_samples=self.config.num_hn_samples_per_batch * self.config.batch_size,
)
logging.info(f" batch {iter_size} Done")
loss /= iter_size
loss.backward()
batch_loss += loss.item()
pos_dist_meter.update(pos_dist)
neg_dist_meter.update(neg_dist)
self.optimizer.step()
gc.collect()
torch.cuda.empty_cache()
total_loss += batch_loss
total_num += 1.0
total_timer.toc()
data_meter.update(data_time)
if curr_iter % self.config.stat_freq == 0:
self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter)
logging.info(
"Train Epoch: {} [{}/{}], Current Loss: {:.3e}, Pos dist: {:.3e}, Neg dist: {:.3e}"
.format(epoch, curr_iter,
len(self.data_loader) //
iter_size, batch_loss, pos_dist_meter.avg, neg_dist_meter.avg) +
"\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format(
data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg))
pos_dist_meter.reset()
neg_dist_meter.reset()
data_meter.reset()
total_timer.reset()
In my network, the Detection module contains no convolution layers; it simply computes a score map that is later used in the loss function. As you can see, I tried to delete intermediate variables right after using them and to call empty_cache(), but that does not help the forward pass at all: the memory consumption stays the same. For reference, JointNet without the Detection module (i.e. only ResUNetBN2C) consumes about 50 MB in the forward pass, and training fits into an 8 GB GPU.
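My main suspect is the pairwise-distance step in _detection_score, which materializes the (N, N, 3) and (N, N, 2, 3) tensors noted above. Here is an untested sketch of a chunked nearest-neighbor variant I am considering; the peak extra allocation should drop to roughly chunk_size x N (chunk_size is an arbitrary knob):
import torch
def nearest_neighbor_chunked(coords, chunk_size=1024):
    # coords: (N, 3); returns the (N, 1) index of each point's nearest
    # neighbor without building the full (N, N, 2, 3) coordinate stack
    coords = coords.float()  # torch.cdist requires a floating-point input
    indices = []
    for start in range(0, coords.shape[0], chunk_size):
        chunk = coords[start:start + chunk_size]  # (B, 3)
        dist = torch.cdist(chunk, coords)         # (B, N), freed every loop
        indices.append(dist.topk(1, largest=False, dim=1).indices)
    return torch.cat(indices, dim=0)
Since the coordinates carry no gradients, this whole step could also run under torch.no_grad().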
So, is there any way to reduce the total memory consumption of my module? I am looking forward to your reply. Thanks in advance!