source: trunk/lib/nanownlib/__init__.py @ 17

Last change on this file since 17 was 16, checked in by tim, 9 years ago

.

File size: 18.9 KB
Line 
1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import time
6import traceback
7import random
8import argparse
9import socket
10import datetime
11import http.client
12import threading
13import queue
14import subprocess
15import multiprocessing
16import csv
17import json
18import gzip
19import statistics
20try:
21    import numpy
22except:
23    sys.stderr.write('ERROR: Could not import numpy module.  Ensure it is installed.\n')
24    sys.stderr.write('       Under Debian, the package name is "python3-numpy"\n.')
25    sys.exit(1)
26
27try:
28    import requests
29except:
30    sys.stderr.write('ERROR: Could not import requests module.  Ensure it is installed.\n')
31    sys.stderr.write('       Under Debian, the package name is "python3-requests"\n.')
32    sys.exit(1)
33
34from .stats import *
35
36
def getLocalIP(remote_host, remote_port):
    """Return the local IP address the OS uses to reach a remote endpoint.

    A TCP connection to (remote_host, remote_port) is opened only long enough
    to read the local address bound to it, and is then closed.
    """
    with socket.create_connection((remote_host, remote_port)) as probe:
        local_ip, *_ = probe.getsockname()
    return local_ip
43
44
def getIfaceForIP(ip):
    """Return the name of the network interface carrying IPv4 address *ip*.

    Requires the third-party ``netifaces`` module.  If it cannot be imported,
    an error is written to stderr and the process exits with status 1.

    Returns None when no interface has the address.
    """
    try:
        import netifaces
    except ImportError:
        sys.stderr.write('ERROR: Could not import netifaces module.  Ensure it is installed.\n')
        sys.stderr.write('       Try: pip3 install netifaces\n')
        sys.exit(1)

    for iface in netifaces.interfaces():
        # Only IPv4 (AF_INET) addresses are compared.
        addrs = netifaces.ifaddresses(iface).get(netifaces.AF_INET, None) or []
        for addr in addrs:
            if addr.get('addr', None) == ip:
                return iface
    return None
59
60
def setTCPTimestamps(enabled=True):
    """Turn Linux TCP timestamps on or off through /proc.

    Writes b'1' (enabled) or b'0' (disabled) to
    /proc/sys/net/ipv4/tcp_timestamps.  Writing normally needs root.

    Returns True if the first byte of the setting was b'1' before the write,
    otherwise False.
    """
    # `with` guarantees the handle is closed even if read/seek/write raises.
    with open('/proc/sys/net/ipv4/tcp_timestamps', 'r+b') as fh:
        was_enabled = fh.read(1) == b'1'
        fh.seek(0)
        fh.write(b'1' if enabled else b'0')
    return was_enabled
75
76
def trickleHTTPRequest(ip, port, hostname):
    """Send a slowly "trickled" HTTP GET request and drain the response.

    The request line is sent immediately.  After a 0.5 s pause the header
    bytes are sent one at a time with a 0.05 s pause after each.  After a
    further 0.5 s pause, the terminating blank line is sent.  The response is
    then read until the server closes the connection.

    Returns the local (source) port of the connection, or None if the
    connection could not be established.  Socket errors after connecting are
    tolerated (best effort) and the port is still returned.
    """
    my_port = None
    try:
        with socket.create_connection((ip, port)) as sock:
            my_port = sock.getsockname()[1]

            sock.sendall(b'GET / HTTP/1.1\r\n')
            time.sleep(0.5)
            rest = (b'Host: ' + hostname.encode('utf-8') + b'\r\n'
                    b'User-Agent: Secret Agent Man\r\n'
                    b'X-Extra: extra read all about it!\r\n'
                    b'Connection: close\r\n')
            for byte in rest:
                sock.sendall(bytes([byte]))
                time.sleep(0.05)

            time.sleep(0.5)
            # Must be bytes: socket.sendall() rejects str under Python 3.
            sock.sendall(b'\r\n')

            # Drain until the peer closes the connection (recv returns b'').
            while sock.recv(16):
                pass
    except OSError:
        # Best effort: refused/reset connections are expected during probing.
        pass

    return my_port
