source: trunk/bin/train @ 9

Last change on this file since 9 was 9, checked in by tim, 9 years ago

.

  • Property svn:executable set to *
File size: 9.4 KB
Line 
1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import os
6import time
7import random
8import statistics
9import functools
10import argparse
11import pprint
12import json
13
14
# Version marker. The literal "{DEVELOPMENT}" is presumably replaced by a
# real version string when the tool is packaged -- TODO confirm.
VERSION = "{DEVELOPMENT}"
if VERSION == "{DEVELOPMENT}":
    # Running from a source checkout: make the sibling ../lib directory
    # importable so that the nanownlib imports below resolve.
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except Exception:
        # __file__ can be undefined in some execution contexts; fall back to
        # the path the script was invoked with.  Only Exception is caught so
        # that KeyboardInterrupt/SystemExit are not swallowed.
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except Exception:
            # Deliberate best effort: keep the '.' default.
            pass
    sys.path.append("%s/../lib" % script_dir)
26
27from nanownlib import *
28import nanownlib.storage
29from nanownlib.stats import boxTest,multiBoxTest,subsample,bootstrap,bootstrap2,trimean,midhinge,midhingeTest,samples2Distributions,samples2MeanDiffs
30from nanownlib.parallel import WorkerThreads
31
32
# Command-line interface: a single required positional argument naming the
# session database.  Parsing happens at import time, so `options` is available
# to the module-level code further down.
parser = argparse.ArgumentParser(
    description="Train the boxtest and midhinge classifiers against a "
                "session database and record the results in it.")
# Disabled option, kept for reference:
#parser.add_argument('-c', dest='cases', type=str, default='{"short":10000,"long":1010000}',
#                    help='JSON representation of echo timing cases. Default: {"short":10000,"long":1010000}')
# No `default` is given: argparse never uses it for a required positional.
parser.add_argument('session_data',
                    help='Database file storing session information')
