import torch
import torch.nn.functional as F

# Convert the DataFrame to a float32 tensor; pandas defaults to float64,
# which would not match the float32 parameters created below. Note that this
# tensor still contains every column of df_admission, including the target.
X = torch.tensor(df_admission.values, dtype=torch.float32)
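The later cells refer to a normalised feature matrix xTrainNorm and a target vector y, neither of which is defined in this excerpt. A minimal sketch of how they could be built, assuming the target column is named 'Chance of Admit' (a guess about this dataset; adjust to match df_admission):

import torch

target_col = 'Chance of Admit'  # hypothetical column name

# Separate the target from the features, as float32 tensors
y = torch.tensor(df_admission[target_col].values, dtype=torch.float32)
feats = torch.tensor(df_admission.drop(columns=[target_col]).values, dtype=torch.float32)

# Standardise each feature column to zero mean and unit variance
xTrainNorm = (feats - feats.mean(dim=0)) / feats.std(dim=0)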
class LinearRegressionGD():
    def __init__(self, nb_feats):
        super().__init__()
        # Create a bias torch.tensor initialised to 0
        # Create a weight torch.tensor initialised to 0's
        self.nb_feats = nb_feats
        self.bias = torch.tensor(0.0, requires_grad=True)
        self.weights = torch.zeros(self.nb_feats, requires_grad=True)

    def forward(self, X):
        # Forward propagation: yhat = X @ w + b
        yhat = self.bias + torch.matmul(X, self.weights)
        return yhat

    def backward(self, X, yhat, y):
        # Backward propagation: let autograd compute the gradients of the
        # MSE loss with respect to the bias and weights
        loss = F.mse_loss(yhat, y)
        loss.backward()
        return loss

    def update(self, lr):
        # Update the bias and weights using the gradients computed in the backward pass
        with torch.no_grad():
            self.bias -= lr * self.bias.grad
            self.weights -= lr * self.weights.grad
            # Reset the gradients to zero so they do not accumulate across epochs
            self.bias.grad.zero_()
            self.weights.grad.zero_()
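Before training on the real data, a quick sanity check of the interface on synthetic inputs (made up for illustration) confirms that one forward/backward/update cycle runs and moves the loss:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
X_demo = torch.randn(16, 4)   # 16 synthetic samples with 4 features
y_demo = torch.randn(16)

demo = LinearRegressionGD(nb_feats=4)
before = F.mse_loss(demo.forward(X_demo), y_demo).item()

# One full gradient-descent step
demo.backward(X_demo, demo.forward(X_demo), y_demo)
demo.update(lr=0.01)

after = F.mse_loss(demo.forward(X_demo), y_demo).item()
print(f'MSE before: {before:.4f} | after one step: {after:.4f}')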
Initialise your Linear Regression model with the corresponding number of features.
model = LinearRegressionGD(nb_feats=xTrainNorm.shape[1])
def trainManualGD(model, X, y, num_epochs, learning_rate=0.001):
    loss = []
    for epoch in range(num_epochs):
        # 1. Perform forward propagation on the data samples and obtain the model's prediction.
        yhat = model.forward(X)
        # 2. Perform backward propagation to compute the gradients of the loss function
        #    over the weights and over the bias term. This populates .grad on both tensors.
        model.backward(X, yhat, y)
        # 3. Use gradient descent to update the weight and bias values: w = w - lr * grad_w.
        #    model.update also zeroes the gradients for the next epoch.
        model.update(learning_rate)
        # 4. Perform forward propagation and compute the MSE loss for logging purposes.
        with torch.no_grad():
            yhat = model.forward(X)
            curr_loss = F.mse_loss(yhat, y).item()
        loss.append(curr_loss)
        if not epoch % 100:
            print('Epoch: {:03d} | MSE: {:.7f}'.format(epoch+1, curr_loss))
    return loss
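For comparison, the same update rule can be delegated to torch.optim.SGD, since vanilla SGD performs exactly p -= lr * p.grad, the rule implemented in model.update. This equivalent loop is a sketch for reference (trainOptimSGD is a name introduced here, not part of the scaffold):

import torch
import torch.nn.functional as F

def trainOptimSGD(model, X, y, num_epochs, learning_rate=0.001):
    # Same loop as trainManualGD, but the update step and the gradient
    # zeroing are handled by the optimizer instead of model.update().
    optimizer = torch.optim.SGD([model.bias, model.weights], lr=learning_rate)
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        curr_loss = F.mse_loss(model.forward(X), y)
        curr_loss.backward()
        optimizer.step()
        losses.append(curr_loss.item())
    return losses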
Finally, train the model (note the variable is named model, not linear_model). The tensor passed to trainManualGD must have exactly nb_feats columns, so train on the normalised features rather than the raw X, which still contains the target column; a mismatched input makes torch.matmul(X, self.weights) fail with a shape error such as "RuntimeError: size mismatch, got 500, 500x8,1".
loss = trainManualGD(model, xTrainNorm, y, num_epochs=10000, learning_rate=0.001)
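Since the loss is plain mean squared error, autograd's gradients can also be verified against the closed-form expressions grad_w = (2/N) X^T (yhat - y) and grad_b = (2/N) sum(yhat - y). A small check on synthetic data (again made up for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(1)
Xc = torch.randn(32, 5)
yc = torch.randn(32)

m = LinearRegressionGD(nb_feats=5)
yhat = m.forward(Xc)
m.backward(Xc, yhat, yc)

# Closed-form gradients of the mean squared error
resid = (yhat - yc).detach()
grad_w = 2.0 / len(yc) * Xc.T @ resid
grad_b = 2.0 / len(yc) * resid.sum()

print(torch.allclose(m.weights.grad, grad_w, atol=1e-6))  # expected: True
print(torch.allclose(m.bias.grad, grad_b, atol=1e-6))     # expected: True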