103
104
def runTimestampProbes(host_ip, port, hostname, num_trials, concurrency=4):
    """Run *num_trials* trickleHTTPRequest probes on background threads.

    After *concurrency* threads have been started, each further thread is
    started only after one earlier probe has delivered its result.  This keeps
    roughly *concurrency* probes in flight.  A probe that raises is logged to
    stderr and contributes None.

    Returns the probes' return values (local ports or None) in completion
    order.
    """
    # A non-positive limit would make the first get() wait on a queue that
    # no thread will ever fill (deadlock), so allow at least one probe.
    concurrency = max(1, concurrency)
    myq = queue.Queue()

    def threadWrapper(*args):
        try:
            myq.put(trickleHTTPRequest(*args))
        except Exception as e:
            sys.stderr.write("ERROR from trickleHTTPRequest: %s\n" % repr(e))
            myq.put(None)

    threads = []
    ports = []
    for _ in range(num_trials):
        if len(threads) >= concurrency:
            # Throttle: wait for one earlier probe to finish first.
            ports.append(myq.get())
        t = threading.Thread(target=threadWrapper, args=(host_ip, port, hostname))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    # Collect results that were not consumed by the throttling above.
    while not myq.empty():
        ports.append(myq.get())

    return ports
130
131
def computeTimestampPrecision(sniffer_fp, ports):
    """Estimate how much observed time corresponds to one remote TSval tick.

    *sniffer_fp* yields JSON lines describing packets.  For received packets
    (``sent == 0``) with a non-zero TSval, ``observed`` is regressed against
    ``tsval`` separately for each local port in *ports*.  The fitted slope is
    the change in observed time per TSval tick.  Ports with fewer than two
    points, or whose last TSval is below their first (wrap), are skipped with
    a warning on stderr.

    Returns (mean_slope, stdev_slope, slopes).  stdev_slope is None when only
    one port yielded a slope, and all three are None when none did.
    """
    # Group received (observed, tsval) points by local port in a single pass
    # instead of rescanning every packet once per port.
    by_port = {}
    for line in sniffer_fp:
        p = json.loads(line)
        if p['sent'] == 0:
            local_port = int(p['local_port'])
            if p['tsval'] != 0:
                by_port.setdefault(local_port, []).append((p['observed'], p['tsval']))

    slopes = []
    for port in ports:
        points = by_port.get(port, [])

        if len(points) < 2:
            sys.stderr.write("WARN: Inadequate data points.\n")
            continue

        if points[0][1] > points[-1][1]:
            sys.stderr.write("WARN: TSval wrap.\n")
            continue

        tsvals = [tsval for _, tsval in points]
        observed = [obs for obs, _ in points]

        slope, _intercept = OLSRegression(tsvals, observed)
        slopes.append(slope)

    if not slopes:
        return None, None, None

    m = statistics.mean(slopes)
    if len(slopes) == 1:
        return (m, None, slopes)
    return (m, statistics.stdev(slopes), slopes)
165
166   
def OLSRegression(x, y):
    """Least-squares fit of the straight line y = m*x + c.

    Returns the pair (m, c): slope first, intercept second.
    """
    # Polynomial.fit works in a scaled window.  convert() maps the result back
    # to the unscaled x domain, with coefficients lowest degree first: [c, m].
    # The original author found numpy.linalg.lstsq "broken" and
    # polynomial.polyfit "less accurate" for this data.
    fitted_line = numpy.polynomial.Polynomial.fit(numpy.array(x), numpy.array(y), 1).convert()
    intercept, slope = fitted_line.coef
    return (slope, intercept)
185
186
def startSniffer(target_ip, target_port, output_file):
    """Launch the ``nanown-listen`` packet sniffer for one target endpoint.

    The local IP used to reach the target, and the interface carrying it, are
    looked up first.  The sniffer runs under ``chrt -r 99`` (SCHED_RR
    real-time priority 99).

    Returns the subprocess.Popen handle (stopSniffer() can terminate it).

    Raises RuntimeError if no local interface carries the selected IP.
    """
    my_ip = getLocalIP(target_ip, target_port)
    my_iface = getIfaceForIP(my_ip)
    if my_iface is None:
        # Without this check Popen fails with an obscure TypeError on None.
        raise RuntimeError("could not find a network interface with address %s" % my_ip)
    return subprocess.Popen(['chrt', '-r', '99', 'nanown-listen', my_iface, my_ip,
                             target_ip, "%d" % target_port, output_file, '0'])
