Changeset 11 for trunk/bin/train


Timestamp: 07/16/15 20:40:01
Author: tim
Message: .
File: 1 edited

Legend:

  Unmodified lines show both the r10 and r11 line numbers.
  Added lines show only the r11 line number.
  Removed lines show only the r10 line number.
  • trunk/bin/train

    r10 r11  
 25  25      sys.path.append("%s/../lib" % script_dir)
 26  26
     27
 27  28  from nanownlib import *
 28  29  from nanownlib.stats import *
     30  from nanownlib.train import *
 29  31  from nanownlib.parallel import WorkerThreads
 30  32  import nanownlib.storage

 36  38  #parser.add_argument('-c', dest='cases', type=str, default='{"short":10000,"long":1010000}',
 37  39  #                    help='JSON representation of echo timing cases. Default: {"short":10000,"long":1010000}')
 38      parser.add_argument('--retrain', action='append', default=[], help='Force a classifier to be retrained.  May be specified multiple times.')
     40  parser.add_argument('--unusual-case', action='store', default=None, help='Specify the unusual case and whether it is greater than the other cases.  Format: {case name},{1 or 0}')
     41  parser.add_argument('--retrain', action='append', default=[], help='Force a classifier to be retrained (and retested).  May be specified multiple times.')
 39  42  parser.add_argument('--retest', action='append', default=[], help='Force a classifier to be retested.  May be specified multiple times.')
 40  43  parser.add_argument('session_data', default=None,
 41  44                      help='Database file storing session information')
 42  45  options = parser.parse_args()
     46  db = nanownlib.storage.db(options.session_data)
 43  47
 44  48
    45 def trainBoxTest(db, unusual_case, greater, num_observations):
    46     db.resetOffsets()
    47    
    48     def trainAux(low,high,num_trials):
    49         estimator = functools.partial(multiBoxTest, {'low':low, 'high':high}, greater)
    50         estimates = bootstrap3(estimator, db, 'train', unusual_case, num_observations, num_trials)
    51         null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)
    52 
    53         bad_estimates = len([e for e in estimates if e != 1])
    54         bad_null_estimates = len([e for e in null_estimates if e != 0])
    55        
    56         false_negatives = 100.0*bad_estimates/num_trials
    57         false_positives = 100.0*bad_null_estimates/num_trials
    58         return false_positives,false_negatives
    59 
    60     #start = time.time()
    61     wt = WorkerThreads(2, trainAux)
    62    
    63     num_trials = 200
    64     width = 1.0
    65     performance = []
    66     for low in range(0,50):
    67         wt.addJob(low, (low,low+width,num_trials))
    68     wt.wait()
    69     while not wt.resultq.empty():
    70         job_id,errors = wt.resultq.get()
    71         fp,fn = errors
    72         performance.append(((fp+fn)/2.0, job_id, fn, fp))
    73     performance.sort()
    74     #pprint.pprint(performance)
    75     #print(time.time()-start)
    76    
    77     num_trials = 200
    78     lows = [p[1] for p in performance[0:5]]
    79     widths = [w/10.0 for w in range(5,65,5)]
    80     performance = []
    81     for width in widths:
    82         false_positives = []
    83         false_negatives = []
    84         for low in lows:
    85             wt.addJob(low,(low,low+width,num_trials))
    86         wt.wait()
    87         while not wt.resultq.empty():
    88             job_id,errors = wt.resultq.get()
    89             fp,fn = errors
    90             false_negatives.append(fn)
    91             false_positives.append(fp)
    92 
    93         #print(width, false_negatives)
    94         #print(width, false_positives)
    95         #performance.append(((statistics.mean(false_positives)+statistics.mean(false_negatives))/2.0,
    96         #                    width, statistics.mean(false_negatives), statistics.mean(false_positives)))
    97         performance.append((abs(statistics.mean(false_positives)-statistics.mean(false_negatives)),
    98                             width, statistics.mean(false_negatives), statistics.mean(false_positives)))
    99     performance.sort()
    100     #pprint.pprint(performance)
    101     good_width = performance[0][1]
    102     #print("good_width:",good_width)
    103 
    104 
    105     num_trials = 500
    106     performance = []
    107     for low in lows:
    108         wt.addJob(low, (low,low+good_width,num_trials))
    109     wt.wait()
    110     while not wt.resultq.empty():
    111         job_id,errors = wt.resultq.get()
    112         fp,fn = errors
    113         performance.append(((fp+fn)/2.0, job_id, fn, fp))
    114     performance.sort()
    115     #pprint.pprint(performance)
    116     best_low = performance[0][1]
    117     #print("best_low:", best_low)
    118 
    119    
    120     num_trials = 500
    121     widths = [good_width+(x/100.0) for x in range(-70,75,5) if good_width+(x/100.0) > 0.0]
    122     performance = []
    123     for width in widths:
    124         wt.addJob(width, (best_low,best_low+width,num_trials))
    125     wt.wait()
    126     while not wt.resultq.empty():
    127         job_id,errors = wt.resultq.get()
    128         fp,fn = errors
    129         #performance.append(((fp+fn)/2.0, job_id, fn, fp))
    130         performance.append((abs(fp-fn), job_id, fn, fp))
    131     performance.sort()
    132     #pprint.pprint(performance)
    133     best_width=performance[0][1]
    134     #print("best_width:",best_width)
    135     #print("final_performance:", performance[0][0])
    136 
    137     wt.stop()
    138     params = json.dumps({"low":best_low,"high":best_low+best_width})
    139     return {'trial_type':"train",
    140             'num_observations':num_observations,
    141             'num_trials':num_trials,
    142             'params':params,
    143             'false_positives':performance[0][3],
    144             'false_negatives':performance[0][2]}
    145 
    146 
    147 def trainSummary(summaryFunc, db, unusual_case, greater, num_observations):
    148     db.resetOffsets()
    149     stest = functools.partial(summaryTest, summaryFunc)
    150    
    151     def trainAux(distance, threshold, num_trials):
    152         estimator = functools.partial(stest, {'distance':distance,'threshold':threshold}, greater)
    153         estimates = bootstrap3(estimator, db, 'train', unusual_case, num_observations, num_trials)
    154         null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)
    155 
    156         bad_estimates = len([e for e in estimates if e != 1])
    157         bad_null_estimates = len([e for e in null_estimates if e != 0])
    158        
    159         false_negatives = 100.0*bad_estimates/num_trials
    160         false_positives = 100.0*bad_null_estimates/num_trials
    161         return false_positives,false_negatives
    162 
    163     #determine expected delta based on differences
    164     mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    165     threshold = summaryFunc(mean_diffs)/2.0
    166     #print("init_threshold:", threshold)
    167    
    168     wt = WorkerThreads(2, trainAux)
    169    
    170     num_trials = 500
    171     performance = []
    172     for distance in range(1,50):
    173         wt.addJob(distance, (distance,threshold,num_trials))
    174     wt.wait()
    175     while not wt.resultq.empty():
    176         job_id,errors = wt.resultq.get()
    177         fp,fn = errors
    178         performance.append(((fp+fn)/2.0, job_id, fn, fp))
    179    
    180     performance.sort()
    181     #pprint.pprint(performance)
    182     good_distance = performance[0][1]
    183     #print("good_distance:",good_distance)
    184 
    185    
    186     num_trials = 500
    187     performance = []
    188     for t in range(80,122,2):
    189         wt.addJob(threshold*(t/100.0), (good_distance,threshold*(t/100.0),num_trials))
    190     wt.wait()
    191     while not wt.resultq.empty():
    192         job_id,errors = wt.resultq.get()
    193         fp,fn = errors
    194         #performance.append(((fp+fn)/2.0, job_id, fn, fp))
    195         performance.append((abs(fp-fn), job_id, fn, fp))
    196     performance.sort()
    197     #pprint.pprint(performance)
    198     good_threshold = performance[0][1]
    199     #print("good_threshold:", good_threshold)
    200 
    201    
    202     num_trials = 500
    203     performance = []
    204     for d in [good_distance+s for s in range(-4,5) if good_distance+s > -1]:
    205         wt.addJob(d, (d,good_threshold,num_trials))
    206     wt.wait()
    207     while not wt.resultq.empty():
    208         job_id,errors = wt.resultq.get()
    209         fp,fn = errors
    210         performance.append(((fp+fn)/2.0, job_id, fn, fp))
    211     performance.sort()
    212     #pprint.pprint(performance)
    213     best_distance = performance[0][1]
    214     #print("best_distance:",best_distance)
    215 
    216    
    217     num_trials = 500
    218     performance = []
    219     for t in range(90,111):
    220         wt.addJob(good_threshold*(t/100.0), (best_distance,good_threshold*(t/100.0),num_trials))
    221     wt.wait()
    222     while not wt.resultq.empty():
    223         job_id,errors = wt.resultq.get()
    224         fp,fn = errors
    225         #performance.append(((fp+fn)/2.0, job_id, fn, fp))
    226         performance.append((abs(fp-fn), job_id, fn, fp))
    227     performance.sort()
    228     #pprint.pprint(performance)
    229     best_threshold = performance[0][1]
    230     #print("best_threshold:", best_threshold)
    231 
    232     wt.stop()
    233     params = json.dumps({'distance':best_distance,'threshold':best_threshold})
    234     return {'trial_type':"train",
    235             'num_observations':num_observations,
    236             'num_trials':num_trials,
    237             'params':params,
    238             'false_positives':performance[0][3],
    239             'false_negatives':performance[0][2]}
    240 
    241 
    242 def trainKalman(db, unusual_case, greater, num_observations):
    243     db.resetOffsets()
    244 
    245     def trainAux(params, num_trials):
    246         estimator = functools.partial(kalmanTest, params, greater)
    247         estimates = bootstrap3(estimator, db, 'train', unusual_case, num_observations, num_trials)
    248         null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)
    249        
    250         bad_estimates = len([e for e in estimates if e != 1])
    251         bad_null_estimates = len([e for e in null_estimates if e != 0])
    252        
    253         false_negatives = 100.0*bad_estimates/num_trials
    254         false_positives = 100.0*bad_null_estimates/num_trials
    255         return false_positives,false_negatives
    256    
    257     mean_diffs = [s['unusual_case']-s['other_cases'] for s in db.subseries('train', unusual_case)]
    258     good_threshold = kfilter({},mean_diffs)['est'][-1]/2.0
    259 
    260     wt = WorkerThreads(2, trainAux)
    261     num_trials = 200
    262     performance = []
    263     for t in range(90,111):
    264         params = {'threshold':good_threshold*(t/100.0)}
    265         wt.addJob(good_threshold*(t/100.0), (params,num_trials))
    266     wt.wait()
    267     while not wt.resultq.empty():
    268         job_id,errors = wt.resultq.get()
    269         fp,fn = errors
    270         #performance.append(((fp+fn)/2.0, job_id, fn, fp))
    271         performance.append((abs(fp-fn), job_id, fn, fp))
    272     performance.sort()
    273     #pprint.pprint(performance)
    274     best_threshold = performance[0][1]
    275     #print("best_threshold:", best_threshold)
    276     params = {'threshold':best_threshold}
    277 
    278     wt.stop()
    279    
    280     return {'trial_type':"train",
    281             'num_observations':num_observations,
    282             'num_trials':num_trials,
    283             'params':json.dumps(params),
    284             'false_positives':performance[0][3],
    285             'false_negatives':performance[0][2]}
    286 
    287    
    288     #determine expected delta based on differences
    289 classifiers = {'boxtest':{'train':trainBoxTest, 'test':multiBoxTest, 'train_results':[]},
    290                'midsummary':{'train':functools.partial(trainSummary, midsummary), 'test':midsummaryTest, 'train_results':[]},
    291                #'ubersummary':{'train':functools.partial(trainSummary, ubersummary), 'test':ubersummaryTest, 'train_results':[]},
    292                'quadsummary':{'train':functools.partial(trainSummary, quadsummary), 'test':quadsummaryTest, 'train_results':[]},
    293                'kalman':{'train':trainKalman, 'test':kalmanTest, 'train_results':[]},
    294                #'_trimean':{'train':None, 'test':trimeanTest, 'train_results':[]},
    295               }
    296 
    297 
    298 db = nanownlib.storage.db(options.session_data)
    299 
    300 import cProfile
301  49
302  50  def trainClassifier(db, unusual_case, greater, classifier, retrain=False):
     
324  72          print("number of observations: %d | error: %f | false_positives: %f | false_negatives: %f | train time: %s | params: %s"
325  73                % (num_obs, error, result['false_positives'],result['false_negatives'], train_time, result['params']))
326           db.addClassifierResults(result)
     74          db.addClassifierResult(result)
327  75      classifiers[classifier]['train_results'].append(result)
328  76
     
355 103
356 104
     105    def getResult(classifier, params, num_obs, num_trials):
     106        jparams = json.dumps(params, sort_keys=True)
     107        result = db.fetchClassifierResult(classifier, 'test', num_obs, jparams)
     108        if result:
     109            fp = result['false_positives']
     110            fn = result['false_negatives']
     111        else:
     112            fp,fn = testAux(params, num_trials, num_obs)
     113            result = {'classifier':classifier,
     114                      'trial_type':"test",
     115                      'num_observations':num_obs,
     116                      'num_trials':num_trials,
     117                      'params':jparams,
     118                      'false_positives':fp,
     119                      'false_negatives':fn}
     120            db.addClassifierResult(result)
     121        return ((fp+fn)/2.0,result)
     122   
357 123      if retest:
358 124          print("Dropping stored test results...")
     
368 134
369 135          print("initial test")
370              fp,fn = testAux(params, num_trials, num_obs)
371              error = (fp+fn)/2.0
    136          error,result = getResult(classifier,params,num_obs,num_trials)
372 137          print("walking up")
373 138          while (error > target_error) and (num_obs < max_obs):
     
375 140              #print("increase_factor:", increase_factor)
376 141              num_obs = min(int(increase_factor*num_obs), max_obs)
377                  fp,fn = testAux(params, num_trials, num_obs)
378                  error = (fp+fn)/2.0
    142              error,result = getResult(classifier,params,num_obs,num_trials)
379 143
380 144          print("walking down")
381 145          while (num_obs > 0):
    382             current_best = (num_obs,error,params,fp,fn)
     146            current_best = (error,result)
383 147              num_obs = int(0.95*num_obs)
384                  fp,fn = testAux(params, num_trials, num_obs)
385                  error = (fp+fn)/2.0
    148              error,result = getResult(classifier,params,num_obs,num_trials)
386 149              if error > target_error:
387 150                  break
388 151
    389         test_results.append(current_best)
    390 
    391     test_results.sort()
    392     best_obs,error,best_params,fp,fn = test_results[0]
    393    
    394     return {'classifier':classifier,
    395             'trial_type':"test",
    396             'num_observations':best_obs,
    397             'num_trials':num_trials,
    398             'params':best_params,
    399             'false_positives':fp,
    400             'false_negatives':fn}
     152    return current_best
401 153
402 154
    403 start = time.time()
    404 unusual_case,unusual_diff = findUnusualTestCase(db)
    405 greater = (unusual_diff > 0)
    406 print("unusual_case:", unusual_case)
    407 print("unusual_diff:", unusual_diff)
    408 end = time.time()
    409 print(":", end-start)
    155  if options.unusual_case != None:
    156      unusual_case,greater = options.unusual_case.split(',')
    157      greater = bool(int(greater))
    158  else:
     159    start = time.time()
     160    unusual_case,unusual_diff = findUnusualTestCase(db)
     161    greater = (unusual_diff > 0)
     162    print("unusual_case:", unusual_case)
     163    print("unusual_diff:", unusual_diff)
     164    end = time.time()
     165    print(":", end-start)
     166
410 167
411 168  for c in sorted(classifiers.keys()):
     
424 181      start = time.time()
425 182      print("Testing %s..." % c)
426          result = testClassifier(db, unusual_case, greater, c, c in options.retest)
    183      error,result = testClassifier(db, unusual_case, greater, c, c in (options.retest+options.retrain))
427 184      print("%s result:" % c)
428 185      pprint.pprint(result)
429          classifiers[c]['test_error'] = (result['false_positives']+result['false_negatives'])/2.0
    186      classifiers[c]['test_error'] = error
430 187      print("completed in:", time.time()-start)
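
For reference, the new --unusual-case option takes a value of the form {case name},{1 or 0}; the script splits it on the comma and converts the trailing digit to a boolean (new lines 155-157 above). A minimal sketch of that parsing, using an illustrative case name and database filename that are not part of this changeset:

    # Hypothetical invocation:  bin/train --unusual-case long,1 session.db
    value = "long,1"                       # value received by --unusual-case
    unusual_case, greater = value.split(',')
    greater = bool(int(greater))           # "1" -> True, "0" -> False
    # unusual_case == "long"; greater == True means the unusual case is
    # expected to be greater than the other cases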