Changeset 16 for trunk/bin/train
- Timestamp: 08/01/15 19:01:31 (9 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
trunk/bin/train
r13 r16 67 67 result = trainer(db,unusual_case,greater,num_obs) 68 68 result['classifier'] = classifier 69 train_time = "% f" % (time.time()-start)69 train_time = "%8.2f" % (time.time()-start) 70 70 71 71 error = statistics.mean([result['false_positives'],result['false_negatives']]) 72 print("num ber of observations: %d | error: %f | false_positives: %f | false_negatives: %f | train time: %s | params: %s"72 print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | train time: %s | params: %s" 73 73 % (num_obs, error, result['false_positives'],result['false_negatives'], train_time, result['params'])) 74 74 db.addClassifierResult(result) … … 99 99 false_negatives = 100.0*bad_estimates/num_trials 100 100 false_positives = 100.0*bad_null_estimates/num_trials 101 print("testAux:", num_observations, false_positives, false_negatives, params)102 101 return false_positives,false_negatives 103 102 … … 107 106 result = db.fetchClassifierResult(classifier, 'test', num_obs, jparams) 108 107 if result: 108 test_time = '(stored)' 109 109 fp = result['false_positives'] 110 110 fn = result['false_negatives'] 111 111 else: 112 start = time.time() 112 113 fp,fn = testAux(params, num_trials, num_obs) 113 114 result = {'classifier':classifier, … … 119 120 'false_negatives':fn} 120 121 db.addClassifierResult(result) 122 test_time = '%8.2f' % (time.time()-start) 123 124 print("num. 
observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | test time: %s" 125 % (num_obs,(fp+fn)/2.0,fp,fn,test_time)) 121 126 return ((fp+fn)/2.0,result) 122 127 … … 126 131 127 132 128 test_results = []129 133 lte = math.log(target_error/100.0) 130 134 for tr in classifiers[classifier]['train_results']: … … 133 137 num_obs = tr['num_observations'] 134 138 135 print(" initial test")139 print("parameters:", params) 136 140 error,result = getResult(classifier,params,num_obs,num_trials) 137 print("walking up")141 #print("walking up") 138 142 while (error > target_error) and (num_obs < max_obs): 139 143 increase_factor = 1.5 * lte/math.log(error/100.0) # don't ask how I came up with this … … 142 146 error,result = getResult(classifier,params,num_obs,num_trials) 143 147 144 print("walking down")148 #print("walking down") 145 149 while (num_obs > 0): 146 current_best = (error,result)147 150 num_obs = int(0.95*num_obs) 148 151 error,result = getResult(classifier,params,num_obs,num_trials) 149 152 if error > target_error: 150 153 break 151 152 return current_best 153 154 154 155 155 156 if options.unusual_case != None: 156 157 unusual_case,greater = options.unusual_case.split(',') 157 158 greater = bool(int(greater)) 159 db.setUnusualCase(unusual_case,greater) 158 160 else: 159 start = time.time() 160 unusual_case,unusual_diff = findUnusualTestCase(db) 161 greater = (unusual_diff > 0) 162 print("unusual_case:", unusual_case) 163 print("unusual_diff:", unusual_diff) 164 end = time.time() 165 print(":", end-start) 161 ucg = db.getUnusualCase() 162 if ucg != None: 163 unusual_case,greater = ucg 164 print("Using cached unusual_case:", unusual_case) 165 else: 166 unusual_case,delta = findUnusualTestCase(db) 167 greater = (delta > 0) 168 print("Auto-detected unusual_case '%s' with delta: %d" % (unusual_case,delta)) 169 db.setUnusualCase(unusual_case,greater) 166 170 167 171 … … 172 176 print("Training %s..." 
% c) 173 177 result = trainClassifier(db, unusual_case, greater, c, c in options.retrain) 174 print("%s result:" % c)175 pprint.pprint(result)176 print("completed in: ", time.time()-start)178 #print("%s result:" % c) 179 #pprint.pprint(result) 180 print("completed in: %8.2f\n"% (time.time()-start)) 177 181 178 182 db.clearCache() … … 181 185 start = time.time() 182 186 print("Testing %s..." % c) 183 error,result = testClassifier(db, unusual_case, greater, c, c in (options.retest+options.retrain)) 184 print("%s result:" % c) 185 pprint.pprint(result) 186 classifiers[c]['test_error'] = error 187 print("completed in:", time.time()-start) 187 testClassifier(db, unusual_case, greater, c, c in (options.retest+options.retrain)) 188 print("completed in: %8.2f\n"% (time.time()-start)) 189 190 191 best_obs,best_error = evaluateTestResults(db) 192 best_obs = sorted(best_obs, key=lambda x: x['num_observations']) 193 best_error = sorted(best_error, key=lambda x: x['error']) 194 winner = None 195 for bo in best_obs: 196 sys.stdout.write("%(num_observations)5d obs | %(classifier)12s | %(params)s" % bo) 197 if winner == None: 198 sys.stdout.write(" (winner)") 199 winner = bo 200 print() 201 202 for be in best_error: 203 sys.stdout.write("%(error)3.2f%% error | %(classifier)12s | %(params)s" % be) 204 if winner == None: 205 sys.stdout.write(" (winner)") 206 winner = be 207 print()
Note: See TracChangeset for help on using the changeset viewer.