source: trunk/lib/nanownlib/__init__.py @ 16

Last change on this file since 16 was 16, checked in by tim, 9 years ago

.

File size: 18.9 KB
RevLine 
[4]1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import time
[10]6import traceback
[4]7import random
8import argparse
9import socket
10import datetime
11import http.client
12import threading
13import queue
14import subprocess
15import multiprocessing
16import csv
17import json
18import gzip
19import statistics
20try:
[16]21    import numpy
22except:
23    sys.stderr.write('ERROR: Could not import numpy module.  Ensure it is installed.\n')
24    sys.stderr.write('       Under Debian, the package name is "python3-numpy"\n.')
25    sys.exit(1)
26
27try:
[4]28    import requests
29except:
30    sys.stderr.write('ERROR: Could not import requests module.  Ensure it is installed.\n')
31    sys.stderr.write('       Under Debian, the package name is "python3-requests"\n.')
32    sys.exit(1)
33
34from .stats import *
35
36
def getLocalIP(remote_host, remote_port):
    """Return the local IP address the kernel selects for reaching the
    given remote endpoint.

    Opens a throwaway TCP connection to (remote_host, remote_port) and
    reads back the local side of it.
    """
    sock = socket.create_connection((remote_host, remote_port))
    local_ip = sock.getsockname()[0]
    sock.close()
    return local_ip
43
44
def getIfaceForIP(ip):
    """Return the name of the network interface holding the given IPv4
    address, or None if no interface matches.

    Imports netifaces lazily so the rest of the module works without it.
    """
    try:
        import netifaces
    except:
        sys.stderr.write('ERROR: Could not import netifaces module.  Ensure it is installed.\n')
        sys.stderr.write('       Try: pip3 install netifaces\n.')
        sys.exit(1)

    for iface in netifaces.interfaces():
        ipv4_entries = netifaces.ifaddresses(iface).get(netifaces.AF_INET, None)
        if not ipv4_entries:
            continue
        for entry in ipv4_entries:
            if entry.get('addr', None) == ip:
                return iface
59
60
def setTCPTimestamps(enabled=True):
    """Turn the system-wide TCP timestamps option on or off via /proc.

    Returns the previous setting (True if timestamps were enabled), so
    callers can restore it later.  Requires write access to
    /proc/sys/net/ipv4/tcp_timestamps (i.e. root, Linux only).
    """
    with open('/proc/sys/net/ipv4/tcp_timestamps', 'r+b') as proc_fh:
        was_enabled = (proc_fh.read(1) == b'1')
        proc_fh.seek(0)
        proc_fh.write(b'1' if enabled else b'0')
    return was_enabled
75
76
def trickleHTTPRequest(ip,port,hostname):
    """Send a deliberately slow ("trickled") HTTP/1.1 GET to ip:port,
    one header byte per segment, then drain the response.

    The slow send coaxes the server into emitting many small segments
    with distinct TCP timestamp values, which the caller uses to
    estimate the remote TSval clock precision.

    Returns the local (ephemeral) source port of the connection, or
    None if the connection could not even be established.  Mid-request
    errors are deliberately ignored: packets already exchanged are
    still useful to the sniffer.
    """
    my_port = None
    sock = None
    try:
        sock = socket.create_connection((ip, port))
        my_port = sock.getsockname()[1]

        sock.sendall(b'GET / HTTP/1.1\r\n')
        time.sleep(0.5)
        rest = b'''Host: '''+hostname.encode('utf-8')+b'''\r\nUser-Agent: Secret Agent Man\r\nX-Extra: extra read all about it!\r\nConnection: close\r\n'''
        for r in rest:
            # one byte per segment, 50ms apart
            sock.sendall(bytearray([r]))
            time.sleep(0.05)

        time.sleep(0.5)
        # BUG FIX: was sock.sendall('\r\n') -- passing str raises
        # TypeError in Python 3 (swallowed by the except below), so the
        # request was never terminated and the response never read.
        sock.sendall(b'\r\n')

        # drain until the server closes the connection
        r = None
        while r != b'':
            r = sock.recv(16)
    except Exception:
        pass  # best effort; see docstring
    finally:
        # BUG FIX: close the socket on error paths too (it used to leak
        # whenever any exception fired above).
        if sock is not None:
            sock.close()

    return my_port
