source: trunk/bin/train @ 9

Last change on this file since 9 was 9, checked in by tim, 9 years ago

.

  • Property svn:executable set to *
File size: 9.4 KB
RevLine 
[4]1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import os
6import time
7import random
8import statistics
9import functools
10import argparse
11import pprint
12import json
13
14
# "{DEVELOPMENT}" marks an uninstalled checkout: in that case the library
# directory that sits next to this script's directory is added to sys.path
# so that `nanownlib` can be imported without installation.
VERSION = "{DEVELOPMENT}"
if VERSION == "{DEVELOPMENT}":
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except Exception:
        # __file__ may be undefined; fall back to the invoked script path.
        # `Exception` (not a bare except) keeps this best-effort without
        # swallowing KeyboardInterrupt/SystemExit.
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except Exception:
            # Keep the '.' default if neither location can be determined.
            pass
    sys.path.append("%s/../lib" % script_dir)
26
27from nanownlib import *
28import nanownlib.storage
29from nanownlib.stats import boxTest,multiBoxTest,subsample,bootstrap,bootstrap2,trimean,midhinge,midhingeTest,samples2Distributions,samples2MeanDiffs
[9]30from nanownlib.parallel import WorkerThreads
[4]31
[9]32
# Command-line interface: a single positional argument naming the session
# database file. Parsing happens at module level, so `options` is available
# to the rest of the script.
parser = argparse.ArgumentParser(
    description="")
parser.add_argument('session_data', default=None,
                    help='Database file storing session information')
options = parser.parse_args()
40
41           
42
def trainBoxTest(db, unusual_case, greater, subseries_size):
    """Tune the 'low'/'high' parameters of multiBoxTest on the training data.

    The search runs in four stages, each scoring candidates with trainAux:

    1. Try low = 0..49 with a fixed width of 1.0 (200 trials) and keep the
       five lows with the smallest mean of false-positive/false-negative rate.
    2. For those five lows, try widths 0.5..6.0 in steps of 0.5 (200 trials)
       and keep the width whose mean false-positive and mean false-negative
       rates are closest to each other.
    3. Re-score the five lows at that width (500 trials) and keep the best.
    4. Refine the width around the stage-2 value (offsets -0.60..+0.70 in
       steps of 0.05, positive widths only, 500 trials) at the best low.

    Args:
        db: session database handle; passed through to bootstrap3.
        unusual_case: test case identifier passed through to bootstrap3.
        greater: passed as the second bound argument of multiBoxTest.
        subseries_size: sample size passed through to bootstrap3.

    Returns:
        dict with keys 'algorithm', 'params' (JSON string holding 'low' and
        'high'), 'sample_size', 'num_trials', 'trial_type',
        'false_positives' and 'false_negatives' (both in percent, taken from
        the best stage-4 candidate).
    """

    def trainAux(low, high, num_trials):
        # Worker job: returns (false_positives, false_negatives) in percent.
        # NOTE(review): bootstrap3 is not in the explicit nanownlib.stats
        # import list; presumably it arrives via `from nanownlib import *`
        # -- confirm.
        estimator = functools.partial(multiBoxTest, {'low':low, 'high':high}, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, subseries_size, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, subseries_size, num_trials)

        # On 'train' data every estimate is expected to be 1, on
        # 'train_null' data every estimate is expected to be 0.
        bad_estimates = sum(1 for e in estimates if e != 1)
        bad_null_estimates = sum(1 for e in null_estimates if e != 0)

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    def runJobs(jobs):
        # Submit (job_id, args) pairs, wait for completion, and drain the
        # result queue into a list of (job_id, false_pos, false_neg).
        for job_id,args in jobs:
            wt.addJob(job_id, args)
        wt.wait()
        results = []
        while not wt.resultq.empty():
            job_id,(fp,fn) = wt.resultq.get()
            results.append((job_id, fp, fn))
        return results

    def rank(results):
        # Best candidate first: (mean error, job_id, false_neg, false_pos).
        return sorted(((fp+fn)/2.0, job_id, fn, fp) for job_id,fp,fn in results)

    wt = WorkerThreads(2, trainAux)
    try:
        # Stage 1: coarse scan of `low` at a fixed width.
        num_trials = 200
        width = 1.0
        performance = rank(runJobs([(low, (low,low+width,num_trials))
                                    for low in range(0,50)]))
        lows = [p[1] for p in performance[0:5]]

        # Stage 2: pick the width that balances false positives and false
        # negatives, averaged over the five candidate lows.
        num_trials = 200
        widths = [w/10.0 for w in range(5,65,5)]
        performance = []
        for width in widths:
            results = runJobs([(low, (low,low+width,num_trials)) for low in lows])
            mean_fp = statistics.mean([fp for _,fp,_ in results])
            mean_fn = statistics.mean([fn for _,_,fn in results])
            performance.append((abs(mean_fp-mean_fn), width, mean_fn, mean_fp))
        performance.sort()
        good_width = performance[0][1]

        # Stage 3: choose the best low at the selected width.
        num_trials = 500
        performance = rank(runJobs([(low, (low,low+good_width,num_trials))
                                    for low in lows]))
        best_low = performance[0][1]

        # Stage 4: fine-tune the width around good_width at the best low.
        num_trials = 500
        widths = [good_width+(x/100.0) for x in range(-60,75,5) if good_width+(x/100.0) > 0.0]
        performance = rank(runJobs([(width, (best_low,best_low+width,num_trials))
                                    for width in widths]))
        _,best_width,false_negatives,false_positives = performance[0]
    finally:
        # Always shut the worker pool down, even if a stage raised.
        wt.stop()

    params = json.dumps({"low":best_low,"high":best_low+best_width})
    return {'algorithm':"boxtest",
            'params':params,
            'sample_size':subseries_size,
            'num_trials':num_trials,
            'trial_type':"train",
            'false_positives':false_positives,
            'false_negatives':false_negatives}
