Correct way to do backpropagation through time?

(Sean Morrison) #1

This is in the context of model-based reinforcement learning. Say I have some reward at time T and I want to do truncated backprop through the network rollout. What is the best way to do this? Are there any good examples out there? I haven’t managed to find much.

Any help would be appreciated!

#2
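# non-truncated: gradients flow back through all T steps
# (out starts as the initial state; backward() needs a scalar, so reduce the final out to a loss if needed)
for t in range(T):
    out = model(out)
out.backward()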

# truncated to the last K timesteps
for t in range(T):
    out = model(out)
    if T - t == K:
        out = out.detach()  # detach() returns a new tensor; it does not modify out in place
out.backward()
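A runnable sketch of both variants, assuming a hypothetical toy recurrent model (one Linear layer plus tanh) and using out.sum() as a stand-in for the real reward at time T. Here the state is detached just before the last k steps, so exactly k steps contribute to the gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy "dynamics model": one linear layer plus tanh, applied recurrently.
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
T, K = 10, 3
x0 = torch.randn(1, 4)  # initial state; does not require grad


def rollout(k=None):
    """Roll the model out for T steps and return the first layer's weight gradient.

    If k is given, the state is detached just before the last k steps,
    so only those k steps contribute to the gradient.
    """
    model.zero_grad()
    out = x0
    for t in range(T):
        if k is not None and t == T - k:
            out = out.detach()  # detach() returns a new tensor; it is not in-place
        out = model(out)
    out.sum().backward()  # stand-in scalar objective at time T
    return model[0].weight.grad.clone()


full_grad = rollout()        # gradient through all T steps
trunc_grad = rollout(k=K)    # gradient through the last K steps only
print(full_grad.norm().item(), trunc_grad.norm().item())
```

Writing out = out.detach() (or calling out.detach_()) matters: a bare out.detach() leaves the graph intact, so the "truncated" version silently becomes the full one.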
(jpeg729) #3

Shouldn’t the truncated example be

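# truncated to the last K timesteps
for t in range(T):
    out = model(out)
    if T - t == K:
        out.backward()      # backprop through the first part of the rollout
        out = out.detach()  # then cut the graph (detach() is not in-place)
out.backward()  # backprop through the remaining steps; gradients accumulate in .grad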
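The same idea extends to cutting the rollout into chunks of K steps: backward at the end of each chunk, then detach before the next one. Below is a minimal sketch under assumptions: a hypothetical toy model, a hypothetical chunk_loss standing in for the real objective, and an optimizer step after every chunk (accumulating gradients and stepping once at the end is an equally valid choice):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy recurrent model and optimizer, for illustration only.
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
T, K = 12, 4
x0 = torch.randn(1, 4)


def chunk_loss(state):
    # Hypothetical stand-in scalar objective; replace with the real loss/reward.
    return state.pow(2).sum()


out = x0
chunk_losses = []
for t in range(T):
    out = model(out)
    if (t + 1) % K == 0:  # end of a K-step chunk
        optimizer.zero_grad()
        loss = chunk_loss(out)
        loss.backward()       # backprop through this chunk only
        optimizer.step()
        chunk_losses.append(loss.item())
        out = out.detach()    # start the next chunk with a fresh graph
```

Because backward() runs before optimizer.step() in each chunk, the in-place parameter update never touches a graph that still needs to be differentiated. If T is not a multiple of K, the leftover steps at the end are not trained in this sketch.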
#4

Yeah, you’re right. I was only backpropagating through the last part.

(Santosh Manicka) #5

I guess an alternative to “tail” truncation is “head” truncation.

Tail truncation is what the detach-based examples above do.

This is head truncation:

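# freeze the parameters so the early steps record no graph
for p in model.parameters():
    p.requires_grad_(False)
for t in range(T):
    out = model(out)
    if T - t == K:
        # unfreeze: only the remaining steps record gradients
        for p in model.parameters():
            p.requires_grad_(True)
out.backward()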
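A runnable sketch of head truncation, assuming a hypothetical toy model and out.sum() as a stand-in objective. Here the parameters are unfrozen just before step T - K, so exactly K steps are recorded. In this toy setup, where the initial state does not require grad, the early frozen steps produce tensors with no graph, so the parameter gradients match detaching the state at the same step; the comparison below checks that:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy recurrent model, for illustration only.
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
T, K = 10, 3
x0 = torch.randn(1, 4)  # initial state; does not require grad


def set_requires_grad(module, flag):
    for p in module.parameters():
        p.requires_grad_(flag)


def head_truncated_grad():
    model.zero_grad()
    set_requires_grad(model, False)  # early steps record no graph
    out = x0
    for t in range(T):
        if t == T - K:
            set_requires_grad(model, True)  # only the last K steps are recorded
        out = model(out)
    out.sum().backward()  # stand-in scalar objective at time T
    return model[0].weight.grad.clone()


def detach_truncated_grad():
    model.zero_grad()
    out = x0
    for t in range(T):
        if t == T - K:
            out = out.detach()  # cut the graph before the last K steps
        out = model(out)
    out.sum().backward()
    return model[0].weight.grad.clone()


g_head = head_truncated_grad()
g_detach = detach_truncated_grad()
```

One practical difference: the frozen-parameter version never builds a graph for the early steps, while the detach version builds it and then discards it. If the initial state itself requires grad (for example, it comes from an encoder), the frozen version still lets gradients flow back through the early steps into that state, just not into the model parameters.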
#6

I have a similar issue in this post.

I followed the pseudocode for non-truncated BPTT in this conversation. The network trains, but I have the feeling that the gradient is not flowing through time. I posted my training code for the network.

Can someone give some tips?

(Duane Nielsen) #7

Check out hooks. If you want to inspect a gradient, you can register a backward hook (Tensor.register_hook on a tensor, or register_full_backward_hook on a module) and drop the values into a print statement or TensorBoard.

For example, in the code below I register a forward hook to monitor the values passing through a softmax function (later I compute the entropy and log it to TensorBoard).

# forward hooks are called as hook(module, input, output); here `self` receives the softmax module
def monitorAttention(self, input, output):
    if writer.global_step % 10 == 0:
        monitors.monitorSoftmax(self, input, output, ' input ', writer, dim=1)
self.softmax.register_forward_hook(monitorAttention)
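To check whether the gradient actually flows through time, one option is a backward hook on each step’s state via Tensor.register_hook. A minimal sketch, assuming a hypothetical toy model and out.sum() as a stand-in objective; it records the gradient norm reaching every time step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy recurrent model, for illustration only.
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
T = 6
x0 = torch.randn(1, 4)

grad_norms = {}  # time step -> norm of d(objective)/d(state_t)


def make_hook(t):
    def hook(grad):
        grad_norms[t] = grad.norm().item()
        # returning None leaves the gradient unchanged
    return hook


out = x0
for t in range(T):
    out = model(out)
    out.register_hook(make_hook(t))  # fires when the gradient w.r.t. this state is computed
out.sum().backward()

for t in sorted(grad_norms):
    print(f"step {t}: grad norm {grad_norms[t]:.4g}")
```

If an early step never shows up in grad_norms, something (a stray detach(), a .data access, a numpy round-trip) cut the graph there; if the norms shrink towards zero for early steps, the gradient is flowing but vanishing.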