医療職からデータサイエンティストへ

統計学、機械学習に関する記事をまとめています。

多項式曲線フィッティング~パターン認識と機械学習~

パターン認識機械学習の第1章多項式曲線フィッティングについてまとめます。

多項式曲線フィッティング

多項式曲線フィッティングは、目的変数にうまくフィットするような線形モデルを作成します。今回はsin関数を多項式フィッティングしていきます。

> x <-  seq(0,1,length=10)
> t <- sin(2*pi*x) + rnorm(8 ,0,0.4)
> plot(x , t ,xlim = c(min(x),max(x)),ylim=c(-2 , 2) , col = "blue")
> curve(sin(2*pi*x),add = TRUE , col="green")

f:id:h-wadsworth02:20181222221724p:plain:w400
まずは、sin(2\pi x)正規分布に従うノイズを加えた目的データtとして、以下のようなxの多項式で近似します。

 \large f(x) = w_0 + w_1x + w_2x^{2} +\cdots + w_mx^{m}

Mが次数を表しています。
係数wはどうやって求めるかというと、f(x)で 求めた値と目的変数tの値が最小になるようにwを決めれば良いので、

\large  E(w) = \dfrac{1}{2}\sum_{n=1}^{N}(f(x) - t_n)^{2}

このE(w)を最小化するようにwを求めます。これは最小二乗法と呼ばれて、偏微分を使って解くのですが、詳しい方法については割愛します!
f:id:h-wadsworth02:20181222224107p:plain:w400
グラフのイメージだと、この赤い線を足した長さが一番小さくなるようにwを決めます!それでは、一次式、三次式、七次式でそれぞれ近似してみましょう!

>#七次までのxの値と目的変数tからなるデータフレームを作成
> df <- data.frame("x" = x,
+                  "x2" = x^2,
+                  "x3" = x^3,
+                  "x4" = x^4,
+                  "x5" = x^5,
+                  "x6" = x^6,
+                  "x7" = x^7,
+                  "t" = t
+ )
>#lm関数を使って一次式、三次式、七次式の線形モデルをつくる
> model1 <- lm(t ~ x, data=df)
> model3 <- lm(t ~ x + x2+x3, data=df)
> model7 <- lm(t ~ x + x2+x3+x4+x5+x6+x7, data=df)
> model3

Call:
lm(formula = t ~ x + x2 + x3, data = df)

Coefficients:
(Intercept)            x           x2           x3  
    -0.3165      13.7086     -39.2870      25.9682 

三次式を見てみると、
 t = f(x) = -0.32 + 13.7x + -39.3x^{2} + 30.0 x^{3}
という式で近似できたことが分かります。 グラフにしてみるとこんな感じです。⇩

f:id:h-wadsworth02:20181222221806p:plain:w400
これを見ると、七次式の時が一番データに当てはまっているように見えます。

ここで、それぞれの線形式がどれぐら当てはまっているかを数値で求めてみましょう。どのように計算するかというと
 誤差 = \sqrt{\dfrac{\sum(t - fx)^{2}}{N}}
で計算できて、これを平均二乗誤差(Root-Mean-Square Error)と呼びます!
数字嫌いな方は理解しなくても大丈夫です!先ほど赤い線の長さ合計が一番小さくなるようにwを決めました。その赤い線の長さ平均みたいなイメージです!

さてそれぞれの一次式、三次式、七次式の平均二乗誤差を計算すると...

>#平均二乗誤差(RMSE)を計算する関数
> rmse <- function(model, y,newdata,n ){
+   (y - predict(model , newdata = newdata))^2 %>% 
+     sum() %>% 
+     divide_by(n) %>%   
+     sqrt()
+ }
>#一次式のRMSE
> rmse(model1, t, df , 10)
[1] 0.699838
>#三次式のRMSE
> rmse(model3, t, df , 10)
[1] 0.3113414
>#七次式のRMSE
> rmse(model7, t, df , 10)
[1] 0.235343

グラフで見た通り、やはり七次式のRMSEが一番小さくなっています。

予測誤差

先ほどのデータでは、七次式が一番当てはまりがよくなりました。
次に目的変数のノイズを新しく選び、t2とした変数を作成します。
これをテストデータとして、先ほど作成した線形式との平均二乗誤差を求めてみましょう!

>t2 <- sin(2*pi*x) + rnorm(10 ,0,0.4)
>#一次式のRMSE
> rmse(model1, t2, df , 10)
[1] 0.8378053
>#三次式のRMSE
> rmse(model3, t2, df , 10)
[1] 0.4696673
>#七次式のRMSE
> rmse(model7, t2, df , 10)
[1] 0.5158258

今度は三次式のモデルが一番当てあまりがよくなりました!
一〜七次式のモデルで、元々のデータ(t)とテストデータ(t2)のRMSEをグラフ化してみると

f:id:h-wadsworth02:20181223095833p:plain:w400

元々のデータ(train)では七次式が一番当てはまりが良いですが、テストデータではそうではないですね。改めてグラフを見てみると
f:id:h-wadsworth02:20181223100137p:plain:w400
どうやらデータのノイズにもフィットしているようです。これを過学習と言います!過学習したモデルは元々のデータにはよく当てはまりますが、新たなデータに対しては当てはまりが悪くなります。

過学習への対処法

過学習への対処法には様々な手法がありますが、ここでは2つ紹介します。

サンプルの数を増やす

先ほどの例では、サンプル数が10個に対して七次式、つまりパラメータの数が8個ありました。サンプルの数に対してパラメーターの数が多いと過学習しやすくなります。そこでサンプルの数を100個に増やしてみましょう! f:id:h-wadsworth02:20181223101619p:plain:w400

過学習が抑えらているように見えます。 同じようにRMSEの値をプロットしてみると
f:id:h-wadsworth02:20181223101756p:plain:w400

おお!どうやらサンプル数を増やすと過学習の問題は減るようです。

罰則項をつける

もう一つの方法が、係数wを求める際に罰則をつけることです。式で表すと
 \large \tilde{E}(x) =  \dfrac{1}{2}\sum_{n=1}^{N}(f(x) - t_n)^{2}+\dfrac{\lambda}{2}\sum_{m=1}^{M}w_m^{2}

第一項は通常の最小二乗法ですが、第二項で係数の二乗を足したものを入れることで、この関数を最小化する際に各係数が大きくなりすぎないように罰則がつきます。ちなみに係数の二乗和を罰則項とする方法をリッジ回帰(ridge)と呼びます。罰則を使った過学習への対処法はまた詳しくまとめたいと思うので、今回はふわっとだけ触れます!

まとめ

今回は多項式フィッティングと過学習についてまとめてみました。過学習問題は機械学習をする上で避けては通れないので、学習を続けていきます。

最後までお読み頂きありがとうございました!

 参考書

パターン認識と機械学習 上

パターン認識と機械学習 上