#!/usr/bin/env python3
#-*- mode: Python;-*-
"""Train and then test every registered classifier against a session database.

For each entry in ``classifiers`` that has a training function, training is
run with a growing number of observations.  Every classifier is then tested
with bootstrap trials, searching for the number of observations at which the
mean of the false positive and false negative rates crosses the target error.
Finally the results returned by ``evaluateTestResults`` are printed.
"""

import sys
import os
import time
import random
import statistics
import functools
import argparse
import pprint
import json
import math


VERSION = "{DEVELOPMENT}"
if VERSION == "{DEVELOPMENT}":
    # Development checkout: make ../lib (relative to this script) importable.
    script_dir = '.'
    try:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    except Exception:
        try:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        except Exception:
            # Best effort only; fall back to the current directory.
            pass
    sys.path.append("%s/../lib" % script_dir)

from nanownlib import *
from nanownlib.stats import *
from nanownlib.train import *
from nanownlib.parallel import WorkerThreads
import nanownlib.storage


parser = argparse.ArgumentParser(
    description="")
#parser.add_argument('-c', dest='cases', type=str, default='{"short":10000,"long":1010000}',
#                    help='JSON representation of echo timing cases. Default: {"short":10000,"long":1010000}')
parser.add_argument('--unusual-case', action='store', default=None, help='Specify the unusual case and whether it is greater than the other cases. Format: {case name},{1 or 0}')
parser.add_argument('--retrain', action='append', default=[], help='Force a classifier to be retrained (and retested). May be specified multiple times.')
parser.add_argument('--retest', action='append', default=[], help='Force a classifier to be retested. May be specified multiple times.')
parser.add_argument('session_data', default=None,
                    help='Database file storing session information')
options = parser.parse_args()
db = nanownlib.storage.db(options.session_data)


def trainClassifier(db, unusual_case, greater, classifier, retrain=False):
    """Train ``classifier`` with progressively more observations.

    Starting from 7, the observation count is multiplied by 1.5 each round
    (capped at one fifth of the 'train' population size).  A stored result
    for a given count is reused when the database has one; otherwise the
    classifier's training function is called.  Every result is printed,
    passed to ``db.addClassifierResult`` and appended to the classifier's
    ``train_results`` list.

    Training stops once the mean of the false positive and false negative
    rates drops below 5% with more than 100 observations, or when the cap
    is reached.

    Args:
        db: session database object.
        unusual_case: name of the unusual test case.
        greater: whether the unusual case is greater than the other cases.
        classifier: key into the ``classifiers`` registry.
        retrain: when true, stored training results are dropped first.

    Returns:
        The last training result, or None if no round was run.
    """
    if retrain:
        print("Dropping stored training results...")
        db.deleteClassifierResults(classifier, 'train')

    trainer = classifiers[classifier]['train']
    threshold = 5.0  # in percent
    num_obs = 7
    max_obs = int(db.populationSize('train')/5)
    result = None
    while num_obs < max_obs:
        num_obs = min(int(num_obs*1.5), max_obs)
        result = db.fetchClassifierResult(classifier, 'train', num_obs)
        if result is not None:
            train_time = "(stored)"
        else:
            start = time.time()
            result = trainer(db, unusual_case, greater, num_obs)
            result['classifier'] = classifier
            train_time = "%8.2f" % (time.time()-start)

        error = statistics.mean([result['false_positives'], result['false_negatives']])
        print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | train time: %s | params: %s"
              % (num_obs, error, result['false_positives'], result['false_negatives'], train_time, result['params']))
        # NOTE(review): this is also called for results that came from the
        # database; presumably the storage layer tolerates re-adding - confirm.
        db.addClassifierResult(result)
        classifiers[classifier]['train_results'].append(result)

        if error < threshold and num_obs > 100:
            break

    return result


