【Python】 SVMによるデータの分類
書籍「PYTHON機械学習プログラミング」を読んでいて、SVMによるXORデータの分類例があったのですが、分類結果をグラフ化する部分が勉強になったのでメモします。
# -*- coding: utf-8 -*- import numpy as np from matplotlib.colors import ListedColormap from matplotlib import pyplot as plt from sklearn.svm import SVC def plot_decision_regions(X, y, classifier, resolution=0.02): # マーカーとカラーマップの準備 markers = ('s', 'x', 'o', '^', 'v') colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan') cmap = ListedColormap(colors[:len(np.unique(y))]) # 領域の最大最小値 x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1 x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1 # グリッドポイントの生成 xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution), np.arange(x2_min, x2_max, resolution)) # 予測を実行する Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T) # 予測結果をプロットする Z = Z.reshape(xx1.shape) plt.contourf(xx1, xx2, Z, alpha=0.4, cmap=cmap) plt.xlim(xx1.min(), xx1.max()) plt.ylim(xx2.min(), xx2.max()) for i, cls in enumerate(np.unique(y)): plt.scatter(X[y==cls, 0], X[y==cls, 1], alpha=0.8, c=cmap(i), marker=markers[i], label=cls) plt.show() if __name__ == '__main__': # データの生成 np.random.seed(0) X = np.random.randn(200, 2) y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0) y = np.where(y, 1, -1) # SVMによるフィッティング svm = SVC(gamma='auto') # γを大きくし過ぎると過適合してしまう svm.fit(X, y) # 分類結果を表示する plot_decision_regions(X, y, svm)
Python機械学習プログラミング 達人データサイエンティストによる理論と実践 (impress top gear)
- 作者: Sebastian Raschka,株式会社クイープ,福島真太朗
- 出版社/メーカー: インプレス
- 発売日: 2016/06/30
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (4件) を見る