Hi all,
I’m encountering a problem where my RAM usage keeps growing during inference of multiple models (the GPU memory is released, though).
I’ve trained 6 binary-classification models and now I’m trying to run inference with all 6 of them, one after the other. For some reason my RAM keeps increasing, as if I have a memory leak somewhere in my code, but I just don’t know where.
Each of the 6 inference models is an I3D, and I’m passing the output of each model’s last layer into a model that produces 6 outputs (I’m doing an ensemble over the 6 inference models).
I called .eval() on all my inference models and also made sure requires_grad = False is set on all of their parameters.
I tried setting torch.backends.cudnn.benchmark = False and also torch.backends.cudnn.benchmark = True; in both cases the RAM still kept growing.
This is the code i’m using for inference:
def load_inference_model(kinetics_weights, checkpoint_weights, use_half, use_dataparallel, convert_to_cuda_before_loading_weights=False):
    """Build an I3D binary-classification model, load trained weights, and freeze it for inference.

    Args:
        kinetics_weights: path to Kinetics-pretrained weights for the 400-class
            backbone, or None to skip backbone initialization.
        checkpoint_weights: path to the fine-tuned checkpoint; must contain a
            "state_dict" entry.
        use_half: forwarded to convert_to_cuda (presumably fp16 conversion —
            confirm in convert_to_cuda).
        use_dataparallel: forwarded to convert_to_cuda (presumably wraps the
            model in nn.DataParallel).
        convert_to_cuda_before_loading_weights: when True, the checkpoint was
            saved from a DataParallel-wrapped model, so the "module." prefix is
            stripped from every parameter name before loading.

    Returns:
        The model in eval() mode with all parameters frozen
        (requires_grad = False).
    """
    model = I3D(num_classes=400, modality='rgb', use_spatial=True)
    if kinetics_weights is not None:
        model.load_state_dict(torch.load(kinetics_weights))
    # Replace the 400-way Kinetics head with a 2-way (binary) classifier head.
    model.conv3d_0c_1x1 = Unit3Dpy(
        in_channels=1024,
        out_channels=2,
        kernel_size=(1, 1, 1),
        activation=None,
        use_bias=True,
        use_bn=False)
    # Load the checkpoint exactly once (the original duplicated this call in
    # both branches of the if/else below).
    state_dict = torch.load(checkpoint_weights)["state_dict"]
    if convert_to_cuda_before_loading_weights:
        # Strip the "module." prefix that nn.DataParallel prepends to every
        # parameter name.  Only strip it when actually present — the original
        # blindly sliced k[7:], which corrupts any key without the prefix.
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith("module.") else k
            new_state_dict[name] = v
        state_dict = new_state_dict
    model.load_state_dict(state_dict)
    model = convert_to_cuda(model, use_half, use_dataparallel)
    # Freeze every parameter.  Handle both DataParallel-wrapped and bare
    # models — the original unconditionally accessed model.module, which
    # raises AttributeError when use_dataparallel leaves the model unwrapped.
    base_model = model.module if hasattr(model, "module") else model
    for p in base_model.parameters():
        p.requires_grad = False
    model.eval()
    return model
def do_single_inference(inputs, model, mean , std, use_half):
    """Normalize one batch, run a single frozen model on it, and return the raw scores.

    The returned value is the detached score tensor (``score.data``), so no
    autograd history is carried back to the caller.
    """
    # Per-model input normalization.
    normed = (inputs - mean) / std
    cuda_input = normed.cuda()
    if use_half:
        cuda_input = cuda_input.half()
    # volatile=True disables graph construction for this forward pass
    # (pre-0.4 PyTorch API, consistent with the rest of this file).
    inputs_var = Variable(cuda_input, volatile=True)
    score, softmax = model(inputs_var)
    return score.data
