My program’s memory usage is roughly an order of magnitude greater when I specify requires_grad=True on the parameters of my model. I’ve looked through the docs to find a way to reduce my program’s memory consumption, but I can’t seem to figure it out. Here is my objective function:
def fun(x, cons, est, trans, model, data):
    """Negative log-likelihood objective with analytic gradient via autograd.

    Returns (-ll, -grad) as (float, np.ndarray) for a scipy-style minimizer;
    on any constraint violation both components are NaN.
    """
    print(x)
    # Check the nonlinear constraints first. `valid` must be initialised
    # before the loop: with an empty `cons` the original raised NameError.
    valid = True
    for con in cons:
        valid = np.all(con['fun'](x, *con['args']) > 0)
        if not valid:
            break
    if not valid:
        print('Constraint violation')
        print()
        # NaN-filled objective/gradient signal an infeasible point.
        return (float('nan'), np.full(len(x), np.nan))
    torch.manual_seed(7)
    ex0 = (None,)              # adds a leading axis: x[p][ex0] has shape (1,)
    ex1 = (slice(None), None)  # adds a trailing axis: v[ex1] has shape (n, 1)
    x = torch.tensor(x.astype(np.float32), requires_grad=True)
    ll = 0
    for d in range(len(data)):
        N, S, K = (data[d]['N'], data[d]['S'], data[d]['K'])
        # Split the flat parameter vector: pooled parameters go straight
        # into theta; random-coefficient means are collected into mu.
        p = 0
        mu = torch.tensor([])
        re = []
        theta = {}
        for (key, val) in est.items():
            if not isinstance(val, str):
                theta[key] = torch.tensor(val)  # fixed, not estimated
            elif val == 'pool':
                theta[key] = x[p][ex0]
                p += 1
            elif val == 'rand':
                mu = torch.cat((mu, x[p][ex0]))
                p += 1
                re.append(key)
        # Covariance of the random coefficients: diagonal entries first,
        # then the strict lower triangle, mirrored to stay symmetric.
        R = len(re)
        di = tuple(torch.tensor(ind, dtype=torch.long)
                   for ind in np.diag_indices(R))
        tril = tuple(torch.tensor(ind, dtype=torch.long)
                     for ind in np.tril_indices(R, -1))
        triu = tuple(torch.tensor(ind, dtype=torch.long)
                     for ind in np.triu_indices(R, 1))
        sigma = torch.empty(R, R)
        sigma[di] = x[p:p + R]
        sigma[tril] = x[p + R:]
        sigma[triu] = sigma.t()[triu]
        # One draw per simulated agent; rsample keeps the draw on the
        # autograd graph (reparameterization trick).
        sample = MultivariateNormal(mu, sigma).rsample((N * S,))
        theta.update({re[r]: trans[re[r]](sample[:, r]) for r in range(R)})
        theta['upsilon'] = torch.zeros(N * S)[ex1]
        if K > 1:
            theta['upsilon'] = torch.cat((theta['upsilon'],
                                          theta['upsilon2'][ex1]), 1)
        if K > 2:
            theta['upsilon'] = torch.cat((theta['upsilon'],
                                          theta['upsilon3'][ex1]), 1)
        theta['upsilon'] = theta['upsilon'] / theta['kappa']
        if data[d]['treat'] == 'pvd':
            theta['omega'] = torch.zeros(1)
        ll = ll + logL(theta, model, data[d])
    ll = ll / 1e+2  # rescale to keep the optimizer well-conditioned
    # NOTE(review): one backward() over the summed loss keeps every
    # dataset's autograd graph alive until this point. Calling backward()
    # per dataset inside the loop (gradients accumulate into x.grad) and
    # summing ll.item() would sharply reduce peak memory.
    ll.backward()
    grad = x.grad.cpu().numpy()
    print(ll.item() * 1e+2)
    print()
    return (-ll.item(), -grad)
Here is the logL function referenced in the objective function:
def logL(theta, model, data):
    """Simulated log-likelihood of the observed choices.

    Multiplies the model's per-period choice probabilities over all T
    periods, averages over the S simulation draws per subject, and returns
    the summed log of the subject-level means.
    """
    n_sub, n_sim, n_per = data['N'], data['S'], data['T']
    model.start(theta, data['mat'])
    prob = torch.ones(n_sub * n_sim)
    for period in range(1, n_per + 1):
        model.update(data, period)
        obs = npeat(data['obs'][:, :, period - 1], (n_sim, 1))
        prob = prob * (obs * model.p_a).sum(1)
    return prob.reshape(n_sub, n_sim).mean(1).log().sum()
