-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestDimensionalReduction.py
More file actions
65 lines (51 loc) · 8.7 KB
/
testDimensionalReduction.py
File metadata and controls
65 lines (51 loc) · 8.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.naive_bayes import BernoulliNB
import numpy as np
import util
# Load MNIST and split it into train/test parts (33% held out for testing).
# Features are divided by 255.0 -- presumably to bring 8-bit pixel values
# into [0, 1]; confirm against the dataset.
# NOTE(review): fetch_mldata and sklearn.cross_validation are only available
# in old scikit-learn releases -- confirm the pinned version before upgrading.
dataset = datasets.fetch_mldata("MNIST Original")
(xtr, xte, ytr, yte) = train_test_split(
    dataset.data / 255.0,
    # np.intp is the type the deprecated "int0" alias referred to.
    dataset.target.astype(np.intp),
    test_size=0.33,
)

# Sweep n_components from 1 to 100 and record the KNN test score obtained
# on the data returned by util.getPrincipleComponents for each setting.
result = []
for n_components in range(1, 101):
    xtrain, xtest = util.getPrincipleComponents(
        xtr, xte, n_components=n_components
    )
    clf = KNeighborsClassifier().fit(xtrain, ytr)
    # score() runs prediction on xtest itself; the former stand-alone
    # predict() call stored a value nobody read and doubled the cost of
    # every iteration.
    score = clf.score(xtest, yte)
    result.append((n_components, score))