103
104
def runTimestampProbes(host_ip, port, hostname, num_trials, concurrency=4):
    """Fire num_trials trickled HTTP requests at host_ip:port, keeping
    roughly `concurrency` requests in flight at once.

    Returns the list of local source ports used by the probes (entries
    may be None for probes that failed).
    """
    results = queue.Queue()

    def worker(*args):
        try:
            results.put(trickleHTTPRequest(*args))
        except Exception as e:
            sys.stderr.write("ERROR from trickleHTTPRequest: %s\n" % repr(e))
            results.put(None)

    workers = []
    ports = []
    for trial in range(num_trials):
        # crude throttle: once `concurrency` threads have been started,
        # wait for one result before starting another
        if len(workers) >= concurrency:
            ports.append(results.get())
        thread = threading.Thread(target=worker, args=(host_ip, port, hostname))
        thread.start()
        workers.append(thread)

    for thread in workers:
        thread.join()

    # collect whatever results the throttle loop did not already consume
    while results.qsize() > 0:
        ports.append(results.get())

    return ports
130
131
def computeTimestampPrecision(sniffer_fp, ports):
    """Estimate the remote TCP timestamp clock precision.

    For each probe connection (identified by its local port), regresses
    observed receive time against the TSval carried in its received
    packets; each regression slope is one estimate of seconds-per-tick.

    Returns (mean, stdev, slopes).  stdev is None when only one slope
    was available, and (None, None, None) when no connection yielded
    usable data.
    """
    received = []
    for line in sniffer_fp:
        pkt = json.loads(line)
        if pkt['sent'] == 0:
            received.append((pkt['observed'], pkt['tsval'], int(pkt['local_port'])))

    slopes = []
    for port in ports:
        # received packets for this connection that carry a timestamp
        samples = [s for s in received if s[2] == port and s[1] != 0]

        if len(samples) < 2:
            sys.stderr.write("WARN: Inadequate data points.\n")
            continue

        # decreasing TSval means the 32-bit counter wrapped mid-probe
        if samples[0][1] > samples[-1][1]:
            sys.stderr.write("WARN: TSval wrap.\n")
            continue

        tsvals = [s[1] for s in samples]
        times = [s[0] for s in samples]

        slope, intercept = OLSRegression(tsvals, times)
        slopes.append(slope)

    if len(slopes) == 0:
        return None, None, None

    mean_slope = statistics.mean(slopes)
    if len(slopes) == 1:
        return (mean_slope, None, slopes)
    return (mean_slope, statistics.stdev(slopes), slopes)
165
166   
def OLSRegression(x,y):
    """Ordinary least-squares fit of y = m*x + c.

    Returns (m, c): slope and intercept.  Uses numpy's degree-1
    Polynomial fit, which the original author found more accurate here
    than numpy.linalg.lstsq or polynomial.polyfit.
    """
    xs = numpy.array(x)
    ys = numpy.array(y)
    # Polynomial.fit() works in a scaled/shifted domain; convert() maps
    # the coefficients back to the natural domain as [intercept, slope].
    intercept, slope = numpy.polynomial.Polynomial.fit(xs, ys, 1).convert().coef
    return (slope, intercept)
185
186
def startSniffer(target_ip, target_port, output_file):
    """Launch the nanown-listen sniffer for traffic between this host
    and target_ip:target_port, writing observations to output_file.

    Runs it at real-time priority via chrt; returns the Popen handle
    (stop it with stopSniffer()).
    """
    local_ip = getLocalIP(target_ip, target_port)
    local_iface = getIfaceForIP(local_ip)
    cmd = ['chrt', '-r', '99', 'nanown-listen', local_iface, local_ip,
           target_ip, "%d" % target_port, output_file, '0']
    return subprocess.Popen(cmd)
192
def stopSniffer(sniffer):
    """Stop a sniffer process started by startSniffer().

    Asks politely with SIGTERM first; if the process is still alive
    after a 2-second grace period, SIGKILLs it.
    """
    sniffer.terminate()
    sniffer.wait(2)
    if sniffer.poll() is None:
        sniffer.kill()
        sniffer.wait(1)
199
200       
def setCPUAffinity():
    """Pin this process to the highest-numbered CPU via the raw libc
    sched_setaffinity(2) call.  Returns the libc call's result
    (0 on success, -1 on failure).  Linux-only.
    """
    import ctypes
    from ctypes import cdll,c_int,byref
    cpus = multiprocessing.cpu_count()
   
    libc = cdll.LoadLibrary("libc.so.6")
    #libc.sched_setaffinity(os.getpid(), 1, ctypes.byref(ctypes.c_int(0x01)))
    # pid 0 means "current process"; the mask is a 4-byte c_int with only
    # the last CPU's bit set.  NOTE(review): a 32-bit mask only covers
    # CPUs 0-31 -- confirm behavior on machines with more cores.
    return libc.sched_setaffinity(0, 4, byref(c_int(0x00000001<<(cpus-1))))
