Hello,
I wrote a toy script to check SGD's weight_decay, but it seems to have no effect on the weight update. Am I misunderstanding the meaning of weight_decay? Thank you very much.
PyTorch 1.0
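My understanding is that weight_decay adds an L2 penalty, i.e. the decay term is folded into the gradient before the step, so the update should be w <- w - lr * (grad + weight_decay * w). Here is a minimal NumPy sketch of the rule I have in mind (my own assumption, not taken from the PyTorch source; sgd_step_with_decay is just a name I made up):

import numpy as np

def sgd_step_with_decay(w, grad, lr, weight_decay):
    # rule I expect: add weight_decay * w to the gradient, then take the step
    return w - lr * (grad + weight_decay * w)

And here is my test script: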
import torch
import numpy as np

np.random.seed(123)
np.set_printoptions(8, suppress=True)

# same data as NumPy arrays and as leaf tensors
x_numpy = np.random.random((3, 4)).astype(np.double)
w_numpy = np.random.random((4, 5)).astype(np.double)
x_torch = torch.tensor(x_numpy, requires_grad=True)
w_torch = torch.tensor(w_numpy, requires_grad=True)

lr = 0.1
sgd = torch.optim.SGD([w_torch], lr=lr, weight_decay=0.9)

y_torch = torch.matmul(x_torch, w_torch)
loss = y_torch.sum()

print("w_torch before SGD")
print(w_torch.data.numpy())

sgd.zero_grad()
loss.backward()
sgd.step()

# read the gradient after the step and redo the update by hand
w_grad = w_torch.grad.data.numpy()
print("w_torch after SGD")
print(w_torch.data.numpy())

# if weight_decay had an effect, this plain SGD update should NOT
# match the weights printed above -- but in the output it matches exactly
print("check_weight_decay")
print(w_numpy - lr * w_grad)
"""
code output:
w_torch before SGD
[[ 0.43857224 0.0596779 0.39804426 0.73799541 0.18249173]
[ 0.17545176 0.53155137 0.53182759 0.63440096 0.84943179]
[ 0.72445532 0.61102351 0.72244338 0.32295891 0.36178866]
[ 0.22826323 0.29371405 0.63097612 0.09210494 0.43370117]]
w_torch after SGD
[[ 0.20941374 -0.13538012 0.17253327 0.48188881 -0.02361953]
[ 0.04952477 0.37357542 0.37382677 0.46716854 0.6628466 ]
[ 0.50417498 0.40095203 0.50234411 0.13881324 0.17414831]
[ 0.01120012 0.07076036 0.37766885 -0.11270393 0.19814865]]
check_weight_decay
[[ 0.20941374 -0.13538012 0.17253327 0.48188881 -0.02361953]
[ 0.04952477 0.37357542 0.37382677 0.46716854 0.6628466 ]
[ 0.50417498 0.40095203 0.50234411 0.13881324 0.17414831]
[ 0.01120012 0.07076036 0.37766885 -0.11270393 0.19814865]]
"""