Analysis
로지스틱 회귀 편미분 정리 및 구현
da-da-da
2022. 11. 14. 09:32
728x90
반응형
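편미분 계산 과정
모델: $z = w^{T}x + b,\quad a = \sigma(z) = \dfrac{1}{1 + e^{-z}}$
비용 함수: $J(w, b) = -\dfrac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log a^{(i)} + \left(1 - y^{(i)}\right)\log\left(1 - a^{(i)}\right)\right]$
편미분: $\dfrac{\partial J}{\partial w} = \dfrac{1}{m}X\,(A - Y)^{T},\quad \dfrac{\partial J}{\partial b} = \dfrac{1}{m}\sum_{i=1}^{m}\left(a^{(i)} - y^{(i)}\right)$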
기본 구현
모델
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, models, transforms
import os
import numpy as np
class LR(nn.Module):
    """Logistic regression trained with hand-derived gradients (no autograd).

    Data is laid out one sample per column: ``x`` is ``(dim, m)`` and both
    labels and predictions are ``(1, m)``.

    All state is created on the module-level ``device`` (defined outside this
    class). ``w``, ``b`` and ``lr`` are registered as buffers so that
    ``model.to(...)`` actually moves them; ``w`` and ``b`` also appear in
    ``state_dict()``.
    """

    def __init__(self, dim, lr=torch.scalar_tensor(0.01)):
        """Create a zero-initialised model.

        Args:
            dim: Number of input features (rows of ``x``).
            lr: Learning rate, either a 0-dim tensor or a plain Python number.
        """
        super().__init__()
        # Buffers (not Parameters): the update rule is applied by hand in
        # ``optimize``, so autograd tracking is not wanted.
        self.register_buffer(
            "w", torch.zeros(dim, 1, dtype=torch.float, device=device)
        )
        self.register_buffer("b", torch.scalar_tensor(0).to(device))
        # ``as_tensor`` lets callers pass a float as well as a tensor.
        # Non-persistent: a hyper-parameter does not belong in checkpoints.
        self.register_buffer(
            "lr", torch.as_tensor(lr).to(device), persistent=False
        )
        # Most recent gradients; overwritten by every ``backward`` call.
        self.grads = {
            'dw': torch.zeros(dim, 1, dtype=torch.float, device=device),
            'db': torch.scalar_tensor(0).to(device),
        }

    def forward(self, x):
        """Return ``sigmoid(w^T x + b)`` with shape ``(1, m)`` for ``x`` of ``(dim, m)``."""
        z = torch.mm(self.w.T, x) + self.b
        return self.sigmoid(z)

    def sigmoid(self, z):
        """Element-wise logistic function ``1 / (1 + exp(-z))``."""
        return 1 / (1 + torch.exp(-z))

    def backward(self, x, yhat, y):
        """Store the cross-entropy gradients w.r.t. ``w`` and ``b`` in ``self.grads``.

        Args:
            x: Inputs, ``(dim, m)``.
            yhat: Predictions from ``forward``, ``(1, m)``.
            y: Ground-truth labels, ``(1, m)``.
        """
        m = x.shape[1]
        error = yhat - y
        self.grads['dw'] = (1 / m) * torch.mm(x, error.T)
        self.grads['db'] = (1 / m) * torch.sum(error)

    def optimize(self):
        """Apply one gradient-descent step using the gradients from ``backward``."""
        self.w = self.w - self.lr * self.grads['dw']
        self.b = self.b - self.lr * self.grads['db']
def loss(yhat, y):
    """Return the mean binary cross-entropy between ``yhat`` and ``y``.

    Args:
        yhat: Predicted probabilities, shape ``(1, m)``.
        y: Ground-truth labels (0 or 1), shape ``(1, m)``.

    Returns:
        A 0-dim tensor holding the cost averaged over the ``m`` columns.
    """
    m = y.size()[1]
    # A saturated prediction (exactly 0 or 1) makes log() return -inf, and
    # ``0 * -inf`` is NaN. Clamp the log terms at -100, the same bound
    # ``torch.nn.BCELoss`` uses, so the cost stays finite.
    log_pos = torch.clamp(torch.log(yhat), min=-100)
    log_neg = torch.clamp(torch.log(1 - yhat), min=-100)
    return -(1 / m) * torch.sum(y * log_pos + (1 - y) * log_neg)
def predict(yhat, y):
    """Return the classification accuracy in percent as a 0-dim tensor.

    Probabilities in the first row of ``yhat`` are thresholded at 0.5
    (values ``<= 0.5`` become class 0, everything else class 1) and compared
    against ``y``.

    Args:
        yhat: Predicted probabilities, shape ``(1, m)``.
        y: Ground-truth labels, shape ``(1, m)``.
    """
    # ``~(yhat <= 0.5)`` instead of ``yhat > 0.5`` keeps the if/else semantics
    # exactly: anything that is not <= 0.5 (NaN included) maps to 1.
    # The result is placed on ``y``'s device so the subtraction below works
    # no matter where ``yhat`` was computed.
    y_prediction = (~(yhat[:1] <= 0.5)).to(dtype=torch.float, device=y.device)
    return 100 - torch.mean(torch.abs(y_prediction - y)) * 100
학습
# Training setup: cost history, hyper-parameters and the model instance.
# NOTE(review): `x_flatten` and `device` are not defined in the visible code;
# presumably they come from earlier cells of the notebook -- confirm.
costs = []  # cost value recorded every 10th iteration
dim = x_flatten.shape[0]  # number of input features (rows of the data matrix)
learning_rate = torch.scalar_tensor(0.0001).to(device)
num_iterations = 100
# nn.Module.to() returns the module itself, so construction and the device
# move can be chained.
lrmodel = LR(dim, learning_rate).to(device)
def transform_data(x, y):
    """Rearrange a batch into the one-sample-per-column layout.

    Args:
        x: Feature tensor; it is transposed. Presumably ``(m, dim)`` --
            confirm against the data loader.
        y: Label tensor; a leading axis of size 1 is added.

    Returns:
        Tuple ``(x.T, y with a new leading axis)``.
    """
    labels_row = torch.unsqueeze(y, 0)
    features = x.T
    return features, labels_row
# Gradient-descent loop: forward pass, manual backward pass, parameter update,
# then an accuracy check on a test batch. Progress is logged every 10 steps.
for i in range(num_iterations):
    # NOTE(review): next(iter(...)) builds a fresh iterator on every pass, so
    # each iteration takes the first batch that iterator yields. Fine for
    # full-batch gradient descent -- confirm that is the intent.
    x, y = next(iter(train_dataset))
    test_x, test_y = next(iter(test_dataset))
    x, y = transform_data(x, y)
    test_x, test_y = transform_data(test_x, test_y)

    # Move the features to the model's device once, BEFORE the forward pass.
    # Previously forward() received the unmoved tensor while backward() moved
    # it, which raises a device-mismatch error whenever `device` is not the
    # device the data loader yields on.
    x_dev = x.to(device)

    # forward
    yhat = lrmodel.forward(x_dev)
    # Cost is evaluated on CPU against the unmoved labels (assumes the loader
    # yields CPU tensors -- confirm). detach() is the supported replacement
    # for the legacy `.data` attribute.
    cost = loss(yhat.detach().cpu(), y)
    train_pred = predict(yhat, y)

    # backward
    lrmodel.backward(x_dev, yhat, y.to(device))
    lrmodel.optimize()

    # test
    yhat_test = lrmodel.forward(test_x.to(device))
    test_pred = predict(yhat_test, test_y)

    if i % 10 == 0:
        costs.append(cost)
        print(f'Cost after iteration {i} : {cost} | Train Acc: {train_pred} | Test Acc : {test_pred} ')
Reference
728x90
반응형