def testClassifier(db, unusual_case, greater, classifier, retest=False):
    """Test ``classifier`` using the parameters from each training result.

    For every entry in the classifier's ``train_results`` the test starts at
    the training observation count, walks the count up while the error is
    above the 5% target (capped at one fifth of the 'test' population size),
    then walks it down in 5% steps until the error rises above the target.
    Each tested count is printed and, when newly computed, stored in the
    database.

    Args:
        db: session database object.
        unusual_case: name of the unusual test case.
        greater: whether the unusual case is greater than the other cases.
        classifier: key into the ``classifiers`` registry.
        retest: when true, stored test results are dropped first.
    """
    target_error = 5.0  # in percent
    num_trials = 1000
    max_obs = int(db.populationSize('test')/5)

    tester = classifiers[classifier]['test']

    def testAux(params, num_trials, num_observations):
        """Run bootstrap trials; return (false_positives, false_negatives) in percent."""
        estimator = functools.partial(tester, params, greater)
        estimates = bootstrap3(estimator, db, 'test', unusual_case, num_observations, num_trials)
        null_estimates = bootstrap3(estimator, db, 'train_null', unusual_case, num_observations, num_trials)

        # An estimate other than 1 on the 'test' data counts as a miss; an
        # estimate other than 0 on the 'train_null' data counts as a false alarm.
        bad_estimates = len([e for e in estimates if e != 1])
        bad_null_estimates = len([e for e in null_estimates if e != 0])

        false_negatives = 100.0*bad_estimates/num_trials
        false_positives = 100.0*bad_null_estimates/num_trials
        return false_positives, false_negatives

    def getResult(classifier, params, num_obs, num_trials):
        """Fetch or compute the test result; return (mean error, result)."""
        jparams = json.dumps(params, sort_keys=True)
        result = db.fetchClassifierResult(classifier, 'test', num_obs, jparams)
        if result:
            test_time = '(stored)'
            fp = result['false_positives']
            fn = result['false_negatives']
        else:
            start = time.time()
            fp, fn = testAux(params, num_trials, num_obs)
            result = {'classifier': classifier,
                      'trial_type': "test",
                      'num_observations': num_obs,
                      'num_trials': num_trials,
                      'params': jparams,
                      'false_positives': fp,
                      'false_negatives': fn}
            db.addClassifierResult(result)
            test_time = '%8.2f' % (time.time()-start)

        print("num. observations: %5d | error: %6.2f | fp: %6.2f | fn: %6.2f | test time: %s"
              % (num_obs, (fp+fn)/2.0, fp, fn, test_time))
        return ((fp+fn)/2.0, result)

    if retest:
        print("Dropping stored test results...")
        db.deleteClassifierResults(classifier, 'test')

    lte = math.log(target_error/100.0)
    for tr in classifiers[classifier]['train_results']:
        db.resetOffsets()
        params = json.loads(tr['params'])
        num_obs = tr['num_observations']

        print("parameters:", params)
        error, result = getResult(classifier, params, num_obs, num_trials)

        # Walk up: grow the observation count until the target is met.
        while (error > target_error) and (num_obs < max_obs):
            if error >= 100.0:
                # log(error/100) is 0 at 100% error, which would divide by
                # zero; the growth factor is unbounded there, so go straight
                # to the cap.
                num_obs = max_obs
            else:
                increase_factor = 1.5 * lte/math.log(error/100.0)  # don't ask how I came up with this
                num_obs = min(int(increase_factor*num_obs), max_obs)
            error, result = getResult(classifier, params, num_obs, num_trials)

        # Walk down: shrink in 5% steps until the error exceeds the target.
        while num_obs > 0:
            num_obs = int(0.95*num_obs)
            if num_obs < 1:
                # Never run (or store) a test with zero observations.
                break
            error, result = getResult(classifier, params, num_obs, num_trials)
            if error > target_error:
                break


# Determine the unusual case: command line first, then the value cached in
# the database, then auto-detection.
if options.unusual_case is not None:
    try:
        unusual_case, greater_flag = options.unusual_case.split(',')
        greater = bool(int(greater_flag))
    except ValueError:
        parser.error("--unusual-case must have the form {case name},{1 or 0}")
    db.setUnusualCase(unusual_case, greater)
else:
    ucg = db.getUnusualCase()
    if ucg is not None:
        unusual_case, greater = ucg
        print("Using cached unusual_case:", unusual_case)
    else:
        unusual_case, delta = findUnusualTestCase(db)
        greater = (delta > 0)
        print("Auto-detected unusual_case '%s' with delta: %d" % (unusual_case, delta))
        db.setUnusualCase(unusual_case, greater)


for c in sorted(classifiers.keys()):
    if classifiers[c]['train'] is None:
        continue
    start = time.time()
    print("Training %s..." % c)
    result = trainClassifier(db, unusual_case, greater, c, c in options.retrain)
    #print("%s result:" % c)
    #pprint.pprint(result)
    print("completed in: %8.2f\n" % (time.time()-start))

db.clearCache()

for c in sorted(classifiers.keys()):
    start = time.time()
    print("Testing %s..." % c)
    # A forced retrain also forces a retest.
    testClassifier(db, unusual_case, greater, c, c in (options.retest+options.retrain))
    print("completed in: %8.2f\n" % (time.time()-start))


best_obs, best_error = evaluateTestResults(db)
best_obs = sorted(best_obs, key=lambda x: x['num_observations'])
best_error = sorted(best_error, key=lambda x: x['error'])

# `winner` is shared by both loops: the first best_obs entry is marked, and
# a best_error entry is only marked when best_obs was empty.
winner = None
for bo in best_obs:
    sys.stdout.write("%(num_observations)5d obs | %(classifier)12s | %(params)s" % bo)
    if winner is None:
        sys.stdout.write(" (winner)")
        winner = bo
    print()

for be in best_error:
    sys.stdout.write("%(error)3.2f%% error | %(classifier)12s | %(params)s" % be)
    if winner is None:
        sys.stdout.write(" (winner)")
        winner = be
    print()