source: trunk/lib/nanownlib/__init__.py @ 13

Last change on this file since 13 was 13, checked in by tim, 9 years ago

File size: 16.6 KB
#!/usr/bin/env python3
#-*- mode: Python;-*-

import sys
import time
import traceback
import random
import argparse
import socket
import datetime
import http.client
import threading
import queue
import subprocess
import multiprocessing
import csv
import json
import gzip
import statistics
import numpy
import netifaces
try:
    import requests
except ImportError:
    sys.stderr.write('ERROR: Could not import requests module.  Ensure it is installed.\n')
    sys.stderr.write('       Under Debian, the package name is "python3-requests".\n')
    sys.exit(1)

from .stats import *

def getLocalIP(remote_host, remote_port):
    connection = socket.create_connection((remote_host, remote_port))
    ret_val = connection.getsockname()[0]
    connection.close()

    return ret_val


def getIfaceForIP(ip):
    for iface in netifaces.interfaces():
        addrs = netifaces.ifaddresses(iface).get(netifaces.AF_INET, None)
        if addrs:
            for a in addrs:
                if a.get('addr', None) == ip:
                    return iface

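# A minimal usage sketch (illustrative only; the target address is an
# arbitrary placeholder) showing how the two helpers above compose:
def _exampleLocalInterface():
    my_ip = getLocalIP('192.0.2.1', 443)  # source IP chosen by the routing table
    iface = getIfaceForIP(my_ip)          # map that IP back to an interface name
    print("local IP %s is on interface %s" % (my_ip, iface))
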
def setTCPTimestamps(enabled=True):
    # Flips the global TCP timestamps sysctl and returns the previous
    # setting so a caller can restore it later.  Requires root.
    fh = open('/proc/sys/net/ipv4/tcp_timestamps', 'r+b')
    ret_val = False
    if fh.read(1) == b'1':
        ret_val = True

    fh.seek(0)
    if enabled:
        fh.write(b'1')
    else:
        fh.write(b'0')
    fh.close()

    return ret_val

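# A minimal sketch of the save/restore pattern this enables (illustrative
# only; like the function above, it requires root):
def _exampleTimestampToggle():
    previous = setTCPTimestamps(True)   # enable, remembering the old setting
    try:
        pass                            # ... run measurements here ...
    finally:
        setTCPTimestamps(previous)      # restore the original setting
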
def trickleHTTPRequest(ip,port,hostname):
    my_port = None
    try:
        sock = socket.create_connection((ip, port))
        my_port = sock.getsockname()[1]

        #print('.')
        sock.sendall(b'GET / HTTP/1.1\r\n')
        time.sleep(0.5)
        rest = b'''Host: '''+hostname.encode('utf-8')+b'''\r\nUser-Agent: Secret Agent Man\r\nX-Extra: extra read all about it!\r\nConnection: close\r\n'''
        for r in rest:
            sock.sendall(bytearray([r]))
            time.sleep(0.05)

        time.sleep(0.5)
        sock.sendall(b'\r\n')

        r = None
        while r != b'':
            r = sock.recv(16)

        sock.close()
    except Exception:
        pass

    return my_port

def runTimestampProbes(host_ip, port, hostname, num_trials, concurrency=4):
    myq = queue.Queue()
    def threadWrapper(*args):
        try:
            myq.put(trickleHTTPRequest(*args))
        except Exception as e:
            sys.stderr.write("ERROR from trickleHTTPRequest: %s\n" % repr(e))
            myq.put(None)

    threads = []
    ports = []
    for i in range(num_trials):
        if len(threads) >= concurrency:
            # Block for one result before launching another probe, which
            # keeps roughly `concurrency` probes in flight at once.
            ports.append(myq.get())
        t = threading.Thread(target=threadWrapper, args=(host_ip, port, hostname))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    while myq.qsize() > 0:
        ports.append(myq.get())

    return ports

def computeTimestampPrecision(sniffer_fp, ports):
    rcvd = []
    for line in sniffer_fp:
        p = json.loads(line)
        if p['sent']==0:
            rcvd.append((p['observed'],p['tsval'],int(p['local_port'])))

    slopes = []
    for port in ports:
        trcvd = [tr for tr in rcvd if tr[2]==port and tr[1]!=0]

        if len(trcvd) < 2:
            sys.stderr.write("WARN: Inadequate data points.\n")
            continue

        if trcvd[0][1] > trcvd[-1][1]:
            sys.stderr.write("WARN: TSval wrap.\n")
            continue

        x = [tr[1] for tr in trcvd]
        y = [tr[0] for tr in trcvd]

        slope,intercept = OLSRegression(x, y)
        slopes.append(slope)

    if len(slopes) == 0:
        return None,None,None

    m = statistics.mean(slopes)
    if len(slopes) == 1:
        return (m, None, slopes)
    else:
        return (m, statistics.stdev(slopes), slopes)

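# A sketch (illustrative; the output path is a placeholder and startSniffer/
# stopSniffer are defined below) of how the probe helpers combine: trickle
# probes while the sniffer logs packets, then fit observed time against TSval
# to estimate the remote timestamp clock's period in observed-clock units:
def _exampleMeasurePrecision(host_ip, port, hostname):
    sniffer = startSniffer(host_ip, port, '/tmp/tsprobe.json')
    try:
        ports = runTimestampProbes(host_ip, port, hostname, num_trials=20)
    finally:
        stopSniffer(sniffer)
    with open('/tmp/tsprobe.json', 'r') as fp:
        mean,stdev,slopes = computeTimestampPrecision(fp, ports)
    print("estimated timestamp precision: %s (stdev: %s)" % (mean, stdev))
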
def OLSRegression(x,y):
    #print(x,y)
    x = numpy.array(x)
    y = numpy.array(y)
    #A = numpy.vstack([x, numpy.ones(len(x))]).T
    #m, c = numpy.linalg.lstsq(A, y)[0] # broken
    #c,m = numpy.polynomial.polynomial.polyfit(x, y, 1) # less accurate
    c,m = numpy.polynomial.Polynomial.fit(x,y,1).convert().coef

    #print(m,c)

    #import matplotlib.pyplot as plt
    #plt.clf()
    #plt.scatter(x, y)
    #plt.plot(x, m*x + c, 'r', label='Fitted line')
    #plt.show()

    return (m,c)

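# A small worked example (illustrative): points on the line y = 2x + 1
# should recover slope ~2.0 and intercept ~1.0.
def _exampleOLS():
    m,c = OLSRegression([0, 1, 2, 3, 4], [1, 3, 5, 7, 9])
    print("slope=%f intercept=%f" % (m, c))  # slope=2.000000 intercept=1.000000
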
def startSniffer(target_ip, target_port, output_file):
    my_ip = getLocalIP(target_ip, target_port)
    my_iface = getIfaceForIP(my_ip)
    return subprocess.Popen(['chrt', '-r', '99', 'nanown-csamp', my_iface, my_ip,
                             target_ip, "%d" % target_port, output_file, '0'])

def stopSniffer(sniffer):
    sniffer.terminate()
    try:
        sniffer.wait(2)
    except subprocess.TimeoutExpired:
        pass
    if sniffer.poll() is None:
        sniffer.kill()
        sniffer.wait(1)

def setCPUAffinity():
    import ctypes
    from ctypes import cdll,c_int,byref
    cpus = multiprocessing.cpu_count()

    libc = cdll.LoadLibrary("libc.so.6")
    #libc.sched_setaffinity(os.getpid(), 1, ctypes.byref(ctypes.c_int(0x01)))
    return libc.sched_setaffinity(0, 4, byref(c_int(0x00000001<<(cpus-1))))

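# Note: on Python 3.3+ the same pinning is available without ctypes through
# the standard library; a sketch of the equivalent call (pins the calling
# process to the highest-numbered CPU, as setCPUAffinity() does):
def _setCPUAffinityPortable():
    import os
    os.sched_setaffinity(0, {multiprocessing.cpu_count() - 1})
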
# Monkey patch that instruments the HTTPResponse to collect connection source port info
class MonitoredHTTPResponse(http.client.HTTPResponse):
    local_address = None

    def __init__(self, sock, *args, **kwargs):
        self.local_address = sock.getsockname()
        super(MonitoredHTTPResponse, self).__init__(sock,*args,**kwargs)

requests.packages.urllib3.connection.HTTPConnection.response_class = MonitoredHTTPResponse

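# A sketch (illustrative) of what the patch above makes possible: recovering
# the client-side (IP, port) that carried a request.  This reaches through
# urllib3's private _original_response attribute, so treat it as fragile.
def _exampleLocalAddress(url):
    r = requests.get(url, stream=True)
    ip,port = r.raw._original_response.local_address
    print("request was sent from %s:%d" % (ip, port))
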
def removeDuplicatePackets(packets):
    #return packets
    suspect = ''
    seen = {}
    # XXX: Need to review this deduplication algorithm and make sure it is correct
    for p in packets:
        key = (p['sent'],p['tcpseq'],p['tcpack'],p['payload_len'])
        if (key not in seen):
            seen[key] = p
            continue
        if p['sent']==1 and (seen[key]['observed'] > p['observed']): #earliest sent
            seen[key] = p
            suspect += 's' # duplicated sent packets
            continue
        if p['sent']==0 and (seen[key]['observed'] > p['observed']): #earliest rcvd
            seen[key] = p
            suspect += 'r' # duplicated received packets
            continue

    #if len(seen) < len(packets):
    #   sys.stderr.write("INFO: removed %d duplicate packets.\n" % (len(packets) - len(seen)))

    return suspect,seen.values()

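# A tiny illustration (made-up packet dicts): two copies of the same sent
# segment collapse to the earliest observation and flag the result with 's'.
def _exampleDedup():
    packets = [
        {'sent':1, 'tcpseq':100, 'tcpack':0, 'payload_len':10, 'observed':2000},
        {'sent':1, 'tcpseq':100, 'tcpack':0, 'payload_len':10, 'observed':1000},
    ]
    suspect,deduped = removeDuplicatePackets(packets)
    print(suspect, list(deduped))  # 's' and the packet observed at 1000
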
def analyzePackets(packets, timestamp_precision, trim_sent=0, trim_rcvd=0):
    suspect,packets = removeDuplicatePackets(packets)

    sort_key = lambda d: (d['observed'],d['tcpseq'])
    alt_key = lambda d: (d['tcpseq'],d['observed'])
    sent = sorted((p for p in packets if p['sent']==1 and p['payload_len']>0), key=sort_key)
    rcvd = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=sort_key)
    rcvd_alt = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=alt_key)

    s_off = trim_sent
    if s_off >= len(sent):
        suspect += 'd' # dropped packet?
        s_off = -1
    last_sent = sent[s_off]

    r_off = len(rcvd) - trim_rcvd - 1
    if r_off < 0:
        suspect += 'd' # dropped packet?
        r_off = 0
    last_rcvd = rcvd[r_off]
    if last_rcvd != rcvd_alt[r_off]:
        suspect += 'R' # reordered received packets

    last_sent_ack = None
    try:
        # NB: if two candidates tie on (tcpack, observed), tuple comparison
        # falls through to the packet dicts and raises a TypeError; the
        # except below absorbs that case along with an empty candidate set.
        last_sent_ack = min(((p['tcpack'],p['observed'],p) for p in packets
                             if p['sent']==0 and p['payload_len']+last_sent['tcpseq']>=p['tcpack']))[2]

    except Exception as e:
        sys.stderr.write("WARN: Could not find last_sent_ack.\n")

    packet_rtt = last_rcvd['observed'] - last_sent['observed']
    tsval_rtt = None
    if None not in (timestamp_precision, last_sent_ack):
        tsval_rtt = int(round((last_rcvd['tsval'] - last_sent_ack['tsval'])*timestamp_precision))

    if packet_rtt < 0 or (tsval_rtt is not None and tsval_rtt < 0):
        #sys.stderr.write("WARN: Negative packet or tsval RTT. last_rcvd=%s,last_sent=%s\n" % (last_rcvd, last_sent))
        suspect += 'N'

    return {'packet_rtt':packet_rtt,
            'tsval_rtt':tsval_rtt,
            'suspect':suspect,
            'sent_trimmed':trim_sent,
            'rcvd_trimmed':trim_rcvd},len(sent),len(rcvd)

# septasummary and mad for each dist of differences
def evaluateTrim(db, unusual_case, strim, rtrim):
    cursor = db.conn.cursor()
    # Unused variant of the query below that additionally excludes samples
    # flagged with 'R' (reordered packets):
    #query="""
    #  SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
    #                     WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
    #  FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test') AND 1 NOT IN (select 1 from probes p,trim_analysis t WHERE p.sample=s AND t.probe_id=p.id AND t.suspect LIKE '%R%')) u
    #"""
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
      FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test')) u
    """
    #TODO: check for "N" in suspect field and return a flag

    params = {"strim":strim,"rtrim":rtrim,"unusual_case":unusual_case}
    cursor.execute(query, params)
    differences = [row[0] for row in cursor]

    return septasummary(differences),mad(differences)


def analyzeProbes(db):
    db.conn.execute("CREATE INDEX IF NOT EXISTS packets_probe ON packets (probe_id)")
    db.conn.commit()

    pcursor = db.conn.cursor()
    pcursor.execute("SELECT tcpts_mean FROM meta")
    try:
        timestamp_precision = pcursor.fetchone()[0]
    except (TypeError, IndexError):
        timestamp_precision = None

    pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()

    def loadPackets(db):
        cursor = db.conn.cursor()
        cursor.execute("SELECT * FROM packets ORDER BY probe_id")

        probe_id = None
        entry = []
        ret_val = []
        for p in cursor:
            if probe_id is None:
                probe_id = p['probe_id']
            if p['probe_id'] != probe_id:
                ret_val.append((probe_id,entry))
                probe_id = p['probe_id']
                entry = []
            entry.append(dict(p))
        ret_val.append((probe_id,entry))
        return ret_val

    start = time.time()
    packet_cache = loadPackets(db)
    print("packets loaded in: %f" % (time.time()-start))

    count = 0
    sent_tally = []
    rcvd_tally = []
    for probe_id,packets in packet_cache:
        try:
            analysis,s,r = analyzePackets(packets, timestamp_precision)
            analysis['probe_id'] = probe_id
            sent_tally.append(s)
            rcvd_tally.append(r)
            db.addTrimAnalyses([analysis])
        except Exception as e:
            #traceback.print_exc()
            sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)

        #print(pid,analysis)
        count += 1
    db.conn.commit()
    num_sent = statistics.mode(sent_tally)
    num_rcvd = statistics.mode(rcvd_tally)
    sent_tally = None
    rcvd_tally = None
    print("num_sent: %d, num_rcvd: %d" % (num_sent,num_rcvd))

    for strim in range(0,num_sent):
        for rtrim in range(0,num_rcvd):
            #print(strim,rtrim)
            if strim == 0 and rtrim == 0:
                continue # no point in doing 0,0 again
            for probe_id,packets in packet_cache:
                try:
                    analysis,s,r = analyzePackets(packets, timestamp_precision, strim, rtrim)
                    analysis['probe_id'] = probe_id
                except Exception as e:
                    #traceback.print_exc()
                    sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)
                    continue # don't re-add the previous probe's analysis

                db.addTrimAnalyses([analysis])
    db.conn.commit()

    # Populate analysis table so findUnusualTestCase can give us a starting point
    pcursor.execute("DELETE FROM analysis")
    db.conn.commit()
    pcursor.execute("INSERT INTO analysis SELECT id,probe_id,suspect,packet_rtt,tsval_rtt FROM trim_analysis WHERE sent_trimmed=0 AND rcvd_trimmed=0")

    unusual_case,delta = findUnusualTestCase(db)
    evaluations = {}
    for strim in range(0,num_sent):
        for rtrim in range(0,num_rcvd):
            evaluations[(strim,rtrim)] = evaluateTrim(db, unusual_case, strim, rtrim)

    import pprint
    pprint.pprint(evaluations)

    delta_margin = 0.15
    best_strim = 0
    best_rtrim = 0
    good_delta,good_mad = evaluations[(0,0)]

    for strim in range(1,num_sent):
        delta,mad = evaluations[(strim,0)]
        if delta*good_delta > 0.0 and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta) and mad < good_mad:
            best_strim = strim
        else:
            break

    good_delta,good_mad = evaluations[(best_strim,0)]
    for rtrim in range(1,num_rcvd):
        delta,mad = evaluations[(best_strim,rtrim)]
        if delta*good_delta > 0.0 and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta) and mad < good_mad:
            best_rtrim = rtrim
        else:
            break

    print("selected trim parameters:",(best_strim,best_rtrim))

    if best_strim != 0 or best_rtrim != 0:
        pcursor.execute("DELETE FROM analysis")
        db.conn.commit()
        pcursor.execute("INSERT INTO analysis SELECT id,probe_id,suspect,packet_rtt,tsval_rtt FROM trim_analysis WHERE sent_trimmed=? AND rcvd_trimmed=?",
                        (best_strim,best_rtrim))

    #pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()

    return count


def parseJSONLines(fp):
    for line in fp:
        yield json.loads(line)

def associatePackets(sniffer_fp, db):
    sniffer_fp.seek(0)

    # now combine sampler data with packet data
    buffered = []

    cursor = db.conn.cursor()
    cursor.execute("SELECT count(*) count,min(time_of_day) start,max(time_of_day+userspace_rtt) end from probes")
    ptimes = cursor.fetchone()
    window_size = 100*int((ptimes['end']-ptimes['start'])/ptimes['count'])
    print("associate window_size:", window_size)

    db.addPackets(parseJSONLines(sniffer_fp), window_size)

    cursor.execute("SELECT count(*) count FROM packets WHERE probe_id is NULL")
    unmatched = cursor.fetchone()['count']
    if unmatched > 0:
        sys.stderr.write("WARNING: %d observed packets didn't find a home...\n" % unmatched)

    return None

def enumStoredTestCases(db):
    cursor = db.conn.cursor()
    cursor.execute("SELECT test_case FROM probes GROUP BY test_case")
    return [tc[0] for tc in cursor]

def findUnusualTestCase(db):
    test_cases = enumStoredTestCases(db)

    cursor = db.conn.cursor()
    cursor.execute("SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test')")
    global_tm = quadsummary([row['packet_rtt'] for row in cursor])

    tm_abs = []
    tm_map = {}
    # XXX: if more speed needed, percentile extension to sqlite might be handy...
    for tc in test_cases:
        cursor.execute("SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test') AND probes.test_case=?", (tc,))
        tm_map[tc] = quadsummary([row['packet_rtt'] for row in cursor])
        tm_abs.append((abs(tm_map[tc]-global_tm), tc))

    magnitude,tc = max(tm_abs)
    cursor.execute("SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test') AND probes.test_case<>?", (tc,))
    remaining_tm = quadsummary([row['packet_rtt'] for row in cursor])

    ret_val = (tc, tm_map[tc]-remaining_tm)
    print("unusual_case: %s, delta: %f" % ret_val)
    return ret_val

def reportProgress(db, sample_types, start_time):
    cursor = db.conn.cursor()
    output = ''
    total_completed = 0
    total_requested = 0
    for st in sample_types:
        cursor.execute("SELECT count(id) c FROM (SELECT id FROM probes WHERE type=? AND time_of_day>? GROUP BY sample)", (st[0],int(start_time*1000000000)))
        count = cursor.fetchone()[0]
        output += " | %s remaining: %d" % (st[0], st[1]-count)
        total_completed += count
        total_requested += st[1]

    rate = total_completed / (time.time() - start_time)
    total_time = total_requested / rate
    eta = datetime.datetime.fromtimestamp(start_time+total_time)
    print("STATUS:",output[3:],"| est. total_time: %s | est. ETA: %s" % (str(datetime.timedelta(seconds=total_time)), str(eta)))