OrderedDict([(‘W.0.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.1.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.2.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.3.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.4.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.5.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.6.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.7.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.8.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.9.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.10.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.11.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.12.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.13.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘W.14.weight’, tensor([[ 6.7416e-04, 7.2296e-04, 1.0417e-03, …, -1.4441e-04,
-2.0521e-04, 1.0587e-04],
[-2.1848e-03, -1.2637e-03, 1.3776e-03, …, 2.1392e-04,
5.4779e-05, 3.0876e-04],
[ 1.2177e-04, 3.9099e-04, 1.6234e-03, …, -8.8361e-04,
-9.1579e-04, 3.7307e-04],
…,
[-1.8724e-03, 1.4805e-03, -7.5514e-04, …, -1.1449e-03,
1.6069e-03, -1.8179e-03],
[ 1.7913e-03, -1.7511e-03, 1.4971e-03, …, -1.3609e-04,
3.4944e-04, -4.7839e-04],
[-1.7623e-03, 2.2035e-03, -2.4384e-03, …, 8.3601e-04,
-7.4248e-04, 4.3736e-04]], device=‘cuda:0’, dtype=torch.float64)), (‘param.0’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.1’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.2’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.3’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.4’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.5’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.6’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.7’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.8’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.9’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.10’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.11’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.12’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.13’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64)), (‘param.14’, tensor(0.0021, device=‘cuda:0’, dtype=torch.float64))])
That's an interesting observation. Could you describe your model architecture in more detail? It seems as if you are reusing the same parameter(s) across layers.
Thank you for your attention. I think I may have made a mistake when using this model — specifically, a mistake in the parameter initialization.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Run on the first CUDA GPU when available, otherwise fall back to CPU.
# (Fixed: the pasted code used typographic quotes “cuda:0”, which is a SyntaxError.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class MyLISTA_Cpss(nn.Module):
    """Unfolded (LISTA-style) network for multiple-snapshot sparse recovery.

    Each of `layers` stages applies a learned linear step on the measurement
    residual followed by a joint (across-snapshot) soft-thresholding shrinkage,
    mimicking one iteration of the classical ISTA algorithm.

    Args:
        n: number of measurements (rows of the dictionary A) — assumed; TODO confirm.
        m: length of the sparse code (columns of A) — assumed; TODO confirm.
        W_e: dictionary tensor A, presumably of shape (n, m); verify against caller.
        layers: number of unfolded iterations (stages).
        lamda: regularization weight used to initialize every threshold.
        share: stored but not used in this snippet — presumably selects weight
            sharing across layers; verify against the caller.
    """

    def __init__(self, n, m, W_e, layers, lamda, share):
        # NOTE(review): the posted code named this method `init`, so Python
        # never ran it as the constructor; it must be `__init__`.
        super(MyLISTA_Cpss, self).__init__()
        self.A = W_e
        self.L = lamda
        self.layers = layers
        self.m = m
        self.n = n
        self.share = share
        A = self.A.cpu().numpy()
        # 1.001 * ||A||_2^2: slightly above the classical ISTA step-size
        # normalization constant (the Lipschitz constant of the gradient step).
        self.scale = 1.001 * np.linalg.norm(A, ord=2) ** 2
        # Classical ISTA initialization: W = A^T / scale, theta = lambda / scale.
        w = torch.from_numpy(A.T / self.scale)
        w = w.to(device)
        # `s` (I - A^T A / scale) is computed but unused in this snippet; kept
        # for parity with the original code — it may be referenced elsewhere.
        s = torch.from_numpy(np.eye(A.shape[1]) - np.matmul(A.T, A) / self.scale)
        s = s.to(device)
        theta = torch.tensor(self.L / self.scale)
        theta = theta.to(device)
        # One linear step and one learnable threshold per unfolded layer.
        self.W = nn.ModuleList([nn.Linear(n, m, bias=False) for i in range(self.layers)])
        # BUG FIX: nn.Parameter(theta) wraps the *same* storage for every layer,
        # so all thresholds alias one another and can never diverge during
        # training (this is why the dumped state_dict shows an identical
        # `param.t` for all layers). `.clone()` keeps the identical
        # initialization but gives each layer its own trainable copy.
        self.param = nn.ParameterList(
            [nn.Parameter(theta.clone(), requires_grad=True) for i in range(self.layers)]
        )

        def weight_init1(e):
            if isinstance(e, nn.Linear):
                # Same aliasing bug as above: nn.Parameter(w) shares w's storage
                # across every Linear, producing identical `W.t.weight` tensors
                # in the state_dict. Clone so each layer trains independently
                # from the same starting point.
                e.weight = nn.Parameter(w.clone(), requires_grad=True)

        self.W.apply(weight_init1)

    def snapsss_shrink(self, x, theta):
        """Joint soft-thresholding shrinkage across the snapshot dimension.

        Args:
            x: tensor of shape (batch, snaps, m) — assumed from forward(); TODO confirm.
            theta: scalar (0-dim) threshold tensor.

        Returns:
            x scaled elementwise by max(g - theta, 0) / (max(g - theta, 0) + theta),
            where g is the L2 norm of x over dim=1 broadcast back to x's shape.
        """
        dim1, dim2, dim3 = x.shape
        kuaipai = torch.norm(x, dim=1)        # group norms over snapshots
        kuaipai = kuaipai.unsqueeze(1)
        kuaipai = kuaipai.repeat(1, dim2, 1)  # broadcast norms back to x's shape
        # Soft-threshold the group norms, then rescale each entry of x.
        guodu = torch.maximum(kuaipai - theta, torch.tensor(0.0).to(device))
        res = torch.div(guodu, (guodu + theta))
        result = x * res
        return result

    def forward(self, y):
        """Run the unfolded iterations.

        Args:
            y: measurement tensor, presumably (batch, n, snaps) — TODO confirm.

        Returns:
            The final code estimate x_pre[-1]; given the loop's permutes this is
            shaped (batch, snaps, m) — assumed; verify against the caller.
        """
        m = self.m
        n = self.n
        batch = y.shape[0]
        snaps = y.shape[2]
        x_pre = []
        xh = torch.zeros((batch, m, snaps), dtype=torch.float64).to(device)
        x_pre.append(xh)
        y = y.permute(0, 2, 1)  # (batch, snaps, n)
        for t in range(self.layers):
            W = self.W[t]
            param = self.param[t]
            batch_A = self.A.unsqueeze(0).repeat(batch, 1, 1)
            # Residual of the measurements under the current estimate.
            y_y = y - torch.bmm(batch_A, xh).permute(0, 2, 1)
            By = W(y_y)                      # gradient-like correction step
            Sy = xh.permute(0, 2, 1) + By
            xh = self.snapsss_shrink(Sy, param)
            x_pre.append(xh)
            xh = xh.permute(0, 2, 1)         # back to (batch, m, snaps) for next bmm
        return x_pre[-1]
Based on your code snippet, you are initializing all parameters with the same value:
self.W = nn.ModuleList([nn.Linear(n, m, bias=False) for i in range(self.layers)])
self.param = nn.ParameterList([nn.Parameter(theta, requires_grad=True) for i in range(self.layers)])
def weight_init1(e):
if isinstance(e, nn.Linear):
e.weight = nn.Parameter(w, requires_grad=True)
self.W.apply(weight_init1)
so your results might be expected. Try randomly initializing the parameters and rerunning your code to see whether different final parameters are learned.
Thank you very much. My goal is to simulate the iterations of the soft-thresholding algorithm, so initializing all layers with the same parameters is more consistent with the true iteration. I don't know whether layer-by-layer greedy training could improve the results.