# Recorded (n_components, KNN score) pairs for n_components 1 to 200:
#result = [(1, 0.27493506493506492), (2, 0.415974025974026), (3, 0.48632034632034632), (4, 0.62329004329004334), (5, 0.73623376623376624), (6, 0.82766233766233765), (7, 0.87051948051948047), (8, 0.89783549783549788), (9, 0.91415584415584417), (10, 0.92926406926406924), (11, 0.93614718614718617), (12, 0.94454545454545458), (13, 0.95077922077922072), (14, 0.95692640692640696), (15, 0.95861471861471859), (16, 0.96324675324675324), (17, 0.96432900432900437), (18, 0.96645021645021645), (19, 0.96796536796536792), (20, 0.96991341991341995), (21, 0.9709523809523809), (22, 0.97177489177489174), (23, 0.97264069264069264), (24, 0.97277056277056273), (25, 0.97298701298701296), (26, 0.9736796536796537), (27, 0.97450216450216454), (28, 0.97428571428571431), (29, 0.97476190476190472), (30, 0.97545454545454546), (31, 0.97558441558441555), (32, 0.9751515151515151), (33, 0.97554112554112549), (34, 0.97523809523809524), (35, 0.97497835497835494), (36, 0.9754112554112554), (37, 0.97506493506493508), (38, 0.97580086580086578), (39, 0.9752813852813853), (40, 0.97497835497835494), (41, 0.97519480519480517), (42, 0.9761471861471861), (43, 0.97558441558441555), (44, 0.97549783549783553), (45, 0.97558441558441555), (46, 0.97554112554112549), (47, 0.97606060606060607), (48, 0.97506493506493508), (49, 0.97597402597402594), (50, 0.97558441558441555), (51, 0.97571428571428576), (52, 0.97567099567099569), (53, 0.97545454545454546), (54, 0.97523809523809524), (55, 0.9754112554112554), (56, 0.97510822510822515), (57, 0.97502164502164501), (58, 0.97493506493506499), (59, 0.97519480519480517), (60, 0.97497835497835494), (61, 0.97523809523809524), (62, 0.97497835497835494), (63, 0.97480519480519479), (64, 0.97480519480519479), (65, 0.97480519480519479), (66, 0.97450216450216454), (67, 0.97437229437229433), (68, 0.9744155844155844), (69, 0.9744155844155844), (70, 0.97437229437229433), (71, 0.97415584415584411), (72, 0.97424242424242424), (73, 0.97389610389610393), (74, 0.97411255411255415), (75, 
0.97376623376623372), (76, 0.97389610389610393), (77, 0.97411255411255415), (78, 0.97428571428571431), (79, 0.97380952380952379), (80, 0.97380952380952379), (81, 0.97350649350649354), (82, 0.97359307359307357), (83, 0.97363636363636363), (84, 0.97316017316017311), (85, 0.97320346320346318), (86, 0.97311688311688316), (87, 0.97350649350649354), (88, 0.97294372294372289), (89, 0.97307359307359309), (90, 0.97329004329004332), (91, 0.97303030303030302), (92, 0.9728138528138528), (93, 0.97264069264069264), (94, 0.9728138528138528), (95, 0.97272727272727277), (96, 0.97268398268398271), (97, 0.97268398268398271), (98, 0.97294372294372289), (99, 0.97242424242424241), (100, 0.97259740259740257), (101, 0.9725541125541125), (102, 0.97246753246753248), (103, 0.97238095238095235), (104, 0.97233766233766239), (105, 0.97264069264069264), (106, 0.97242424242424241), (107, 0.97233766233766239), (108, 0.97216450216450212), (109, 0.97229437229437232), (110, 0.97233766233766239), (111, 0.97251082251082255), (112, 0.97246753246753248), (113, 0.97233766233766239), (114, 0.97194805194805189), (115, 0.97238095238095235), (116, 0.97216450216450212), (117, 0.97203463203463203), (118, 0.97259740259740257), (119, 0.97238095238095235), (120, 0.97203463203463203), (121, 0.97212121212121216), (122, 0.97186147186147187), (123, 0.97212121212121216), (124, 0.9720779220779221), (125, 0.97212121212121216), (126, 0.97186147186147187), (127, 0.97212121212121216), (128, 0.9720779220779221), (129, 0.9718181818181818), (130, 0.97177489177489174), (131, 0.97177489177489174), (132, 0.97151515151515155), (133, 0.97168831168831171), (134, 0.97151515151515155), (135, 0.97147186147186149), (136, 0.9718181818181818), (137, 0.97142857142857142), (138, 0.97173160173160178), (139, 0.97147186147186149), (140, 0.97142857142857142), (141, 0.97129870129870133), (142, 0.97112554112554117), (143, 0.97125541125541126), (144, 0.97125541125541126), (145, 0.97164502164502164), (146, 0.97160173160173158), (147, 
0.97116883116883113), (148, 0.97103896103896103), (149, 0.97077922077922074), (150, 0.97125541125541126), (151, 0.97142857142857142), (152, 0.97138528138528135), (153, 0.97086580086580088), (154, 0.97103896103896103), (155, 0.97116883116883113), (156, 0.97077922077922074), (157, 0.97125541125541126), (158, 0.9709523809523809), (159, 0.97060606060606058), (160, 0.97086580086580088), (161, 0.97090909090909094), (162, 0.97086580086580088), (163, 0.97077922077922074), (164, 0.97056277056277052), (165, 0.97064935064935065), (166, 0.97060606060606058), (167, 0.97073593073593079), (168, 0.97043290043290042), (169, 0.97060606060606058), (170, 0.97064935064935065), (171, 0.97060606060606058), (172, 0.97073593073593079), (173, 0.97056277056277052), (174, 0.97043290043290042), (175, 0.97077922077922074), (176, 0.9703463203463204), (177, 0.97069264069264072), (178, 0.97047619047619049), (179, 0.97073593073593079), (180, 0.9703463203463204), (181, 0.97060606060606058), (182, 0.97060606060606058), (183, 0.97043290043290042), (184, 0.97051948051948056), (185, 0.9709523809523809), (186, 0.97051948051948056), (187, 0.9703463203463204), (188, 0.97086580086580088), (189, 0.97051948051948056), (190, 0.97043290043290042), (191, 0.9702164502164502), (192, 0.97056277056277052), (193, 0.97086580086580088), (194, 0.97056277056277052), (195, 0.97064935064935065), (196, 0.97043290043290042), (197, 0.97060606060606058), (198, 0.97064935064935065), (199, 0.97069264069264072), (200, 0.97060606060606058)]
# Recorded (n_components, KNN score) pairs for n_components 1 to 50:
#result = [(1, 0.27493506493506492), (2, 0.415974025974026), (3, 0.48632034632034632), (4, 0.62329004329004334), (5, 0.73623376623376624), (6, 0.82766233766233765), (7, 0.87051948051948047), (8, 0.89783549783549788), (9, 0.91415584415584417), (10, 0.92926406926406924), (11, 0.93614718614718617), (12, 0.94454545454545458), (13, 0.95077922077922072), (14, 0.95692640692640696), (15, 0.95861471861471859), (16, 0.96324675324675324), (17, 0.96432900432900437), (18, 0.96645021645021645), (19, 0.96796536796536792), (20, 0.96991341991341995), (21, 0.9709523809523809), (22, 0.97177489177489174), (23, 0.97264069264069264), (24, 0.97277056277056273), (25, 0.97298701298701296), (26, 0.9736796536796537), (27, 0.97450216450216454), (28, 0.97428571428571431), (29, 0.97476190476190472), (30, 0.97545454545454546), (31, 0.97558441558441555), (32, 0.9751515151515151), (33, 0.97554112554112549), (34, 0.97523809523809524), (35, 0.97497835497835494), (36, 0.9754112554112554), (37, 0.97506493506493508), (38, 0.97580086580086578), (39, 0.9752813852813853), (40, 0.97497835497835494), (41, 0.97519480519480517), (42, 0.9761471861471861), (43, 0.97558441558441555), (44, 0.97549783549783553), (45, 0.97558441558441555), (46, 0.97554112554112549), (47, 0.97606060606060607), (48, 0.97506493506493508), (49, 0.97597402597402594), (50, 0.97558441558441555)]
# Split the (n_components, score) pairs into two columns once.
# list(zip(...)) works on both Python 2 and 3, whereas indexing zip(...)
# directly only works on Python 2.
columns = list(zip(*result))
# The x values are handed over as a list and the scores as a tuple, exactly
# what the previous map()/zip() expressions produced.
util.plotBarGraph(
    list(columns[0]),
    columns[1],
    "PCA on MNIST for KNN",
    "n_components",
    "KNN Score",
    "PCAKNN",
)