And here is my model:
class Model:
    """Composable decision model: each component (choice probabilities,
    utility learning, memory activation, ...) is selected by name via
    the `opt` dict passed to the constructor."""
    def __init__(self, opt):
        # Registry of available implementations per component; the variant
        # chosen in `opt` is bound once into self.fun.
        fun = {'p_k': {'rand': Probability.rand,
                       'hmax': Probability.hmax,
                       'smax': Probability.smax},
               'p_ak': {'rand': Probability.rand,
                        'hmax': Probability.hmax,
                        'smax': Probability.smax},
               'U': {'none': Utility.Learning.none,
                     'diff': Utility.Learning.diff},
               'v': {'lin': Utility.lin},
               'w': {'lin': Probability.Weighting.lin,
                     'linlog': Probability.Weighting.linlog,
                     'ratio': Probability.Weighting.ratio},
               'b_1': {'rand': Probability.rand,
                       'hmax': Probability.hmax,
                       'smax': Probability.smax},
               'b_k': {'rand': Probability.rand,
                       'hmax': Probability.hmax,
                       'smax': Probability.smax},
               'B': {'none': Activation.Base.none,
                     'exp': Activation.Base.exp,
                     'pow': Activation.Base.pow},
               'S': {'none': Activation.Associative.none,
                     'chain': Activation.Associative.chain},
               'P': {'none': Activation.Matching.none,
                     'part': Activation.Matching.part},
               'R_a': {'lin': Utility.lin},
               'R_c': {'none': Activation.Base.Learning.none,
                       'belief': Activation.Base.Learning.belief,
                       'reinf': Activation.Base.Learning.reinf,
                       'attrac': Activation.Base.Learning.attrac},
               'S_q': {'none': Activation.Associative.Learning.none,
                       'diff': Activation.Associative.Learning.diff}}
        self.fun = {key: fun[key][val] for (key, val) in opt.items()}
        self.opt = opt
    def start(self, theta, mat):
        # Reset per-trajectory state and pre-bind every component's extra
        # arguments from the current parameter draw `theta`.
        ex1 = (slice(None), None)         # append one trailing axis
        ex12 = (slice(None), None, None)  # append two trailing axes
        # W = [omega, 1 - omega] concatenated column-wise; scales xi, beta.
        omega = theta['omega'][ex1].expand(-1, mat.size()[-1])
        W = torch.cat((omega, 1 - omega), 1)
        xi = W * theta['xi'][ex1]
        beta = W * theta['beta'][ex1]
        # Per-component argument tables, keyed like self.fun in __init__.
        # NOTE(review): these theta-derived tensors live on the autograd
        # graph and are retained in self.arg for the whole trajectory.
        arg = {'p_k': {'rand': (1,),
                       'hmax': (1,),
                       'smax': (theta['kappa'][ex1], 1)},
               'p_ak': {'rand': (1,),
                        'hmax': (1,),
                        'smax': (theta['lam'][ex12], 1)},
               'U': {'none': (),
                     'diff': (theta['alpha'][ex1],)},
               'v': {'lin': (mat,)},
               'w': {'lin': (),
                     'linlog': (theta['gamma'][ex12], theta['delta'][ex12]),
                     'ratio': (theta['gamma'][ex12], theta['delta'][ex12], 2)},
               'b_1': {'rand': (2,),
                       'hmax': (2,),
                       'smax': (1, 2)},
               'b_k': {'rand': (2,),
                       'hmax': (2,),
                       'smax': (theta['iota'][ex12], 2)},
               # NOTE(review): key 'cons' does not match any variant
               # registered for 'B' in __init__ ('none'/'exp'/'pow') —
               # looks like it should be 'none'; confirm before an
               # opt['B'] == 'none' configuration is used.
               'B': {'cons': {},
                     'exp': {'gamma': theta['phi'][ex1]},
                     'pow': {'d': theta['phi'][ex12], 'dim': 2}},
               'S': {'none': (),
                     'chain': (2,)},
               'P': {'none': (),
                     'part': (xi[ex1], 2)},
               'R_a': {'lin': ()},
               'R_c': {'none': {},
                       'belief': {},
                       'reinf': {},
                       'attrac': {'delta': theta['rho'][ex1]}},
               'S_q': {'none': (),
                       'diff': (theta['psi'][ex12], beta[ex1])}}
        self.arg = {key: arg[key][val] for (key, val) in self.opt.items()}
        self.theta = theta
        # Initial state; all of these are advanced each period by update().
        self.p_ak = torch.zeros(1)
        self.U = theta['upsilon']
        self.v = self.fun['v'](*self.arg['v'])
        self.B = torch.zeros(1).log()  # log(0) = -inf starting value
        self.S = torch.zeros(1)[ex1]
        self.R_ct = torch.zeros(0)
        self.S_q = torch.zeros(1)
    def update(self, dat, t):
        # Advance every component by one period t of dataset `dat`.
        old = (Ellipsis, t - 1)  # index tuple for last period's data
        new = (Ellipsis, t)      # index tuple for this period's data
        ex1 = (slice(None), None)
        ex2 = (slice(None), slice(None), None)
        N, S, A, X, M, C, K, I = (dat['N'], dat['S'], dat['A'], dat['X'],
                                  dat['M'], dat['C'], dat['K'], dat['I'])
        hypo, hist, act, cue, sim, pay = (dat['hypo'], dat['hist'], dat['act'],
                                          dat['cue'], dat['sim'], dat['pay'])
        # npeat presumably tiles the data S times along dim 0, one copy per
        # simulation draw — TODO confirm against npeat's definition.
        hypo_old = npeat(hypo[old], (S,)+(1,)*(hypo[old].dim()-1))
        hist_old = npeat(hist[old], (S,)+(1,)*(hist[old].dim()-1))
        act_old = npeat(act[old], (S,)+(1,)*(act[old].dim()-1))
        cue_old = npeat(cue[old], (S,)+(1,)*(cue[old].dim()-1))
        cue_new = npeat(cue[new], (S,)+(1,)*(cue[new].dim()-1))
        sim_new = npeat(sim[new], (S,)+(1,)*(sim[new].dim()-1))
        pay_old = npeat(pay[old], (S,)+(1,)*(pay[old].dim()-1))
        fun, arg = (self.fun, self.arg)
        # Reward from last period's payoffs, and its expectation over the
        # current per-category action probabilities.
        R_a = fun['R_a'](pay_old, *arg['R_a'])
        R_k = (self.p_ak * R_a[ex2]).sum(1)
        self.U = fun['U'](self.U, R_k, int(t > 1), *arg['U'])
        # theta['pi'] weights only the very first period (t == 1).
        attn = 1 if t > 1 else self.theta['pi']
        R_c = fun['R_c'](attn, hypo=hypo_old, hist=hist_old, **arg['R_c'])
        # Running history of R_c, one slice appended per period.
        self.R_ct = torch.cat((self.R_ct, R_c[ex2]), 2)
        R_ct = self.R_ct[:,:,:t]
        time = torch.arange(float(t), 0, -1)  # ages of the stored periods
        self.S_q = fun['S_q'](self.S_q, self.S[ex2],
                              act_old[ex2], cue_old, *arg['S_q'])
        self.B = fun['B'](B=self.B, t=time, R=R_c, R_j=R_ct, **arg['B'])
        self.S = fun['S'](self.S_q, cue_new, *arg['S'])
        P = fun['P'](sim_new, *arg['P'])
        # Total activation -> first-level beliefs, X dimension summed out.
        A_c = (self.B + self.S + P).reshape(N * S, I * M, C // M)
        b_1 = fun['b_1'](A_c, *arg['b_1']).reshape(N * S, I * M, A, X).sum(3)
        w = npeat(fun['w'](b_1, *arg['w']), (1, A // M, 1))
        V_k = (w * self.v).sum(2)[ex2]
        # Build K-1 further valuation levels; each level evaluates the
        # previous one with own/opponent action blocks swapped (appears to
        # be a level-k reasoning construction — confirm).
        for k in range(1, K):
            own = (slice(None), slice(0, A), k - 1)
            opp = (slice(None), slice(A, 2 * A), k - 1)
            V = torch.cat((V_k[opp], V_k[own]), 1).reshape(N * S, 2, A)
            b_k = npeat(fun['b_k'](V, *arg['b_k']), (1, A, 1))
            V_k = torch.cat((V_k, (b_k * self.v).sum(2)[ex2]), 2)
        # Choice probabilities: p(level) and p(action | level) combine into
        # the per-action probability p_a that logL consumes.
        self.p_k = fun['p_k'](self.U, *arg['p_k'])
        self.p_ak = fun['p_ak'](V_k[:,:A], *arg['p_ak'])
        self.p_a = (self.p_k[ex1] * self.p_ak).sum(2)
I don’t understand why autograd is consuming so much memory. Please let me know if you see anything I can do to reduce my memory consumption. I’m willing to provide any additional information related to the program if you think it will help diagnose the problem.