source: trunk/bin/train @ 11

Last change on this file since 11 was 11, checked in by tim, 9 years ago

.

  • Property svn:executable set to *
File size: 6.7 KB
Line 
1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import os
6import time
7import random
8import statistics
9import functools
10import argparse
11import pprint
12import json
13
14
# "{DEVELOPMENT}" appears to be a placeholder substituted when the script is
# packaged/installed (TODO confirm); while it is still present we are running
# from a source checkout.
VERSION = "{DEVELOPMENT}"
if VERSION == "{DEVELOPMENT}":
    # Make the checkout's ../lib directory importable; the nanownlib imports
    # that follow presumably resolve from there.
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        # __file__ is undefined in some execution contexts (e.g. exec'd code);
        # fall back to the path the script was invoked with.
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except (AttributeError, IndexError):
            # sys.argv can be missing or empty in embedded interpreters;
            # keep the '.' default.
            pass
    sys.path.append(os.path.join(script_dir, "..", "lib"))
26
27
28from nanownlib import *
29from nanownlib.stats import *
30from nanownlib.train import *
31from nanownlib.parallel import WorkerThreads
32import nanownlib.storage
33
34
35
# Command-line interface.  The parsed options and the opened session database
# are module-level globals used by the driver code at the bottom of the file.
parser = argparse.ArgumentParser(
    description="Train and test timing classifiers on stored session data.")
parser.add_argument('--unusual-case', action='store', default=None,
                    help='Specify the unusual case and whether it is greater than the other cases.  Format: {case name},{1 or 0}')
parser.add_argument('--retrain', action='append', default=[], metavar='CLASSIFIER',
                    help='Force a classifier to be retrained (and retested).  May be specified multiple times.')
parser.add_argument('--retest', action='append', default=[], metavar='CLASSIFIER',
                    help='Force a classifier to be retested.  May be specified multiple times.')
# Positional arguments are required, so no default is needed.
parser.add_argument('session_data',
                    help='Database file storing session information')
options = parser.parse_args()
db = nanownlib.storage.db(options.session_data)
47
48
49
def trainClassifier(db, unusual_case, greater, classifier, retrain=False):
    """Train *classifier* on progressively larger samples until its error is acceptable.

    The sample size starts at 1000 observations and grows by a factor of 1.5
    each round. It is capped at one fifth of the 'train' population, and the
    first round therefore uses min(1500, cap). Results already stored in *db*
    for a given size are reused instead of retraining. Only freshly trained
    results are written back to *db*. Every round's result is appended to
    ``classifiers[classifier]['train_results']``.

    Training stops once the mean of the false-positive and false-negative
    rates (percent) drops below 5, or once the size cap is reached.

    Args:
        db: the session database (nanownlib.storage.db).
        unusual_case: name of the test case that differs from the others.
        greater: whether the unusual case is greater than the other cases.
        classifier: key into the ``classifiers`` registry.
        retrain: if true, delete this classifier's stored 'train' results first.

    Returns:
        The result dict from the last round. Returns None if one fifth of the
        'train' population is at most 1000, in which case no round runs.
    """
    if retrain:
        print("Dropping stored training results...")
        db.deleteClassifierResults(classifier, 'train')

    trainer = classifiers[classifier]['train']
    threshold = 5.0  # in percent
    num_obs = 1000
    max_obs = int(db.populationSize('train') / 5)
    result = None
    while num_obs < max_obs:
        num_obs = min(int(num_obs * 1.5), max_obs)
        result = db.fetchClassifierResult(classifier, 'train', num_obs)
        is_new = result is None
        if is_new:
            start = time.time()
            result = trainer(db, unusual_case, greater, num_obs)
            result['classifier'] = classifier
            train_time = "%f" % (time.time() - start)
        else:
            train_time = "(stored)"

        error = statistics.mean([result['false_positives'], result['false_negatives']])
        print("number of observations: %d | error: %f | false_positives: %f | false_negatives: %f | train time: %s | params: %s"
              % (num_obs, error, result['false_positives'], result['false_negatives'], train_time, result['params']))
        if is_new:
            # Results fetched from db are already stored; re-adding them would
            # duplicate rows (getResult in testClassifier follows the same rule).
            db.addClassifierResult(result)
        classifiers[classifier]['train_results'].append(result)

        if error < threshold:
            break

    return result
