Section2 단순선형회귀(Simple Linear Regression)

Gerald·2022년 5월 5일
0

1.선형회귀분석(Linear Regression)이란이란?

회귀분석은 독립변수와 종속변수간의 관계를 보여주는 통계적 방법입니다. 즉 X(독립변수)에 따라 달라지는 Y(종속변수)의 변화를 예측하는 것으로 이해하면 될 것입니다. 여기서 독립변수와 종속변수는 표현이 다양함으로 한번 짚고 넘어가자.

X(독립변수)
예측변수(predict)
설명변수(explanatory)
특성(feature)

Y(종속변수)
반응변수(response)
레이블, 라벨(label)
타겟(target)

코딩할때는 feature, target을 주로 많이썻다.

선형이라는건 직선이니까 기울기를 가지지 않을까?? 맞다 따라서 선형회귀모델계수가 존재하며
y= ax+b 단순하게 1차 방정식처럼 생겼다. 여기서 기울기 a, y절편 b를 계수라 부른다.

2.기준모델(Baseline Model)

모델 성능을 비교하기 위해서는 기준이 되는 모델이 필요합니다. 따라서 모델 성능의 기준이 되는 모델을 Baseline model이라고 합니다. 따라서 문제별로 기준모델은 아래와 같이 설정한다고 합니다.

회귀문제 : 타겟의 평균값
분류문제 : 타겟의 최빈 클래스
시계열회귀문제 : 이전 타임스탬프의 값

3.실습

이제 한번 코드를 사용해서 해보겠습니다.
(데이터 파일 : https://www.kaggle.com/harlfoxem/housesalesprediction?select=kc_house_data.csv)
케글의 경우 train, test 두가지 데이터가 파일이 존재합니다.

df.info() 

데이터 정보확인(결측치, 타입 등), 결측치는 없고 따라서 여기선 날짜 데이트를 타입을 바꿔주면 좋을꺼 같다.

df['date'] = pd.to_datetime(df['date'])
df.info()

다음으로 특성들과 타겟(price)과의 상관계수를 알아보고 높은 특성을 찾아서 시각화 해보겠습니다.

correlation = df.corr()
correlation2 = correlation.sort_values(by='price', ascending=False) 
correlation2['price']

import matplotlib.pyplot as plt
import seaborn as sns
plt.scatter(df['price'], df['sqft_living'])

feature가 커질수록 타겟도 증가하는거 같습니다. 이제 기준모델을 그려서 다시 보겠습니다.

x = df['sqft_living']
y = df['price']

predict = df['price'].mean()
errors = predict - df['price']
mean_absolute_error = errors.abs().mean()

sns.lineplot(x=x, y=predict, color='red')
sns.scatterplot(x=x, y=y, color='blue');

확실히 평균으로 쭉 가기 때문에 성능이 매우 안좋을꺼 같습니다. 따라서 Scikit-Learn 라이브러리를 사용해 특성 sqft_living에 대한 선형회귀모델을 만들어 보겠습니다.

from sklearn.linear_model import LinearRegression

model = LinearRegression()

feature = ['sqft_living']
target = ['price']
X_train = df[feature]
y_train = df[target]

model.fit(X_train, y_train)

y_pred = model.predict(X_train)

plt.scatter(X_train, y_train, color='black', linewidth=1)
plt.scatter(X_train, y_pred, color='blue', linewidth=1);

이렇게 그리면 파란점이 모델 학습을 한 그래프를 확인할 수 있습니다.
마지막으로 선형회귀계수를 알아보자

In) model.coef_, model.intercept_

Out) (array([[280.6235679]]), array([-43580.74309447]))

이렇게 기울기가 280이며 절편이 -43580임을 알 수 있다.

profile
비전공자로 도전하기 시작입니다!

0개의 댓글