142
143
def trainMidhinge(db, unusual_case, greater, subseries_size):
    """Tune the 'distance'/'threshold' parameters of midhingeTest on training data.

    The starting threshold is half the trimean of
    (unusual_case - other_cases) over db.subseries('train', unusual_case).
    Four search stages follow, each using 500 trials per candidate and
    ranking by the mean of the false-positive and false-negative rates:

    1. Try distance = 1..49 at the starting threshold.
    2. Try thresholds from 50% to 150% of the starting threshold (4% steps)
       at the best distance.
    3. Try distances within +/-4 of the stage-1 distance at the stage-2
       threshold.
    4. Try thresholds from 90% to 110% of the stage-2 threshold (1% steps)
       at the stage-3 distance.

    Args:
        db: session database handle; provides subseries() and is passed
            through to bootstrap3.
        unusual_case: test case identifier passed to db.subseries and
            bootstrap3.
        greater: passed as the second bound argument of midhingeTest.
        subseries_size: sample size passed through to bootstrap3.

    Returns:
        dict with keys 'algorithm', 'params' (JSON string holding 'distance'
        and 'threshold'), 'sample_size', 'num_trials', 'trial_type',
        'false_positives' and 'false_negatives' (both in percent, taken from
        the best stage-4 candidate).
    """

    def trainAux(distance, threshold, num_trials):
        # Worker job: returns (false_positives, false_negatives) in percent.
        # NOTE(review): bootstrap3 is not in the explicit nanownlib.stats
        # import list; presumably it arrives via `from nanownlib import *`
        # -- confirm.
        estimator = functools.partial(midhingeTest, {'distance':distance,'threshold':threshold}, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, subseries_size, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, subseries_size, num_trials)

        # On 'train' data every estimate is expected to be 1, on
        # 'train_null' data every estimate is expected to be 0.
        bad_estimates = sum(1 for e in estimates if e != 1)
        bad_null_estimates = sum(1 for e in null_estimates if e != 0)

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    def runJobs(jobs):
        # Submit (job_id, args) pairs, wait for completion, and return the
        # candidates ranked best-first as
        # (mean error, job_id, false_neg, false_pos).
        for job_id,args in jobs:
            wt.addJob(job_id, args)
        wt.wait()
        ranked = []
        while not wt.resultq.empty():
            job_id,(fp,fn) = wt.resultq.get()
            ranked.append(((fp+fn)/2.0, job_id, fn, fp))
        ranked.sort()
        return ranked

    #determine expected delta based on differences
    mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    threshold = trimean(mean_diffs)/2.0

    wt = WorkerThreads(2, trainAux)
    try:
        # Stage 1: coarse scan of distance at the initial threshold.
        num_trials = 500
        performance = runJobs([(distance, (distance,threshold,num_trials))
                               for distance in range(1,50)])
        good_distance = performance[0][1]

        # Stage 2: coarse scan of threshold (50%..150% of the initial value).
        num_trials = 500
        candidates = [threshold*(t/100.0) for t in range(50,154,4)]
        performance = runJobs([(c, (good_distance,c,num_trials)) for c in candidates])
        good_threshold = performance[0][1]

        # Stage 3: refine distance around the stage-1 choice.
        # NOTE(review): the `> -1` filter admits distance 0, although stage 1
        # only tried distances >= 1 -- confirm that 0 is a valid distance.
        num_trials = 500
        candidates = [good_distance+s for s in range(-4,5) if good_distance+s > -1]
        performance = runJobs([(d, (d,good_threshold,num_trials)) for d in candidates])
        best_distance = performance[0][1]

        # Stage 4: refine threshold (90%..110% of the stage-2 choice).
        num_trials = 500
        candidates = [good_threshold*(t/100.0) for t in range(90,111)]
        performance = runJobs([(c, (best_distance,c,num_trials)) for c in candidates])
        _,best_threshold,false_negatives,false_positives = performance[0]
    finally:
        # Always shut the worker pool down, even if a stage raised.
        wt.stop()

    params = json.dumps({'distance':best_distance,'threshold':best_threshold})
    return {'algorithm':"midhinge",
            'params':params,
            'sample_size':subseries_size,
            'num_trials':num_trials,
            'trial_type':"train",
            'false_positives':false_positives,
            'false_negatives':false_negatives}
