アヤメの品種分類

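GitHub（pandas リポジトリのテストデータ）から iris.csv をダウンロードします。
https://github.com/pandas-dev/pandas/tree/master/pandas/tests/data
（注: 現在 iris.csv は https://github.com/pandas-dev/pandas/blob/main/pandas/tests/io/data/csv/iris.csv にあります）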

from sklearn import svm, metrics
import random, re

def _to_number(cell):
    """Return *cell* as a float when it parses as a number, else unchanged.

    Replaces a per-line regex check that also accepted strings such as "."
    or "1.2.3", on which float() then raised ValueError. Note that float()
    also accepts spellings like "nan" or "inf"; the iris data has none.
    """
    try:
        return float(cell)
    except ValueError:
        return cell


# Read iris.csv into a list of rows: numeric cells become floats, the
# species name stays a string.
csv = []
with open('iris.csv', 'r', encoding='utf-8') as fp:
    for line in fp:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing empty line at EOF); they would
            # otherwise become a one-element row and break csv[i][4] later.
            continue
        cols = line.split(',')
        csv.append([_to_number(cell) for cell in cols])

# Drop the header row (column names), keeping only samples.
del csv[0]

# Version 1: manual split. Shuffle the rows, then use the first two thirds
# for training and the remaining third for testing.
random.shuffle(csv)

total_len = len(csv)
train_len = int(total_len * 2 / 3)
train_rows = csv[:train_len]
test_rows = csv[train_len:]
# Columns 0-3 are the four measurements; column 4 is the species name.
train_data = [row[0:4] for row in train_rows]
train_label = [row[4] for row in train_rows]
test_data = [row[0:4] for row in test_rows]
test_label = [row[4] for row in test_rows]

clf = svm.SVC()
clf.fit(train_data, train_label)
pre = clf.predict(test_data)
ac_score = metrics.accuracy_score(test_label, pre)
print("正解率=", ac_score)

# Sample run:
#   [vagrant@localhost python]$ python3 app.py
#   return f(*args, **kwds)
#   正解率= 0.96


# Version 2: the same task using pandas and train_test_split.
# sklearn.cross_validation (used originally) was removed in scikit-learn
# 0.20; train_test_split now lives in sklearn.model_selection.
import pandas as pd
from sklearn.model_selection import train_test_split

csv = pd.read_csv('iris.csv')
csv_data = csv[["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]]
csv_label = csv["Name"]
# With no test_size given, train_test_split holds out 25% of the rows.
train_data, test_data, train_label, test_label = \
    train_test_split(csv_data, csv_label)

clf = svm.SVC()
clf.fit(train_data, train_label)
pre = clf.predict(test_data)
ac_score = metrics.accuracy_score(test_label, pre)
print("正解率=", ac_score)

# Sample run:
#   [vagrant@localhost python]$ python3 app.py
#   正解率= 0.9736842105263158