クロスバリデーション

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from sklearn import svm, metrics
import random, re
 
# Load the iris CSV and parse every non-blank line into a row.
# The context manager closes the file handle, and blank lines (such as the
# empty string left after a trailing newline) are skipped so they never
# become one-cell rows that would break the column slicing further down.
with open('iris.csv', 'r', encoding='utf-8') as fp:
    lines = [li for li in fp.read().split("\n") if li.strip()]


def f_tonum(n):
    """Return *n* as a float if it consists only of digits and dots, else unchanged."""
    return float(n) if re.match(r'^[0-9\.]+$', n) else n


def f_cols(li):
    """Split one CSV line on commas and convert the numeric-looking cells."""
    return [f_tonum(cell) for cell in li.strip().split(',')]


csv = [f_cols(li) for li in lines]
del csv[:1]  # drop the first row (the header); a no-op when the file is empty
random.shuffle(csv)  # randomize row order before the rows are dealt into folds

# Deal the shuffled rows round-robin into K folds: row i goes to fold i % K.
K = 5
csvk = [csv[i::K] for i in range(K)]
 
def split_data_label(rows):
    """Separate rows into feature vectors and labels.

    The first four cells of each row are taken as the features and the
    cell at index 4 as the label.

    Returns a ``(data, label)`` tuple of two parallel lists.
    """
    # Walk the rows once, pairing each feature slice with its label.
    pairs = [(record[:4], record[4]) for record in rows]
    features = [feat for feat, _ in pairs]
    targets = [tag for _, tag in pairs]
    return features, targets
 
def calc_score(test, train):
    """Fit an SVM classifier on *train* and return its accuracy on *test*.

    Both arguments are lists of rows that ``split_data_label`` separates
    into features and labels. Note the argument order: the test rows come
    first, the training rows second.
    """
    test_features, expected = split_data_label(test)
    train_features, train_targets = split_data_label(train)
    # A fresh SVC with default settings is trained for every call.
    model = svm.SVC().fit(train_features, train_targets)
    predicted = model.predict(test_features)
    return metrics.accuracy_score(expected, predicted)
 
# Run the K-fold cross-validation: each fold serves as the test set once
# while the rows of every other fold are merged into the training set.
score_list = []
for test_idx, testc in enumerate(csvk):
    # Exclude the held-out fold by its position. Comparing the fold lists
    # themselves would be a deep value comparison, which is slower and would
    # also drop any other fold that happened to hold equal rows.
    trainc = [
        row
        for idx, fold in enumerate(csvk)
        if idx != test_idx
        for row in fold
    ]
    sc = calc_score(testc, trainc)
    score_list.append(sc)
print("各正解率=", score_list)
# NOTE(review): the label below looks like a typo for "平均正解率" (mean
# accuracy); it is left unchanged because it is runtime output.
print("平均成果率=", sum(score_list) / len(score_list))

各正解率= [0.9666666666666667, 1.0, 1.0, 0.9333333333333333, 1.0]
平均成果率= 0.9800000000000001

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pandas as pd
from sklearn import svm, metrics

# GridSearchCV lives in sklearn.model_selection in current scikit-learn; the
# old sklearn.grid_search location is kept only as a fallback for legacy
# installs. Nothing in this script uses sklearn.cross_validation, so it is
# not imported.
try:
    from sklearn.model_selection import GridSearchCV
except ImportError:
    from sklearn.grid_search import GridSearchCV

# Load the training and test sets from CSV.
train_csv = pd.read_csv("./mnist/train.csv")
test_csv = pd.read_csv("./mnist/t10k.csv")

# Column 0 is used as the label and columns 1..576 as the features.
# .iloc selects by position with an exclusive end, and exists in both old
# and current pandas (unlike DataFrame.ix).
# NOTE(review): 28x28 MNIST images would presumably give 784 pixel columns;
# confirm that stopping at 577 is intended.
train_label = train_csv.iloc[:, 0]
train_data = train_csv.iloc[:, 1:577]
test_label = test_csv.iloc[:, 0]
test_data = test_csv.iloc[:, 1:577]
print("学習データ数=", len(train_label))

# Hyper-parameter grid: a linear kernel over four C values, and an RBF
# kernel over the same C values combined with two gamma values.
params = [
    {"C": [1, 10, 100, 1000], "kernel": ["linear"]},
    {"C": [1, 10, 100, 1000], "kernel": ["rbf"], "gamma": [0.001, 0.0001]},
]

# Search the grid; n_jobs=-1 requests all available CPU cores.
clf = GridSearchCV(svm.SVC(), params, n_jobs=-1)
clf.fit(train_data, train_label)
print("学習器=", clf.best_estimator_)

# Score the best estimator on the test set. accuracy_score takes
# (y_true, y_pred), so the true labels are passed first.
pre = clf.predict(test_data)
ac_score = metrics.accuracy_score(test_label, pre)
print("正解率=", ac_score)