def do_inference(inputs, models, means, stds, use_half):
    """Run every ensemble member on `inputs` and concatenate their 2-way scores.

    Args:
        inputs: the input batch; its first dimension is the batch size.
        models: sequence of frozen inference models (2 output classes each).
        means, stds: per-model normalization statistics, indexed in step with
            `models`.
        use_half: forwarded to do_single_inference for fp16 inference.

    Returns:
        A CPU tensor of shape (batch_size, 2 * len(models)); columns
        [2*i : 2*i + 2] hold model i's scores.

    The original hard-coded the batch size to 4, the model count to 6 (via six
    copy-pasted lines), and carried a comment claiming "8 models"; this
    version derives both sizes, so any batch size / ensemble size works.
    """
    batch_size = inputs.size(0)
    scores = torch.zeros([batch_size, 2 * len(models)])
    for i, model in enumerate(models):
        scores[:, 2 * i:2 * i + 2] = do_single_inference(
            inputs, model, means[i], stds[i], use_half)
    return scores
def train_micro_batches(epoch, model, ensemble_models, models_mean, models_std, steps_per_epoch, num_micro_batches, data_loader, use_half):
    """Train the ensemble head `model` for one epoch, accumulating gradients over micro-batches.

    Each of the `steps_per_epoch` optimizer steps backprops `num_micro_batches`
    micro-batches before calling optimizer.step().  The frozen `ensemble_models`
    produce detached score tensors (via do_inference), which are the inputs to
    the trainable head `model`.

    NOTE(review): relies on module-level names `optimizer`, `criterion`, `gc`,
    `ProgressBar`, and `Variable` being in scope — confirm at the call site.
    NOTE(review): `epoch` is accepted but unused inside this function.
    """
    stateful_metrics = ["Loss", "Acc"]
    progress_bar = ProgressBar(steps_per_epoch, stateful_metrics=stateful_metrics)
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    loss_avg = 0.0
    acc_avg = 0.0
    # NOTE(review): steps_per_epoch * num_micro_batches must not exceed the
    # loader's length, or next() below raises StopIteration — confirm sizing.
    data_loader_iter = iter(data_loader)
    for i in range(steps_per_epoch):
        batch_loss_value = 0
        optimizer.zero_grad()
        for j in range(num_micro_batches):
            inputs, targets = next(data_loader_iter)
            # Detached ensemble scores (do_inference returns score.data), so
            # no graph from the frozen models is retained here.
            scores = do_inference(inputs, ensemble_models, models_mean, models_std, use_half)
            if use_half:
                scores, targets = Variable(scores.cuda().half()), Variable(targets.cuda())
            else:
                scores, targets = Variable(scores.cuda()), Variable(targets.cuda())
            score = model(scores)
            loss = criterion(score, targets)
            # Gradients accumulate across micro-batches until optimizer.step().
            loss.backward()
            # Pre-0.4 scalar extraction: loss.data[...] yields a Python number.
            if use_half:
                batch_loss_value += loss.data.cpu()[0]
            else:
                batch_loss_value += loss.data.cpu().numpy()[0]
            _, predicted = torch.max(score.data, 1)
            total += targets.size(0)
            # NOTE(review): on torch >= 0.4, .sum() returns a 0-dim tensor, so
            # this line accumulates tensors into `running_corrects` — a common
            # source of slowly growing (host) memory.  Prefer
            # `.sum().item()` there; on torch 0.3 this returns a plain int and
            # is fine.  Worth checking as the RAM-growth culprit.
            running_corrects += predicted.eq(targets.data).cpu().sum()
        optimizer.step()
        running_loss += batch_loss_value/num_micro_batches
        # Explicit collection each step; treats the symptom, not the leak.
        gc.collect()
        # NOTE(review): under Python 2 this is integer division and would
        # always report 0 accuracy until corrects == total — confirm Python 3.
        acc_avg = running_corrects / total
        loss_avg = running_loss / (i + 1)
        vals = [("Loss", "{:0.4f}".format(loss_avg)), ("Acc", "{:0.4f}".format(acc_avg))]
        progress_bar.update(i+1, vals)