前回のOctaveで散布図をプロットしてみるに続いて、今後はOctaveで最急降下法を実装して、θの値を探索してみます。
結論から言うと失敗した。どこかfeature scalingで間違っているっぽいんだけど、どう間違っているのかがわからず。改めてリベンジしたい。ひとまずは記録を残しておく。
前に実装した目的関数(costFunctionJ.m)。
function J = costFunctionJ(X, y, theta) m = size(X,1); predictions = X*theta; sqrErrors = (predictions-y).^2; J = 1 / (2*m) * sum(sqrErrors);
今回は、新たに最急降下法の実装をしてみた。解説を書きたいところだけども、ちょっと時間が足りなそうなので一旦諦める。
gradientDescent.m
% X => 訓練セットのfeatureのベクトル % y => 訓練セットの結果のベクトル % theta => θの初期値 % alpha => 学習の度合い % num_iters => 反復回数 function theta = gradientDescent(X, y, theta, alpha, num_iters) m = length(y); for iter = 1:num_iters costFunctionJ(X, y, theta) predictions = X * theta; theta(1) = theta(1) - alpha / m * sum(predictions - y); theta(2) = theta(2) - alpha / m * sum((predictions - y) .* X(:, 2)); end
実行してみます。
octave> load ramenX.txt octave> ramenX ramenX = 596 522 1135 598 389 605 484 214 417 445 351 134 525 32 251 136 63 134 193 150 octave> load ramenY.txt octave> y = ramenY y = 98.657 98.454 97.738 97.650 97.461 97.207 97.179 96.901 96.538 95.969 94.228 93.939 93.225 92.935 92.223 92.154 91.902 91.728 91.420 91.341 octave> m = length(X) m = 20 octave> X = [ones(m, 1), X]; octave> alpha = 0.01; octave> theta = [0;0]; octave> gradientDescent(X, y, theta, alpha, 100) ans = 4510.4 ans = 1.2653e+10 ans = 5.1166e+16 ans = 2.0690e+23 ans = 8.3668e+29 ans = 3.3833e+36 ans = 1.3681e+43 ans = 5.5325e+49 ans = 2.2372e+56 ans = 9.0468e+62 ans = 3.6583e+69 ans = 1.4794e+76 ans = 5.9822e+82 ans = 2.4191e+89 ans = 9.7822e+95 ans = 3.9557e+102 ans = 1.5996e+109 ans = 6.4684e+115 ans = 2.6157e+122 ans = 1.0577e+129 ans = 4.2772e+135 ans = 1.7296e+142 ans = 6.9942e+148 ans = 2.8283e+155 ans = 1.1437e+162 ans = 4.6249e+168 ans = 1.8702e+175 ans = 7.5627e+181 ans = 3.0582e+188 ans = 1.2367e+195 ans = 5.0008e+201 ans = 2.0222e+208 ans = 8.1774e+214 ans = 3.3068e+221 ans = 1.3372e+228 ans = 5.4073e+234 ans = 2.1866e+241 ans = 8.8421e+247 ans = 3.5755e+254 ans = 1.4459e+261 ans = 5.8468e+267 ans = 2.3643e+274 ans = 9.5608e+280 ans = 3.8662e+287 ans = 1.5634e+294 ans = 6.3220e+300 ans = Inf ans = Inf ans = Inf ans = Inf ans = Inf ans = Inf
学習の度合い(α)が大きすぎて点が降下しなかった。値を変更してやりおなし。
octave> alpha = 0.0000099 alpha = 9.9000e-06 octave:89> gradientDescent(X, y, theta, alpha, 1000) ans = 4510.4 ans = 4459.2 ans = 4408.9 ans = 4359.5 ans = 4310.8 ans = 4262.9 ans = 4215.8 ans = 4169.5 ans = 4123.9 ans = 4079.1 ans = 4035.0 ans = 3991.7 ans = 3949.0 ans = 3907.0 ans = 3865.8 ans = 3825.1 ans = 3785.2 ans = 3745.9 ans = 3707.3 ans = 3669.3 ans = 3631.9 ans = 3595.1 ans = 3558.9 ans = 3523.3 ans = 3488.3 ans = 3453.9 ans = 3420.0 ans = 3386.7 ans = 3353.9 ans = 3321.6 ans = 3289.9 ans = 3258.7 ans = 3228.0 ans = 3197.8 ans = 3168.2 ans = 3138.9 ans = 3110.2 ans = 3081.9 ans = 3054.1 ans = 3026.8 ans = 2999.9 ans = 2973.4 ans = 2947.4 ans = 2921.8 ans = 2896.6 ans = 2871.8 ans = 2847.5 ans = 2823.5 ans = 2799.9 ans = 2776.7 ans = 2753.9 ans = 2731.5 ...skipping... ans = 1373.0 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.9 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.8 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.7 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.6 ans = 1372.5 ans = 0.29602 0.17578
一応降下はしたものの、途中で降下が鈍くなって最適解にたどり着いていない模様。こういう場合はパラメータfeatureをスケーリングすればいいので見よう見まねでやってみる。ここもスケーリングの理論があるんだけれども書くのに時間がかかるので実装だけ載せちゃう。
octave> range = max(X(:, 2)) - min(X(:, 2)) octave> avg = sum(X(:,2)) / length(X) octave> scaledX = [ones(m, 1), (X(:, 2) - avg) ./ range] octave> alpha = 0.3 octave> gradientDescent(scaledX, y, theta, alpha, 500) ans = 4510.4 ans = 2211.7 ans = 1085.4 ans = 533.44 ans = 262.96 ans = 130.41 ans = 65.435 ans = 33.575 ans = 17.943 ans = 10.263 ans = 6.4798 ans = 4.6069 ans = 3.6705 ans = 3.1936 ans = 2.9424 ans = 2.8023 ans = 2.7171 ans = 2.6594 ans = 2.6156 ans = 2.5791 ans = 2.5466 ans = 2.5166 ans = 2.4881 ans = 2.4608 ans = 2.4346 ans = 2.4092 ans = 2.3846 ans = 2.3608 ans = 2.3377 ans = 2.3153 ans = 2.2936 ans = 2.2725 ans = 2.2521 ans = 2.2323 ans = 2.2131 ans = 2.1945 ans = 2.1764 ans = 2.1589 ans = 2.1419 ans = 2.1254 ans = 2.1094 ans = 2.0939 ans = 2.0789 ans = 2.0643 ans = 2.0501 ans = 2.0364 ans = 2.0231 ans = 2.0102 ans = 1.9977 ans = 1.9856 ans = 1.9738 ans = 1.9624 ...skipping... ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 1.5954 ans = 94.9424 8.3081
どうやら最小値らへんに収束したっぽい。
ただ残念なことにθの値がこれだとすると、仮説関数が以下になるので、
グラフにプロットするまでもなく傾きが間違っていることがわかる。
うーん。またリベンジしよう。