192
def stopSniffer(sniffer):
    """Stop a sniffer subprocess (a Popen handle).

    The process is asked to terminate and given up to 2 seconds to exit.  If it
    is still running after that, it is killed and waited on for up to 1 more
    second.
    """
    sniffer.terminate()
    try:
        sniffer.wait(2)
    except subprocess.TimeoutExpired:
        # Popen.wait() raises (rather than returning) when the timeout
        # elapses, so the kill fallback has to live in this handler.
        sniffer.kill()
        sniffer.wait(1)
199
200       
def setCPUAffinity():
    """Pin the calling process to the highest-numbered CPU.

    The CPU index is multiprocessing.cpu_count() - 1.

    Returns 0 on success and -1 on failure, the same convention as libc's
    sched_setaffinity().
    """
    import os

    last_cpu = multiprocessing.cpu_count() - 1
    try:
        # os.sched_setaffinity builds a correctly sized CPU mask.  A fixed
        # 4-byte c_int mask cannot represent CPU 32 or higher; ctypes
        # silently truncates it to 0.
        os.sched_setaffinity(0, {last_cpu})
    except OSError:
        return -1
    return 0
209
210
class MonitoredHTTPResponse(http.client.HTTPResponse):
    """HTTPResponse that remembers the local address of its connection.

    Used to instrument HTTP responses so the connection's source port can be
    recovered from ``local_address`` after a request.
    """

    # Result of sock.getsockname() for the response's socket; None on the class.
    local_address = None

    def __init__(self, sock, *args, **kwargs):
        # Record the local (source) address before normal initialisation.
        self.local_address = sock.getsockname()
        super().__init__(sock, *args, **kwargs)
218           
# Make urllib3 (and therefore requests) build every HTTP response as a
# MonitoredHTTPResponse, so each response records its connection's local address.
requests.packages.urllib3.connection.HTTPConnection.response_class = (
    MonitoredHTTPResponse
)
220
221
def removeDuplicatePackets(packets):
    """Collapse duplicate packet records, keeping the earliest-observed copy.

    Records are duplicates when they share (sent, tcpseq, tcpack, payload_len).

    Returns (suspect, unique).  *unique* is a dict values view of the retained
    records, in first-seen key order.  *suspect* gains one character each time
    a later record replaces the kept copy because it was observed earlier:
    's' for sent packets (sent == 1), 'r' for received ones (sent == 0).
    """
    # XXX: Need to review this deduplication algorithm and make sure it is correct
    # NOTE(review): a duplicate observed *later* than the kept copy is dropped
    # without any suspect flag, so flagging depends on input order.
    flags = []
    earliest = {}
    for pkt in packets:
        key = (pkt['sent'], pkt['tcpseq'], pkt['tcpack'], pkt['payload_len'])
        if key not in earliest:
            earliest[key] = pkt
            continue

        if pkt['sent'] == 1:
            marker = 's'  # duplicated sent packets
        elif pkt['sent'] == 0:
            marker = 'r'  # duplicated received packets
        else:
            continue

        if earliest[key]['observed'] > pkt['observed']:
            earliest[key] = pkt
            flags.append(marker)

    return ''.join(flags), earliest.values()
245
246
def analyzePackets(packets, timestamp_precision, trim_sent=0, trim_rcvd=0):
    """Compute round-trip-time measurements for one probe's packets.

    Duplicates are removed first (removeDuplicatePackets), and only packets
    with a non-zero payload are considered.  The reference sent packet is
    sent[trim_sent] in (observed, tcpseq) order.  The reference received
    packet is the one trim_rcvd positions before the last in that order.

    Args:
        packets: iterable of packet dicts with at least the keys 'sent',
            'tcpseq', 'tcpack', 'payload_len', 'observed' and 'tsval'.
        timestamp_precision: factor converting a TSval difference into the
            units of 'observed', or None to skip the TSval-based RTT.
        trim_sent: number of sent packets to skip from the start.
        trim_rcvd: number of received packets to skip from the end.

    Returns:
        (analysis, num_sent, num_rcvd).  analysis is a dict with keys
        'packet_rtt', 'tsval_rtt', 'suspect', 'sent_trimmed' and
        'rcvd_trimmed'.  num_sent/num_rcvd count payload-carrying packets.

    Raises:
        IndexError if there are no payload-carrying sent or received packets.

    Suspect flags: 's'/'r' come from deduplication.  'd' means a trim offset
    fell outside the available packets (the nearest end is used instead).
    'R' means the chosen received packet differs between time order and
    sequence order.  'N' means a negative RTT.
    """
    suspect, packets = removeDuplicatePackets(packets)

    sort_key = lambda d: (d['observed'], d['tcpseq'])
    alt_key = lambda d: (d['tcpseq'], d['observed'])
    sent = sorted((p for p in packets if p['sent'] == 1 and p['payload_len'] > 0), key=sort_key)
    rcvd = sorted((p for p in packets if p['sent'] == 0 and p['payload_len'] > 0), key=sort_key)
    # Same received packets in sequence-number order, to detect reordering.
    rcvd_alt = sorted((p for p in packets if p['sent'] == 0 and p['payload_len'] > 0), key=alt_key)

    s_off = trim_sent
    if s_off >= len(sent):
        suspect += 'd'  # dropped packet?
        s_off = -1
    last_sent = sent[s_off]

    r_off = len(rcvd) - trim_rcvd - 1
    if r_off < 0:
        suspect += 'd'  # dropped packet?
        r_off = 0
    last_rcvd = rcvd[r_off]
    if last_rcvd != rcvd_alt[r_off]:
        suspect += 'R'  # reordered received packets

    # NOTE(review): this bound uses the received packet's own payload_len
    # rather than last_sent['payload_len']; confirm it is the intended ACK test.
    ack_candidates = [p for p in packets
                      if p['sent'] == 0 and p['payload_len'] + last_sent['tcpseq'] >= p['tcpack']]
    last_sent_ack = None
    if ack_candidates:
        # key= stops min() from comparing the packet dicts themselves when two
        # candidates tie on (tcpack, observed); that raised TypeError before.
        last_sent_ack = min(ack_candidates, key=lambda p: (p['tcpack'], p['observed']))
    else:
        sys.stderr.write("WARN: Could not find last_sent_ack.\n")

    packet_rtt = last_rcvd['observed'] - last_sent['observed']
    tsval_rtt = None
    if timestamp_precision is not None and last_sent_ack is not None:
        tsval_rtt = int(round((last_rcvd['tsval'] - last_sent_ack['tsval']) * timestamp_precision))

    if packet_rtt < 0 or (tsval_rtt is not None and tsval_rtt < 0):
        suspect += 'N'

    return {'packet_rtt': packet_rtt,
            'tsval_rtt': tsval_rtt,
            'suspect': suspect,
            'sent_trimmed': trim_sent,
            'rcvd_trimmed': trim_rcvd}, len(sent), len(rcvd)
292
293
def evaluateTrim(db, unusual_case, strim, rtrim):
    """Score one (sent, received) trim setting for the unusual test case.

    For each train/test probe of *unusual_case* analysed with trims
    (strim, rtrim), the query takes the probe's packet_rtt minus the average
    packet_rtt of the other test cases' probes in the same sample.

    Returns (septasummary(differences), mad(differences)).
    """
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
      FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test')) u
    """
    #TODO: check for "N" in suspect field and return a flag

    params = {"strim": strim, "rtrim": rtrim, "unusual_case": unusual_case}
    cursor = db.conn.cursor()
    cursor.execute(query, params)
    differences = [row[0] for row in cursor]

    return septasummary(differences), mad(differences)
314
315
316
def analyzeProbes(db, trim=None, recompute=False):
    """Analyse stored packets for probes that do not yet have an analysis row.

    trim_analysis is cleared (and analysis too when *recompute* is true).
    Packets with a probe_id that is not yet in analysis are then loaded and
    analysed.

    If *trim* is a (sent_trim, rcvd_trim) pair, only that setting is computed.
    Otherwise every (strim, rtrim) below the most common sent/received
    payload-packet counts is computed and a setting is chosen greedily:
    strim grows while evaluateTrim's delta keeps the (0,0) delta's sign,
    shrinks in magnitude by less than 15% of it, and has a smaller MAD.  Then
    rtrim grows the same way against (best_strim, 0).

    Rows for the chosen trim are copied from trim_analysis into analysis.

    Returns the number of probes whose packets were loaded (0 if none).
    """
    db.conn.execute("CREATE INDEX IF NOT EXISTS packets_probe ON packets (probe_id)")
    db.conn.commit()

    pcursor = db.conn.cursor()
    pcursor.execute("SELECT tcpts_mean FROM meta")
    try:
        timestamp_precision = pcursor.fetchone()[0]
    except (TypeError, IndexError):
        # No meta row (fetchone() returned None): TSval RTTs are skipped.
        timestamp_precision = None

    pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()
    if recompute:
        pcursor.execute("DELETE FROM analysis")
        db.conn.commit()

    def loadPackets(db):
        """Return [(probe_id, [packet dict, ...]), ...] for unanalysed probes."""
        cursor = db.conn.cursor()
        # Unmatched packets (NULL probe_id) are excluded.  Previously they
        # sorted first and were merged into the first real probe's group.
        cursor.execute("SELECT * FROM packets WHERE probe_id IS NOT NULL AND probe_id NOT IN (SELECT probe_id FROM analysis) ORDER BY probe_id")

        groups = []
        for row in cursor:
            pkt = dict(row)
            if not groups or groups[-1][0] != pkt['probe_id']:
                groups.append((pkt['probe_id'], []))
            groups[-1][1].append(pkt)
        return groups

    def processPackets(packet_cache, strim, rtrim):
        """Analyse every probe with trims (strim, rtrim) and store the rows.

        Returns the most common (sent, received) payload-packet counts.
        """
        sent_tally = []
        rcvd_tally = []
        analyses = []
        for probe_id, packets in packet_cache:
            try:
                # Pass the trims through; without them every pass stored the
                # (0, 0) analysis and trim evaluation had nothing to compare.
                analysis, s, r = analyzePackets(packets, timestamp_precision, strim, rtrim)
            except Exception:
                sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)
                continue
            analysis['probe_id'] = probe_id
            analyses.append(analysis)
            sent_tally.append(s)
            rcvd_tally.append(r)
        db.addTrimAnalyses(analyses)
        db.conn.commit()
        return statistics.mode(sent_tally), statistics.mode(rcvd_tally)

    packet_cache = loadPackets(db)
    if not packet_cache:
        # Nothing new to analyse; statistics.mode() would fail on empty tallies.
        return 0

    if trim is not None:
        best_strim, best_rtrim = trim
        processPackets(packet_cache, best_strim, best_rtrim)
    else:
        num_sent, num_rcvd = processPackets(packet_cache, 0, 0)
        print("num_sent: %d, num_rcvd: %d" % (num_sent, num_rcvd))

        for strim in range(num_sent):
            for rtrim in range(num_rcvd):
                if (strim, rtrim) == (0, 0):
                    continue  # already computed above
                processPackets(packet_cache, strim, rtrim)

        unusual_case, delta = findUnusualTestCase(db, (0, 0))
        evaluations = {(strim, rtrim): evaluateTrim(db, unusual_case, strim, rtrim)
                       for strim in range(num_sent)
                       for rtrim in range(num_rcvd)}

        import pprint
        pprint.pprint(evaluations)

        delta_margin = 0.15

        def improves(candidate, baseline):
            """True if candidate keeps baseline's delta sign, loses less than
            delta_margin of its magnitude, and has a strictly smaller MAD."""
            cand_delta, cand_mad = candidate
            base_delta, base_mad = baseline
            return (cand_delta * base_delta > 0.0
                    and (abs(base_delta) - abs(cand_delta)) < abs(delta_margin * base_delta)
                    and cand_mad < base_mad)

        best_strim = 0
        baseline = evaluations[(0, 0)]
        for strim in range(1, num_sent):
            if not improves(evaluations[(strim, 0)], baseline):
                break
            best_strim = strim

        best_rtrim = 0
        baseline = evaluations[(best_strim, 0)]
        for rtrim in range(1, num_rcvd):
            if not improves(evaluations[(best_strim, rtrim)], baseline):
                break
            best_rtrim = rtrim

        print("selected trim parameters:", (best_strim, best_rtrim))

    pcursor.execute("""INSERT OR IGNORE INTO analysis
                         SELECT id,probe_id,suspect,packet_rtt,tsval_rtt
                           FROM trim_analysis
                           WHERE sent_trimmed=? AND rcvd_trimmed=?""",
                    (best_strim, best_rtrim))
    db.conn.commit()

    return len(packet_cache)
429
430
431       
def parseJSONLines(fp):
    """Lazily decode every line of *fp* as a separate JSON document."""
    yield from map(json.loads, fp)
435
436
def associatePackets(sniffer_fp, db):
    """Store sniffed packets via db.addPackets, which links them to probes.

    *sniffer_fp* is rewound and read as JSON lines.  The window size passed to
    db.addPackets is 100 times (time span covered by all probes / number of
    probes), truncated to an integer.  A warning is written to stderr for any
    stored packet left without a probe_id.

    Raises ValueError if the probes table is empty.
    """
    sniffer_fp.seek(0)

    cursor = db.conn.cursor()
    cursor.execute("SELECT count(*) count,min(time_of_day) start,max(time_of_day+userspace_rtt) end from probes")
    ptimes = cursor.fetchone()
    if not ptimes['count']:
        # Without probes, start/end are NULL and the window is undefined.
        raise ValueError("no probes recorded; cannot associate sniffed packets")
    window_size = 100 * int((ptimes['end'] - ptimes['start']) / ptimes['count'])

    db.addPackets(parseJSONLines(sniffer_fp), window_size)

    cursor.execute("SELECT count(*) count FROM packets WHERE probe_id is NULL")
    unmatched = cursor.fetchone()['count']
    if unmatched > 0:
        sys.stderr.write("WARNING: %d observed packets didn't find a home...\n" % unmatched)

    return None
457
458
def enumStoredTestCases(db):
    """List the distinct test_case values recorded in the probes table."""
    cur = db.conn.cursor()
    cur.execute("SELECT test_case FROM probes GROUP BY test_case")
    return [test_case for (test_case,) in cur]
463
464
def findUnusualTestCase(db, trim=None):
    """Find the test case whose packet RTTs stand out the most.

    quadsummary() of packet_rtt is computed over all train/test probes and
    separately for each test case.  The case whose summary is farthest from
    the global one is chosen.  delta is that case's summary minus the summary
    of all other cases.  With exactly two test cases and a negative delta, the
    other case is returned instead with delta made positive.

    If *trim* is a (sent_trimmed, rcvd_trimmed) pair, rows come from
    trim_analysis for that trim; otherwise from the analysis table.

    Returns (test_case, delta).
    """
    test_cases = enumStoredTestCases(db)
    if trim is None:
        params = {}
        qsuffix = ""
        table = "analysis"
    else:
        params = {'strim': trim[0], 'rtrim': trim[1]}
        qsuffix = " AND sent_trimmed=:strim AND rcvd_trimmed=:rtrim"
        table = "trim_analysis"

    cursor = db.conn.cursor()

    def summarize(sql):
        # quadsummary of the packet_rtt column returned by *sql* with params.
        cursor.execute(sql, params)
        return quadsummary([row['packet_rtt'] for row in cursor])

    all_sql = "SELECT packet_rtt FROM probes,"+table+" a WHERE probes.id=a.probe_id AND probes.type in ('train','test')"+qsuffix
    one_case_sql = ("SELECT packet_rtt FROM probes,"+table+" a\n"
                    "                   WHERE probes.id=a.probe_id AND probes.type in ('train','test')\n"
                    "                   AND probes.test_case=:test_case") + qsuffix
    other_cases_sql = ("SELECT packet_rtt FROM probes,"+table+" a\n"
                       "               WHERE probes.id=a.probe_id AND probes.type in ('train','test')\n"
                       "               AND probes.test_case<>:test_case") + qsuffix

    global_tm = summarize(all_sql)

    # XXX: if more speed needed, percentile extension to sqlite might be handy...
    tm_map = {}
    distances = []
    for case in test_cases:
        params['test_case'] = case
        tm_map[case] = summarize(one_case_sql)
        distances.append((abs(tm_map[case] - global_tm), case))

    _magnitude, chosen = max(distances)
    params['test_case'] = chosen
    remaining_tm = summarize(other_cases_sql)

    delta = tm_map[chosen] - remaining_tm
    # Hack to make the chosen unusual_case more intuitive to the user
    if len(test_cases) == 2 and delta < 0.0:
        chosen = [other for other in test_cases if other != chosen][0]
        delta = abs(delta)

    return chosen, delta
508
509
def reportProgress(db, sample_types, start_time):
    """Print a one-line progress and ETA report for a collection run.

    *sample_types* is a sequence whose items hold (probe_type,
    requested_count) in positions 0 and 1.  For each type, distinct samples
    with time_of_day greater than start_time * 1e9 are counted (start_time is
    a time.time() value).  Total run time and ETA are extrapolated from the
    completion rate so far.  They are reported as "unknown" while nothing has
    completed, which previously caused a division by zero.
    """
    cursor = db.conn.cursor()
    segments = []
    total_completed = 0
    total_requested = 0
    for st in sample_types:
        cursor.execute("SELECT count(id) c FROM (SELECT id FROM probes WHERE type=? AND time_of_day>? GROUP BY sample)", (st[0],int(start_time*1000000000)))
        count = cursor.fetchone()[0]
        segments.append("%s remaining: %6d" % (st[0], st[1]-count))
        total_completed += count
        total_requested += st[1]

    elapsed = time.time() - start_time
    rate = total_completed / elapsed if elapsed > 0 else 0.0
    if rate > 0:
        total_time = total_requested / rate
        eta = datetime.datetime.fromtimestamp(start_time+total_time)
        timing = "est. total_time: %s | ETA: %s" % (str(datetime.timedelta(seconds=total_time)), eta.strftime("%Y-%m-%d %X"))
    else:
        # Nothing completed yet (or no time elapsed): cannot extrapolate.
        timing = "est. total_time: unknown | ETA: unknown"
    print("STATUS:", " | ".join(segments), "| " + timing)
526
527
528
def evaluateTestResults(db):
    """Pick one representative 'test' result row per classifier.

    For each classifier in classifier_results, the preferred row is a 'test'
    trial whose (false_positives+false_negatives)/2.0 is below 5.0, taking the
    fewest observations first and then the lowest error.  Those rows go to
    *best_obs*.  A classifier with no such row contributes its lowest-error
    'test' row (ties: fewest observations) to *best_error* instead.  A
    classifier with no 'test' rows at all is reported on stderr and skipped.

    Returns (best_obs, best_error): lists of dicts built from the rows
    (columns classifier, params, num_observations, error).
    """
    # NOTE(review): the selected `error` column uses /2 while the threshold
    # uses /2.0; if the counts are INTEGER columns SQLite truncates `error`.
    # Confirm the column types.
    classifiers_query = """
      SELECT classifier FROM classifier_results GROUP BY classifier ORDER BY classifier;
    """
    preferred_query = """
        SELECT classifier,params,num_observations,(false_positives+false_negatives)/2 error
        FROM classifier_results
        WHERE trial_type='test'
         AND classifier=:classifier
         AND (false_positives+false_negatives)/2.0 < 5.0
        ORDER BY num_observations,(false_positives+false_negatives)
        LIMIT 1
        """
    fallback_query = """
            SELECT classifier,params,num_observations,(false_positives+false_negatives)/2 error
            FROM classifier_results
            WHERE trial_type='test' and classifier=:classifier
            ORDER BY (false_positives+false_negatives),num_observations
            LIMIT 1
            """

    cursor = db.conn.cursor()
    cursor.execute(classifiers_query)
    classifiers = [record[0] for record in cursor]

    best_obs = []
    best_error = []
    for classifier in classifiers:
        bind = {'classifier': classifier}

        cursor.execute(preferred_query, bind)
        chosen = cursor.fetchone()
        if chosen is not None:
            best_obs.append(dict(chosen))
            continue

        cursor.execute(fallback_query, bind)
        chosen = cursor.fetchone()
        if chosen is None:
            sys.stderr.write("WARN: couldn't find test results for classifier '%s'.\n" % classifier)
            continue
        best_error.append(dict(chosen))

    return best_obs, best_error
567            row = dict(row)
568
569            best_error.append(dict(row))
570        else:
571            best_obs.append(dict(row))
572
573
574    return best_obs,best_error
Note: See TracBrowser for help on using the repository browser.