Jupyter Notebookなどで、コードを実装して実際に確かめてみましょう。
Jetson NanoでJupyter Notebookを使う場合
mglearnはプリインストールされていないので、事前にインストールしておきます。
$ sudo pip3 install mglearn
以下のコードを実行します。
# === Cell 1: classification with decision trees on the two-moons dataset ===
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# One seed for the dataset, the split and the trees so every run prints the
# same scores and draws the same plots.
RANDOM_STATE = 0

# 200 noisy samples with two features and labels 0/1.
moons = make_moons(n_samples=200, noise=0.1, random_state=RANDOM_STATE)
x, y = moons
x_train, x_test, y_train, y_test = train_test_split(
    x, y, random_state=RANDOM_STATE
)

# Unrestricted tree vs. a tree limited to depth 3.
# random_state is fixed because scikit-learn permutes the features at every
# split; without a seed, ties between equally good splits can be broken
# differently on each run.
tree_clf = DecisionTreeClassifier(random_state=RANDOM_STATE).fit(x_train, y_train)
tree_clf_3 = DecisionTreeClassifier(max_depth=3, random_state=RANDOM_STATE).fit(
    x_train, y_train
)

# Test-set accuracy of each model.
print(tree_clf.score(x_test, y_test))
print(tree_clf_3.score(x_test, y_test))


def plot_decision_boundary(clf, x, y, resolution=100, margin=0.5):
    """Shade the plane by the class that ``clf`` predicts.

    The classifier is evaluated on a ``resolution`` x ``resolution`` grid that
    spans the range of both columns of ``x`` padded by ``margin`` on every
    side, and the predicted labels are drawn as filled contours on the
    current axes.

    Parameters
    ----------
    clf : fitted classifier whose ``predict`` accepts an (n, 2) array.
    x : array with two feature columns; only used to choose the plotted range.
    y : accepted but not used by this function.
    resolution : grid points per axis (default 100, the original value).
    margin : padding added around the data range (default 0.5).
    """
    grid_x0 = np.linspace(x[:, 0].min() - margin, x[:, 0].max() + margin, resolution)
    grid_x1 = np.linspace(x[:, 1].min() - margin, x[:, 1].max() + margin, resolution)
    x1, x2 = np.meshgrid(grid_x0, grid_x1)
    # Flatten the grid into (resolution**2, 2) points, predict, and fold the
    # labels back into the grid's shape for contourf.
    x_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(x_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['mediumblue', 'orangered'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)


def plot_dataset(x, y):
    """Scatter the samples: class 0 as blue circles, class 1 as red triangles."""
    plt.plot(x[y == 0, 0], x[y == 0, 1], 'bo', ms=15)
    plt.plot(x[y == 1, 0], x[y == 1, 1], 'r^', ms=15)
    plt.xlabel("$x_0$", fontsize=30)
    plt.ylabel("$x_1$", fontsize=30, rotation=0)


# Side-by-side comparison: the full dataset (train + test) over each model's
# decision regions.
plt.figure(figsize=(24, 8))
for position, model, title in (
    (121, tree_clf, "max_depth=None"),
    (122, tree_clf_3, "max_depth=3"),
):
    plt.subplot(position)
    plot_decision_boundary(model, x, y)
    plot_dataset(x, y)
    plt.title(title, fontsize=30)
plt.show()

# ----------------------------------------------
# === Cell 2: regression with decision trees on mglearn's "wave" data ===
# (uses np, plt and RANDOM_STATE from cell 1)
import mglearn
from sklearn.tree import DecisionTreeRegressor

reg_x, reg_y = mglearn.datasets.make_wave(n_samples=100)
# Seeded for reproducibility, as with the classifiers above.
tree_reg = DecisionTreeRegressor(random_state=RANDOM_STATE).fit(reg_x, reg_y)
tree_reg_3 = DecisionTreeRegressor(max_depth=3, random_state=RANDOM_STATE).fit(
    reg_x, reg_y
)


def plot_regression_predictions(tree_reg, x, y):
    """Plot the samples and the regressor's prediction curve.

    ``tree_reg`` is evaluated on 500 evenly spaced points covering the range
    of ``x`` extended by 1 on each side. The points are reshaped to a single
    column before ``predict``, so this expects a one-feature model.
    """
    x1 = np.linspace(x.min() - 1, x.max() + 1, 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.xlabel('x', fontsize=30)
    plt.ylabel('y', fontsize=30, rotation=0)
    plt.plot(x, y, "bo", ms=15)
    plt.plot(x1, y_pred, "r-", linewidth=6)


plt.figure(figsize=(24, 8))
for position, model, title in (
    (121, tree_reg, "max_depth=None"),
    (122, tree_reg_3, "max_depth=3"),
):
    plt.subplot(position)
    plot_regression_predictions(model, reg_x, reg_y)
    plt.title(title, fontsize=30)
plt.show()
Leave a Reply