source: trunk/bin/train

Last change on this file was revision 16, checked in by tim, 9 years ago.

.

  • Property svn:executable set to *
File size: 7.4 KB
Line 
1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import os
6import time
7import random
8import statistics
9import functools
10import argparse
11import pprint
12import json
13
14
VERSION = "{DEVELOPMENT}"
# The placeholder value means we are running from a source checkout (it is
# presumably substituted by a release/packaging step — TODO confirm).  In that
# case make the in-tree library directory, ../lib relative to this script,
# importable.
if VERSION == "{DEVELOPMENT}":
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except (NameError, TypeError):
        # __file__ can be undefined (or None) in embedded/frozen contexts;
        # fall back to the path the interpreter was invoked with.
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except (AttributeError, IndexError, TypeError):
            # No usable argv either: keep the '.' default (best effort).
            pass
    sys.path.append(os.path.join(script_dir, os.pardir, "lib"))
26
27
28from nanownlib import *
29from nanownlib.stats import *
30from nanownlib.train import *
31from nanownlib.parallel import WorkerThreads
32import nanownlib.storage
33
34
35
# Command-line interface.
parser = argparse.ArgumentParser(
    description="Train and test each available classifier on a session "
                "database, then report the best results by number of "
                "observations and by error rate.")
parser.add_argument('--unusual-case', action='store', default=None,
                    metavar='CASE,GREATER',
                    help='Specify the unusual case and whether it is greater than the other cases.  Format: {case name},{1 or 0}')
parser.add_argument('--retrain', action='append', default=[],
                    metavar='CLASSIFIER',
                    help='Force a classifier to be retrained (and retested).  May be specified multiple times.')
parser.add_argument('--retest', action='append', default=[],
                    metavar='CLASSIFIER',
                    help='Force a classifier to be retested.  May be specified multiple times.')
# Required positional argument; a default would never be used.
parser.add_argument('session_data',
                    help='Database file storing session information')
options = parser.parse_args()
db = nanownlib.storage.db(options.session_data)
47
48
49
def trainClassifier(db, unusual_case, greater, classifier, retrain=False):
    """Train *classifier* on growing sample sizes until its error is acceptable.

    The number of observations starts small and grows by roughly 50% per
    step, capped at one fifth of the 'train' population.  At each size:
    - a stored result is reused if the database has one;
    - otherwise the classifier's trainer runs and the new result is stored.

    Every result, stored or fresh, is appended to
    ``classifiers[classifier]['train_results']`` so it can be tested later.
    Training stops early once the mean of the false-positive and
    false-negative rates is below 5% with more than 100 observations.

    :param db: session database (a ``nanownlib.storage.db`` instance)
    :param unusual_case: name of the unusual test case
    :param greater: True if the unusual case is expected to be the larger one
    :param classifier: key into the ``classifiers`` table
    :param retrain: if True, discard stored training results first
    :returns: the last training result dict, or None if the population was
              too small for any sample size to be tried
    """
    if retrain:
        print("Dropping stored training results...")
        db.deleteClassifierResults(classifier, 'train')

    trainer = classifiers[classifier]['train']
    threshold = 5.0  # acceptable mean error, in percent
    num_obs = 7
    max_obs = int(db.populationSize('train')/5)
    result = None
    while num_obs < max_obs:
        num_obs = min(int(num_obs*1.5), max_obs)
        result = db.fetchClassifierResult(classifier, 'train', num_obs)
        if result is not None:
            train_time = "(stored)"
        else:
            start = time.time()
            result = trainer(db, unusual_case, greater, num_obs)
            result['classifier'] = classifier
            train_time = "%8.2f" % (time.time()-start)
            # Only persist freshly computed results; a fetched result is
            # already in the database and must not be inserted again.
            db.addClassifierResult(result)

        error = statistics.mean([result['false_positives'], result['false_negatives']])
        print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | train time: %s | params: %s"
              % (num_obs, error, result['false_positives'], result['false_negatives'], train_time, result['params']))
        classifiers[classifier]['train_results'].append(result)

        if error < threshold and num_obs > 100:
            break

    return result
