Hey, I have a question about the backward() function.
Let's say I have two computation graphs, g1 and g2, which are unlinked and live on separate hosts. The gradients of g2 are computed from a loss function L, and the output of models[0] (call it A) is the input of models[1]. After running L.backward() I have the gradient of L with respect to A (DofA). I want to feed this gradient backward to g1 by sending DofA to g1's host and running backward() from that gradient. Is this possible?
In the example below the computation graph is linked because everything runs on the same host. However, let's imagine the models are on separate hosts and the gradient needs to be communicated backward from models[1] to models[0]. Is there a way to compute the gradients of models[0] given the gradient of its output?
import torch
from torch import nn, optim
# A Toy Dataset
x = torch.tensor([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0],[1,1,0,0],[1,0,1,0],[0,1,1,0],[1,1,1,0],[0,0,0,1],[1,0,0,1],[0,1,0,1],[0,0,1,1],[1,1,0,1],[1,0,1,1],[0,1,1,1],[1,1,1,1.]])
target = torch.tensor([[0],[0],[0],[0],[0],[0],[0],[0],[1],[1],[1],[1],[1],[1],[1],[1.]])
# Hyperparameters and a step counter
epochs = 20
lr = 0.2
counter = 0
# Define 2 chained models
models = [
    nn.Sequential(
        nn.Linear(4, 3),
        nn.Tanh()
    ),
    nn.Sequential(
        nn.Linear(3, 1),
        nn.Sigmoid()
    )
]
# Create optimisers for each segment and link to their segment
optimizers = [
    optim.SGD(params=model.parameters(), lr=lr)
    for model in models
]
def train():
    # Training logic
    for epoch in range(epochs):
        # 1) erase previous gradients (if they exist)
        for opt in optimizers:
            opt.zero_grad()
        # 2) make a prediction
        a = models[0](x)
        # Janky pseudocode: ship the activation to the host of models[1]
        a.send(models[1].location)
        # End janky pseudocode
        pred = models[1](a)
        # 3) calculate how much we missed
        loss = ((pred - target) ** 2).sum()
        # 4) figure out which weights caused us to miss
        loss.backward()
        # Pseudocode for the functionality I want
        DofA = a.grad
        DofA.send(models[0].location)
        a = models[0](x)
        a.grad = DofA
        a.backward()
        # Pseudocode over
        # 5) change the weights
        for opt in optimizers:
            opt.step()
        # 6) print our progress
        print(loss.item())

train()
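
For reference, here is a minimal single-process sketch of what I think the pseudocode above is asking for. It is not a real multi-host setup: the "send" steps are only simulated by detaching the activation, and the names model0, model1, a_remote, grad_a and split_backward are made up for this illustration. The toy data is also simplified (all 16 four-bit inputs, label = last bit). The idea is the chain rule: if the host of the first model receives dL/dA, then calling a.backward(gradient=grad_a) should propagate it into that model's parameters.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy data (assumption: all 16 four-bit inputs, label = last bit).
x = torch.tensor([[float(b) for b in format(i, "04b")] for i in range(16)])
target = x[:, 3:].clone()

# Two chained models; imagine each one lives on its own host.
model0 = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
model1 = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
opt0 = optim.SGD(model0.parameters(), lr=0.2)
opt1 = optim.SGD(model1.parameters(), lr=0.2)


def split_backward():
    """Backpropagate through two unlinked graphs; returns the loss value."""
    # "Host 0": forward pass, keeping `a` and its graph locally.
    a = model0(x)

    # Simulated send: detach() cuts the graph, so "host 1" only sees a
    # plain leaf tensor that happens to require a gradient.
    a_remote = a.detach().requires_grad_(True)

    # "Host 1": forward, loss, backward. This fills a_remote.grad (dL/dA)
    # and the gradients of model1, but does not touch model0.
    pred = model1(a_remote)
    loss = ((pred - target) ** 2).sum()
    loss.backward()
    grad_a = a_remote.grad  # this is what would be sent back to "host 0"

    # "Host 0": resume backpropagation from the received gradient.
    a.backward(gradient=grad_a)
    return loss.item()


for epoch in range(20):
    opt0.zero_grad()
    opt1.zero_grad()
    loss_value = split_backward()
    opt0.step()
    opt1.step()
    print(epoch, loss_value)
```

In this sketch the gradients that end up on model0 should match what a single linked graph would give, since a.backward(gradient=grad_a) applies the same vector-Jacobian product that autograd would apply internally. How the tensors actually travel between hosts is left open here.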