# Rank the settings by score, best first; list.sort is stable, so entries
# with equal scores keep their existing relative order.
result.sort(key=lambda pair: pair[1], reverse=True)

# Single-argument print(...) emits the same text on Python 2 and Python 3.
print(result)
print(" ".join(["best principle component is", str(result[0])]))
# Compare several classifiers on the 50-component data returned by
# util.getPrincipleComponents.
xtrain, xtest = util.getPrincipleComponents(xtr, xte, n_components=50)

# Parameter dump kept from the original source for the SVC configuration:
# {'kernel': 'rbf', 'C': 1000, 'verbose': False, 'probability': False, 'degree': 3,
# 'shrinking': True, 'max_iter': -1, 'random_state': None, 'tol': 0.001,
# 'cache_size': 200, 'coef0': 0.0, 'gamma': 0.001, 'class_weight': None}
#
# Each entry is (label printed before the score, unfitted estimator).
# The trailing numbers are the scores noted in the original source.
classifiers = [
    ("KNN score ", KNeighborsClassifier()),          # 0.975411255411
    ("DT score ", DecisionTreeClassifier()),         # 0.834415584416
    ("GaussianNB score ", GaussianNB()),             # 0.869090909091
    ("BernoulliNB score ", BernoulliNB()),           # 0.727835497835
    ("SVC score ", SVC(kernel='rbf', C=1000, gamma=0.001)),  # 0.980086580087
]

for label, clf in classifiers:
    clf = clf.fit(xtrain, ytr)
    # score() predicts on xtest internally, so no separate predict() pass is
    # made. Joining with " " reproduces the label/value separator that the
    # Python 2 print statement inserted.
    print(" ".join([label, str(clf.score(xtest, yte))]))