source: trunk/bin/train @ 8

Last change on this file since 8 was 8, checked in by tim, 9 years ago

#!/usr/bin/env python3
#-*- mode: Python;-*-

import sys
import os
import time
import random
import statistics
import functools
import argparse
import threading
import queue
import pprint
import json


VERSION = "{DEVELOPMENT}"
if VERSION == "{DEVELOPMENT}":
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except:
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except:
            pass
    sys.path.append("%s/../lib" % script_dir)

from nanownlib import *
import nanownlib.storage
from nanownlib.stats import boxTest,multiBoxTest,subsample,bootstrap,bootstrap2,bootstrap3,trimean,midhinge,midhingeTest,samples2Distributions,samples2MeanDiffs

parser = argparse.ArgumentParser(
    description="")
#parser.add_argument('-c', dest='cases', type=str, default='{"short":10000,"long":1010000}',
#                    help='JSON representation of echo timing cases. Default: {"short":10000,"long":1010000}')
parser.add_argument('session_data', default=None,
                    help='Database file storing session information')
options = parser.parse_args()



class WorkerThreads(object):
    # Minimal thread pool: addJob() queues (job_id, args), daemon workers run
    # target(*args), and (job_id, result) tuples accumulate on resultq.
    workq = None
    resultq = None
    target = None

    def __init__(self, num_workers, target):
        self.workq = queue.Queue()
        self.resultq = queue.Queue()
        self.target = target

        self.workers = []
        for i in range(num_workers):
            t = threading.Thread(target=self._worker)
            t.daemon = True
            t.start()
            self.workers.append(t)

    def _worker(self):
        while True:
            item = self.workq.get()
            if item is None:
                self.workq.task_done()
                break

            job_id,args = item
            self.resultq.put((job_id, self.target(*args)))
            self.workq.task_done()

    def addJob(self, job_id, args):
        self.workq.put((job_id, args))

    def wait(self):
        self.workq.join()

    def stop(self):
        # One None sentinel per worker, then wait for each thread to exit.
        for _ in self.workers:
            self.workq.put(None)
        for w in self.workers:
            w.join()


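# Illustrative use of the WorkerThreads pool employed by the training routines
# below (a sketch only; my_estimator and its arguments are hypothetical):
#
#   wt = WorkerThreads(2, my_estimator)
#   wt.addJob('job-a', (arg1, arg2))
#   wt.wait()
#   while not wt.resultq.empty():
#       job_id, result = wt.resultq.get()
#   wt.stop()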
def trainBoxTest(db, unusual_case, greater, subseries_size):

    def trainAux(low,high,num_trials):
        estimator = functools.partial(multiBoxTest, {'low':low, 'high':high}, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, subseries_size, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, subseries_size, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    start = time.time()
    wt = WorkerThreads(2, trainAux)

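    # Stage 1: coarse sweep of the box's lower bound (0-49) at a fixed width
    # of 1.0, ranking each candidate by its mean error rate ((FP+FN)/2).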
    num_trials = 200
    width = 1.0
    performance = []
    for low in range(0,50):
        wt.addJob(low, (low,low+width,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    pprint.pprint(performance)
    print(time.time()-start)

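    # Stage 2: for the five best lower bounds, try widths 0.5 through 6.0 in
    # steps of 0.5 and keep the width whose mean false-positive and
    # false-negative rates are the most balanced.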
    num_trials = 200
    lows = [p[1] for p in performance[0:5]]
    widths = [w/10.0 for w in range(5,65,5)]
    performance = []
    for width in widths:
        false_positives = []
        false_negatives = []
        for low in lows:
            wt.addJob(low,(low,low+width,num_trials))
        wt.wait()
        while not wt.resultq.empty():
            job_id,errors = wt.resultq.get()
            fp,fn = errors
            false_negatives.append(fn)
            false_positives.append(fp)

        #print(width, false_negatives)
        #print(width, false_positives)
        #performance.append(((statistics.mean(false_positives)+statistics.mean(false_negatives))/2.0,
        #                    width, statistics.mean(false_negatives), statistics.mean(false_positives)))
        performance.append((abs(statistics.mean(false_positives)-statistics.mean(false_negatives)),
                            width, statistics.mean(false_negatives), statistics.mean(false_positives)))
    performance.sort()
    pprint.pprint(performance)
    good_width = performance[0][1]
    print("good_width:",good_width)


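    # Stage 3: re-evaluate the candidate lower bounds at the chosen width with
    # more trials and keep the best one.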
    num_trials = 500
    performance = []
    for low in lows:
        wt.addJob(low, (low,low+good_width,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    pprint.pprint(performance)
    best_low = performance[0][1]
    print("best_low:", best_low)


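    # Stage 4: fine-tune the width in 0.1 increments around good_width
    # (keeping it positive), using the best lower bound found above.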
    num_trials = 500
    widths = [good_width+(x/10.0) for x in range(-6,7) if good_width+(x/10.0) > 0.0]
    performance = []
    for width in widths:
        wt.addJob(width, (best_low,best_low+width,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    pprint.pprint(performance)
    best_width = performance[0][1]
    print("best_width:",best_width)
    print("final_performance:", performance[0][0])

    params = json.dumps({"low":best_low,"high":best_low+best_width})
    return {'algorithm':"boxtest",
            'params':params,
            'sample_size':subseries_size,
            'num_trials':num_trials,
            'trial_type':"train",
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}


def trainMidhinge(db, unusual_case, greater, subseries_size):

    def trainAux(distance, threshold, num_trials):
        estimator = functools.partial(midhingeTest, {'distance':distance,'threshold':threshold}, greater)
        estimates = bootstrap3(estimator, db, 'train', unusual_case, subseries_size, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, subseries_size, num_trials)

        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives,false_negatives

    # determine expected delta based on differences
    mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    threshold = trimean(mean_diffs)/2.0
    print("init_threshold:", threshold)

    wt = WorkerThreads(2, trainAux)

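    # Stage 1: sweep the midhinge distance parameter from 1 to 49 at the
    # initial threshold, ranking candidates by mean error rate.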
    num_trials = 500
    performance = []
    for distance in range(1,50):
        wt.addJob(distance, (distance,threshold,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))

    performance.sort()
    #pprint.pprint(performance)
    good_distance = performance[0][1]
    print("good_distance:",good_distance)


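    # Stage 2: scale the initial threshold from 50% to 150% in 4% steps at the
    # chosen distance.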
    num_trials = 500
    performance = []
    for t in range(50,154,4):
        wt.addJob(threshold*(t/100.0), (good_distance,threshold*(t/100.0),num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    good_threshold = performance[0][1]
    print("good_threshold:", good_threshold)


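    # Stage 3: fine-tune the distance within +/-4 of good_distance (keeping it
    # non-negative).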
    num_trials = 500
    performance = []
    for d in [good_distance+s for s in range(-4,5) if good_distance+s > -1]:
        wt.addJob(d, (d,good_threshold,num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_distance = performance[0][1]
    print("best_distance:",best_distance)


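    # Stage 4: fine-tune the threshold between 95% and 105% of good_threshold
    # in 1% steps.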
    num_trials = 500
    performance = []
    for t in range(95,106):
        wt.addJob(good_threshold*(t/100.0), (best_distance,good_threshold*(t/100.0),num_trials))
    wt.wait()
    while not wt.resultq.empty():
        job_id,errors = wt.resultq.get()
        fp,fn = errors
        performance.append(((fp+fn)/2.0, job_id, fn, fp))
    performance.sort()
    #pprint.pprint(performance)
    best_threshold = performance[0][1]
    print("best_threshold:", best_threshold)

    params = json.dumps({'distance':best_distance,'threshold':best_threshold})
    return {'algorithm':"midhinge",
            'params':params,
            'sample_size':subseries_size,
            'num_trials':num_trials,
            'trial_type':"train",
            'false_positives':performance[0][3],
            'false_negatives':performance[0][2]}


#classifiers = {'boxtest':{'train':trainBoxTest2, 'test':multiBoxTest},
#               'midhinge':{'train':trainMidhinge, 'test':midhinge}}


db = nanownlib.storage.db(options.session_data)
#cursor = db.cursor()
#cursor.execute("SELECT min(sample) min, max(sample) max FROM probes")
#train_start,test_end = cursor.fetchone()
#train_end = int(test_end-train_start)
#test_start = train_end+1
#subsample_size = min(10000,(train_end-train_start+1)/4)

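# Main driver: identify the test case whose timing differs most from the
# others, then train the midhinge classifier at several subseries sizes.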
start = time.time()
unusual_case,unusual_diff = findUnusualTestCase(db)
greater = (unusual_diff > 0)  # True when the unusual case's timings are larger
print("unusual_case:", unusual_case)
print("unusual_diff:", unusual_diff)
end = time.time()
print(":", end-start)

import cProfile


for size in (500,1000,2000,4000,5000,6000):
    start = time.time()
    #cProfile.run('results = trainMidhinge(db, unusual_case, greater, 100)')
    results = trainMidhinge(db, unusual_case, greater, size)
    #db.addClassifierResults(results)
    print("midhinge result:")
    pprint.pprint(results)
    print(":", time.time()-start)

# Exit here; the box test training below is currently disabled.
sys.exit(0)

start = time.time()
results = trainBoxTest(db, unusual_case, greater, 6000)
#db.addClassifierResults(results)
print("multi box test result:")
pprint.pprint(results)
print(":", time.time()-start)