In my case it works. Look at the following code, in particular at the values of a and b before and after the forward/backward pass and the optimizer step (I extracted these snippets directly from one of my text classification problems).
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn.common_types import _size_2_t

class Conv2d(torch.nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros'
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode
        )
        # self.weight.shape ~ out_channels x in_channels x filter_size x emb_dim
        # (this assumes kernel_size was passed as a (filter_size, emb_dim) tuple)
        filter_size, emb_dim = kernel_size
        assert self.weight.shape == torch.Size([out_channels, in_channels, filter_size, emb_dim])
        # Here I just want to make sure that A @ B has the same shape as
        # self.weight (it is up to you to guarantee this on your side).
        torch.manual_seed(0)
        intermediate_dim = 7
        A = torch.rand((out_channels, in_channels, filter_size, intermediate_dim))
        B = torch.rand((intermediate_dim, emb_dim))
        # A @ B ~ out_channels x in_channels x filter_size x emb_dim
        self.a = torch.nn.Parameter(A, requires_grad=True)
        self.b = torch.nn.Parameter(B, requires_grad=True)
        # self.weight = torch.matmul(self.a, self.b)  # would raise a TypeError:
        # a registered parameter can only be assigned a Parameter (or None), and
        # wrapping the product in a Parameter would make it a new leaf tensor
        # instead of a function of a and b

    def forward(self, input: Tensor) -> Tensor:
        # return self._conv_forward(input, self.weight, self.bias)
        return self._conv_forward(input, torch.matmul(self.a, self.b), self.bias)
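# Note: the inherited self.weight is still a registered parameter. It shows up
# in state_dict() and is handed to the optimizer, but since forward() never
# touches it, its gradient stays None and it is never updated; you can verify
# this in the two state_dict dumps below.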
## Models
in_channels = 1
n_filters = 2
emb_dim = 7
filter_size = 3
model = Conv2d(in_channels=in_channels, out_channels=n_filters, kernel_size=(filter_size, emb_dim))
n_labels = 3
pred_layer = torch.nn.Linear(n_filters, n_labels)

## Data
torch.manual_seed(0)
bs, slen = 5, 6
x = torch.rand((bs, slen, emb_dim))
y = torch.empty(bs, dtype=torch.long).random_(n_labels)

## Optimizer
# optimizer = torch.optim.Adam(list(model.parameters()) + list(pred_layer.parameters()), lr=1e-2)
# The update of the parameters of the classification layer is of no interest to us here.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
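# Snapshot of every entry of the state dict before the update; this is only a
# helper for the programmatic before/after check further down.
params_before = {k: v.detach().clone() for k, v in model.state_dict().items()}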
model.state_dict()
"""
OrderedDict([('weight',
tensor([[[[-0.0464, 0.0466, -0.1422, -0.0112, 0.1562, -0.0224, 0.0061],
[-0.0188, 0.0442, 0.1388, 0.2067, 0.1386, 0.2072, -0.0158],
[-0.1960, -0.1035, 0.1486, -0.0014, -0.1085, -0.1672, -0.2042]]],
[[[-0.1842, -0.0443, 0.1197, 0.1180, -0.2105, 0.1361, -0.1708],
[-0.0461, -0.0885, -0.0420, -0.0428, -0.1958, -0.1884, -0.0341],
[ 0.0028, -0.0991, 0.0822, -0.1964, -0.0147, 0.1919, -0.0890]]]])),
('bias', tensor([0.1971, 0.0790])),
('a',
tensor([[[[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901],
[0.8964, 0.4556, 0.6323, 0.3489, 0.4017, 0.0223, 0.1689],
[0.2939, 0.5185, 0.6977, 0.8000, 0.1610, 0.2823, 0.6816]]],
[[[0.9152, 0.3971, 0.8742, 0.4194, 0.5529, 0.9527, 0.0362],
[0.1852, 0.3734, 0.3051, 0.9320, 0.1759, 0.2698, 0.1507],
[0.0317, 0.2081, 0.9298, 0.7231, 0.7423, 0.5263, 0.2437]]]])),
('b',
tensor([[0.5846, 0.0332, 0.1387, 0.2422, 0.8155, 0.7932, 0.2783],
[0.4820, 0.8198, 0.9971, 0.6984, 0.5675, 0.8352, 0.2056],
[0.5932, 0.1123, 0.1535, 0.2417, 0.7262, 0.7011, 0.2038],
[0.6511, 0.7745, 0.4369, 0.5191, 0.6159, 0.8102, 0.9801],
[0.1147, 0.3168, 0.6965, 0.9143, 0.9351, 0.9412, 0.5995],
[0.0652, 0.5460, 0.1872, 0.0340, 0.9442, 0.8802, 0.0012],
[0.5936, 0.4158, 0.4177, 0.2711, 0.6923, 0.2038, 0.6833]]))])
"""
## Zero all the gradients
optimizer.zero_grad()
## Forward pass
x = x.unsqueeze(dim=1)  # bs x 1 x slen x emb_dim (add the channel dimension)
conved = model(x)  # bs x n_filters x (slen - filter_size + 1) x 1, per the shape formula in https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
conved = F.relu(conved).squeeze(3)  # bs x n_filters x (slen - filter_size + 1)
pooled = F.max_pool1d(conved, conved.shape[2]).squeeze(2)  # bs x n_filters
pooled = F.dropout(pooled, p=0.1)  # bs x n_filters
y_pred = pred_layer(pooled)  # bs x n_labels
## Loss and backward pass
loss = F.cross_entropy(input=y_pred, target=y) # tensor(13.4064, grad_fn=<NllLossBackward>)
loss.backward()
## optimizer gradient step
optimizer.step()
model.state_dict()
"""
OrderedDict([('weight',
tensor([[[[-0.0464, 0.0466, -0.1422, -0.0112, 0.1562, -0.0224, 0.0061],
[-0.0188, 0.0442, 0.1388, 0.2067, 0.1386, 0.2072, -0.0158],
[-0.1960, -0.1035, 0.1486, -0.0014, -0.1085, -0.1672, -0.2042]]],
[[[-0.1842, -0.0443, 0.1197, 0.1180, -0.2105, 0.1361, -0.1708],
[-0.0461, -0.0885, -0.0420, -0.0428, -0.1958, -0.1884, -0.0341],
[ 0.0028, -0.0991, 0.0822, -0.1964, -0.0147, 0.1919, -0.0890]]]])),
('bias', tensor([0.1871, 0.0690])),
('a',
tensor([[[[0.4863, 0.7582, 0.0785, 0.1220, 0.2974, 0.6241, 0.4801],
[0.8864, 0.4456, 0.6223, 0.3389, 0.3917, 0.0123, 0.1589],
[0.2839, 0.5085, 0.6877, 0.7900, 0.1510, 0.2723, 0.6716]]],
[[[0.9052, 0.3871, 0.8642, 0.4094, 0.5429, 0.9427, 0.0262],
[0.1752, 0.3634, 0.2951, 0.9220, 0.1659, 0.2598, 0.1407],
[0.0217, 0.1981, 0.9198, 0.7131, 0.7323, 0.5163, 0.2337]]]])),
('b',
tensor([[ 0.5746, 0.0232, 0.1287, 0.2322, 0.8055, 0.7832, 0.2683],
[ 0.4720, 0.8098, 0.9871, 0.6884, 0.5575, 0.8252, 0.1956],
[ 0.5832, 0.1023, 0.1435, 0.2317, 0.7162, 0.6911, 0.1938],
[ 0.6411, 0.7645, 0.4269, 0.5091, 0.6059, 0.8002, 0.9701],
[ 0.1047, 0.3068, 0.6865, 0.9043, 0.9251, 0.9312, 0.5895],
[ 0.0552, 0.5360, 0.1772, 0.0240, 0.9342, 0.8702, -0.0088],
[ 0.5836, 0.4058, 0.4077, 0.2611, 0.6823, 0.1938, 0.6733]]))])
"""
a and b have indeed changed.
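As a side note, on PyTorch 1.9 or newer the same idea can be written with torch.nn.utils.parametrize instead of subclassing Conv2d. The sketch below is only an illustration of that alternative (LowRankWeight and rank are illustrative names, and it reuses the toy dimensions defined above):

import torch.nn.utils.parametrize as parametrize

class LowRankWeight(torch.nn.Module):
    def __init__(self, n_filters, in_channels, filter_size, emb_dim, rank=7):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(n_filters, in_channels, filter_size, rank))
        self.b = torch.nn.Parameter(torch.rand(rank, emb_dim))
    def forward(self, X):
        # X is the original kernel; it is ignored here, exactly as self.weight
        # is ignored in the subclass above
        return torch.matmul(self.a, self.b)

conv = torch.nn.Conv2d(in_channels, n_filters, kernel_size=(filter_size, emb_dim))
parametrize.register_parametrization(conv, "weight", LowRankWeight(n_filters, in_channels, filter_size, emb_dim))
# conv.weight is now recomputed as a @ b on every access, and a and b appear in
# conv.parameters(), so a plain optimizer over conv.parameters() will update them.

The advantage is that any already-built nn.Conv2d can be reparameterized after the fact, without touching its class.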