Hi, I want to tweak aladdinpersson's implementation of style transfer.
I have two tensors that never change during training: the features of the original (content) image and the features of the style image. But the original implementation recomputes them on every step. It looks like this:
for step in trange(6_000):
    generated_features = model(generated)
    # The two following tensors are constant
    original_img_features = model(original_img)
    style_features = model(style_img)

    style_loss = original_loss = 0
    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_img_features, style_features
    ):
        batch_size, channel, height, width = gen_feature.shape
        # Content loss: MSE between generated and original features
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Gram matrices of the generated (G) and style (A) feature maps
        G = gen_feature.view(channel, height * width).mm(
            gen_feature.view(channel, height * width).t()
        )
        A = style_feature.view(channel, height * width).mm(
            style_feature.view(channel, height * width).t()
        )
        style_loss += torch.mean((G - A) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward(retain_graph=True)
    optimizer.step()
When I move the creation of these tensors outside the loop, I get this error:

"Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

How can I create these tensors once and keep them for the whole training? I am pretty sure that I don't need to save the whole graph, just these two tensors (my current guess at a fix is sketched after the code below):
original_img_features = model(original_img)
style_features = model(style_img)

for step in trange(6_000):
    generated_features = model(generated)

    style_loss = original_loss = 0
    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_img_features, style_features
    ):
        batch_size, channel, height, width = gen_feature.shape
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        G = gen_feature.view(channel, height * width).mm(
            gen_feature.view(channel, height * width).t()
        )
        A = style_feature.view(channel, height * width).mm(
            style_feature.view(channel, height * width).t()
        )
        style_loss += torch.mean((G - A) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward(retain_graph=True)
    optimizer.step()
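My current guess (unverified) is that the precomputed features still carry the autograd graph of the model's forward pass: the first backward() frees that graph, and every later step's loss still points into it, hence the error. If that is right, computing the features under torch.no_grad() (or detaching them) should turn them into plain constants, and retain_graph=True should no longer be needed. Here is a minimal sketch of that idea, reusing the names and setup from the code above; I have not confirmed this is the idiomatic fix:

# Sketch of my guess: compute the constant features without tracking
# gradients, so backward() never tries to traverse their (freed) graph.
with torch.no_grad():
    original_img_features = model(original_img)
    style_features = model(style_img)
# Alternative, assuming model(...) returns a list of feature maps:
# original_img_features = [f.detach() for f in model(original_img)]
# style_features = [f.detach() for f in model(style_img)]

for step in trange(6_000):
    generated_features = model(generated)  # fresh graph built every step

    style_loss = original_loss = 0
    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_img_features, style_features
    ):
        batch_size, channel, height, width = gen_feature.shape
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        G = gen_feature.view(channel, height * width).mm(
            gen_feature.view(channel, height * width).t()
        )
        A = style_feature.view(channel, height * width).mm(
            style_feature.view(channel, height * width).t()
        )
        style_loss += torch.mean((G - A) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()  # retain_graph=True should be unnecessary now, I think
    optimizer.step()

Is that the right way to do it, or is there a more standard pattern for caching fixed feature tensors across training steps?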