Simple neural network not converging

Hi, I am just beginning to learn deep learning in pytorch. I am running the following code I got from pytorch tutorial by Justin Johnson.

#With autograd
import torch
from torch.autograd import Variable

dtype = torch.cuda.FloatTensor

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in).type(dtype), requires_grad = False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad = False)

w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad = True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad = True)

learning_rate = 1e-6
for t in range(500):
    y_pred = x.mm(w1).clamp(min = 0).mm(w2)
    
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.data[0])
    
    #w1.grad.data.zero_()
    #w2.grad.data.zero_()
    
    loss.backward()
    #print w1.grad.data
    #print w2.grad.data
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data

But it seems it is exploding and after 123 epochs, the loss is becoming nan. This is the output-

(0, 31723518.0)
(1, 28070452.0)
(2, 8525556.0)
(3, 14738816.0)
(4, 9347755.0)
(5, 12841774.0)
(6, 18114290.0)
(7, 6447365.5)
(8, 11224685.0)
(9, 9882719.0)
(10, 2951912.0)
(11, 2978006.25)
(12, 6616687.5)
(13, 7743705.0)
(14, 5883046.5)
(15, 3643038.25)
(16, 2570257.25)
(17, 2455251.0)
(18, 2659530.75)
(19, 2724341.5)
(20, 2513530.25)
(21, 2057666.625)
(22, 1586186.375)
(23, 1254101.625)
(24, 1110446.375)
(25, 1110734.0)
(26, 1145980.0)
(27, 1071132.875)
(28, 910926.4375)
(29, 782463.5)
(30, 719357.125)
(31, 717793.9375)
(32, 761821.125)
(33, 756986.375)
(34, 682688.4375)
(35, 646783.5625)
(36, 679672.0)
(37, 676811.0)
(38, 600790.3125)
(39, 631020.375)
(40, 692508.6875)
(41, 696700.5625)
(42, 615305.625)
(43, 504780.4375)
(44, 505154.0)
(45, 507697.0625)
(46, 498239.1875)
(47, 478827.5)
(48, 531659.3125)
(49, 472687.5)
(50, 433654.9375)
(51, 504356.59375)
(52, 475822.34375)
(53, 465258.40625)
(54, 490428.53125)
(55, 542419.6875)
(56, 480332.28125)
(57, 456323.03125)
(58, 548866.5)
(59, 460200.1875)
(60, 582967.375)
(61, 467767.125)
(62, 399487.1875)
(63, 525414.75)
(64, 563015.5)
(65, 630127.125)
(66, 339907.625)
(67, 485001.0625)
(68, 541414.6875)
(69, 637931.8125)
(70, 424327.5)
(71, 444804.25)
(72, 542814.6875)
(73, 624015.6875)
(74, 405953.71875)
(75, 523452.90625)
(76, 604742.4375)
(77, 624313.0625)
(78, 665899.8125)
(79, 796917.625)
(80, 1059727.875)
(81, 1661096.375)
(82, 3876985.5)
(83, 5157832.0)
(84, 2041864.25)
(85, 5117962.0)
(86, 5582782.0)
(87, 9489012.0)
(88, 28304358.0)
(89, 92396984.0)
(90, 135757312.0)
(91, 30141958.0)
(92, 36246224.0)
(93, 63904096.0)
(94, 27171200.0)
(95, 22396498.0)
(96, 18266130.0)
(97, 25967810.0)
(98, 23575290.0)
(99, 8453866.0)
(100, 13056855.0)
(101, 7837615.5)
(102, 10242168.0)
(103, 8700571.0)
(104, 178546768.0)
(105, 311015104.0)
(106, 264007536.0)
(107, 31766490.0)
(108, 79658920.0)
(109, 19210790.0)
(110, 20177744.0)
(111, 24349004.0)
(112, 158815472.0)
(113, 51590388.0)
(114, 42294844.0)
(115, 20198332.0)
(116, 26488356.0)
(117, 14971826.0)
(118, 296145664.0)
(119, 11408661504.0)
(120, 472693047296.0)
(121, 1.5815737104924672e+16)
(122, 2.7206068612442637e+30)
(123, inf)
(124, nan)
(125, nan)
(126, nan)
(127, nan)
(128, nan)
(129, nan)
(130, nan)
(131, nan)
(132, nan)
(133, nan)
(134, nan)
(135, nan)
(136, nan)
(137, nan)
(138, nan)
(139, nan)
(140, nan)
(141, nan)
(142, nan)
(143, nan)
(144, nan)
(145, nan)
(146, nan)
(147, nan)
(148, nan)
(149, nan)
(150, nan)
(151, nan)
(152, nan)
(153, nan)
(154, nan)
(155, nan)
(156, nan)
(157, nan)
(158, nan)
(159, nan)
(160, nan)
(161, nan)
(162, nan)
(163, nan)
(164, nan)
(165, nan)
(166, nan)
(167, nan)
(168, nan)
(169, nan)
(170, nan)
(171, nan)
(172, nan)
(173, nan)
(174, nan)
(175, nan)
(176, nan)
(177, nan)
(178, nan)
(179, nan)
(180, nan)
(181, nan)
(182, nan)
(183, nan)
(184, nan)
(185, nan)
(186, nan)
(187, nan)
(188, nan)
(189, nan)
(190, nan)
(191, nan)
(192, nan)
(193, nan)
(194, nan)
(195, nan)
(196, nan)
(197, nan)
(198, nan)
(199, nan)
(200, nan)
(201, nan)
(202, nan)
(203, nan)
(204, nan)
(205, nan)
(206, nan)
(207, nan)
(208, nan)
(209, nan)
(210, nan)
(211, nan)
(212, nan)
(213, nan)
(214, nan)
(215, nan)
(216, nan)
(217, nan)
(218, nan)
(219, nan)
(220, nan)
(221, nan)
(222, nan)
(223, nan)
(224, nan)
(225, nan)
(226, nan)
(227, nan)
(228, nan)
(229, nan)
(230, nan)
(231, nan)
(232, nan)
(233, nan)
(234, nan)
(235, nan)
(236, nan)
(237, nan)
(238, nan)
(239, nan)
(240, nan)
(241, nan)
(242, nan)
(243, nan)
(244, nan)
(245, nan)
(246, nan)
(247, nan)
(248, nan)
(249, nan)
(250, nan)
(251, nan)
(252, nan)
(253, nan)
(254, nan)
(255, nan)
(256, nan)
(257, nan)
(258, nan)
(259, nan)
(260, nan)
(261, nan)
(262, nan)
(263, nan)
(264, nan)
(265, nan)
(266, nan)
(267, nan)
(268, nan)
(269, nan)
(270, nan)
(271, nan)
(272, nan)
(273, nan)
(274, nan)
(275, nan)
(276, nan)
(277, nan)
(278, nan)
(279, nan)
(280, nan)
(281, nan)
(282, nan)
(283, nan)
(284, nan)
(285, nan)
(286, nan)
(287, nan)
(288, nan)
(289, nan)
(290, nan)
(291, nan)
(292, nan)
(293, nan)
(294, nan)
(295, nan)
(296, nan)
(297, nan)
(298, nan)
(299, nan)
(300, nan)
(301, nan)
(302, nan)
(303, nan)
(304, nan)
(305, nan)
(306, nan)
(307, nan)
(308, nan)
(309, nan)
(310, nan)
(311, nan)
(312, nan)
(313, nan)
(314, nan)
(315, nan)
(316, nan)
(317, nan)
(318, nan)
(319, nan)
(320, nan)
(321, nan)
(322, nan)
(323, nan)
(324, nan)
(325, nan)
(326, nan)
(327, nan)
(328, nan)
(329, nan)
(330, nan)
(331, nan)
(332, nan)
(333, nan)
(334, nan)
(335, nan)
(336, nan)
(337, nan)
(338, nan)
(339, nan)
(340, nan)
(341, nan)
(342, nan)
(343, nan)
(344, nan)
(345, nan)
(346, nan)
(347, nan)
(348, nan)
(349, nan)
(350, nan)
(351, nan)
(352, nan)
(353, nan)
(354, nan)
(355, nan)
(356, nan)
(357, nan)
(358, nan)
(359, nan)
(360, nan)
(361, nan)
(362, nan)
(363, nan)
(364, nan)
(365, nan)
(366, nan)
(367, nan)
(368, nan)
(369, nan)
(370, nan)
(371, nan)
(372, nan)
(373, nan)
(374, nan)
(375, nan)
(376, nan)
(377, nan)
(378, nan)
(379, nan)
(380, nan)
(381, nan)
(382, nan)
(383, nan)
(384, nan)
(385, nan)
(386, nan)
(387, nan)
(388, nan)
(389, nan)
(390, nan)
(391, nan)
(392, nan)
(393, nan)
(394, nan)
(395, nan)
(396, nan)
(397, nan)
(398, nan)
(399, nan)
(400, nan)
(401, nan)
(402, nan)
(403, nan)
(404, nan)
(405, nan)
(406, nan)
(407, nan)
(408, nan)
(409, nan)
(410, nan)
(411, nan)
(412, nan)
(413, nan)
(414, nan)
(415, nan)
(416, nan)
(417, nan)
(418, nan)
(419, nan)
(420, nan)
(421, nan)
(422, nan)
(423, nan)
(424, nan)
(425, nan)
(426, nan)
(427, nan)
(428, nan)
(429, nan)
(430, nan)
(431, nan)
(432, nan)
(433, nan)
(434, nan)
(435, nan)
(436, nan)
(437, nan)
(438, nan)
(439, nan)
(440, nan)
(441, nan)
(442, nan)
(443, nan)
(444, nan)
(445, nan)
(446, nan)
(447, nan)
(448, nan)
(449, nan)
(450, nan)
(451, nan)
(452, nan)
(453, nan)
(454, nan)
(455, nan)
(456, nan)
(457, nan)
(458, nan)
(459, nan)
(460, nan)
(461, nan)
(462, nan)
(463, nan)
(464, nan)
(465, nan)
(466, nan)
(467, nan)
(468, nan)
(469, nan)
(470, nan)
(471, nan)
(472, nan)
(473, nan)
(474, nan)
(475, nan)
(476, nan)
(477, nan)
(478, nan)
(479, nan)
(480, nan)
(481, nan)
(482, nan)
(483, nan)
(484, nan)
(485, nan)
(486, nan)
(487, nan)
(488, nan)
(489, nan)
(490, nan)
(491, nan)
(492, nan)
(493, nan)
(494, nan)
(495, nan)
(496, nan)
(497, nan)
(498, nan)
(499, nan)

