"""Evaluate an SVM on the iris data set with hand-rolled K-fold cross-validation."""
import random
import re

# Number of folds used for cross-validation.
K = 5

# A field counts as numeric when it consists only of digits and dots.
_NUMERIC = re.compile(r'^[0-9\.]+$')


def f_tonum(n):
    """Convert the string *n* to float when it looks numeric, else return it as-is.

    Strings that pass the digit/dot pattern but are not valid floats
    (for example "1.2.3" or ".") are returned unchanged instead of raising.
    """
    if _NUMERIC.match(n):
        try:
            return float(n)
        except ValueError:
            return n
    return n


def f_cols(li):
    """Split one CSV line on commas and convert each numeric-looking field."""
    return [f_tonum(field) for field in li.strip().split(',')]


def load_rows(path='iris.csv'):
    """Read the CSV file at *path* and return its data rows as lists.

    The first line is treated as a header and dropped. Blank lines
    (such as the one produced by a trailing newline) are skipped so that
    every returned row has real fields.
    """
    with open(path, 'r', encoding='utf-8') as fp:
        lines = fp.read().split("\n")
    return [f_cols(line) for line in lines[1:] if line.strip()]


def split_folds(rows, k=K):
    """Distribute *rows* round-robin into *k* folds and return the list of folds.

    Raises:
        ValueError: if *k* is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    # rows[i::k] picks rows i, i+k, i+2k, ... i.e. every row whose index % k == i.
    return [rows[i::k] for i in range(k)]


def split_data_label(rows):
    """Separate rows into feature lists (columns 0-3) and labels (column 4).

    Returns:
        A ``(data, label)`` tuple of two parallel lists.
    """
    data = [row[0:4] for row in rows]
    label = [row[4] for row in rows]
    return (data, label)


def calc_score(test, train):
    """Train an SVC on *train* rows and return its accuracy on *test* rows."""
    # Imported here so the data-preparation helpers above stay usable
    # (and importable) without scikit-learn installed.
    from sklearn import metrics, svm

    test_f, test_l = split_data_label(test)
    train_f, train_l = split_data_label(train)
    clf = svm.SVC()
    clf.fit(train_f, train_l)
    pre = clf.predict(test_f)
    return metrics.accuracy_score(test_l, pre)


def cross_val_scores(folds, score=None):
    """Score every fold against a model trained on all the other folds.

    Args:
        folds: list of folds, each a list of rows.
        score: callable ``score(test_rows, train_rows)``; defaults to
            :func:`calc_score`.

    Returns:
        A list with one score per fold, in fold order.
    """
    if score is None:
        score = calc_score
    scores = []
    for test_idx, test_rows in enumerate(folds):
        train_rows = []
        for idx, fold in enumerate(folds):
            # Exclude the test fold by position, not by value: two folds that
            # happen to hold equal rows must not both be dropped from training.
            if idx != test_idx:
                train_rows.extend(fold)
        scores.append(score(test_rows, train_rows))
    return scores


def main():
    """Load iris.csv, shuffle it, run K-fold cross-validation and print the scores."""
    rows = load_rows('iris.csv')
    random.shuffle(rows)
    score_list = cross_val_scores(split_folds(rows, K))
    print("各正解率=", score_list)
    print("平均成果率=", sum(score_list) / len(score_list))


if __name__ == "__main__":
    main()
各正解率= [0.9666666666666667, 1.0, 1.0, 0.9333333333333333, 1.0]
平均成果率= 0.9800000000000001
# Tune an SVM on MNIST CSV data with a grid search and report test accuracy.
import pandas as pd
from sklearn import metrics, svm
# GridSearchCV lives in sklearn.model_selection; the old sklearn.grid_search
# and sklearn.cross_validation modules no longer exist in current releases.
from sklearn.model_selection import GridSearchCV

# Load the training and test sets.
# NOTE(review): read_csv treats the first line as a header by default; if the
# CSV files have no header row, the first sample is silently consumed - confirm
# against the files and pass header=None if needed.
train_csv = pd.read_csv("./mnist/train.csv")
test_csv = pd.read_csv("./mnist/t10k.csv")

# Column 0 is used as the label and columns 1..576 as the features.
# .iloc is the positional replacement for the removed DataFrame.ix accessor.
train_label = train_csv.iloc[:, 0]
train_data = train_csv.iloc[:, 1:577]
test_label = test_csv.iloc[:, 0]
test_data = test_csv.iloc[:, 1:577]
print("学習データ数=", len(train_label))

# Candidate hyper-parameters: a linear kernel, and an RBF kernel with two gammas.
params = [
    {"C": [1, 10, 100, 1000], "kernel": ["linear"]},
    {"C": [1, 10, 100, 1000], "kernel": ["rbf"], "gamma": [0.001, 0.0001]},
]

# Run the grid search on all available cores (n_jobs=-1).
clf = GridSearchCV(svm.SVC(), params, n_jobs=-1)
clf.fit(train_data, train_label)
print("学習器=", clf.best_estimator_)

# Evaluate the best estimator on the held-out test set.
pre = clf.predict(test_data)
# accuracy_score expects (y_true, y_pred).
ac_score = metrics.accuracy_score(test_label, pre)
print("正解率=", ac_score)