options = parser.parse_args()
40
41           
42
def trainBoxTest(db, unusual_case, greater, subseries_size):
    """Search for the 'low'/'high' box parameters that work best for multiBoxTest.

    The search runs in four stages on a two-worker WorkerThreads pool:

      1. coarse sweep of `low` over 0..49 with a fixed width of 1.0;
      2. for the five best lows, sweep the width over 0.5..6.0 and keep the
         width whose mean false-positive and false-negative rates are closest;
      3. re-rank the five lows using that width and more trials;
      4. fine sweep of the width around the stage 2 value for the best low.

    Parameters:
      db             -- session database, passed through to bootstrap3
      unusual_case   -- name of the unusual test case, passed to bootstrap3
      greater        -- passed to multiBoxTest as its second argument
      subseries_size -- sample size handed to bootstrap3 and reported back
                        as 'sample_size'

    Returns a dict with the keys 'algorithm', 'params' (JSON text holding
    "low" and "high"), 'sample_size', 'num_trials', 'trial_type',
    'false_positives' and 'false_negatives' (both percentages).

    Raises whatever the underlying stages raise (for example IndexError if a
    stage produced no results); the worker pool is stopped in every case.
    """

    def trainAux(low,high,num_trials):
        """Return (false_positives, false_negatives) in percent for one box.

        An estimate other than 1 on the 'train' data counts as a false
        negative; an estimate other than 0 on the 'train_null' data counts
        as a false positive.
        """
        estimator = functools.partial(multiBoxTest, {'low':low, 'high':high}, greater)
        # NOTE(review): bootstrap3 is not among the names this file imports
        # explicitly from nanownlib.stats; presumably it is provided by
        # `from nanownlib import *` -- confirm.
        estimates = bootstrap3(estimator, db, 'train', unusual_case, subseries_size, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, subseries_size, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    def drainResults():
        """Empty the pool's result queue into a list of (job_id, fp, fn)."""
        results = []
        while not wt.resultq.empty():
            job_id,(fp,fn) = wt.resultq.get()
            results.append((job_id, fp, fn))
        return results

    def rankedByMeanError():
        """Return the queued results as (mean error, job_id, fn, fp), best first."""
        return sorted(((fp+fn)/2.0, job_id, fn, fp) for job_id,fp,fn in drainResults())

    wt = WorkerThreads(2, trainAux)
    try:
        # Stage 1: coarse sweep of `low` with a fixed width.
        num_trials = 200
        width = 1.0
        for low in range(0,50):
            wt.addJob(low, (low,low+width,num_trials))
        wt.wait()
        performance = rankedByMeanError()

        # Stage 2: sweep the width across the five best lows.
        num_trials = 200
        lows = [p[1] for p in performance[0:5]]
        widths = [w/10.0 for w in range(5,65,5)]
        performance = []
        for width in widths:
            for low in lows:
                wt.addJob(low,(low,low+width,num_trials))
            wt.wait()
            results = drainResults()
            mean_fp = statistics.mean([fp for _,fp,_ in results])
            mean_fn = statistics.mean([fn for _,_,fn in results])
            # Rank by how balanced the two error rates are, not by their sum.
            performance.append((abs(mean_fp-mean_fn), width, mean_fn, mean_fp))
        performance.sort()
        good_width = performance[0][1]

        # Stage 3: re-rank the candidate lows using the chosen width.
        num_trials = 500
        for low in lows:
            wt.addJob(low, (low,low+good_width,num_trials))
        wt.wait()
        performance = rankedByMeanError()
        best_low = performance[0][1]

        # Stage 4: fine sweep of the width around good_width (positive only).
        num_trials = 500
        widths = [good_width+(x/100.0) for x in range(-60,75,5) if good_width+(x/100.0) > 0.0]
        for width in widths:
            wt.addJob(width, (best_low,best_low+width,num_trials))
        wt.wait()
        performance = rankedByMeanError()
        best_width = performance[0][1]
    finally:
        # Always shut the workers down, even when a stage fails.
        wt.stop()

    params = json.dumps({"low":best_low,"high":best_low+best_width})
    return {'algorithm':"boxtest",
            'params':params,
            'sample_size':subseries_size,
            'num_trials':num_trials,
            'trial_type':"train",
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}
142
143
def trainMidhinge(db, unusual_case, greater, subseries_size):
    """Search for the 'distance'/'threshold' parameters that work best for midhingeTest.

    The starting threshold is half the trimean of the per-subseries
    differences between the unusual case and the other cases.  The search
    then alternates between the two parameters on a two-worker
    WorkerThreads pool:

      1. sweep `distance` over 1..49 at the starting threshold;
      2. sweep the threshold from 50% to 150% of the starting value in 4% steps;
      3. sweep `distance` within +/-4 of the stage 1 value;
      4. sweep the threshold from 90% to 110% of the stage 2 value in 1% steps.

    Parameters:
      db             -- session database; provides subseries() and is passed
                        through to bootstrap3
      unusual_case   -- name of the unusual test case
      greater        -- passed to midhingeTest as its second argument
      subseries_size -- sample size handed to bootstrap3 and reported back
                        as 'sample_size'

    Returns a dict with the keys 'algorithm', 'params' (JSON text holding
    'distance' and 'threshold'), 'sample_size', 'num_trials', 'trial_type',
    'false_positives' and 'false_negatives' (both percentages).

    Raises whatever the underlying stages raise (for example IndexError if a
    stage produced no results); the worker pool is stopped in every case.
    """

    def trainAux(distance, threshold, num_trials):
        """Return (false_positives, false_negatives) in percent for one parameter pair.

        An estimate other than 1 on the 'train' data counts as a false
        negative; an estimate other than 0 on the 'train_null' data counts
        as a false positive.
        """
        estimator = functools.partial(midhingeTest, {'distance':distance,'threshold':threshold}, greater)
        # NOTE(review): bootstrap3 is not among the names this file imports
        # explicitly from nanownlib.stats; presumably it is provided by
        # `from nanownlib import *` -- confirm.
        estimates = bootstrap3(estimator, db, 'train', unusual_case, subseries_size, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, subseries_size, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    def rankedByMeanError():
        """Drain the result queue into (mean error, job_id, fn, fp) tuples, best first."""
        performance = []
        while not wt.resultq.empty():
            job_id,(fp,fn) = wt.resultq.get()
            performance.append(((fp+fn)/2.0, job_id, fn, fp))
        performance.sort()
        return performance

    #determine expected delta based on differences
    mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    threshold = trimean(mean_diffs)/2.0

    wt = WorkerThreads(2, trainAux)
    try:
        # Stage 1: coarse sweep of the distance at the initial threshold.
        num_trials = 500
        for distance in range(1,50):
            wt.addJob(distance, (distance,threshold,num_trials))
        wt.wait()
        performance = rankedByMeanError()
        good_distance = performance[0][1]

        # Stage 2: coarse sweep of the threshold; the candidate value doubles
        # as the job id so the winner can be read back from the ranking.
        num_trials = 500
        for t in range(50,154,4):
            candidate = threshold*(t/100.0)
            wt.addJob(candidate, (good_distance,candidate,num_trials))
        wt.wait()
        performance = rankedByMeanError()
        good_threshold = performance[0][1]

        # Stage 3: fine sweep of the distance around good_distance.
        # NOTE(review): the `> -1` filter admits distance 0 although stage 1
        # starts at 1 -- confirm midhingeTest accepts a distance of 0.
        num_trials = 500
        for d in [good_distance+s for s in range(-4,5) if good_distance+s > -1]:
            wt.addJob(d, (d,good_threshold,num_trials))
        wt.wait()
        performance = rankedByMeanError()
        best_distance = performance[0][1]

        # Stage 4: fine sweep of the threshold around good_threshold.
        num_trials = 500
        for t in range(90,111):
            candidate = good_threshold*(t/100.0)
            wt.addJob(candidate, (best_distance,candidate,num_trials))
        wt.wait()
        performance = rankedByMeanError()
        best_threshold = performance[0][1]
    finally:
        # Always shut the workers down, even when a stage fails.
        wt.stop()

    params = json.dumps({'distance':best_distance,'threshold':best_threshold})
    return {'algorithm':"midhinge",
            'params':params,
            'sample_size':subseries_size,
            'num_trials':num_trials,
            'trial_type':"train",
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}
234
235
# Registry of the available classifiers.  For each algorithm name, 'train' is
# the parameter-search function and 'test' is the classification function the
# trainer tunes (multiBoxTest for boxtest, midhingeTest for midhinge).
classifiers = {'boxtest':{'train':trainBoxTest, 'test':multiBoxTest},
               'midhinge':{'train':trainMidhinge, 'test':midhingeTest}}


# Open the session database named on the command line.
db = nanownlib.storage.db(options.session_data)
241
242import cProfile
243
def trainClassifier(db, unusual_case, greater, trainer, threshold=5.0, initial_size=4000):
    """Train with growing subseries sizes until the error is low enough.

    The subseries size is doubled before every training run (so the first run
    uses 2*initial_size) and is capped at one fifth of the 'train'
    population.  Training stops as soon as the mean of the false-positive and
    false-negative rates falls below `threshold`, or once the cap has been
    tried.

    Parameters:
      db           -- session database; must provide populationSize() and
                      addClassifierResults()
      unusual_case -- passed through to `trainer`
      greater      -- passed through to `trainer`
      trainer      -- callable(db, unusual_case, greater, size) returning a
                      dict with 'false_positives' and 'false_negatives'
      threshold    -- acceptable mean error, in percent
      initial_size -- starting size; doubled before the first run

    Returns the last result produced by `trainer` (also stored through
    db.addClassifierResults), or None when the population is too small for
    even a single run.
    """
    # Use the same integer cap for the loop test and for clamping `size`.
    # Comparing against the unrounded populationSize/5 would let the loop
    # re-run forever at the cap when the population is not a multiple of 5.
    max_size = int(db.populationSize('train')/5)
    size = initial_size
    result = None
    while size < max_size:
        size = min(size*2, max_size)
        result = trainer(db,unusual_case,greater,size)
        error = statistics.mean([result['false_positives'],result['false_negatives']])
        print("subseries size: %d | error: %f | false_positives: %f | false_negatives: %f"
              % (size,error,result['false_positives'],result['false_negatives']))
        if error < threshold:
            break
    if result is not None:
        db.addClassifierResults(result)

    return result
260
261
# Work out which test case stands out and the sign of its difference; both
# values are handed to every trainer below.
start = time.time()
unusual_case,unusual_diff = findUnusualTestCase(db)
greater = (unusual_diff > 0)
print("unusual_case:", unusual_case)
print("unusual_diff:", unusual_diff)
end = time.time()
print(":", end-start)


# Train each registered classifier in turn and report its result and timing.
for c,funcs in classifiers.items():
    start = time.time()
    print("Training %s..." % c)
    result = trainClassifier(db, unusual_case, greater, funcs['train'])
    print("%s result:" % c)
    pprint.pprint(result)
    print("completed in:", time.time()-start)

sys.exit(0)
Note: See TracBrowser for help on using the repository browser.