scikit-learn とは
- 機械学習で使われているアルゴリズムに対応
- 機械学習の結果を検証する機能がある
- 機械学習でよく利用されるライブラリ(「Pandas」「NumPy」「Scipy」「Matplotlib」など)と親和性が高い
- BSDライセンスのオープンソースのため、無料で商用利用が可能
AND演算を試す
ゴールの設定
- 入力(X,Y)と結果(X and Y)の全パターンを学習させる
- 正しい結果(X and Y)に分類してくれるかを評価する
アルゴリズムの選択
アルゴリズムチートシートを利用することも可能
今回は「LinearSVC」を採用します。
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
# Data for learning
learn_data = [[0,0], [1,0], [0,1], [1,1]]
learn_label = [0, 0, 0, 1]
# Algorithm
clf = LinearSVC()
# Learning
clf.fit(learn_data, learn_label)
# Data for testing
test_data = [[0,0], [1,0], [0,1], [1,1]]
test_label = clf.predict(test_data)
# Evaluation
print("Prediction for ", test_data, " is ", test_label)
print("Accuracy score = ", accuracy_score([0, 0, 0, 1], test_label))
- sklearn.svm.LinearSVC : LinearSVCアルゴリズムを利用するためのパッケージ
- sklearn.metrics.accuracy_score:テスト結果を評価するためのパッケージ
結果:正解率が1.0で100%である
XOR演算を試す
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
# Data for learning
learn_data = [[0,0], [1,0], [0,1], [1,1]]
learn_label = [0, 1, 1, 0]
# Algorithm
clf = LinearSVC()
# Learning
clf.fit(learn_data, learn_label)
# Data for testing
test_data = [[0,0], [1,0], [0,1], [1,1]]
test_label = clf.predict(test_data)
# Evaluation
print("Prediction for ", test_data, " is ", test_label)
print("Accuracy score = ", accuracy_score([0, 1, 1, 0], test_label))
結果:正解率が25%, 50%, 75%と実行する度に変化する。100%にはならない。
別のアルゴリズムを試す。ここでは、「KNeighborsClassifier」を採用します。
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# Data for learning
learn_data = [[0,0], [1,0], [0,1], [1,1]]
learn_label = [0, 1, 1, 0]
# Algorithm
clf = KNeighborsClassifier(n_neighbors = 1)
# Learning
clf.fit(learn_data, learn_label)
# Data for testing
test_data = [[0,0], [1,0], [0,1], [1,1]]
test_label = clf.predict(test_data)
# Evaluation
print("Prediction for ", test_data, " is ", test_label)
print("Accuracy score = ", accuracy_score([0, 1, 1, 0], test_label))
- KNeighborsClassifier(n_neighbors = 1) では、コンストラクタを呼び出してオブジェクトを生成しているが、そのパラメータに n_neighbors を指定している。
n_neighbors
により、近傍として取り扱うデータの数を指定しています。
結果:正解率が1.0で100%である
評価結果が良くない場合、別のアルゴリズムを採用することで、評価が向上することがあります。