source: trunk/bin/train @ 10

Last change on this file since 10 was 10, checked in by tim, 9 years ago

.

  • Property svn:executable set to *
File size: 15.4 KB
#!/usr/bin/env python3
#-*- mode: Python;-*-

import sys
import os
import time
import random
import statistics
import functools
import argparse
import pprint
import json
import math  # math.log is used in testClassifier below; imported explicitly rather than relying on the wildcard imports


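# When running from a source checkout (VERSION is presumably rewritten at install time),
# locate this script's directory and add the sibling ../lib directory to the module
# search path so the nanownlib package can be imported without being installed.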
VERSION = "{DEVELOPMENT}"
if VERSION == "{DEVELOPMENT}":
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except Exception:
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except Exception:
            pass
    sys.path.append("%s/../lib" % script_dir)

from nanownlib import *
from nanownlib.stats import *
from nanownlib.parallel import WorkerThreads
import nanownlib.storage



parser = argparse.ArgumentParser(
    description="")
#parser.add_argument('-c', dest='cases', type=str, default='{"short":10000,"long":1010000}',
#                    help='JSON representation of echo timing cases. Default: {"short":10000,"long":1010000}')
parser.add_argument('--retrain', action='append', default=[], help='Force a classifier to be retrained.  May be specified multiple times.')
parser.add_argument('--retest', action='append', default=[], help='Force a classifier to be retested.  May be specified multiple times.')
parser.add_argument('session_data', default=None,
                    help='Database file storing session information')
options = parser.parse_args()
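
# Example invocation (classifier names given to --retrain/--retest must match keys of
# the `classifiers` dict defined below, e.g. boxtest, midsummary, quadsummary, kalman):
#
#   ./train --retrain boxtest --retest kalman session.db
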
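# trainBoxTest: grid-search the box-test parameters (a window [low, low+width]) in four
# passes: (1) coarse scan of integer lows with the width fixed at 1.0, ranked by mean
# error; (2) scan of widths 0.5-6.0 over the five best lows, ranked by the |false
# positive - false negative| imbalance; (3) re-rank the candidate lows at the chosen
# width with more trials; (4) fine-tune the width around the chosen value.  bootstrap3()
# is assumed to draw num_observations samples per trial and return one 0/1 decision per
# trial; decisions != 1 on the 'train' subset count as misses, and decisions != 0 on
# 'train_null' count as false alarms.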
def trainBoxTest(db, unusual_case, greater, num_observations):
    db.resetOffsets()

    def trainAux(low,high,num_trials):
        estimator = functools.partial(multiBoxTest, {'low':low, 'high':high}, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, num_observations, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    #start = time.time()
    wt = WorkerThreads(2, trainAux)

    num_trials = 200
    width = 1.0
    performance = []
    for low in range(0,50):
        wt.addJob(low, (low,low+width,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    #print(time.time()-start)

    num_trials = 200
    lows = [p[1] for p in performance[0:5]]
    widths = [w/10.0 for w in range(5,65,5)]
    performance = []
    for width in widths:
        false_positives = []
        false_negatives = []
        for low in lows:
            wt.addJob(low,(low,low+width,num_trials))
        wt.wait()
        while not wt.resultq.empty():
            job_id,errors = wt.resultq.get()
            fp,fn = errors
            false_negatives.append(fn)
            false_positives.append(fp)

        #print(width, false_negatives)
        #print(width, false_positives)
        #performance.append(((statistics.mean(false_positives)+statistics.mean(false_negatives))/2.0,
        #                    width, statistics.mean(false_negatives), statistics.mean(false_positives)))
        performance.append((abs(statistics.mean(false_positives)-statistics.mean(false_negatives)),
                            width, statistics.mean(false_negatives), statistics.mean(false_positives)))
    performance.sort()
    #pprint.pprint(performance)
    good_width = performance[0][1]
    #print("good_width:",good_width)


    num_trials = 500
    performance = []
    for low in lows:
        wt.addJob(low, (low,low+good_width,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_low = performance[0][1]
    #print("best_low:", best_low)


    num_trials = 500
    widths = [good_width+(x/100.0) for x in range(-70,75,5) if good_width+(x/100.0) > 0.0]
    performance = []
    for width in widths:
        wt.addJob(width, (best_low,best_low+width,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        #performance.append(((fp+fn)/2.0, job_id, fn, fp))
        performance.append((abs(fp-fn), job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_width=performance[0][1]
    #print("best_width:",best_width)
    #print("final_performance:", performance[0][0])

    wt.stop()
    params = json.dumps({"low":best_low,"high":best_low+best_width})
    return {'trial_type':"train",
            'num_observations':num_observations,
            'num_trials':num_trials,
            'params':params,
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}


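# trainSummary: the same style of search for the summary-statistic classifiers
# (midsummary, quadsummary).  An initial threshold is seeded at half of summaryFunc()
# applied to the per-sample differences; the 'distance' parameter and the threshold are
# then refined in alternating coarse and fine passes, with the distance passes ranked by
# mean error and the threshold passes by the false positive / false negative imbalance.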
def trainSummary(summaryFunc, db, unusual_case, greater, num_observations):
    db.resetOffsets()
    stest = functools.partial(summaryTest, summaryFunc)

    def trainAux(distance, threshold, num_trials):
        estimator = functools.partial(stest, {'distance':distance,'threshold':threshold}, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, num_observations, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    #determine expected delta based on differences
    mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    threshold = summaryFunc(mean_diffs)/2.0
    #print("init_threshold:", threshold)

    wt = WorkerThreads(2, trainAux)

    num_trials = 500
    performance = []
    for distance in range(1,50):
        wt.addJob(distance, (distance,threshold,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))

    performance.sort()
    #pprint.pprint(performance)
    good_distance = performance[0][1]
    #print("good_distance:",good_distance)


    num_trials = 500
    performance = []
    for t in range(80,122,2):
        wt.addJob(threshold*(t/100.0), (good_distance,threshold*(t/100.0),num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        #performance.append(((fp+fn)/2.0, job_id, fn, fp))
        performance.append((abs(fp-fn), job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    good_threshold = performance[0][1]
    #print("good_threshold:", good_threshold)


    num_trials = 500
    performance = []
    for d in [good_distance+s for s in range(-4,5) if good_distance+s > -1]:
        wt.addJob(d, (d,good_threshold,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_distance = performance[0][1]
    #print("best_distance:",best_distance)


    num_trials = 500
    performance = []
    for t in range(90,111):
        wt.addJob(good_threshold*(t/100.0), (best_distance,good_threshold*(t/100.0),num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        #performance.append(((fp+fn)/2.0, job_id, fn, fp))
        performance.append((abs(fp-fn), job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_threshold = performance[0][1]
    #print("best_threshold:", best_threshold)

    wt.stop()
    params = json.dumps({'distance':best_distance,'threshold':best_threshold})
    return {'trial_type':"train",
            'num_observations':num_observations,
            'num_trials':num_trials,
            'params':params,
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}


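# trainKalman: a single-parameter search.  The initial threshold is seeded at half of
# the final Kalman-filter estimate over the per-sample differences (kfilter() from
# nanownlib is assumed to expose its estimates under the 'est' key, as used below), and
# the threshold is then tuned over 90%-110% of that seed, again minimizing the false
# positive / false negative imbalance.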
def trainKalman(db, unusual_case, greater, num_observations):
    db.resetOffsets()

    def trainAux(params, num_trials):
        estimator = functools.partial(kalmanTest, params, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, num_observations, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    good_threshold = kfilter({},mean_diffs)['est'][-1]/2.0

    wt = WorkerThreads(2, trainAux)
    num_trials = 200
    performance = []
    for t in range(90,111):
        params = {'threshold':good_threshold*(t/100.0)}
        wt.addJob(good_threshold*(t/100.0), (params,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        #performance.append(((fp+fn)/2.0, job_id, fn, fp))
        performance.append((abs(fp-fn), job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_threshold = performance[0][1]
    #print("best_threshold:", best_threshold)
    params = {'threshold':best_threshold}

    wt.stop()

    return {'trial_type':"train",
            'num_observations':num_observations,
            'num_trials':num_trials,
            'params':json.dumps(params),
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}


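# Registry of available classifiers: each entry maps a name to its training routine, its
# test-time decision function, and a list that accumulates training results for the test
# phase below.  The commented-out entries are presumably disabled experiments.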
classifiers = {'boxtest':{'train':trainBoxTest, 'test':multiBoxTest, 'train_results':[]},
               'midsummary':{'train':functools.partial(trainSummary, midsummary), 'test':midsummaryTest, 'train_results':[]},
               #'ubersummary':{'train':functools.partial(trainSummary, ubersummary), 'test':ubersummaryTest, 'train_results':[]},
               'quadsummary':{'train':functools.partial(trainSummary, quadsummary), 'test':quadsummaryTest, 'train_results':[]},
               'kalman':{'train':trainKalman, 'test':kalmanTest, 'train_results':[]},
               #'_trimean':{'train':None, 'test':trimeanTest, 'train_results':[]},
              }


db = nanownlib.storage.db(options.session_data)

import cProfile  # unused in this script; presumably left over from profiling experiments

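# trainClassifier: adaptively grow the per-trial sample size.  Starting from a base of
# 1000 observations and growing by 1.5x per iteration (capped at one fifth of the 'train'
# population), reuse any stored result for that size or run the classifier's training
# routine, record the result, and stop as soon as the mean of the false positive and
# false negative rates drops below the 5% threshold.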
def trainClassifier(db, unusual_case, greater, classifier, retrain=False):
    if retrain:
        print("Dropping stored training results...")
        db.deleteClassifierResults(classifier, 'train')

    trainer = classifiers[classifier]['train']
    threshold = 5.0 # in percent
    num_obs = 1000
    max_obs = int(db.populationSize('train')/5)
    result = None
    while num_obs < max_obs:
        num_obs = min(int(num_obs*1.5), max_obs)
        result = db.fetchClassifierResult(classifier, 'train', num_obs)
        if result is not None:
            train_time = "(stored)"
        else:
            start = time.time()
            result = trainer(db,unusual_case,greater,num_obs)
            result['classifier'] = classifier
            train_time = "%f" % (time.time()-start)

        error = statistics.mean([result['false_positives'],result['false_negatives']])
        print("number of observations: %d | error: %f | false_positives: %f | false_negatives: %f | train time: %s | params: %s"
              % (num_obs, error, result['false_positives'],result['false_negatives'], train_time, result['params']))
        db.addClassifierResults(result)
        classifiers[classifier]['train_results'].append(result)

        if error < threshold:
            break

    return result


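# testClassifier: for each stored training result, estimate the smallest number of
# observations that achieves the 5% target error on the held-out 'test' subset (with
# 'train_null' still supplying the null distribution).  The sample size is first "walked
# up" using a growth factor derived from the ratio of the log target error to the log
# current error, then "walked down" in 5% decrements until the error exceeds the target,
# keeping the last size recorded before that happened.  The best result across all
# training parameter sets is returned.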
def testClassifier(db, unusual_case, greater, classifier, retest=False):
    target_error = 5.0 # in percent
    num_trials = 1000
    max_obs = int(db.populationSize('test')/5)

    tester = classifiers[classifier]['test']

    def testAux(params, num_trials, num_observations):
        estimator = functools.partial(tester, params, greater)
        estimates = bootstrap3(estimator, db, 'test', unusual_case, num_observations, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        print("testAux:", num_observations, false_positives, false_negatives, params)
        return false_positives,false_negatives


    if retest:
        print("Dropping stored test results...")
        db.deleteClassifierResults(classifier, 'test')


    test_results = []
    lte = math.log(target_error/100.0)
    for tr in classifiers[classifier]['train_results']:
        db.resetOffsets()
        params = json.loads(tr['params'])
        num_obs = tr['num_observations']

        print("initial test")
        fp,fn = testAux(params, num_trials, num_obs)
        error = (fp+fn)/2.0
        print("walking up")
        while (error > target_error) and (num_obs < max_obs):
            increase_factor = 1.5 * lte/math.log(error/100.0) # don't ask how I came up with this
            #print("increase_factor:", increase_factor)
            num_obs = min(int(increase_factor*num_obs), max_obs)
            fp,fn = testAux(params, num_trials, num_obs)
            error = (fp+fn)/2.0

        print("walking down")
        while (num_obs > 0):
            current_best = (num_obs,error,params,fp,fn)
            num_obs = int(0.95*num_obs)
            fp,fn = testAux(params, num_trials, num_obs)
            error = (fp+fn)/2.0
            if error > target_error:
                break

        test_results.append(current_best)

    test_results.sort()
    best_obs,error,best_params,fp,fn = test_results[0]

    return {'classifier':classifier,
            'trial_type':"test",
            'num_observations':best_obs,
            'num_trials':num_trials,
            'params':best_params,
            'false_positives':fp,
            'false_negatives':fn}


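# Main driver: identify which test case shows the anomalous timing difference and the
# direction of that difference, train every registered classifier (honoring --retrain),
# clear the database cache, and then evaluate each classifier on the 'test' subset
# (honoring --retest), printing results along the way.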
start = time.time()
unusual_case,unusual_diff = findUnusualTestCase(db)
greater = (unusual_diff > 0)
print("unusual_case:", unusual_case)
print("unusual_diff:", unusual_diff)
end = time.time()
print("elapsed:", end-start)

for c in sorted(classifiers.keys()):
    if classifiers[c]['train'] is None:
        continue
    start = time.time()
    print("Training %s..." % c)
    result = trainClassifier(db, unusual_case, greater, c, c in options.retrain)
    print("%s result:" % c)
    pprint.pprint(result)
    print("completed in:", time.time()-start)

db.clearCache()

for c in sorted(classifiers.keys()):
    start = time.time()
    print("Testing %s..." % c)
    result = testClassifier(db, unusual_case, greater, c, c in options.retest)
    print("%s result:" % c)
    pprint.pprint(result)
    classifiers[c]['test_error'] = (result['false_positives']+result['false_negatives'])/2.0
    print("completed in:", time.time()-start)