209
210
# Monkey patching that instruments the HTTPResponse to collect connection source port info
class MonitoredHTTPResponse(http.client.HTTPResponse):
    """HTTPResponse subclass that records the local (ip, port) of the
    socket it was read from, so HTTP probes can later be matched to
    sniffed packets by source port."""
    # local (ip, port) tuple captured from the socket at construction
    local_address = None

    def __init__(self, sock, *args, **kwargs):
        # capture the local endpoint before handing the socket to the
        # normal response machinery
        self.local_address = sock.getsockname()
        super(MonitoredHTTPResponse, self).__init__(sock,*args,**kwargs)

# Install the instrumented class globally for all connections made via requests/urllib3.
requests.packages.urllib3.connection.HTTPConnection.response_class = MonitoredHTTPResponse
220
221
def removeDuplicatePackets(packets):
    """Collapse duplicate packets -- same (sent, tcpseq, tcpack,
    payload_len) -- keeping the earliest-observed copy of each.

    Returns (suspect, packets): suspect accumulates one flag character
    per duplicate that displaced an earlier entry ('s' for sent, 'r'
    for received); packets is a view of the surviving records.
    """
    suspect = ''
    # XXX: Need to review this deduplication algorithm and make sure it is correct
    earliest = {}
    for pkt in packets:
        key = (pkt['sent'], pkt['tcpseq'], pkt['tcpack'], pkt['payload_len'])
        if key not in earliest:
            earliest[key] = pkt
        elif pkt['sent'] == 1 and earliest[key]['observed'] > pkt['observed']:
            earliest[key] = pkt
            suspect += 's'  # duplicated sent packets
        elif pkt['sent'] == 0 and earliest[key]['observed'] > pkt['observed']:
            earliest[key] = pkt
            suspect += 'r'  # duplicated received packets

    return suspect, earliest.values()
245
246
def analyzePackets(packets, timestamp_precision, trim_sent=0, trim_rcvd=0):
    """Derive round-trip-time measurements for one probe's packets.

    Endpoints are the trim_sent-th earliest payload-carrying sent packet
    and the trim_rcvd-th-from-last payload-carrying received packet.
    Suspicion flags accumulate in a string: duplicates ('s'/'r' from
    removeDuplicatePackets), apparent drops ('d'), received reordering
    ('R'), negative RTT ('N').

    Returns ({packet_rtt, tsval_rtt, suspect, sent_trimmed,
    rcvd_trimmed}, num_sent, num_rcvd).  Raises IndexError if no
    payload-carrying packet exists in either direction.
    """
    suspect,packets = removeDuplicatePackets(packets)

    # primary order: capture time; alternate order: sequence number
    # (disagreement between the two reveals reordering)
    sort_key = lambda d: (d['observed'],d['tcpseq'])
    alt_key = lambda d: (d['tcpseq'],d['observed'])
    sent = sorted((p for p in packets if p['sent']==1 and p['payload_len']>0), key=sort_key)
    rcvd = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=sort_key)
    rcvd_alt = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=alt_key)

    s_off = trim_sent
    if s_off >= len(sent):
        suspect += 'd' # dropped packet?
        s_off = -1     # fall back to the last sent packet
    last_sent = sent[s_off]

    r_off = len(rcvd) - trim_rcvd - 1
    if r_off < 0:
        suspect += 'd' # dropped packet?
        r_off = 0      # fall back to the earliest received packet
    last_rcvd = rcvd[r_off]
    if last_rcvd != rcvd_alt[r_off]:
        suspect += 'R' # reordered received packets

    # received packet with the smallest ACK (ties broken by earliest
    # observation) whose ACK could cover last_sent's payload
    last_sent_ack = None
    try:
        last_sent_ack = min(((p['tcpack'],p['observed'],p) for p in packets
                             if p['sent']==0 and p['payload_len']+last_sent['tcpseq']>=p['tcpack']))[2]

    except Exception as e:
        sys.stderr.write("WARN: Could not find last_sent_ack.\n")

    packet_rtt = last_rcvd['observed'] - last_sent['observed']
    tsval_rtt = None
    if None not in (timestamp_precision, last_sent_ack):
        # convert the TSval tick difference into time units using the
        # previously calibrated clock precision
        tsval_rtt = int(round((last_rcvd['tsval'] - last_sent_ack['tsval'])*timestamp_precision))

    if packet_rtt < 0 or (tsval_rtt != None and tsval_rtt < 0):
        #sys.stderr.write("WARN: Negative packet or tsval RTT. last_rcvd=%s,last_sent=%s\n" % (last_rcvd, last_sent))
        suspect += 'N'

    return {'packet_rtt':packet_rtt,
            'tsval_rtt':tsval_rtt,
            'suspect':suspect,
            'sent_trimmed':trim_sent,
            'rcvd_trimmed':trim_rcvd},len(sent),len(rcvd)