81
82
83
84def testClassifier(db, unusual_case, greater, classifier, retest=False):
85    target_error = 5.0 # in percent
86    num_trials = 1000
87    max_obs = int(db.populationSize('test')/5)
88
89    tester = classifiers[classifier]['test']
90   
91    def testAux(params, num_trials, num_observations):
92        estimator = functools.partial(tester, params, greater)
93        estimates = bootstrap3(estimator, db, 'test', unusual_case, num_observations, num_trials)
94        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)
95
96        bad_estimates = len([e for e in estimates if e != 1])
97        bad_null_estimates = len([e for e in null_estimates if e != 0])
98       
99        false_negatives = 100.0*bad_estimates/num_trials
100        false_positives = 100.0*bad_null_estimates/num_trials
101        return false_positives,false_negatives
102
103
104    def getResult(classifier, params, num_obs, num_trials):
105        jparams = json.dumps(params, sort_keys=True)
106        result = db.fetchClassifierResult(classifier, 'test', num_obs, jparams)
107        if result:
108            test_time = '(stored)'
109            fp = result['false_positives']
110            fn = result['false_negatives']
111        else:
112            start = time.time()
113            fp,fn = testAux(params, num_trials, num_obs)
114            result = {'classifier':classifier,
115                      'trial_type':"test",
116                      'num_observations':num_obs,
117                      'num_trials':num_trials,
118                      'params':jparams,
119                      'false_positives':fp,
120                      'false_negatives':fn}
121            db.addClassifierResult(result)
122            test_time = '%8.2f' % (time.time()-start)
123           
124        print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | test time: %s"
125              % (num_obs,(fp+fn)/2.0,fp,fn,test_time))
126        return ((fp+fn)/2.0,result)
127   
128    if retest:
129        print("Dropping stored test results...")
130        db.deleteClassifierResults(classifier, 'test')
131
132
133    lte = math.log(target_error/100.0)
134    for tr in classifiers[classifier]['train_results']:
135        db.resetOffsets()
136        params = json.loads(tr['params'])
137        num_obs = tr['num_observations']
138   
139        print("parameters:", params)
140        error,result = getResult(classifier,params,num_obs,num_trials)
141        #print("walking up")
142        while (error > target_error) and (num_obs < max_obs):
143            increase_factor = 1.5 * lte/math.log(error/100.0) # don't ask how I came up with this
144            #print("increase_factor:", increase_factor)
145            num_obs = min(int(increase_factor*num_obs), max_obs)
146            error,result = getResult(classifier,params,num_obs,num_trials)
147
148        #print("walking down")
149        while (num_obs > 0):
150            num_obs = int(0.95*num_obs)
151            error,result = getResult(classifier,params,num_obs,num_trials)
152            if error > target_error:
153                break
154   
155
# Decide which test case is the "unusual" one, and whether it is the greater.
# Precedence: explicit --unusual-case, then the value cached in the database,
# then auto-detection from the data (which is cached for later runs).
if options.unusual_case is not None:
    try:
        # Split on the last comma so case names may themselves contain commas.
        unusual_case, greater = options.unusual_case.rsplit(',', 1)
        greater = bool(int(greater))
    except ValueError:
        parser.error("--unusual-case must have the form {case name},{1 or 0}")
    db.setUnusualCase(unusual_case, greater)
else:
    ucg = db.getUnusualCase()
    if ucg is not None:
        unusual_case, greater = ucg
        print("Using cached unusual_case:", unusual_case)
    else:
        unusual_case, delta = findUnusualTestCase(db)
        greater = (delta > 0)
        print("Auto-detected unusual_case '%s' with delta: %d" % (unusual_case, delta))
        db.setUnusualCase(unusual_case, greater)
170
171
# Train every classifier that has a trainer, in name order.  Classifiers
# named via --retrain have their stored training results discarded first.
for c in sorted(classifiers):
    if classifiers[c]['train'] is None:
        continue
    start = time.time()
    print("Training %s..." % c)
    trainClassifier(db, unusual_case, greater, c, c in options.retrain)
    print("completed in: %8.2f\n" % (time.time()-start))
181
# Clear the database object's cache before testing (presumably cached data
# left over from training — TODO confirm against nanownlib.storage).
db.clearCache()

# Test every classifier in name order.  A classifier is retested from scratch
# if it was named via either --retest or --retrain.
forced_retest = options.retest + options.retrain
for name in sorted(classifiers):
    started = time.time()
    print("Testing %s..." % name)
    testClassifier(db, unusual_case, greater, name, name in forced_retest)
    elapsed = time.time() - started
    print("completed in: %8.2f\n" % elapsed)
189
190
# Print the best test results two ways: ordered by fewest observations, then
# ordered by lowest error.  Only the very first line printed overall gets the
# " (winner)" tag.
best_obs, best_error = evaluateTestResults(db)
best_obs = sorted(best_obs, key=lambda row: row['num_observations'])
best_error = sorted(best_error, key=lambda row: row['error'])

winner = None
for row in best_obs:
    line = "%(num_observations)5d obs   | %(classifier)12s | %(params)s" % row
    if winner is None:
        winner = row
        line += " (winner)"
    print(line)

for row in best_error:
    line = "%(error)3.2f%% error | %(classifier)12s | %(params)s" % row
    if winner is None:
        winner = row
        line += " (winner)"
    print(line)
Note: See TracBrowser for help on using the repository browser.