Can someone please show me what is wrong here?

Uncomment the first pair of comments in the for loop. The gradient buffers have to be manually reset before fresh gradients are calculated.

Your for loop should be:

for t in range(500):
    y_pred = x.mm(w1).clamp(min = 0).mm(w2)
    
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.data[0])
    
    w1.grad.data.zero_()
    w2.grad.data.zero_()
    
    loss.backward()
    #print w1.grad.data
    #print w2.grad.data
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data
3 Likes

In the latest version of pytorch, uncommenting those 2 lines shows error. That is why I uncommented them.

AttributeError: 'NoneType' object has no attribute 'data'

@nafizh1
Try this:

Manually zero the gradients before running the backward pass

if t:
w1.grad.data.zero_()
w2.grad.data.zero_()

I had this issue too and found out that in the latest the grad data is not created (not just initialized but its not even created) until backward is called and gradients need to be computed. The above code basically checks:

If first iteration (i.e, t = 0) then don’t zero the data else zero out the grad data.

1 Like

Thanks, that solved the problem. But I was wondering if there is a more elegant solution. Also, why do you have to zero the grad data every time in a loop? I assume that is something that should take care of itself.

Looking at the examples like https://github.com/pytorch/examples/blob/master/mnist/main.py, the canonical way to do this appears to be calling optimizer.zero_grad() unconditionally.

I also have question why we have to run
w1.grad.data.zero_()
w2.grad.data.zero_()
before
loss.backward()
every step?

@Chun_Li

Yes we must.

I suppose that it may be the design choice. The designer set a buffer to keep the accumulated sum of every parameter’s grad. For every training step, loss.backward() calculate the current grads for every parameter, and then dump them to the buffer. After that, we can use optimizer.step() to apply these grads which stored on the buffer to update parameters. So if we don’t use w1.grad.data.zero_() or optimizer.zero_grad() before update the parameters, we will apply the accumulated grads to update them, which may be wrong in most situations.