234
235
# Registry of classifiers: for each algorithm name, the training routine
# ('train') and the function registered under 'test'. Only the 'train'
# entries are used in this script.
# NOTE(review): the 'test' entry for 'midhinge' is `midhinge`, whereas
# trainMidhinge tunes `midhingeTest` -- confirm which one is intended.
classifiers = {'boxtest':{'train':trainBoxTest, 'test':multiBoxTest},
               'midhinge':{'train':trainMidhinge, 'test':midhinge}}


# Open the session database named on the command line.
db = nanownlib.storage.db(options.session_data)
241
[9]242import cProfile
243
def trainClassifier(db, unusual_case, greater, trainer):
    """Train with growing sample sizes until the error drops below 5 percent.

    Starting from 4000, the sample size is doubled on each round, capped at
    one fifth of db.populationSize('train'). Each round calls
    trainer(db, unusual_case, greater, size) and prints a progress line.
    The loop stops as soon as the mean of the reported false-positive and
    false-negative rates is below the threshold, or when the cap is reached.

    Args:
        db: session database; must provide populationSize(kind) and
            addClassifierResults(result).
        unusual_case: passed through to `trainer`.
        greater: passed through to `trainer`.
        trainer: callable returning a dict with at least the keys
            'false_positives' and 'false_negatives'.

    Returns:
        The result dict of the last training round (also stored through
        db.addClassifierResults), even if it did not reach the threshold,
        or None when the population is too small to run a single round.
    """
    threshold = 5.0 # in percent
    size = 4000
    result = None
    while size < db.populationSize('train')/5:
        size = min(size*2, int(db.populationSize('train')/5))
        result = trainer(db,unusual_case,greater,size)
        error = statistics.mean([result['false_positives'],result['false_negatives']])
        print("subseries size: %d | error: %f | false_positives: %f | false_negatives: %f"
              % (size,error,result['false_positives'],result['false_negatives']))
        if error < threshold:
            break
    if result is not None:
        db.addClassifierResults(result)

    return result
260
261
# Identify the unusual test case and the sign of its timing difference.
start = time.time()
unusual_case,unusual_diff = findUnusualTestCase(db)
greater = (unusual_diff > 0)
print("unusual_case:", unusual_case)
print("unusual_diff:", unusual_diff)
end = time.time()
print(":", end-start)


# Train every registered classifier and report its result and run time.
for c,funcs in classifiers.items():
    start = time.time()
    print("Training %s..." % c)
    result = trainClassifier(db, unusual_case, greater, funcs['train'])
    print("%s result:" % c)
    pprint.pprint(result)
    print("completed in:", time.time()-start)

# The statements that previously followed this exit (a one-off
# trainBoxTest run with sample size 6000) were unreachable and are removed.
sys.exit(0)
Note: See TracBrowser for help on using the repository browser.