Changeset 16 for trunk/bin/train


Ignore:
Timestamp:
08/01/15 19:01:31 (9 years ago)
Author:
tim
Message:

.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/bin/train

    r13 r16  
    6767            result = trainer(db,unusual_case,greater,num_obs)
    6868            result['classifier'] = classifier
    69             train_time = "%f" % (time.time()-start)
     69            train_time = "%8.2f" % (time.time()-start)
    7070           
    7171        error = statistics.mean([result['false_positives'],result['false_negatives']])
    72         print("number of observations: %d | error: %f | false_positives: %f | false_negatives: %f | train time: %s | params: %s"
     72        print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | train time: %s | params: %s"
    7373              % (num_obs, error, result['false_positives'],result['false_negatives'], train_time, result['params']))
    7474        db.addClassifierResult(result)
     
    9999        false_negatives = 100.0*bad_estimates/num_trials
    100100        false_positives = 100.0*bad_null_estimates/num_trials
    101         print("testAux:", num_observations, false_positives, false_negatives, params)
    102101        return false_positives,false_negatives
    103102
     
    107106        result = db.fetchClassifierResult(classifier, 'test', num_obs, jparams)
    108107        if result:
     108            test_time = '(stored)'
    109109            fp = result['false_positives']
    110110            fn = result['false_negatives']
    111111        else:
     112            start = time.time()
    112113            fp,fn = testAux(params, num_trials, num_obs)
    113114            result = {'classifier':classifier,
     
    119120                      'false_negatives':fn}
    120121            db.addClassifierResult(result)
     122            test_time = '%8.2f' % (time.time()-start)
     123           
     124        print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | test time: %s"
     125              % (num_obs,(fp+fn)/2.0,fp,fn,test_time))
    121126        return ((fp+fn)/2.0,result)
    122127   
     
    126131
    127132
    128     test_results = []
    129133    lte = math.log(target_error/100.0)
    130134    for tr in classifiers[classifier]['train_results']:
     
    133137        num_obs = tr['num_observations']
    134138   
    135         print("initial test")
     139        print("parameters:", params)
    136140        error,result = getResult(classifier,params,num_obs,num_trials)
    137         print("walking up")
     141        #print("walking up")
    138142        while (error > target_error) and (num_obs < max_obs):
    139143            increase_factor = 1.5 * lte/math.log(error/100.0) # don't ask how I came up with this
     
    142146            error,result = getResult(classifier,params,num_obs,num_trials)
    143147
    144         print("walking down")
     148        #print("walking down")
    145149        while (num_obs > 0):
    146             current_best = (error,result)
    147150            num_obs = int(0.95*num_obs)
    148151            error,result = getResult(classifier,params,num_obs,num_trials)
    149152            if error > target_error:
    150153                break
    151        
    152     return current_best
    153 
     154   
    154155
    155156if options.unusual_case != None:
    156157    unusual_case,greater = options.unusual_case.split(',')
    157158    greater = bool(int(greater))
     159    db.setUnusualCase(unusual_case,greater)
    158160else:
    159     start = time.time()
    160     unusual_case,unusual_diff = findUnusualTestCase(db)
    161     greater = (unusual_diff > 0)
    162     print("unusual_case:", unusual_case)
    163     print("unusual_diff:", unusual_diff)
    164     end = time.time()
    165     print(":", end-start)
     161    ucg = db.getUnusualCase()
     162    if ucg != None:
     163        unusual_case,greater = ucg
     164        print("Using cached unusual_case:", unusual_case)
     165    else:
     166        unusual_case,delta = findUnusualTestCase(db)
     167        greater = (delta > 0)
     168        print("Auto-detected unusual_case '%s' with delta: %d" %  (unusual_case,delta))
     169        db.setUnusualCase(unusual_case,greater)
    166170
    167171
     
    172176    print("Training %s..." % c)
    173177    result = trainClassifier(db, unusual_case, greater, c, c in options.retrain)
    174     print("%s result:" % c)
    175     pprint.pprint(result)
    176     print("completed in:", time.time()-start)
     178    #print("%s result:" % c)
     179    #pprint.pprint(result)
     180    print("completed in: %8.2f\n"% (time.time()-start))
    177181
    178182db.clearCache()
     
    181185    start = time.time()
    182186    print("Testing %s..." % c)
    183     error,result = testClassifier(db, unusual_case, greater, c, c in (options.retest+options.retrain))
    184     print("%s result:" % c)
    185     pprint.pprint(result)
    186     classifiers[c]['test_error'] = error
    187     print("completed in:", time.time()-start)
     187    testClassifier(db, unusual_case, greater, c, c in (options.retest+options.retrain))
     188    print("completed in: %8.2f\n"% (time.time()-start))
     189
     190
     191best_obs,best_error = evaluateTestResults(db)
     192best_obs =   sorted(best_obs,   key=lambda x: x['num_observations'])
     193best_error = sorted(best_error, key=lambda x: x['error'])
     194winner = None
     195for bo in best_obs:
     196    sys.stdout.write("%(num_observations)5d obs   | %(classifier)12s | %(params)s" % bo)
     197    if winner == None:
     198        sys.stdout.write(" (winner)")
     199        winner = bo
     200    print()
     201       
     202for be in best_error:
     203    sys.stdout.write("%(error)3.2f%% error | %(classifier)12s | %(params)s" % be)
     204    if winner == None:
     205        sys.stdout.write(" (winner)")
     206        winner = be
     207    print()
Note: See TracChangeset for help on using the changeset viewer.