Unable to Learn XOR Representation using 2 layers of a Multi-Layer Perceptron (MLP)

@jpeg729, you’re missing optimizer.zero_grad() at the start of the training loop. The results will change quite a bit if you add it.

python xor.py --loss MSELoss --learning_rate 0.001 --activation Tanh --optimizer Adam
100 / 100 = 100.00% successes

python xor.py --loss MSELoss --learning_rate 0.001 --activation Sigmoid --optimizer Adam
99 / 100 = 99.00% successes

python xor.py --loss MSELoss --learning_rate 0.001 --activation ReLU --optimizer Adam
77 / 100 = 77.00% successes
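
For reference, a minimal self-contained version of the fixed loop (Tanh + Adam + MSELoss, the combination from the first command above; treat it as a sketch rather than the exact xor.py script):

import torch
from torch import nn, optim
from torch.autograd import Variable

X = Variable(torch.FloatTensor([[0, 0], [0, 1], [1, 0], [1, 1]]))
Y = Variable(torch.FloatTensor([[0], [1], [1], [0]]))

model = nn.Sequential(nn.Linear(2, 5), nn.Tanh(), nn.Linear(5, 1), nn.Sigmoid())
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for _ in range(10000):
    optimizer.zero_grad()              # the missing call: clear the previous step's gradients
    loss = criterion(model(X), Y)      # forward pass on the whole XOR batch
    loss.backward()                    # compute fresh gradients
    optimizer.step()                   # apply the update

print(model(X).data)                   # should be close to 0, 1, 1, 0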

Oh. Oops. And I’m the one who pointed out that it was missing in the original code. :thinking:

I tried again with ELU and SGD, and got 2% success. Tuning the learning rate does nothing much, but adding momentum=.9 gives 98% success with SGD.

I suppose all this goes to show the vast number of ways we can alter our models and get widely differing results.

That’s probably because not zeroing the gradient has the same effect as adding momentum: it retains some of the gradient from the previous training step.
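
A toy sketch of what I mean (a single throwaway parameter, not the xor.py code): without zero_grad(), the .grad buffer keeps the sum of all past gradients, which is exactly what a momentum buffer with coefficient 1 would hold.

import torch
from torch.autograd import Variable

w = Variable(torch.ones(1), requires_grad=True)
x = Variable(torch.ones(1))

for step in range(3):
    loss = (w * x).sum()
    loss.backward()
    # .grad is never cleared, so it accumulates: 1.0, 2.0, 3.0, ...
    print(step, float(w.grad.data[0]))

# SGD with momentum keeps v_t = mu * v_{t-1} + g_t; skipping zero_grad()
# behaves like the mu = 1 case, reusing the previous step's direction.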

Exactly what I thought. Though the fact that not zeroing the gradients actually worked suggests that with ELU activation and MSELoss the problem was nicely convex, so that repeatedly taking steps in roughly the same direction nearly always led to the solution.

It’s pretty much overkill, but here is a sweep over Activation × Loss × Optimizer:


from itertools import product
from collections import Counter 

import time

import random
random.seed(100)

import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
from torch import FloatTensor
from torch import optim
use_cuda = torch.cuda.is_available()

# Activation functions.
from torch.nn import ReLU, ReLU6, ELU, SELU, LeakyReLU
from torch.nn import Hardtanh, Sigmoid, Tanh, LogSigmoid
from torch.nn import Softplus, Softshrink, Tanhshrink, Softmin
from torch.nn import Softmax, LogSoftmax # Softmax2d


# Loss functions.
from torch.nn import L1Loss, MSELoss # NLLLoss, CrossEntropyLoss
from torch.nn import PoissonNLLLoss, KLDivLoss, BCELoss
from torch.nn import BCEWithLogitsLoss, HingeEmbeddingLoss # MarginRankingLoss
from torch.nn import SmoothL1Loss, SoftMarginLoss # MultiLabelMarginLoss, CosineEmbeddingLoss, 
from torch.nn import MultiLabelSoftMarginLoss # MultiMarginLoss, TripletMarginLoss

# Optimizers.
from torch.optim import Adadelta, Adagrad, Adam, Adamax # SparseAdam
from torch.optim import ASGD, RMSprop, Rprop # LBFGS

Activations = [ReLU, ReLU6, ELU, SELU, LeakyReLU, 
                Hardtanh, Sigmoid, Tanh, LogSigmoid,
                Softplus, Softshrink, Tanhshrink, Softmin, 
                Softmax, LogSoftmax]

Criterions = [L1Loss, MSELoss,
              PoissonNLLLoss, KLDivLoss, BCELoss,
              BCEWithLogitsLoss, HingeEmbeddingLoss,
              SmoothL1Loss, SoftMarginLoss,
              MultiLabelSoftMarginLoss]

Optimizers = [Adadelta, Adagrad, Adam, Adamax,
             ASGD, RMSprop, Rprop]

X = xor_input = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = xor_output = np.array([[0,1,1,0]]).T

# Converting the X to PyTorch-able data structure.
X_pt = Variable(FloatTensor(X))
X_pt = X_pt.cuda() if use_cuda else X_pt
# Converting the Y to PyTorch-able data structure.
Y_pt = Variable(FloatTensor(Y), requires_grad=False)
Y_pt = Y_pt.cuda() if use_cuda else Y_pt

# Use FloatTensor.shape to get the shape of the matrix/tensor.
num_data, input_dim = X_pt.shape
num_data, output_dim = Y_pt.shape

learning_rate = 0.03
hidden_dim = 5
num_epochs = 10000
num_experiments = 100



for Activation, Criterion, Optimizer in product(Activations, Criterions, Optimizers):
    all_results=[]
    start = time.time()
    for _ in range(num_experiments):
        model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                              Activation(), 
                              nn.Linear(hidden_dim, output_dim),
                              nn.Sigmoid())
        model = model.cuda() if use_cuda else model
        criterion = Criterion()
        optimizer = Optimizer(model.parameters(), lr=learning_rate)
        
        for _e in range(num_epochs):
            optimizer.zero_grad()
            predictions = model(X_pt)
            loss_this_epoch = criterion(predictions, Y_pt)
            loss_this_epoch.backward()
            optimizer.step()
            ##print(_e, [float(_pred) for _pred in predictions], list(map(int, Y_pt)), loss_this_epoch.data[0])

        x_pred = [int(model(_x) > 0.5) for _x in X_pt]
        y_truth = list([int(_y[0]) for _y in Y_pt])
        all_results.append([x_pred == y_truth, x_pred, loss_this_epoch.data[0]])

    tf, outputsss, losses__ = zip(*all_results)
    print(Activation, Criterion, Optimizer, Counter(tf), time.time() - start)

I’ve some numbers from the sweep: https://github.com/alvations/ixora/blob/master/xor.output

It seems like nothing gets close to the numpy implementation at https://www.kaggle.com/alvations/xor-with-mlp, where I get 100% all the time.

I tried out your code and got your results of non-convergence. I then upped the hidden dimension to 20 and got convergence 100% of the time. Can you confirm that you see the same before we try to figure out why convergence improves with a bigger one-layer hidden dimension? (I think it has to do with saddle points, but first I want to know that you get convergence with a bigger net.)

Also, you don’t need to try out so many losses and activations. For this problem, BCELoss and ReLU will work just fine.

Hi,

In theory, we should be able to obtain a solution with a much smaller network (i.e., 2 hidden units + bias). Please see Section 6.1 of Goodfellow et al. (2016).
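
A quick check of that claim, using the hand-built 2-hidden-unit ReLU network as I recall it from the book (so treat the exact weights as my reading of Section 6.1 rather than a quote):

import torch

# f(x) = w^T max(0, W^T x + c) + b, with b = 0
X = torch.FloatTensor([[0, 0], [0, 1], [1, 0], [1, 1]])
W = torch.FloatTensor([[1, 1], [1, 1]])
c = torch.FloatTensor([0, -1])
w = torch.FloatTensor([[1], [-2]])

print((X.mm(W) + c).clamp(min=0).mm(w))   # -> 0, 1, 1, 0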

The smooth L1 loss and the SELU activation function seem to help the learning process. Below is a solution that uses the autograd example as a starting point.

# -*- coding: utf-8 -*-
import torch
import numpy as np
from torch.autograd import Variable
from torch import FloatTensor
import torch.nn.functional as F

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 2, 2, 2, 1

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Variables during the backward pass.
x = Variable(FloatTensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]])))
y = Variable(FloatTensor(np.array([[0., 1., 1., 0.]])))

# Create random Tensors for weights, and wrap them in Variables.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Variables during the backward pass.
W = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

c = Variable(torch.zeros(D_in).type(dtype), requires_grad=True)
b = Variable(torch.zeros(D_out).type(dtype), requires_grad=True)

learning_rate = 1e-3
for t in range(200000):
    # Forward pass: compute predicted y using operations on Variables; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.

    y_pred = F.selu(x.mm(W).add(c)).mm(w).add(b)

    # Compute and print loss using operations on Variables.
    # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape
    # (1,); loss.data[0] is a scalar value holding the loss.
    # loss = (y_pred - y).pow(2).sum()
    loss = F.smooth_l1_loss(y_pred, y)
    if t % 10000 == 0:
        print(t, loss.data[0])
        print(t, y_pred.data)
        # print(t, c.data)
        # print(t, w.data)

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Variables with requires_grad=True.
    # After this call w1.grad and w2.grad will be Variables holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Update weights using gradient descent; w1.data and w2.data are Tensors,
    # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are
    # Tensors.
    W.data -= learning_rate * W.grad.data
    w.data -= learning_rate * w.grad.data
    c.data -= learning_rate * c.grad.data
    b.data -= learning_rate * b.grad.data

    # Manually zero the gradients after updating weights
    W.grad.data.zero_()
    w.grad.data.zero_()
    c.grad.data.zero_()
    b.grad.data.zero_()

print("W: ")
print(W)

print("w: ")
print(w)



Hi,

After running the code from above many times, I noticed that in some cases the process got stuck in local minima (error around 0.125). To avoid this, I added time-dependent Gaussian noise to the gradients. With noise, the results are much better.
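
The noise standard deviation follows a decaying schedule, mirroring the sigma update in the code below; a tiny helper just to make the schedule explicit (the names eta and gamma are mine):

import numpy as np

def noise_sigma(t, eta=2.0, gamma=0.55):
    # standard deviation of the zero-mean Gaussian added to each gradient at step t
    return np.sqrt(eta / np.power(1.0 + t, gamma))

print(noise_sigma(0), noise_sigma(100000))   # the noise shrinks as training progresses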

import torch
import numpy as np
from torch.autograd import Variable
from torch import FloatTensor
import torch.nn.functional as F

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 2, 2, 2, 1

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Variables during the backward pass.
x = Variable(FloatTensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]])))
y = Variable(FloatTensor(np.array([[0., 1., 1., 0.]])))

# Create random Tensors for weights, and wrap them in Variables.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Variables during the backward pass.
W = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

c = Variable(torch.zeros(D_in).type(dtype), requires_grad=True)
b = Variable(torch.zeros(D_out).type(dtype), requires_grad=True)

# Create tensors to simulate a normal distribution
W_zeros = torch.zeros(D_in, H).type(dtype)
w_zeros = torch.zeros(H, D_out).type(dtype)
c_zeros = torch.zeros(D_in).type(dtype)
b_zeros = torch.zeros(D_out).type(dtype)

W_sigma = torch.zeros(D_in, H).type(dtype)
w_sigma = torch.zeros(H, D_out).type(dtype)
c_sigma = torch.zeros(D_in).type(dtype)
b_sigma = torch.zeros(D_out).type(dtype)


learning_rate = 1e-3
for t in range(400000):
    # Forward pass: compute predicted y using operations on Variables; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.

    y_pred = F.selu(x.mm(W).add(c)).mm(w).add(b)

    # Compute and print loss using operations on Variables.
    # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape
    # (1,); loss.data[0] is a scalar value holding the loss.
    # loss = (y_pred - y).pow(2).sum()
    loss = F.smooth_l1_loss(y_pred, y)
    if t % 10000 == 0:
        print(t, loss.data[0])
        print(t, y_pred.data)
        # print(t, c.data)
        # print(t, w.data)

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Variables with requires_grad=True.
    # After this call w1.grad and w2.grad will be Variables holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Update sigma
    s_2 = 2.0 / np.power(1+t, 0.55)
    W_sigma.fill_(np.sqrt(s_2))
    w_sigma.fill_(np.sqrt(s_2))
    c_sigma.fill_(np.sqrt(s_2))
    b_sigma.fill_(np.sqrt(s_2))

    # Update the gradients
    mW = torch.distributions.Normal(W_zeros, W_sigma)
    W.grad.data += mW.sample()

    mw = torch.distributions.Normal(w_zeros, w_sigma)
    w.grad.data += mw.sample()

    mc = torch.distributions.Normal(c_zeros, c_sigma)
    c.grad.data += mc.sample()

    mb = torch.distributions.Normal(b_zeros, b_sigma)
    b.grad.data += mb.sample()


    # Update weights using gradient descent; w1.data and w2.data are Tensors,
    # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are
    # Tensors.
    W.data -= learning_rate * W.grad.data
    w.data -= learning_rate * w.grad.data
    c.data -= learning_rate * c.grad.data
    b.data -= learning_rate * b.grad.data

    # Manually zero the gradients after updating weights
    W.grad.data.zero_()
    w.grad.data.zero_()
    c.grad.data.zero_()
    b.grad.data.zero_()

print("W: ")
print(W)

print("w: ")
print(w)

Interesting, I never thought of adding noise. Nor did I check what the theoretical minimum network size was. I did all my experiments with 5 hidden units, and like yours, my models would occasionally get stuck in local minima. Had I realised that 5 units was more than necessary, I would have tried adding dropout.

I notice that your model doesn’t use any activation after the second linear layer where I used sigmoid, which may have been slowing down the flow of gradients in some cases.

Hi,

Regarding the architecture, I was just trying to replicate the results from the book.

But after looking at the gradients, I noticed I needed a mechanism to get out of a “vicious circle”, and noise seemed like a good starting point given recent results in the literature (reinforcement learning, training very deep networks, etc.).

Hope this helps!


Going back to the original question, (why does the PyTorch version not succeed as well as the numpy version), maybe we have been going about this the wrong way.

The PyTorch version is obviously not doing the same thing as the numpy version, and apart from the non-linearity between the linear layers, we haven’t touched on what those differences are, nor why they could be important.

Here are some of the differences between the numpy version and the pytorch version in the first post.

The weight initialisation

In the numpy version

# random float values uniformly taken from [0, 1)
W1 = np.random.random((input_dim, hidden_dim))
W2 = np.random.random((hidden_dim, output_dim))

In the PyTorch version (from the source code for nn.Linear)

# random values taken uniformly from (-1/sqrt(input_size), 1/sqrt(input_size))
stdv = 1. / math.sqrt(input_size)
self.weight.data.uniform_(-stdv, stdv)
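
So if we wanted the PyTorch model to start from the same kind of initialisation as the numpy version, we would presumably have to override the default, something like this sketch (layer sizes taken from the earlier posts):

import torch.nn as nn

hidden = nn.Linear(2, 5, bias=False)
hidden.weight.data.uniform_(0., 1.)   # overwrite the default init with U[0, 1), like np.random.random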

The learning rate

In the numpy version, learning_rate = 1
In the PyTorch version, learning_rate = 0.03

The loss function

In the numpy version

def cost(predicted, truth):
    return truth - predicted # N.B. this is NOT equivalent to L1 loss

In the PyTorch version

criterion = nn.L1Loss() # = abs(predictions - target)

The biggest difference here is in the way they are used.

In the PyTorch version criterion is added to the end of the computation graph and then differentiated with the rest of the computation graph. In the numpy version cost is used directly, which I think is equivalent to doing predictions.backward(cost(predictions)) in PyTorch.

I think the PyTorch loss function that would create the same result as the numpy update would be…

criterion = lambda predictions, target: 0.5 * torch.nn.functional.mse_loss(predictions, target, size_average=False)

Justification: d_criterion / d_predictions == predictions - target. The difference in sign is corrected later when applying the update. (The numpy version adds by doing self.W1 += lr*grad, whereas PyTorch SGD subtracts by doing param.data.add_(-group['lr'], param.grad.data).)
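
A quick numerical check of that claim (a standalone sketch with a single linear layer, not the original script):

import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]]))
target = Variable(torch.FloatTensor([[0.], [1.], [1.], [0.]]))
W0 = torch.randn(2, 1)

# Route 1: differentiate the loss 0.5 * sum((pred - target)^2).
W1 = Variable(W0.clone(), requires_grad=True)
pred = x.mm(W1)
(0.5 * (pred - target).pow(2).sum()).backward()

# Route 2: inject the numpy-style cost (target - pred) straight into backward().
W2 = Variable(W0.clone(), requires_grad=True)
pred = x.mm(W2)
pred.backward((target - pred).data)

# The two gradients are equal and opposite, so W += lr * grad (numpy) and
# W -= lr * grad (SGD) end up performing the same update.
print(W1.grad.data + W2.grad.data)   # ~ zeros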

I can’t see any other significant differences between the numpy version and the PyTorch version.

Hi,
Currently traveling and it is hard to read the code on the phone, so apologies in advance if I create more confusion…

In your numpy code, if your cost function returns a vector, aren’t you effectively applying a batch size of 1 versus PyTorch’s 4? In other words, you calculate the gradients for each example and then add them together?

I believe this is not what is happening in PyTorch, where you sum the error over the 4 examples and then backprop it.

The numpy version keeps the error for each sample separate. The PyTorch loss with size_average=False sums the errors; nevertheless, I think the end result is the same, since the gradient of a sum is the sum of the gradients of the parts.
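
A small sanity check of that statement (again just a sketch): accumulating per-sample backward() calls in .grad gives the same gradient as one backward() on the summed batch loss.

import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]]))
y = Variable(torch.FloatTensor([[0.], [1.], [1.], [0.]]))
W0 = torch.randn(2, 1)

# One backward on the summed loss over the whole batch.
W_batch = Variable(W0.clone(), requires_grad=True)
(x.mm(W_batch) - y).pow(2).sum().backward()

# Four backwards, one per sample; the gradients accumulate in .grad.
W_single = Variable(W0.clone(), requires_grad=True)
for i in range(4):
    (x[i:i+1].mm(W_single) - y[i:i+1]).pow(2).sum().backward()

print(W_batch.grad.data - W_single.grad.data)   # ~ zeros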

Hi,

I will take a closer look at your code this weekend.

Was looking online for alternative PyTorch solutions and found this gist (I haven’t run it, as I am far away from a Linux box :o( ).

The only difference I can spot vs your implementation is the inner loop to update the weights for each sample.

Hope this helps!

Hi,

I think I managed to reproduce your numpy code with PyTorch, including the nice results :o)

Just be aware that I used the MSE error function.

Hope this helps!

import torch
import numpy as np
from torch.autograd import Variable
from torch import FloatTensor
import torch.nn.functional as F

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
D_in, H, D_out = 2, 5, 1

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Variables during the backward pass.
x = Variable(FloatTensor(np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])), requires_grad=False)
y = Variable(FloatTensor(np.array([[0., 1., 1., 0.]])), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Variables during the backward pass.
W1 = Variable(torch.Tensor(D_in, H).uniform_(0., 1.).type(dtype), requires_grad=True)
W2 = Variable(torch.Tensor(H, D_out).uniform_(0., 1.).type(dtype), requires_grad=True)

print("W1: ", W1.data)
print("W2: ", W2.data)


learning_rate = 1.

for t in range(10000):
    # Forward pass: compute predicted y using operations on Variables; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.

    layer1 = F.sigmoid(x.mm(W1))
    layer2 = F.sigmoid(layer1.mm(W2))

    # Compute and print loss using operations on Variables.
    # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape
    # (1,); loss.data[0] is a scalar value holding the loss.
    # loss = (y_pred - y).pow(2).sum()

    loss = (y.t() - layer2).pow(2).sum()
    # print("Loss: ")
    # print(t, layer2.data)

    # print(t, loss.data[0])
    if t % 1000 == 0:
        print("Loss: ")
        print(t, loss.data[0])
        print("Current predictions: ")
        print(t, layer2.data)

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Variables with requires_grad=True.
    # After this call w1.grad and w2.grad will be Variables holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Update weights using gradient descent; w1.data and w2.data are Tensors,
    # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are
    # Tensors.
    W1.data -= learning_rate * W1.grad.data
    W2.data -= learning_rate * W2.grad.data

    # Manually zero the gradients after updating weights
    W1.grad.data.zero_()
    W2.grad.data.zero_()


print("W: ")
print(W1)

print("w: ")
print(W2)




And here is a version using nn.Linear, optim.SGD and F.mse_loss.

As I said above, the gradient of this MSE loss is a constant multiple of the cost function in the numpy version.

10000 epochs is overkill. Most of the time this thing gets to 100% accuracy in under 200 epochs.

I tried timing epochs with the numpy version and this pytorch version. The numpy version is ~3x faster per epoch. I am sure that a larger model would run faster in pytorch.

import torch
import numpy as np
from torch.autograd import Variable
from torch import FloatTensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
D_in, H, D_out = 2, 5, 1

x = Variable(FloatTensor(np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])), requires_grad=False)
y = Variable(FloatTensor(np.array([[0., 1., 1., 0.]])), requires_grad=False)

# Create two linear modules and initialize their weights
L1 = nn.Linear(D_in, H, bias=False)
L2 = nn.Linear(H, D_out, bias=False)
L1.weight.data.uniform_(0., 1.).type(dtype)
L2.weight.data.uniform_(0., 1.).type(dtype)

print("W1: ", L1.weight.data)
print("W2: ", L2.weight.data)

optimizer = optim.SGD([L1.weight, L2.weight], lr=1.)

success = False
for epoch in range(1000):
    layer1 = F.sigmoid(L1(x))
    layer2 = F.sigmoid(L2(layer1))

    loss = F.mse_loss(layer2, y, size_average=False)
    
    worst_error = (y.t() - layer2).abs().max()
    if not success and worst_error.data[0] < .5:
        print("100% accuracy achieved in", epoch+1, "epochs")
        success = True
    if worst_error.data[0] < .45:
        break

    if epoch % 100 == 0:
        print("Epoch %d: Loss %f  Predictions %s" % (epoch+1, loss.data[0], ' '.join(["%.3f" % p for p in (layer2.data.cpu().numpy())])))
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print("Epoch %d: Loss %f  Predictions %s" % (epoch+1, loss.data[0], ' '.join(["%.3f" % p for p in (layer2.data.cpu().numpy())])))

print("W1: ", L1.weight.data)
print("W2: ", L2.weight.data)

Very good! I think we managed to address the issue!

Regarding run-time performance, I would say it is somewhat expected; after all, PyTorch needs to compute the gradients automatically, while the numpy version has hand-coded formulas. I am used to seeing performance costs of around 4x in Stan-Math and Autograd, for example.

This thread was very interesting given the “simplicity” of the problem, yet online you find similar issues with CNTK and TensorFlow, for example.

Wouldn’t it be good if a clean version of the solution above ended up as an example on the PyTorch website?

@pedronahum that’s interesting! I would have expected nn.Sequential to do the same as the code you’ve posted. Currently, you’re manually passing the weights to the optimizer.

Hi @alvations,

And it does. Please have a look at the code below (minor changes to the code from @jpeg729)

import torch
import numpy as np
from torch.autograd import Variable
from torch import FloatTensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
D_in, H, D_out = 2, 5, 1

x = Variable(FloatTensor(np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])), requires_grad=False)
y = Variable(FloatTensor(np.array([[0., 1., 1., 0.]])), requires_grad=False)

# Create two linear modules and initialize their weights
L1 = nn.Linear(D_in, H, bias=False)
L2 = nn.Linear(H, D_out, bias=False)
L1.weight.data.uniform_(0., 1.).type(dtype)
L2.weight.data.uniform_(0., 1.).type(dtype)

model = nn.Sequential(L1,
                      nn.Sigmoid(),
                      L2,
                      nn.Sigmoid())


print("W1: ", L1.weight.data)
print("W2: ", L2.weight.data)

optimizer = optim.SGD(model.parameters(), lr=1.)

success = False
for epoch in range(10000):

    layer2 = model(x)

    loss = F.mse_loss(layer2, y, size_average=False)

    worst_error = (y.t() - layer2).abs().max()
    if not success and worst_error.data[0] < .5:
        print("100% accuracy achieved in", epoch + 1, "epochs")
        success = True
    if worst_error.data[0] < .05:
        break

    if epoch % 100 == 0:
        print("Epoch %d: Loss %f  Predictions %s" % (
        epoch + 1, loss.data[0], ' '.join(["%.3f" % p for p in (layer2.data.cpu().numpy())])))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print("Epoch %d: Loss %f  Predictions %s" % (
epoch + 1, loss.data[0], ' '.join(["%.3f" % p for p in (layer2.data.cpu().numpy())])))

print("W1: ", L1.weight.data)
print("W2: ", L2.weight.data)