292
293
# septasummary and mad for each dist of differences
def evaluateTrim(db, unusual_case, strim, rtrim):
    """Score one (sent, received) trim setting.

    For every sample of the unusual case, computes its packet_rtt minus
    the mean packet_rtt of the other cases in the same sample, then
    summarizes that difference distribution.

    Returns (septasummary, mad): a center estimate of the separation
    and its spread, as computed by the .stats helpers.
    """
    cursor = db.conn.cursor()
    # NOTE: an earlier variant of this query additionally excluded
    # samples flagged 'R' (reordered); it was assigned and then
    # immediately overwritten (dead code), so it has been removed.
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
      FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test')) u
    """
    #TODO: check for "N" in suspect field and return a flag

    params = {"strim":strim,"rtrim":rtrim,"unusual_case":unusual_case}
    cursor.execute(query, params)
    differences = [row[0] for row in cursor]

    return septasummary(differences),mad(differences)
[4]314
315
316
def analyzeProbes(db, trim=None, recompute=False):
    """Compute per-probe RTT analyses and store them in the database.

    Runs analyzePackets() over every not-yet-analyzed probe for a range
    of trim settings, picks the trim pair that preserves the inter-case
    delta while reducing spread, and copies the winning rows from
    trim_analysis into the analysis table.

    db        -- database wrapper exposing .conn (sqlite3) and
                 addTrimAnalyses()
    trim      -- optional (sent_trim, rcvd_trim) to use as-is, skipping
                 the trim search
    recompute -- if True, discard existing analysis rows and redo all

    Returns the number of probes processed.
    """
    db.conn.execute("CREATE INDEX IF NOT EXISTS packets_probe ON packets (probe_id)")
    db.conn.commit()

    pcursor = db.conn.cursor()
    pcursor.execute("SELECT tcpts_mean FROM meta")
    try:
        timestamp_precision = pcursor.fetchone()[0]
    except:
        # no TCP timestamp calibration stored; tsval_rtt will be None
        timestamp_precision = None

    pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()
    if recompute:
        pcursor.execute("DELETE FROM analysis")
        db.conn.commit()

    def loadPackets(db):
        # Group not-yet-analyzed packets by probe_id; relies on ORDER BY
        # making each probe's packets contiguous.
        cursor = db.conn.cursor()
        cursor.execute("SELECT * FROM packets WHERE probe_id NOT IN (SELECT probe_id FROM analysis) ORDER BY probe_id")

        probe_id = None
        entry = []
        ret_val = []
        for p in cursor:
            if probe_id is None:
                probe_id = p['probe_id']
            if p['probe_id'] != probe_id:
                ret_val.append((probe_id,entry))
                probe_id = p['probe_id']
                entry = []
            entry.append(dict(p))
        ret_val.append((probe_id,entry))
        return ret_val

    def processPackets(packet_cache, strim, rtrim):
        # Analyze every cached probe at one trim setting, persist the
        # results, and return the modal sent/received packet counts.
        sent_tally = []
        rcvd_tally = []
        analyses = []
        for probe_id,packets in packet_cache:
            try:
                # BUG FIX: strim/rtrim were previously not forwarded to
                # analyzePackets(), so every trim_analysis row was
                # computed at trims (0,0) and the trim search below
                # could never find rows for other settings.
                analysis,s,r = analyzePackets(packets, timestamp_precision, strim, rtrim)
                analysis['probe_id'] = probe_id
                analyses.append(analysis)
                sent_tally.append(s)
                rcvd_tally.append(r)
            except Exception:
                sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)
        db.addTrimAnalyses(analyses)
        db.conn.commit()
        return statistics.mode(sent_tally),statistics.mode(rcvd_tally)

    packet_cache = loadPackets(db)

    if trim is not None:
        # caller supplied the trim settings; no search needed
        best_strim,best_rtrim = trim
        processPackets(packet_cache, best_strim, best_rtrim)
    else:
        num_sent,num_rcvd = processPackets(packet_cache, 0, 0)
        print("num_sent: %d, num_rcvd: %d" % (num_sent,num_rcvd))

        # exhaustively compute analyses for every candidate trim pair
        for strim in range(0,num_sent):
            for rtrim in range(0,num_rcvd):
                if strim == 0 and rtrim == 0:
                    continue # no point in doing 0,0 again
                processPackets(packet_cache, strim, rtrim)

        unusual_case,delta = findUnusualTestCase(db, (0,0))
        evaluations = {}
        for strim in range(0,num_sent):
            for rtrim in range(0,num_rcvd):
                evaluations[(strim,rtrim)] = evaluateTrim(db, unusual_case, strim, rtrim)

        import pprint
        pprint.pprint(evaluations)

        # Greedy selection: grow each trim while the delta keeps its
        # sign, stays within delta_margin of the untrimmed delta, and
        # the spread (MAD) improves.
        delta_margin = 0.15
        best_strim = 0
        best_rtrim = 0
        good_delta,good_mad = evaluations[(0,0)]

        for strim in range(1,num_sent):
            delta,mad = evaluations[(strim,0)]
            if delta*good_delta > 0.0 and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta) and mad < good_mad:
                best_strim = strim
            else:
                break

        good_delta,good_mad = evaluations[(best_strim,0)]
        for rtrim in range(1,num_rcvd):
            delta,mad = evaluations[(best_strim,rtrim)]
            if delta*good_delta > 0.0 and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta) and mad < good_mad:
                best_rtrim = rtrim
            else:
                break

        print("selected trim parameters:",(best_strim,best_rtrim))

    # promote the winning trim's rows into the final analysis table
    pcursor.execute("""INSERT OR IGNORE INTO analysis
                         SELECT id,probe_id,suspect,packet_rtt,tsval_rtt
                           FROM trim_analysis
                           WHERE sent_trimmed=? AND rcvd_trimmed=?""",
                    (best_strim,best_rtrim))
    db.conn.commit()

    return len(packet_cache)
[4]429
430
431       
def parseJSONLines(fp):
    """Yield one decoded JSON object per line of the file-like fp."""
    yield from map(json.loads, fp)
435
436
def associatePackets(sniffer_fp, db):
    """Merge sniffed packet records (JSON lines in sniffer_fp) into the
    database, matching them to probes by time window.

    Warns on stderr if any packets could not be matched to a probe.
    Always returns None.
    """
    sniffer_fp.seek(0)

    cursor = db.conn.cursor()
    cursor.execute("SELECT count(*) count,min(time_of_day) start,max(time_of_day+userspace_rtt) end from probes")
    probe_times = cursor.fetchone()
    # matching window: 100x the average probe duration
    window_size = 100*int((probe_times['end']-probe_times['start'])/probe_times['count'])

    db.addPackets(parseJSONLines(sniffer_fp), window_size)

    cursor.execute("SELECT count(*) count FROM packets WHERE probe_id is NULL")
    orphan_count = cursor.fetchone()['count']
    if orphan_count > 0:
        sys.stderr.write("WARNING: %d observed packets didn't find a home...\n" % orphan_count)

    return None
457
458
def enumStoredTestCases(db):
    """Return the distinct test case names present in the probes table."""
    cur = db.conn.cursor()
    cur.execute("SELECT test_case FROM probes GROUP BY test_case")
    return [row[0] for row in cur]
463
464
def findUnusualTestCase(db, trim=None):
    """Identify the test case whose summary packet RTT (per quadsummary
    from .stats) deviates most from the overall population.

    trim -- optional (sent_trim, rcvd_trim): query the trim_analysis
            table at that setting instead of the final analysis table.

    Returns (test_case, delta) where delta is the unusual case's
    summary RTT minus that of all the other cases combined.
    """
    test_cases = enumStoredTestCases(db)
    if trim != None:
        params = {'strim':trim[0], 'rtrim':trim[1]}
        qsuffix = " AND sent_trimmed=:strim AND rcvd_trimmed=:rtrim"
        table = "trim_analysis"
    else:
        params = {}
        qsuffix = ""
        table = "analysis"

    cursor = db.conn.cursor()
    cursor.execute("SELECT packet_rtt FROM probes,"+table+" a WHERE probes.id=a.probe_id AND probes.type in ('train','test')"+qsuffix, params)
    global_tm = quadsummary([row['packet_rtt'] for row in cursor])

    tm_abs = []
    tm_map = {}

    # XXX: if more speed needed, percentile extension to sqlite might be handy...
    for tc in test_cases:
        params['test_case']=tc
        query = """SELECT packet_rtt FROM probes,"""+table+""" a
                   WHERE probes.id=a.probe_id AND probes.type in ('train','test')
                   AND probes.test_case=:test_case""" + qsuffix
        cursor.execute(query, params)
        tm_map[tc] = quadsummary([row['packet_rtt'] for row in cursor])
        tm_abs.append((abs(tm_map[tc]-global_tm), tc))

    # test case with the largest absolute deviation from the global summary
    magnitude,tc = max(tm_abs)
    params['test_case']=tc
    query = """SELECT packet_rtt FROM probes,"""+table+""" a
               WHERE probes.id=a.probe_id AND probes.type in ('train','test')
               AND probes.test_case<>:test_case""" + qsuffix
    cursor.execute(query,params)
    remaining_tm = quadsummary([row['packet_rtt'] for row in cursor])

    delta = tm_map[tc]-remaining_tm
    # Hack to make the chosen unusual_case more intuitive to the user
    if len(test_cases) == 2 and delta < 0.0:
        tc = [t for t in test_cases if t != tc][0]
        delta = abs(delta)

    return tc,delta
[4]508
[16]509
def reportProgress(db, sample_types, start_time):
    """Print a one-line status summary of sampling progress with an ETA.

    sample_types -- iterable of (type_name, requested_count) pairs
    start_time   -- epoch seconds when sampling began; only probes with
                    time_of_day after this (in nanoseconds) are counted
    """
    cursor = db.conn.cursor()
    summary = ''
    completed_total = 0
    requested_total = 0
    for st in sample_types:
        type_name, requested = st[0], st[1]
        cursor.execute("SELECT count(id) c FROM (SELECT id FROM probes WHERE type=? AND time_of_day>? GROUP BY sample)", (type_name,int(start_time*1000000000)))
        done = cursor.fetchone()[0]
        summary += " | %s remaining: %6d" % (type_name, requested - done)
        completed_total += done
        requested_total += requested

    # extrapolate total runtime from the completion rate so far
    rate = completed_total / (time.time() - start_time)
    total_time = requested_total / rate
    eta = datetime.datetime.fromtimestamp(start_time+total_time)
    print("STATUS:",summary[3:],"| est. total_time: %s | ETA: %s" % (str(datetime.timedelta(seconds=total_time)), eta.strftime("%Y-%m-%d %X")))
526
527
528
def evaluateTestResults(db):
    """Summarize classifier test outcomes from the classifier_results
    table.

    For each classifier, prefers the run that achieved a mean error
    (false positives + false negatives)/2 below 5.0 with the fewest
    observations; failing that, falls back to the run with the lowest
    error overall.

    Returns (best_obs, best_error): rows (as dicts) for classifiers
    that met the error threshold, and fallback rows for those that
    did not.
    """
    cursor = db.conn.cursor()
    query = """
      SELECT classifier FROM classifier_results GROUP BY classifier ORDER BY classifier;
    """
    cursor.execute(query)
    classifiers = []
    for c in cursor:
        classifiers.append(c[0])

    best_obs = []
    best_error = []
    max_obs = 0
    for classifier in classifiers:
        # preferred: fewest observations among runs under 5.0 mean error
        query="""
        SELECT classifier,params,num_observations,(false_positives+false_negatives)/2 error
        FROM classifier_results
        WHERE trial_type='test'
         AND classifier=:classifier
         AND (false_positives+false_negatives)/2.0 < 5.0
        ORDER BY num_observations,(false_positives+false_negatives)
        LIMIT 1
        """
        cursor.execute(query, {'classifier':classifier})
        row = cursor.fetchone()
        if row == None:
            # fallback: no run met the threshold; take the lowest error
            query="""
            SELECT classifier,params,num_observations,(false_positives+false_negatives)/2 error
            FROM classifier_results
            WHERE trial_type='test' and classifier=:classifier
            ORDER BY (false_positives+false_negatives),num_observations
            LIMIT 1
            """
            cursor.execute(query, {'classifier':classifier})
            row = cursor.fetchone()
            if row == None:
                sys.stderr.write("WARN: couldn't find test results for classifier '%s'.\n" % classifier)
                continue
            row = dict(row)

            best_error.append(dict(row))
        else:
            best_obs.append(dict(row))


    return best_obs,best_error
Note: See TracBrowser for help on using the repository browser.