Hi, I have a pretrained, frozen model `net`, and I'd like to train another model `pre_net` so that the input `x` is processed as `net(pre_net(x))`. My training loss is constructed similarly to contrastive learning: `d(net(pre_net(x)), net(pre_net_ema(x_t)))`, where `pre_net_ema` is the EMA version of `pre_net` and `x_t` is a transformation of `x`.
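To make the setup concrete, here is a minimal sketch of one training step. The models, the transformation of `x`, the `ema_update` helper, its decay value, and the cosine-based distance `d` are all illustrative stand-ins, not my actual code:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in models: in my real setting, net is a pretrained, frozen network.
net = nn.Linear(32, 16)
for p in net.parameters():
    p.requires_grad_(False)
net.eval()

pre_net = nn.Linear(32, 32)            # trainable preprocessor
pre_net_ema = copy.deepcopy(pre_net)   # EMA copy, never touched by the optimizer

optimizer = torch.optim.Adam(pre_net.parameters(), lr=1e-4)

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # p_ema <- decay * p_ema + (1 - decay) * p
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)

x = torch.randn(8, 32)                 # a batch of inputs
x_t = x + 0.1 * torch.randn_like(x)    # stand-in for the transformation of x

online = net(pre_net(x))               # gradient must flow into pre_net here
target = net(pre_net_ema(x_t))         # <- where .detach() / torch.no_grad() would go

# d(., .): cosine distance, just as an example
loss = 1 - F.cosine_similarity(online, target, dim=-1).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
ema_update(pre_net_ema, pre_net)
```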
My question is which of `.detach()` or `torch.no_grad()` I should choose when computing `net(pre_net_ema(x_t))`. I'm sure `.detach()` is a correct choice, while `torch.no_grad()` does not build the computation graph at all and should therefore be faster than `.detach()` in theory. However, I'm not sure whether using `torch.no_grad()` is also a correct choice in my setting.
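Concretely, these are the two variants I'm deciding between for the target branch (just a sketch):

```python
# Option A: build the graph for the target branch, then cut it off
target = net(pre_net_ema(x_t)).detach()

# Option B: skip building the graph for the target branch entirely
with torch.no_grad():
    target = net(pre_net_ema(x_t))
```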