Hi,
I’m new to pytorch, and I’m doing something a little heterodox. I build a very dynamic computational graph with a lot of in-place operations because of the complex nature of my subject. I expected it to be slow, since I build the whole ‘network’ from scratch, but the backward pass is so slow (about 1 minute for the forward pass and 20+ minutes for the backward pass) that I must have done something wrong. So I’m asking for some advice on using pytorch for general-purpose computation. Below is an example of my code:
First, I defined a function to update the state of one node:
def neuronForward(self, InputS, InputR, InputTau, StateTau, endTime, V0, Dv0, G, Tau_m, H, Tau_n, I, Tau_r, Tau_s, G_exc, G_inh):
    """Evaluate one node's closed-form state and decay its input traces.

    V and Dv are each assembled from three pieces: a per-input term driven
    by Kr (Vr/Dvr), a per-input term driven by Ks (Vs/Dvs), and a term that
    depends only on the initial state (VI/DvI).  Which closed form is used
    depends on the sign of Delta, the discriminant of the characteristic
    polynomial r**2 + P*r + Q = 0.

    Args:
        self: unused in the body.
        InputS, InputR: per-input tensors.  Positive InputR entries are
            scaled by G_exc, negative ones by G_inh.
        InputTau: per-input elapsed times.  Entries > 0 are treated as
            effective; entries equal to StateTau contribute their InputS
            to S0.
        StateTau: elapsed time at which the initial-state term is evaluated.
        endTime: forwarded to ``couple`` together with V.
        V0, Dv0: initial value and derivative that fix C3 and C4.
        G, Tau_m, H, Tau_n, I, Tau_r, Tau_s, G_exc, G_inh: model constants.
            Delta, Tau_r and Tau_s expressions must expose ``.data[0]``.

    Returns:
        Tuple ``(V, endTime, Dv, InputR, InputS)`` where endTime is the
        result of ``couple`` and InputR/InputS are the decayed traces.

    Raises:
        ValueError: if Delta is NaN, so no closed form applies.
    """
    # Delta is the discriminant of r**2 + P*r + Q = 0 multiplied by
    # (Tau_m*Tau_n)**2; only its sign is used to choose a branch.
    Delta = (Tau_m + G*Tau_n)**2 - 4*Tau_m*Tau_n*(H + G)
    deta = Delta.data[0]
    nptr = (-1/Tau_r).data[0]
    npts = (-1/Tau_s).data[0]
    P = G/Tau_m + 1/Tau_n
    Q = (H + G)/(Tau_m*Tau_n)
    L = I/(H + G)
    Cr = 1/(Tau_m*Tau_n) - 1/(Tau_m*Tau_r)
    Cs = 1/(Tau_m*Tau_n) - 1/(Tau_m*Tau_s)
    # PQR (PQS) vanishes exactly when -1/Tau_r (-1/Tau_s) is a root of the
    # characteristic polynomial; those cases get dedicated formulas below.
    PQR = 1 - P*Tau_r + Q*Tau_r**2
    PQS = 1 - P*Tau_s + Q*Tau_s**2

    # Masks are detached so they act as constants during backward.
    InheritMask = (InputTau == StateTau).float().detach()
    EffectiveMask = (InputTau > 0).float().detach()
    GexcMask = (InputR > 0).float().detach()
    GinhMask = (InputR < 0).float().detach()

    S0 = torch.sum(InheritMask*InputS)
    # Shared by Kr and Ks; computed once so autograd records them once.
    excTerm = GexcMask*G_exc*Tau_r*InputR/(Tau_r - Tau_s)
    inhTerm = GinhMask*G_inh*Tau_r*InputR/(Tau_r - Tau_s)
    Kr = EffectiveMask*(excTerm - inhTerm)
    Ks = EffectiveMask*(InputS - excTerm + inhTerm)

    # Every branch needs these; hoisting keeps the graph smaller without
    # changing any value.
    expR = torch.exp(-InputTau/Tau_r)
    expS = torch.exp(-InputTau/Tau_s)

    if deta > 0:
        # Two distinct real roots.
        sqrtDelta = Delta**0.5
        r1 = -0.5/Tau_n - 0.5*G/Tau_m + sqrtDelta/(2*Tau_m*Tau_n)
        r2 = -0.5/Tau_n - 0.5*G/Tau_m - sqrtDelta/(2*Tau_m*Tau_n)
        r1Val = r1.data[0]
        r2Val = r2.data[0]
        exp1 = torch.exp(r1*InputTau)
        exp2 = torch.exp(r2*InputTau)

        if r1Val != nptr and r2Val != nptr:
            Vr = Cr*Kr*Tau_r**2*expR/PQR
            C1r = Kr*(1/(Tau_m*(r1 - r2)) + Cr*(Tau_r + r2*Tau_r**2)/((r1 - r2)*PQR))
            C2r = -Cr*Kr*Tau_r**2/PQR - C1r
            Vr = Vr + C1r*exp1 + C2r*exp2
            Dvr = -Cr*Kr*Tau_r*expR/PQR + C1r*r1*exp1 + C2r*r2*exp2
        else:
            # -1/Tau_r coincides with a root, so PQR == 0.
            Vr = Cr*Kr*Tau_r*InputTau*expR/(P*Tau_r - 2)
            C1r = Kr*(Tau_r*(P - Cr*Tau_m) - 2)/((r1 - r2)*(P*Tau_r - 2)*Tau_m)
            C2r = -C1r
            Vr = Vr + C1r*exp1 + C2r*exp2
            Dvr = (Cr*Kr*Tau_r*expR/(P*Tau_r - 2)
                   - Cr*Kr*InputTau*expR/(P*Tau_r - 2)
                   + C1r*r1*exp1 + C2r*r2*exp2)

        if r1Val != npts and r2Val != npts:
            Vs = Cs*Ks*Tau_s**2*expS/PQS
            C1s = Ks*(1/(Tau_m*(r1 - r2)) + Cs*(Tau_s + r2*Tau_s**2)/((r1 - r2)*PQS))
            C2s = -Cs*Ks*Tau_s**2/PQS - C1s
            Vs = Vs + C1s*exp1 + C2s*exp2
            Dvs = -Cs*Ks*Tau_s*expS/PQS + C1s*r1*exp1 + C2s*r2*exp2
        else:
            # -1/Tau_s coincides with a root, so PQS == 0.
            Vs = Cs*Ks*Tau_s*InputTau*expS/(P*Tau_s - 2)
            C1s = Ks*(Tau_s*(P - Cs*Tau_m) - 2)/((r1 - r2)*(P*Tau_s - 2)*Tau_m)
            C2s = -C1s
            Vs = Vs + C1s*exp1 + C2s*exp2
            Dvs = (Cs*Ks*Tau_s*expS/(P*Tau_s - 2)
                   - Cs*Ks*InputTau*expS/(P*Tau_s - 2)
                   + C1s*r1*exp1 + C2s*r2*exp2)

        # C3/C4 satisfy VI(0) = V0 and DvI(0) = Dv0 - S0/Tau_m.
        C3 = (Dv0 - S0/Tau_m + r2*L - r2*V0)/(r1 - r2)
        C4 = V0 - C3 - L
        exp1State = torch.exp(r1*StateTau)
        exp2State = torch.exp(r2*StateTau)
        VI = C3*exp1State + C4*exp2State + L
        DvI = C3*r1*exp1State + C4*r2*exp2State
    elif deta < 0:
        # Complex-conjugate roots alpha +/- i*beta; -1/Tau_r and -1/Tau_s are
        # real, so PQR and PQS cannot vanish here.
        alpha = -0.5/Tau_n - 0.5*G/Tau_m
        beta = (-Delta)**0.5/(2*Tau_m*Tau_n)
        expAlpha = torch.exp(alpha*InputTau)
        cosBeta = torch.cos(beta*InputTau)
        sinBeta = torch.sin(beta*InputTau)

        C1r = -Cr*Kr*Tau_r**2/PQR
        C2r = Cr*Kr*Tau_r*(1 + alpha*Tau_r)/(beta*PQR) + Kr/(beta*Tau_m)
        Vr = Cr*Kr*Tau_r**2*expR/PQR + expAlpha*(C1r*cosBeta + C2r*sinBeta)
        Dvr = -Cr*Kr*Tau_r*expR/PQR + expAlpha*((alpha*C1r + beta*C2r)*cosBeta +
                                                (alpha*C2r - beta*C1r)*sinBeta)

        C1s = -Cs*Ks*Tau_s**2/PQS
        C2s = Cs*Ks*Tau_s*(1 + alpha*Tau_s)/(beta*PQS) + Ks/(beta*Tau_m)
        Vs = Cs*Ks*Tau_s**2*expS/PQS + expAlpha*(C1s*cosBeta + C2s*sinBeta)
        Dvs = -Cs*Ks*Tau_s*expS/PQS + expAlpha*((alpha*C1s + beta*C2s)*cosBeta +
                                                (alpha*C2s - beta*C1s)*sinBeta)

        C3 = V0 - L
        C4 = (Dv0 - S0/Tau_m - alpha*C3)/beta
        expAlphaState = torch.exp(alpha*StateTau)
        cosBetaState = torch.cos(beta*StateTau)
        sinBetaState = torch.sin(beta*StateTau)
        VI = expAlphaState*(C3*cosBetaState + C4*sinBetaState) + L
        DvI = expAlphaState*((alpha*C3 + beta*C4)*cosBetaState +
                             (alpha*C4 - beta*C3)*sinBetaState)
    elif deta == 0:
        # Repeated real root.
        r = -0.5/Tau_n - 0.5*G/Tau_m
        expRoot = torch.exp(r*InputTau)

        if r != nptr:
            C1r = -Cr*Kr*Tau_r**2/PQR
            C2r = Cr*Kr*Tau_r*(1 + r*Tau_r)/PQR + Kr/Tau_m
            Vr = Cr*Kr*Tau_r**2*expR/PQR + (C1r + C2r*InputTau)*expRoot
            Dvr = -Cr*Kr*Tau_r*expR/PQR + (r*C1r + (r*InputTau + 1)*C2r)*expRoot
        else:
            C2r = Kr/Tau_m
            Vr = Cr*Kr*InputTau**2*expR/2 + C2r*InputTau*expRoot
            Dvr = (Cr*Kr*InputTau*expR
                   - Cr*Kr*InputTau**2*expR/(2*Tau_r)
                   + (r*InputTau + 1)*C2r*expRoot)

        if r != npts:
            C1s = -Cs*Ks*Tau_s**2/PQS
            C2s = Cs*Ks*Tau_s*(1 + r*Tau_s)/PQS + Ks/Tau_m
            Vs = Cs*Ks*Tau_s**2*expS/PQS + (C1s + C2s*InputTau)*expRoot
            Dvs = -Cs*Ks*Tau_s*expS/PQS + (r*C1s + (r*InputTau + 1)*C2s)*expRoot
        else:
            C2s = Ks/Tau_m
            # The /2 matches the Tau_r twin above; Dvs below is the derivative
            # of the halved expression, not of the un-halved one.
            Vs = Cs*Ks*InputTau**2*expS/2 + C2s*InputTau*expRoot
            # NOTE(review): the pasted source was cut off in the middle of
            # this expression; it is completed by symmetry with Dvr above.
            Dvs = (Cs*Ks*InputTau*expS
                   - Cs*Ks*InputTau**2*expS/(2*Tau_s)
                   + (r*InputTau + 1)*C2s*expRoot)

        # NOTE(review): the initial-state term for this branch was absent from
        # the pasted source.  It is reconstructed from the same conditions the
        # other branches satisfy, VI(0) = V0 and DvI(0) = Dv0 - S0/Tau_m;
        # confirm against the original derivation.
        C3 = V0 - L
        C4 = Dv0 - S0/Tau_m - r*C3
        expRootState = torch.exp(r*StateTau)
        VI = (C3 + C4*StateTau)*expRootState + L
        DvI = (r*C3 + (r*StateTau + 1)*C4)*expRootState
    else:
        # Only reachable when Delta is NaN; fail with a clear message instead
        # of an UnboundLocalError further down.
        raise ValueError("neuronForward: discriminant is not a number (%r)" % (deta,))

    V = torch.sum(Vr + Vs) + VI
    Dv = torch.sum(Dvr + Dvs) + DvI

    # `couple` is defined elsewhere in the project; its semantics are not
    # visible from this function.
    endTime = couple(V, endTime, -Dv.data[0])
    theTime = 0*InputTau + endTime  # endTime expanded to InputTau's shape
    InputTau = couple(theTime, InputTau, 1)
    # The result is not read again in this function; the call is kept as in
    # the original in case `couple` has side effects.
    StateTau = couple(endTime, StateTau, 1)

    # InputTau was rebound above, so the exponentials are recomputed here
    # rather than reusing expR/expS.
    decayR = torch.exp(-InputTau/Tau_r)
    decayS = torch.exp(-InputTau/Tau_s)
    InputS = (Kr*(decayR - decayS) + EffectiveMask*InputS*decayS +
              (1 - EffectiveMask)*InputS)
    InputR = EffectiveMask*InputR*decayR + (1 - EffectiveMask)*InputR
    return V, endTime, Dv, InputR, InputS
And then, I update the state one-by-one in the main loop:
# Process the scheduled state updates one at a time, in list order.
# NOTE(review): the pasted source lost its indentation; the nesting below is
# reconstructed from the statement order and should be confirmed.
for ii in range(StateUpdateNum):
    neuronInd = int(SrIndex[ii])
    eventTime = SrTime[ii]

    # Elapsed times since this neuron's last update and since each of its
    # recorded inputs, the latter capped at the former.
    StateTau = eventTime - Time[neuronInd]
    InpTau = torch.min(eventTime - PreSynpticTimeList[neuronInd], StateTau)

    endTime = Variable(torch.Tensor([eventTime]), requires_grad=True)
    theV, endTime, theDv, InputR, InputS = neuronForward(
        S[neuronInd], R[neuronInd], InpTau, StateTau, endTime,
        V[neuronInd], Dv[neuronInd], theG, Tau_m[neuronInd], H[neuronInd],
        theTau_n, theI, Tau_r[neuronInd], Tau_s[neuronInd], theG_exc, theG_inh)

    Time[neuronInd] = endTime
    S[neuronInd] = InputS
    R[neuronInd] = InputR

    # Dv and V are cloned before each indexed write, presumably so the
    # in-place assignment does not touch a tensor autograd has already saved.
    Dv = Dv.clone()
    Dv[neuronInd] = theDv
    V = V.clone()
    V[neuronInd] = theV.detach()
    State[neuronInd] = SrState[ii]

    if SrState[ii] != 4:
        continue

    # State 4: record the time and fan out to every forward-linked neuron.
    if NetSpkTime is None:
        NetSpkTime = endTime.clone()
    else:
        NetSpkTime = torch.cat((NetSpkTime, endTime))

    preInd = neuronInd + InpNeuronNum
    for postInd in ForwardLink[preInd]:
        # NOTE(review): preTime is not assigned anywhere in this excerpt;
        # verify where it comes from.
        arrivalTime = preTime + DelaysMatrix[postInd, preInd]
        if len(PreSynpticTimeList[postInd]) == 0:
            PreSynpticTimeList[postInd] = arrivalTime
        else:
            PreSynpticTimeList[postInd] = torch.cat(
                (PreSynpticTimeList[postInd], arrivalTime))

        # Append a fresh zero S entry and the link weight as the R entry.
        linkWeight = WeightsMatrix[postInd, preInd]
        if len(S[postInd]) == 0:
            S[postInd] = Variable(torch.zeros(1), requires_grad=True).clone()
            R[postInd] = linkWeight.clone()
        else:
            S[postInd] = torch.cat((S[postInd], Variable(torch.zeros(1), requires_grad=True)))
            R[postInd] = torch.cat((R[postInd], linkWeight))
I understand that my ‘network’ is too fragmentary and is really bad for vectorization. But the ‘network’ is so dynamic that I can’t know each state and the connectivity before the state update. Even so, it shouldn’t be that slow. Is there something I can do to speed it up? Thanks in advance. And by the way, I can’t find the module torch.autograd.profiler in my pytorch (linux python2.7 conda installation, latest version); do I need a separate installation?