81
82
83
84def testClassifier(db, unusual_case, greater, classifier, retest=False):
85    target_error = 5.0 # in percent
86    num_trials = 1000
87    max_obs = int(db.populationSize('test')/5)
88
89    tester = classifiers[classifier]['test']
90   
91    def testAux(params, num_trials, num_observations):
92        estimator = functools.partial(tester, params, greater)
93        estimates = bootstrap3(estimator, db, 'test', unusual_case, num_observations, num_trials)
94        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)
95
96        bad_estimates = len([e for e in estimates if e != 1])
97        bad_null_estimates = len([e for e in null_estimates if e != 0])
98       
99        false_negatives = 100.0*bad_estimates/num_trials
100        false_positives = 100.0*bad_null_estimates/num_trials
101        print("testAux:", num_observations, false_positives, false_negatives, params)
102        return false_positives,false_negatives
103
104
105    def getResult(classifier, params, num_obs, num_trials):
106        jparams = json.dumps(params, sort_keys=True)
107        result = db.fetchClassifierResult(classifier, 'test', num_obs, jparams)
108        if result:
109            fp = result['false_positives']
110            fn = result['false_negatives']
111        else:
112            fp,fn = testAux(params, num_trials, num_obs)
113            result = {'classifier':classifier,
114                      'trial_type':"test",
115                      'num_observations':num_obs,
116                      'num_trials':num_trials,
117                      'params':jparams,
118                      'false_positives':fp,
119                      'false_negatives':fn}
120            db.addClassifierResult(result)
121        return ((fp+fn)/2.0,result)
122   
123    if retest:
124        print("Dropping stored test results...")
125        db.deleteClassifierResults(classifier, 'test')
126
127
128    test_results = []
129    lte = math.log(target_error/100.0)
130    for tr in classifiers[classifier]['train_results']:
131        db.resetOffsets()
132        params = json.loads(tr['params'])
133        num_obs = tr['num_observations']
134   
135        print("initial test")
136        error,result = getResult(classifier,params,num_obs,num_trials)
137        print("walking up")
138        while (error > target_error) and (num_obs < max_obs):
139            increase_factor = 1.5 * lte/math.log(error/100.0) # don't ask how I came up with this
140            #print("increase_factor:", increase_factor)
141            num_obs = min(int(increase_factor*num_obs), max_obs)
142            error,result = getResult(classifier,params,num_obs,num_trials)
143
144        print("walking down")
145        while (num_obs > 0):
146            current_best = (error,result)
147            num_obs = int(0.95*num_obs)
148            error,result = getResult(classifier,params,num_obs,num_trials)
149            if error > target_error:
150                break
151       
152    return current_best
153
154
# Decide which test case is the "unusual" one and whether it is greater than
# the other cases: either taken from --unusual-case or found from the data.
if options.unusual_case is not None:
    # rsplit so that only the final comma separates the {1 or 0} flag.
    unusual_case_parts = options.unusual_case.rsplit(',', 1)
    if len(unusual_case_parts) != 2 or not unusual_case_parts[0]:
        parser.error("--unusual-case must have the form {case name},{1 or 0}")
    unusual_case, greater = unusual_case_parts
    try:
        greater = bool(int(greater))
    except ValueError:
        parser.error("--unusual-case must have the form {case name},{1 or 0}")
else:
    start = time.time()
    unusual_case, unusual_diff = findUnusualTestCase(db)
    # A positive difference means the unusual case is greater than the others.
    greater = (unusual_diff > 0)
    print("unusual_case:", unusual_case)
    print("unusual_diff:", unusual_diff)
    end = time.time()
    print("unusual case search completed in:", end - start)
167
# Train every classifier that has a trainer.  Stored results are reused
# unless the classifier was named with --retrain.
for c in sorted(classifiers):
    if classifiers[c]['train'] is None:
        continue
    start = time.time()
    print("Training %s..." % c)
    result = trainClassifier(db, unusual_case, greater, c, c in options.retrain)
    print("%s result:" % c)
    pprint.pprint(result)
    print("completed in:", time.time() - start)

# NOTE(review): presumably drops data cached by the db during training so the
# test phase starts clean -- confirm against nanownlib.storage.
db.clearCache()

# Test every classifier.  Per the --retrain help text, a retrained classifier
# is also retested.  The lookup set is built once rather than per iteration.
force_retest = set(options.retest) | set(options.retrain)
for c in sorted(classifiers):
    start = time.time()
    print("Testing %s..." % c)
    error, result = testClassifier(db, unusual_case, greater, c, c in force_retest)
    print("%s result:" % c)
    pprint.pprint(result)
    classifiers[c]['test_error'] = error
    print("completed in:", time.time() - start)
Note: See TracBrowser for help on using the repository browser.