Hi All,
I am very new to PyTorch and I’m seeing something weird when my code runs that I can’t figure out.
In short, I am applying a Gaussian to many images and then running a regression against brain data.
The code batches the Gaussian/image step. The center location and width of the Gaussian change, and each combination is treated as one ‘model’; we then pick the combination that gives the best prediction of the brain data. I use nvtop to monitor GPU usage. I’m trying to find the best batch sizes for both the images and the brain data so that I can model a whole brain without it taking days and days.
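To make the ‘model’ idea concrete, here is a minimal sketch of how such a candidate grid could be built. The values and the names `centers` and `sigmas` are hypothetical; the real grid is whatever gets passed in as `models`.

```python
from itertools import product

# Hypothetical candidate values; the real grid is whatever is passed in as `models`.
centers = [-0.25, 0.0, 0.25]  # candidate x and y positions, in aperture units
sigmas = [0.05, 0.1]          # candidate Gaussian widths

# One 'model' per (x, y, sigma) combination.
models = [(x, y, s) for x, y, s in product(centers, centers, sigmas)]
print(len(models))  # 18
```

Each tuple corresponds to one pass of the outer loop, so the grid size multiplies directly into the total run time.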
When I call my function and watch the GPU, the first model starts and I can see the memory shoot up when the image batching begins, which makes sense. But when the second model runs, memory shoots up again, as if PyTorch were allocating a new set of memory, and I’m not sure why it would. Even weirder, it only does this on the second model; it doesn’t keep going up for each model after that. If it happened for every model, I would assume my code was creating an unnecessary new tensor somewhere, but seeing it only between the first and second loop iterations is really puzzling to me. I’ve tried stopping the code at various places and looking at the tensors that have been created, but I can’t find the culprit, as there is nothing new after the second run.
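One way to narrow this down is to log PyTorch’s own counters alongside nvtop. The helper below is a minimal sketch (the name `cuda_memory_mib` is made up for illustration); it relies only on `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved`, and `torch.cuda.max_memory_allocated`. Memory held by live tensors shows up as allocated, while reserved also includes blocks the caching allocator keeps around for reuse, so comparing the two per model should indicate whether the second-model jump comes from new live tensors or from cached blocks.

```python
import torch

def cuda_memory_mib(device=None):
    """Return allocated, reserved and peak-allocated CUDA memory in MiB."""
    if not torch.cuda.is_available():
        # No GPU present: report zeros so the helper can still be called.
        return {"allocated": 0.0, "reserved": 0.0, "peak_allocated": 0.0}
    mib = 1024 ** 2
    return {
        # Memory currently occupied by live tensors.
        "allocated": torch.cuda.memory_allocated(device) / mib,
        # Memory held by the caching allocator (live tensors plus cached blocks).
        "reserved": torch.cuda.memory_reserved(device) / mib,
        # Highest 'allocated' value seen since the start (or the last reset).
        "peak_allocated": torch.cuda.max_memory_allocated(device) / mib,
    }

# Hypothetical placement: call once per model, e.g. at the top of the loop.
print(cuda_memory_mib())
```

Printing this at the start and end of each model iteration would show whether the allocated figure itself grows on the second pass or only the reserved one does.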
The full function is below. It was mostly written by someone else, and I just changed some aspects to ‘optimize’ it. I’m sure there are other things that can be done, but at the moment figuring out this memory issue is the most important, because it severely limits the batch sizes I can use. I also want to gain a better understanding of memory issues so I can get better at using PyTorch in general. I appreciate any help or insights!
import sys
import time

import numpy as np
import torch

# pnu.make_gaussian_mass, _to_torch and iterate_range are helpers defined
# elsewhere in the code base (not shown here).

def learn_params_ridge_regressionM(data, voxels, _fmaps_fn, models, lambdas, aperture=1.0, _nonlinearity=None, zscore=False, sample_batch_size=100, voxel_batch_size=100, holdout_size=100, shuffle=True, add_bias=False):
    """
    Learn the parameters of the fwRF model

    Parameters
    ----------
    data : ndarray, shape (#samples, #channels, x, y)
        Input image block.
    voxels: ndarray, shape (#samples, #voxels)
        Input voxel activities.
    _fmaps_fn: Torch module
        Torch module that returns a list of torch tensors.
    models: ndarray, shape (#candidateRF, 3)
        The (x, y, sigma) of all candidate RFs for gridsearch.
    lambdas: ndarray, shape (#candidateRegression)
        The ridge parameter candidates.
    aperture (default: 1.0): scalar
        The span of the stimulus in the unit used for the RF models.
    _nonlinearity (default: None)
        A nonlinearity expressed with torch's functions.
    zscore (default: False)
        Whether to zscore the feature maps or not.
    sample_batch_size (default: 100)
        The sample batch size (used where appropriate)
    voxel_batch_size (default: 100)
        The voxel batch size (used where appropriate)
    holdout_size (default: 100)
        The holdout size for model and hyperparameter selection
    shuffle (default: True)
        Whether to shuffle the training set or not.
    add_bias (default: False)
        Whether to add a bias term to the ridge regression or not.

    Returns
    -------
    losses : ndarray, shape (#voxels)
        The final loss for each voxel.
    lambdas : ndarray, shape (#voxels)
        The regression regularization index for each voxel.
    models : ndarray, shape (#voxels, 3)
        The RF model (x, y, sigma) associated with each voxel.
    params : list of ndarray, shape (#voxels, #features)
        Can contain a bias parameter of shape (#voxels) if add_bias is True.
    mst_mean : ndarray, shape (#voxels, #feature)
        None if zscore is False. Otherwise returns zscoring average per feature.
    mst_std : ndarray, shape (#voxels, #feature)
        None if zscore is False. Otherwise returns zscoring std.dev. per feature.
    """
    def _cofactor_fn(_x, lambdas):
        '''input matrix [#samples, #features], a list of lambda values'''
        _f = torch.stack([(torch.mm(torch.t(_x), _x) + torch.eye(_x.size()[1], device=device) * l).inverse() for l in lambdas], axis=0) # [#lambdas, #feature, #feature]
        return torch.tensordot(_f, _x, dims=[[2],[1]]) # [#lambdas, #feature, #sample]

    def _loss_fn(_cofactor, _vtrn, _xout, _vout):
        '''input '''
        _beta = torch.tensordot(_cofactor, _vtrn, dims=[[2], [0]]) # [#lambdas, #feature, #voxel]
        _pred = torch.tensordot(_xout, _beta, dims=[[1],[1]]) # [#samples, #lambdas, #voxels]
        _loss = torch.sum(torch.pow(_vout[:,None,:] - _pred, 2), dim=0) # [#lambdas, #voxels]
        return _beta, _loss

    #############################################################################
    dtype = np.float32
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trn_size = len(voxels) - holdout_size
    assert trn_size>0, 'Training size needs to be greater than zero'
    print ('trn_size = %d (%.1f%%)' % (trn_size, float(trn_size)*100/len(voxels)))
    sys.stdout.flush()
    nt = len(data)
    nm = len(models)
    nv = voxels.shape[1]
    data = torch.from_numpy(data)
    data = data.pin_memory()
    if shuffle:
        order = np.arange(len(voxels), dtype=int)
        np.random.shuffle(order)
        data = data[order]
        voxels = voxels[order]
    trn_voxels = torch.from_numpy(voxels[:trn_size]).to(device)
    out_voxels = torch.from_numpy(voxels[trn_size:]).to(device)
    ### Calculate total feature count
    nf = 0
    _fmaps = _fmaps_fn(data[:3].float().to(device))
    fmaps_rez = []
    for k,_fm in enumerate(_fmaps):
        nf += _fm.size()[1]
        assert _fm.size()[2]==_fm.size()[3], 'All feature maps need to be square'
        # Spatial resolution of this feature map (was _fm[k].size()[2], which
        # indexes sample k and fails once k exceeds the 3-sample probe batch).
        fmaps_rez += [_fm.size()[2],]
        #print (_fm.size())
    #print ('---------------------------------------')
    #sys.stdout.flush()
    #############################################################################
    ### Create full model value buffers
    # np.int was removed in NumPy 1.24; the builtin int is equivalent here.
    best_models = np.full(shape=(nv,), fill_value=-1, dtype=int)
    best_lambdas = torch.ones(nv, device=device, dtype=torch.long).neg_()
    best_losses = torch.full((nv,), float("Inf"), device=device)
    best_w_params = torch.zeros(nv,nf,device=device)
    nfd=nf
    if add_bias:
        nfd = nf+1
        best_w_params = torch.cat([best_w_params, torch.ones(nv,1,device=device)], axis=1)
    mst_mean = None
    mst_std = None
    if zscore:
        mst_mean = torch.zeros(nv, nf, device=device)
        mst_std = torch.zeros(nv, nf, device=device)
    start_time = time.time()
    vox_loop_time = 0
    with torch.no_grad():
        for m,(x,y,sigma) in enumerate(models):
            print ('\rmodel %4d of %-4d' % (m, nm), flush=True)
            mst = torch.zeros(nt, nf, device=device)
            _pfs = [_to_torch(pnu.make_gaussian_mass(x, y, sigma, n_pix, size=aperture, dtype=dtype)[2], device=device) for n_pix in fmaps_rez]
            for rt,rl in iterate_range(0, nt, sample_batch_size):
                mst[rt] = torch.cat([torch.tensordot(_fm, _pf, dims=[[2,3], [0,1]]) for _fm,_pf in zip(_fmaps_fn(data[rt].float().to(device)), _pfs)], dim=1) # [#samples, #features]
            if _nonlinearity is not None:
                mst = _nonlinearity(mst)
            if zscore:
                mstm = torch.mean(mst, axis=0, keepdims=True) #[:trn_size]
                msts = torch.std(mst, axis=0, keepdims=True) + 1e-6
                mst -= mstm
                mst /= msts
            if add_bias:
                mst = torch.cat([mst, torch.ones(len(mst), 1, device=device)], axis=1)
            _xtrn = mst[:trn_size]
            _xout = mst[trn_size:]
            _cof = _cofactor_fn(_xtrn, lambdas)
            vox_start = time.time()
            for rv,lv in iterate_range(0, nv, voxel_batch_size):
                _vtrn = trn_voxels[:,rv]
                _vout = out_voxels[:,rv]
                _betas, _loss = _loss_fn(_cof, _vtrn, _xout, _vout) # [#lambda, #feature, #voxel, ], [#lambda, #voxel]
                _values, _select = torch.min(_loss, dim=0)
                imp = _values<best_losses[rv]
                if torch.sum(imp)>0:
                    arv = torch.arange(rv[0],rv[-1]+1)[imp]
                    li = _select[imp]
                    best_lambdas[arv] = li
                    best_losses[arv] = _values[imp]
                    best_models[arv.numpy()] = m
                    if zscore:
                        mst_mean[arv] = mstm # broadcast over updated voxels
                        mst_std[arv] = msts
                    best_w_params[arv,:]= _betas[:,:,imp].gather(dim=0,index=li.repeat(nfd,1).view(1,nfd,-1)).squeeze().T
            vox_loop_time += (time.time() - vox_start)
    #############################################################################
    total_time = time.time() - start_time
    inv_time = total_time - vox_loop_time
    best_w_params=best_w_params.cpu().numpy()
    return_params = [best_w_params[:,:nf],]
    if add_bias:
        return_params += [best_w_params[:,-1],]
    else:
        return_params += [None,]
    print ('\n---------------------------------------')
    print ('total time = %fs' % total_time)
    print ('total throughput = %fs/voxel' % (total_time / nv))
    print ('voxel throughput = %fs/voxel' % (vox_loop_time / nv))
    print ('setup throughput = %fs/model' % (inv_time / nm))
    sys.stdout.flush()
    return best_losses.cpu().numpy(), best_lambdas.cpu().numpy(), [models[best_models],]+return_params+[mst_mean, mst_std]
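For anyone skimming the function, the two helpers `_cofactor_fn` and `_loss_fn` appear to implement the standard closed-form ridge solution, evaluated once per candidate λ. Written out as a sketch, where the symbols X_trn, V_trn, X_out and V_out are shorthand introduced here for `_xtrn`, `_vtrn`, `_xout` and `_vout`:

```latex
\begin{aligned}
% _cofactor_fn: one [#feature, #sample] matrix per candidate lambda
C_\lambda &= \left( X_{\mathrm{trn}}^{\top} X_{\mathrm{trn}} + \lambda I \right)^{-1} X_{\mathrm{trn}}^{\top} \\
% _loss_fn: weights for every voxel in the batch
\beta_\lambda &= C_\lambda \, V_{\mathrm{trn}} \\
% _loss_fn: summed squared error on the held-out samples s, per voxel v
L_{\lambda, v} &= \sum_{s} \Bigl( V_{\mathrm{out}}[s, v] - \bigl( X_{\mathrm{out}} \, \beta_\lambda \bigr)[s, v] \Bigr)^{2}
\end{aligned}
```

The per-voxel minimum of L over λ is then compared against the best loss found so far, so each voxel ends up with the (model, λ) pair that gives its lowest held-out loss.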
The code for the _fmaps_fn module that is passed into it is below. The code was originally written to extract features from AlexNet for the images. I’ve already computed the features for my images, so instead of passing in the raw stimuli I’m giving it the features directly, but I still need this module to play nicely with the code as it is currently written. I thought about removing it, but at some point my own model is going to require more complicated feature maps, and I’ll probably need to pass in the raw stimuli and have a function build the feature maps for each batch, so taking it out would be counterproductive.
import torch
import torch.nn as nn

class Torch_Split(nn.Module):
    def __init__(self,s,d):
        super(Torch_Split, self).__init__()
        # s: split size(s), d: dimension to split along; stored as frozen parameters
        self.s = nn.Parameter(torch.as_tensor(s),requires_grad=False)
        self.d = nn.Parameter(torch.as_tensor(d),requires_grad=False)

    def forward(self, _x):
        return list(torch.split(_x, self.s, self.d))
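To illustrate what the main function receives from this module, here is a minimal sketch of the `torch.split` call that `forward` wraps, using plain Python values rather than parameters. The tensor shape and split sizes are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical pre-computed features: 5 images, 6 channels in total, 4x4 maps.
features = torch.randn(5, 6, 4, 4)

# Split the channel axis (dim=1) into one block of 2 and one block of 4 channels.
fmaps = list(torch.split(features, [2, 4], dim=1))

shapes = [tuple(fm.shape) for fm in fmaps]
print(shapes)  # [(5, 2, 4, 4), (5, 4, 4, 4)]

# The main function sums the channel counts to get the total feature count.
nf = sum(fm.size(1) for fm in fmaps)
print(nf)  # 6
```

Each element of the returned list plays the role of one feature map, which is then weighted by the Gaussian of matching resolution.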
Thanks in advance for any help you can provide!!