Clone and detach in v0.4.0

I’m currently migrating my old code from v0.3.1 to v0.4.0. During the migration, I got confused by the documentation about clone and detach. After searching related topics in the forum, I found that most discussions are too old. Specifically, I would like answers to the following three questions:

  1. the difference between tensor.clone() and tensor.detach() in v0.4.0?
  2. the difference between tensor and tensor.data in v0.4.0?
  3. the difference between tensor and tensor.data when calling clone() or detach() in v0.4.0?
    And an additional question, which will help me understand how my old code works and what those old topics are talking about:
  4. the difference between tensor.clone(), variable.clone() and variable.detach() in v0.3.1?

We’ll provide a migration guide when 0.4.0 is officially released. Here are the answers to your questions:

  1. tensor.detach() creates a new tensor that shares storage with tensor but does not require grad. tensor.clone() creates a copy of tensor that imitates the original tensor's requires_grad field.
    You should use detach() when attempting to remove a tensor from a computation graph, and clone() as a way to copy the tensor while still keeping the copy as part of the computation graph it came from.

  2. tensor.data returns a new tensor that shares storage with tensor. However, it always has requires_grad=False (even if the original tensor had requires_grad=True).

  3. You should try not to call tensor.data in 0.4.0. What are your use cases for tensor.data?

  4. tensor.clone() makes a copy of tensor. variable.clone() and variable.detach() in 0.3.1 act the same as tensor.clone() and tensor.detach() in 0.4.0.
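The properties in answers 1 and 2 can be checked directly. A minimal sketch, assuming a PyTorch release of 0.4 or later; `data_ptr()` is used here only as a convenient way to test whether two tensors share storage:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2  # non-leaf: has a grad_fn

# detach(): same storage, requires_grad=False, no history
d = y.detach()
assert d.data_ptr() == y.data_ptr()
assert not d.requires_grad and d.grad_fn is None

# .data: also same storage and requires_grad=False
t = y.data
assert t.data_ptr() == y.data_ptr()
assert not t.requires_grad

# clone(): new storage, but the copy stays in the graph
c = y.clone()
assert c.data_ptr() != y.data_ptr()
assert c.requires_grad and c.grad_fn is not None

# gradients flow back through the clone to x (d(2x)/dx = 2)
c.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```

Note that detach() and .data look identical here; the difference only shows up when the shared storage is modified in place.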


Thank you so much for such a detailed explanation.
In my code, I use clone() for two things:

  1. Save the model parameters whenever validation-set performance reaches a new peak, and restore them after several training rounds in which validation performance stops improving (to implement early stopping). Here's my code:
# my parameters are small compared with GPU memory, so I don't save them to disk
best_params = dict()
if map_score > best_map:
    for name, tensor in model.state_dict(keep_vars=True).items():
        # a large part of my parameters is an embedding that is frozen during training, so I don't save it
        if tensor.requires_grad:
            # use tensor.data to drop its grad_fn, then clone it so the copy is not changed during training
            best_params[name] = tensor.data.clone()
            # if I understand correctly, tensor.clone() would be OK too,
            # because no computation is performed with the best parameters and backward() is never called on them.
            # In that case a non-empty grad_fn won't cause any problem
# later, once validation performance stops improving:
model.load_state_dict(best_params, strict=False)
  2. Save the model output when testing the model on the validation set. I divide the validation task into several batches, save each batch's output, and concatenate them on the GPU before transferring them back to the CPU. Here's my code:
save_res = list()
for test_data in task:
    score_res = model('test', test_data)
    # use tensor.data to drop its grad_fn, then clone it so it is not freed
    # using tensor.data is essential because requires_grad must be False when calling .numpy()
    save_res.append(score_res.data.clone())
    # I didn't use score_res.detach() because score_res is deleted right afterwards
    # I think using just score_res.detach() without del score_res would keep the graph in GPU memory
    del score_res
save_res = torch.cat(save_res).cpu().numpy()
do_some_eval(save_res)
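Since the advice in this thread is to avoid .data, here is a minimal sketch of the same two use cases written with detach().clone() and torch.no_grad(). The tiny nn.Linear model, the frozen weight standing in for the embedding, and the random batches are hypothetical stand-ins for the poster's model and validation data; it also uses named_parameters() instead of state_dict(keep_vars=True):

```python
import torch
import torch.nn as nn

# hypothetical stand-in for the real model
model = nn.Linear(4, 2)
model.weight.requires_grad_(False)  # pretend this is the frozen embedding

# 1. snapshot the trainable parameters without .data
best_params = {
    name: p.detach().clone()
    for name, p in model.named_parameters()
    if p.requires_grad
}

# ... training keeps changing the parameters ...
with torch.no_grad():
    model.bias.add_(1.0)

# restore the snapshot; strict=False because frozen parameters were skipped
model.load_state_dict(best_params, strict=False)

# 2. collect validation outputs without building a graph at all
batches = [torch.randn(3, 4) for _ in range(2)]  # hypothetical validation batches
outputs = []
with torch.no_grad():
    for batch in batches:
        outputs.append(model(batch))  # requires_grad=False, no graph is kept
save_res = torch.cat(outputs).cpu().numpy()
```

Under torch.no_grad() no graph is recorded in the first place, so neither .data nor del is needed to release it.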

Hi,
I have read many forum posts trying to understand .detach(), and although I have an intuition of what it is, I still don't understand it completely. I don't understand what removing a tensor from the computational graph implies. For example, consider the piece of code below:

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x**2
z = 2*y
w = z**3
z = z.detach()

# x.grad is None before the first backward(), so there is nothing to zero yet
w.backward()
print(x.grad)

The output came out to be 48, which is the same with or without z.detach(). What does z.detach() do here? Why wasn't it removed from the computational graph?

thanks


detach() does not affect the original graph. You only need detach() if you want to take a non-leaf variable, e.g. z, and perform operations on it without affecting the original gradients (here, the gradients w.r.t. x, y and z).

E.g.:

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x**2
z = 2*y
w = z**3

# This is the subpath
# Do not use detach()
p = z
q = torch.tensor([2.0], requires_grad=True)
pq = p*q
pq.backward(retain_graph=True)

w.backward()
print(x.grad)

The gradients from both backward() calls flow back through z and accumulate in x.grad, giving 56.
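The numbers follow from the chain rule. Evaluated at x = 1 and q = 2, the two backward() calls contribute 48 and 8 to x.grad:

```latex
\begin{aligned}
y &= x^2,\quad z = 2y,\quad w = z^3,\quad x = 1,\ q = 2\\
\frac{\partial w}{\partial x} &= 3z^2 \cdot 2 \cdot 2x = 3 \cdot 2^2 \cdot 4 = 48\\
\frac{\partial (zq)}{\partial x} &= q \cdot \frac{\partial z}{\partial x} = q \cdot 4x = 8\\
\texttt{x.grad} &= 48 + 8 = 56
\end{aligned}
```

With p = z.detach(), the second term has no path back to x, so only the 48 remains.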
However, when you try:

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x**2
z = 2*y
w = z**3

# detach it, so the gradient w.r.t. `p` does not affect `z`!
p = z.detach()
q = torch.tensor([2.0], requires_grad=True)
pq = p*q
pq.backward(retain_graph=True)

w.backward()
print(x.grad)

The gradient is still 48. So detach() keeps the gradient flow in the subpath from affecting the main path.
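Back to the original question: there, z = z.detach() runs after w has already been built, so it only rebinds the Python name z; the graph referenced by w still contains the original z node. A minimal sketch comparing both orders (x_grad is a hypothetical helper name):

```python
import torch

def x_grad(detach_before_w: bool):
    """Hypothetical helper: rebuild the question's graph and return x.grad."""
    x = torch.tensor([1.0], requires_grad=True)
    y = x ** 2
    z = 2 * y
    if detach_before_w:
        z = z.detach()  # cut z off before w is built from it
    w = z ** 3
    if not w.requires_grad:
        return None  # w no longer depends on x through the graph
    if not detach_before_w:
        z = z.detach()  # the question's order: only rebinds the name z
    w.backward()
    return x.grad

print(x_grad(detach_before_w=False))  # tensor([48.])
print(x_grad(detach_before_w=True))   # None
```

Detaching only changes graphs that are built from the detached tensor afterwards.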


Thanks for the quick reply, that cleared things up.

Just curious: how do you clone a Variable without its history in 0.4 (i.e. clone it and make it a leaf variable)?

Is it x_clone = torch.tensor(x.detach(), requires_grad=True)? Or is there a better way?


Hi,

x_clone = x.detach().clone() will do the job.
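A minimal sketch of that pattern, with an optional requires_grad_(True) at the end to make the copy a trainable leaf (h and h_clone are illustrative names). It also shows that clone().detach() gives an equivalent leaf; detaching first simply avoids recording the clone in the graph:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
h = x * 3  # non-leaf with history

# copy without history, then make it a trainable leaf
h_clone = h.detach().clone().requires_grad_(True)
assert h_clone.is_leaf and h_clone.grad_fn is None
assert h_clone.data_ptr() != h.data_ptr()  # independent storage

# the other order also ends up as a leaf with the same values
alt = h.clone().detach()
assert alt.is_leaf and torch.equal(alt, h_clone.detach())

# gradients on the copy do not reach x
(h_clone * 2).sum().backward()
print(h_clone.grad)  # tensor([2., 2.])
print(x.grad)        # None
```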


@albanD Is sourceTensor.clone().detach() better than sourceTensor.detach().clone()? Why?


I tried to copy from the tensor to a new leaf variable, like b = torch.tensor(a.detach(), requires_grad=True); however, the REPL advised me:

/home/acgtyrant/.local/bin/ipython:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

But I think sourceTensor.detach().clone().requires_grad_(True) is better than sourceTensor.clone().requires_grad_(True), because functions like clone or cuda will create a subgraph if they are called before detach. So detach first, then call functions that have no side effects, and finally call requires_grad_(True) to create a pure leaf variable, like sourceTensor.detach().index_select(0).clone().cuda(1).requires_grad_(True).

Hi,

The requires_grad_ call should come last if you do it all in one line, to make sure you get a leaf tensor.
Otherwise, the difference between all of these is going to be almost impossible to see!

By the way, calling .cuda() on a CPU tensor will copy it, so there is no need to clone it first. You can check the device and clone only if the tensor is already on the correct device; otherwise just call .cuda(1) without the clone.
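A minimal sketch of both points. It assumes `cuda:0` when a GPU is available (instead of the `cuda(1)` in the post) and uses `.to(device)`, which copies when the device changes; on a CPU-only machine it falls back to clone():

```python
import torch

src = torch.tensor([1.0, 2.0, 3.0])

# requires_grad_ last: the result is a leaf
good = src.detach().clone().requires_grad_(True)
assert good.is_leaf

# requires_grad_ before another op: the op is recorded, so the result is
# not a leaf and backward() will not populate its .grad
bad = src.detach().requires_grad_(True).clone()
assert not bad.is_leaf

# moving to another device already copies, so clone only when staying put
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
moved = src.detach()
moved = moved.to(device) if moved.device != device else moved.clone()
moved.requires_grad_(True)
assert moved.is_leaf and moved.data_ptr() != src.data_ptr()
```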


But why is .data still used in the source?

Because we did not have time to modify all the existing code.
The uses of .data in the PyTorch codebase are safe, so changing them is not critical except for consistency.


Why should we not call .data?

It is going to go away soon! It's not even in the docs anymore!
You can replace it with .detach() or with torch.no_grad(), depending on what you want to do.
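The practical reason is that in-place changes made through .data are invisible to autograd, while changes made through detach() are checked. A minimal sketch using sigmoid, whose backward reuses its own output (sigmoid_grad is a hypothetical helper name), followed by the torch.no_grad() replacement for the common `param.data -= ...` update:

```python
import torch

def sigmoid_grad(use_data: bool):
    """Hypothetical helper: modify sigmoid's output in place through .data or detach()."""
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()  # sigmoid's backward reuses `out`
    c = out.data if use_data else out.detach()
    c.zero_()  # shared storage, so this also zeroes `out`
    try:
        out.sum().backward()
    except RuntimeError as err:
        return err  # detach(): autograd detects the in-place change
    return a.grad  # .data: no error, but the gradient is silently wrong

print(sigmoid_grad(use_data=True))  # tensor([0., 0., 0.])
print(type(sigmoid_grad(use_data=False)).__name__)  # RuntimeError

# for in-place updates that autograd should not record, use torch.no_grad()
w = torch.ones(2, requires_grad=True)
(w * 3).sum().backward()  # w.grad = [3., 3.]
with torch.no_grad():
    w -= 0.1 * w.grad  # instead of w.data -= 0.1 * w.grad
print(w)
```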


Does this mean that it also removes the flow of gradients down its path? I.e., does it remove an edge from the graph, and thus a path for backprop to flow through?

It is a temporary restriction to prevent people from relying on impls for large arrays on a stable compiler just yet.

In the rare chance that we need to give up on const generics and rip it out of the compiler, it avoids getting into a situation where users are already relying on e.g. IntoIterator for [T; 262143] and we have no way to support them.

When const generics are finally stabilized the restriction would be removed and the impl would apply to arrays of arbitrarily large size.

My apologies ziyad, but I am having a hard time following what you are responding to. Are you responding to my question about gradient flows?

Yes. Compare the printed graphs:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

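x = torch.ones(10, requires_grad=True)
weights = {'x': x}

# keep exactly one of these three lines active and compare the graphs
y = ((x**2)**3)                 # no detach: gradient flows to x through y
# y = ((x**2).detach()**3)      # detach in the middle of the y branch
# y = ((x**2)**3).detach()      # detach at the end of the y branch

z = (x+3)+4
r = (y+z).sum()

# make_dot returns a graphviz Digraph, which renders inline in Jupyter
make_dot(r, params=weights)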
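The same three variants can also be compared numerically without torchviz. A minimal sketch (x has 3 elements instead of 10 for brevity; x_grad is a hypothetical helper name): without detach, x.grad is 6x^5 + 1 = 7 at x = 1; detaching anywhere on the y branch removes that edge, so only the gradient of 1 from the z path remains:

```python
import torch

def x_grad(variant: str):
    """Hypothetical helper: build r for one variant and return x.grad."""
    x = torch.ones(3, requires_grad=True)
    if variant == "none":
        y = (x ** 2) ** 3
    elif variant == "inner":
        y = (x ** 2).detach() ** 3
    else:  # "outer"
        y = ((x ** 2) ** 3).detach()
    z = (x + 3) + 4
    r = (y + z).sum()
    r.backward()
    return x.grad

for variant in ("none", "inner", "outer"):
    print(variant, x_grad(variant))
```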

I am trying to understand what "shares storage" means. Does that mean that if I detach a tensor from its original and then call opt.step() to modify the original, both would change?

i.e.

import torch
from torch import optim

a = torch.tensor([1, 2, 3.], requires_grad=True)
b = a.detach()
opt = optim.Adam([a], lr=0.01)
a.sum().backward()
opt.step() # changes both because they share storage?

So both change because they share storage? That seems like odd semantics, or am I missing the point of .detach()?
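Yes, both change: detach() returns a tensor over the same memory, and the optimizer updates a in place, so b sees the new values. detach() is for cutting the graph, not for copying; detach().clone() gives a copy that stays fixed. A minimal sketch, using SGD instead of Adam only so the updated values are easy to check by hand:

```python
import torch
from torch import optim

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = a.detach()                 # shares storage with a
snapshot = a.detach().clone()  # independent copy

opt = optim.SGD([a], lr=0.1)
a.sum().backward()             # a.grad = [1., 1., 1.]
opt.step()                     # in-place update: a -= 0.1 * a.grad

print(a)         # updated values
print(b)         # the same updated values, since b views a's storage
print(snapshot)  # tensor([1., 2., 3.]), unchanged
```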


Here is the migration guide Richard mentioned:


Related useful links:
