source: trunk/lib/nanownlib/__init__.py @ 12

#!/usr/bin/env python3
#-*- mode: Python;-*-

import sys
import time
import traceback
import random
import argparse
import socket
import datetime
import http.client
import threading
import queue
import subprocess
import multiprocessing
import csv
import json
import gzip
import statistics
import numpy
import netifaces
try:
    import requests
except ImportError:
    sys.stderr.write('ERROR: Could not import requests module.  Ensure it is installed.\n')
    sys.stderr.write('       Under Debian, the package name is "python3-requests".\n')
    sys.exit(1)

from .stats import *


def getLocalIP(remote_host, remote_port):
    # Open a throwaway connection to learn which local address routes to the target
    connection = socket.create_connection((remote_host, remote_port))
    ret_val = connection.getsockname()[0]
    connection.close()

    return ret_val


def getIfaceForIP(ip):
    for iface in netifaces.interfaces():
        addrs = netifaces.ifaddresses(iface).get(netifaces.AF_INET, None)
        if addrs:
            for a in addrs:
                if a.get('addr', None) == ip:
                    return iface

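# getLocalIP/getIfaceForIP usage sketch (hypothetical addresses; assumes a
# route to the target exists):
#   my_ip = getLocalIP('192.0.2.10', 443)   # e.g. '10.0.0.5'
#   my_iface = getIfaceForIP(my_ip)         # e.g. 'eth0'
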
def setTCPTimestamps(enabled=True):
    fh = open('/proc/sys/net/ipv4/tcp_timestamps', 'r+b')
    ret_val = False
    if fh.read(1) == b'1':
        ret_val = True

    fh.seek(0)
    if enabled:
        fh.write(b'1')
    else:
        fh.write(b'0')
    fh.close()

    return ret_val

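# setTCPTimestamps note: writing /proc/sys/net/ipv4/tcp_timestamps requires
# root.  The return value is the previous setting, so a caller can restore it:
#   old = setTCPTimestamps(False)
#   ...
#   setTCPTimestamps(old)
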
def trickleHTTPRequest(ip,port,hostname):
    my_port = None
    try:
        sock = socket.create_connection((ip, port))
        my_port = sock.getsockname()[1]

        #print('.')
        sock.sendall(b'GET / HTTP/1.1\r\n')
        time.sleep(0.5)
        rest = b'''Host: '''+hostname.encode('utf-8')+b'''\r\nUser-Agent: Secret Agent Man\r\nX-Extra: extra read all about it!\r\nConnection: close\r\n'''
        for r in rest:
            sock.sendall(bytearray([r]))
            time.sleep(0.05)

        time.sleep(0.5)
        sock.sendall(b'\r\n')  # must be bytes; a str here raises TypeError on Python 3

        r = None
        while r != b'':
            r = sock.recv(16)

        sock.close()
    except Exception as e:
        pass

    return my_port

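# trickleHTTPRequest usage sketch (hypothetical target): sends one deliberately
# slow request and returns the local source port used, or None on failure:
#   port = trickleHTTPRequest('192.0.2.10', 80, 'example.org')
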
def runTimestampProbes(host_ip, port, hostname, num_trials, concurrency=4):
    myq = queue.Queue()
    def threadWrapper(*args):
        try:
            myq.put(trickleHTTPRequest(*args))
        except Exception as e:
            sys.stderr.write("ERROR from trickleHTTPRequest: %s\n" % repr(e))
            myq.put(None)

    threads = []
    ports = []
    for i in range(num_trials):
        if len(threads) >= concurrency:
            ports.append(myq.get())
        t = threading.Thread(target=threadWrapper, args=(host_ip, port, hostname))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    while myq.qsize() > 0:
        ports.append(myq.get())

    return ports

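# runTimestampProbes usage sketch (hypothetical target): runs 50 trickled
# probes, at most four outstanding at a time, returning their source ports:
#   ports = runTimestampProbes('192.0.2.10', 80, 'example.org', 50)
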
def computeTimestampPrecision(sniffer_fp, ports):
    rcvd = []
    for line in sniffer_fp:
        p = json.loads(line)
        if p['sent']==0:
            rcvd.append((p['observed'],p['tsval'],int(p['local_port'])))

    slopes = []
    for port in ports:
        trcvd = [tr for tr in rcvd if tr[2]==port and tr[1]!=0]

        if len(trcvd) < 2:
            sys.stderr.write("WARN: Inadequate data points.\n")
            continue

        if trcvd[0][1] > trcvd[-1][1]:
            sys.stderr.write("WARN: TSval wrap.\n")
            continue

        x = [tr[1] for tr in trcvd]
        y = [tr[0] for tr in trcvd]

        slope,intercept = OLSRegression(x, y)
        slopes.append(slope)

    if len(slopes) == 0:
        return None,None,None

    m = statistics.mean(slopes)
    if len(slopes) == 1:
        return (m, None, slopes)
    else:
        return (m, statistics.stdev(slopes), slopes)

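# computeTimestampPrecision note: each per-connection slope regresses observed
# arrival time against the remote TCP timestamp (TSval), approximating the
# duration of one remote clock tick.  Assuming 'observed' is in nanoseconds,
# a slope near 4000000 would indicate a 250 HZ remote timer.  The mean and
# stdev are computed across connections.
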
def OLSRegression(x,y):
    #print(x,y)
    x = numpy.array(x)
    y = numpy.array(y)
    #A = numpy.vstack([x, numpy.ones(len(x))]).T
    #m, c = numpy.linalg.lstsq(A, y)[0] # broken
    #c,m = numpy.polynomial.polynomial.polyfit(x, y, 1) # less accurate
    c,m = numpy.polynomial.Polynomial.fit(x,y,1).convert().coef

    #print(m,c)

    #import matplotlib.pyplot as plt
    #plt.clf()
    #plt.scatter(x, y)
    #plt.plot(x, m*x + c, 'r', label='Fitted line')
    #plt.show()

    return (m,c)

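# OLSRegression example: fitting a perfect line y = 2x + 1 recovers the slope
# and intercept exactly:
#   OLSRegression([0, 1, 2, 3], [1, 3, 5, 7])  # -> (2.0, 1.0)
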
def startSniffer(target_ip, target_port, output_file):
    my_ip = getLocalIP(target_ip, target_port)
    my_iface = getIfaceForIP(my_ip)
    return subprocess.Popen(['chrt', '-r', '99', './bin/csamp', my_iface, my_ip,
                             target_ip, "%d" % target_port, output_file, '0'])

def stopSniffer(sniffer):
    sniffer.terminate()
    try:
        sniffer.wait(2)
    except subprocess.TimeoutExpired:
        pass  # wait() raises on timeout rather than returning; fall through to kill
    if sniffer.poll() is None:
        sniffer.kill()
        sniffer.wait(1)

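# startSniffer/stopSniffer usage sketch: the sniffer runs the external
# ./bin/csamp capture tool under real-time scheduling (chrt -r 99), so this
# typically requires root:
#   sniffer = startSniffer('192.0.2.10', 443, '/tmp/capture.json')
#   ...
#   stopSniffer(sniffer)
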
def setCPUAffinity():
    import ctypes
    from ctypes import cdll,c_int,byref
    cpus = multiprocessing.cpu_count()

    libc = cdll.LoadLibrary("libc.so.6")
    #libc.sched_setaffinity(os.getpid(), 1, ctypes.byref(ctypes.c_int(0x01)))
    return libc.sched_setaffinity(0, 4, byref(c_int(0x00000001<<(cpus-1))))

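# setCPUAffinity note: pins the calling process to the highest-numbered CPU
# via the raw libc sched_setaffinity(2) call; the literal 4 is sizeof(int),
# the size of the mask passed in.  Returns 0 on success, nonzero on failure.
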
# Monkey patching that instruments the HTTPResponse to collect connection source port info
class MonitoredHTTPResponse(http.client.HTTPResponse):
    local_address = None

    def __init__(self, sock, *args, **kwargs):
        self.local_address = sock.getsockname()
        super(MonitoredHTTPResponse, self).__init__(sock,*args,**kwargs)

requests.packages.urllib3.connection.HTTPConnection.response_class = MonitoredHTTPResponse

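# MonitoredHTTPResponse usage sketch (assumes requests is backed by the patched
# urllib3 connection; _original_response is urllib3's reference to the
# underlying http.client response, which is now a MonitoredHTTPResponse):
#   r = requests.get('http://192.0.2.10/')
#   local_ip, local_port = r.raw._original_response.local_address
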
def removeDuplicatePackets(packets):
    #return packets
    suspect = ''
    seen = {}
    # XXX: Need to review this deduplication algorithm and make sure it is correct
    for p in packets:
        key = (p['sent'],p['tcpseq'],p['tcpack'],p['payload_len'])
        if (key not in seen):
            seen[key] = p
            continue
        if p['sent']==1 and (seen[key]['observed'] > p['observed']): #earliest sent
            seen[key] = p
            suspect += 's' # duplicated sent packets
            continue
        if p['sent']==0 and (seen[key]['observed'] > p['observed']): #earliest rcvd
            seen[key] = p
            suspect += 'r' # duplicated received packets
            continue

    #if len(seen) < len(packets):
    #   sys.stderr.write("INFO: removed %d duplicate packets.\n" % (len(packets) - len(seen)))

    return suspect,seen.values()

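# removeDuplicatePackets usage sketch: keeps the earliest observation of each
# (sent, tcpseq, tcpack, payload_len) tuple and flags duplicates:
#   suspect, packets = removeDuplicatePackets(packets)
#   # 's' marks a duplicated sent packet, 'r' a duplicated received one
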
def analyzePackets(packets, timestamp_precision, trim_sent=0, trim_rcvd=0):
    suspect,packets = removeDuplicatePackets(packets)

    sort_key = lambda d: (d['observed'],d['tcpseq'])
    alt_key = lambda d: (d['tcpseq'],d['observed'])
    sent = sorted((p for p in packets if p['sent']==1 and p['payload_len']>0), key=sort_key)
    rcvd = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=sort_key)
    rcvd_alt = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=alt_key)

    s_off = trim_sent
    if s_off >= len(sent):
        suspect += 'd' # dropped packet?
        s_off = -1
    last_sent = sent[s_off]

    r_off = len(rcvd) - trim_rcvd - 1
    if r_off < 0:
        suspect += 'd' # dropped packet?
        r_off = 0
    last_rcvd = rcvd[r_off]
    if last_rcvd != rcvd_alt[r_off]:
        suspect += 'R' # reordered received packets

    packet_rtt = last_rcvd['observed'] - last_sent['observed']
    if packet_rtt < 0:
        sys.stderr.write("WARN: Negative packet_rtt. last_rcvd=%s,last_sent=%s\n" % (last_rcvd, last_sent))

    last_sent_ack = None
    try:
        # Earliest received packet that ACKs the last sent payload; using a key
        # function avoids comparing dicts when two packets tie on 'observed'
        last_sent_ack = min((p for p in packets
                             if p['sent']==0 and p['payload_len']+last_sent['tcpseq']==p['tcpack']),
                            key=lambda p: p['observed'])
    except Exception as e:
        sys.stderr.write("WARN: Could not find last_sent_ack.\n")

    tsval_rtt = None
    if None not in (timestamp_precision, last_sent_ack):
        tsval_rtt = int(round((last_rcvd['tsval'] - last_sent_ack['tsval'])*timestamp_precision))

    return {'packet_rtt':packet_rtt,
            'tsval_rtt':tsval_rtt,
            'suspect':suspect,
            'sent_trimmed':trim_sent,
            'rcvd_trimmed':trim_rcvd},len(sent),len(rcvd)

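# analyzePackets usage sketch: produces one RTT measurement per probe plus
# packet counts:
#   analysis, num_sent, num_rcvd = analyzePackets(packets, timestamp_precision)
# analysis['packet_rtt'] is the observed-time delta between the last sent and
# last received payload packets; 'tsval_rtt' is the TSval-based estimate,
# available only when the remote timestamp precision is known.
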
# trimean and mad for each dist of differences
def evaluateTrim(db, unusual_case, strim, rtrim):
    cursor = db.conn.cursor()
    # NOTE: this first query (which excludes samples flagged with reordered
    # packets, suspect LIKE '%R%') is dead code; it is immediately replaced
    # by the simpler variant below.
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
      FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test') AND 1 NOT IN (select 1 from probes p,trim_analysis t WHERE p.sample=s AND t.probe_id=p.id AND t.suspect LIKE '%R%')) u
    """
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
      FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test')) u
    """

    params = {"strim":strim,"rtrim":rtrim,"unusual_case":unusual_case}
    cursor.execute(query, params)
    differences = [row[0] for row in cursor]

    return ubersummary(differences),mad(differences)

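# evaluateTrim usage sketch: summarizes how well the unusual case separates
# from the rest at a given trim level; the first value is ubersummary() of the
# per-sample RTT differences and the second is their mad() (both from .stats):
#   delta, spread = evaluateTrim(db, unusual_case, 0, 0)
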
def analyzeProbes(db):
    db.conn.execute("CREATE INDEX IF NOT EXISTS packets_probe ON packets (probe_id)")
    db.conn.commit()

    pcursor = db.conn.cursor()
    pcursor.execute("SELECT tcpts_mean FROM meta")
    try:
        timestamp_precision = pcursor.fetchone()[0]
    except:
        timestamp_precision = None

    pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()

    def loadPackets(db):
        cursor = db.conn.cursor()
        cursor.execute("SELECT * FROM packets ORDER BY probe_id")

        probe_id = None
        entry = []
        ret_val = []
        for p in cursor:
            if probe_id is None:
                probe_id = p['probe_id']
            if p['probe_id'] != probe_id:
                ret_val.append((probe_id,entry))
                probe_id = p['probe_id']
                entry = []
            entry.append(dict(p))
        ret_val.append((probe_id,entry))
        return ret_val

    start = time.time()
    packet_cache = loadPackets(db)
    print("packets loaded in: %f" % (time.time()-start))

    count = 0
    sent_tally = []
    rcvd_tally = []
    for probe_id,packets in packet_cache:
        try:
            analysis,s,r = analyzePackets(packets, timestamp_precision)
            analysis['probe_id'] = probe_id
            sent_tally.append(s)
            rcvd_tally.append(r)
            db.addTrimAnalyses([analysis])
        except Exception as e:
            #traceback.print_exc()
            sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)

        #print(pid,analysis)
        count += 1
    db.conn.commit()
    num_sent = statistics.mode(sent_tally)
    num_rcvd = statistics.mode(rcvd_tally)
    sent_tally = None
    rcvd_tally = None
    print("num_sent: %d, num_rcvd: %d" % (num_sent,num_rcvd))

    for strim in range(0,num_sent):
        for rtrim in range(0,num_rcvd):
            if strim == 0 and rtrim == 0:
                continue # no point in doing 0,0 again
            for probe_id,packets in packet_cache:
                try:
                    analysis,s,r = analyzePackets(packets, timestamp_precision, strim, rtrim)
                    analysis['probe_id'] = probe_id
                except Exception as e:
                    #traceback.print_exc()
                    sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)
                    continue # don't re-add the previous probe's analysis

                db.addTrimAnalyses([analysis])
    db.conn.commit()

    # Populate analysis table so findUnusualTestCase can give us a starting point
    pcursor.execute("DELETE FROM analysis")
    db.conn.commit()
    pcursor.execute("INSERT INTO analysis SELECT id,probe_id,suspect,packet_rtt,tsval_rtt FROM trim_analysis WHERE sent_trimmed=0 AND rcvd_trimmed=0")

    unusual_case,delta = findUnusualTestCase(db)
    evaluations = {}
    for strim in range(0,num_sent):
        for rtrim in range(0,num_rcvd):
            evaluations[(strim,rtrim)] = evaluateTrim(db, unusual_case, strim, rtrim)

    import pprint
    pprint.pprint(evaluations)

    delta_margin = 0.15
    best_strim = 0
    best_rtrim = 0
    good_delta,good_mad = evaluations[(0,0)]

    for strim in range(1,num_sent):
        delta,mad = evaluations[(strim,0)]
        if delta*good_delta > 0.0 and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta) and mad < good_mad:
            best_strim = strim
        else:
            break

    good_delta,good_mad = evaluations[(best_strim,0)]
    for rtrim in range(1,num_rcvd):
        delta,mad = evaluations[(best_strim,rtrim)]
        if delta*good_delta > 0.0 and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta) and mad < good_mad:
            best_rtrim = rtrim
        else:
            break

    print("selected trim parameters:",(best_strim,best_rtrim))

    if best_strim != 0 or best_rtrim != 0:
        pcursor.execute("DELETE FROM analysis")
        db.conn.commit()
        pcursor.execute("INSERT INTO analysis SELECT id,probe_id,suspect,packet_rtt,tsval_rtt FROM trim_analysis WHERE sent_trimmed=? AND rcvd_trimmed=?",
                        (best_strim,best_rtrim))

    #pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()

    return count


def parseJSONLines(fp):
    for line in fp:
        yield json.loads(line)

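# parseJSONLines usage sketch: lazily parses a JSON-lines capture file:
#   with open('capture.json') as fp:
#       for record in parseJSONLines(fp):
#           ...
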
def associatePackets(sniffer_fp, db):
    sniffer_fp.seek(0)

    # now combine sampler data with packet data
    buffered = []

    cursor = db.conn.cursor()
    cursor.execute("SELECT count(*) count,min(time_of_day) start,max(time_of_day+userspace_rtt) end from probes")
    ptimes = cursor.fetchone()
    window_size = 100*int((ptimes['end']-ptimes['start'])/ptimes['count'])
    print("associate window_size:", window_size)

    db.addPackets(parseJSONLines(sniffer_fp), window_size)

    cursor.execute("SELECT count(*) count FROM packets WHERE probe_id is NULL")
    unmatched = cursor.fetchone()['count']
    if unmatched > 0:
        sys.stderr.write("WARNING: %d observed packets didn't find a home...\n" % unmatched)

    return None

def enumStoredTestCases(db):
    cursor = db.conn.cursor()
    cursor.execute("SELECT test_case FROM probes GROUP BY test_case")
    return [tc[0] for tc in cursor]

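# enumStoredTestCases usage sketch: lists the distinct test cases recorded in
# the probes table:
#   test_cases = enumStoredTestCases(db)   # e.g. ['longpath', 'shortpath']
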
def findUnusualTestCase(db):
    test_cases = enumStoredTestCases(db)

    cursor = db.conn.cursor()
    cursor.execute("SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test')")
    global_tm = quadsummary([row['packet_rtt'] for row in cursor])

    tm_abs = []
    tm_map = {}
    # XXX: if more speed needed, percentile extension to sqlite might be handy...
    for tc in test_cases:
        cursor.execute("SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test') AND probes.test_case=?", (tc,))
        tm_map[tc] = quadsummary([row['packet_rtt'] for row in cursor])
        tm_abs.append((abs(tm_map[tc]-global_tm), tc))

    magnitude,tc = max(tm_abs)
    cursor.execute("SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test') AND probes.test_case<>?", (tc,))
    remaining_tm = quadsummary([row['packet_rtt'] for row in cursor])

    ret_val = (tc, tm_map[tc]-remaining_tm)
    print("unusual_case: %s, delta: %f" % ret_val)
    return ret_val

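# findUnusualTestCase note: the test case whose quadsummary() of packet_rtt
# deviates most from the global summary is flagged as unusual; the returned
# delta is its offset from the summary of all remaining test cases.
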
def reportProgress(db, sample_types, start_time):
    cursor = db.conn.cursor()
    output = ''
    total_completed = 0
    total_requested = 0
    for st in sample_types:
        cursor.execute("SELECT count(id) c FROM (SELECT id FROM probes WHERE type=? AND time_of_day>? GROUP BY sample)", (st[0],int(start_time*1000000000)))
        count = cursor.fetchone()[0]
        output += " | %s remaining: %d" % (st[0], st[1]-count)
        total_completed += count
        total_requested += st[1]

    rate = total_completed / (time.time() - start_time)
    total_time = total_requested / rate
    eta = datetime.datetime.fromtimestamp(start_time+total_time)
    print("STATUS:",output[3:],"| est. total_time: %s | est. ETA: %s" % (str(datetime.timedelta(seconds=total_time)), str(eta)))