source: trunk/lib/nanownlib/__init__.py @ 4

Last change on this file since 4 was 4, checked in by tim, 9 years ago

.

File size: 14.7 KB
RevLine 
[4]1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import time
6import random
7import argparse
8import socket
9import datetime
10import http.client
11import threading
12import queue
13import subprocess
14import multiprocessing
15import csv
16import json
17import gzip
18import statistics
19import numpy
20import netifaces
try:
    import requests
except ImportError:
    # requests is a hard dependency of this module; fail early with
    # installation guidance instead of a bare traceback.
    sys.stderr.write('ERROR: Could not import requests module.  Ensure it is installed.\n')
    sys.stderr.write('       Under Debian, the package name is "python3-requests".\n')
    sys.exit(1)
27
28from .stats import *
29
30
def getLocalIP(remote_host, remote_port):
    """Return the local IP address used to reach a remote TCP endpoint.

    Opens a TCP connection to (remote_host, remote_port), reads the local
    address the socket was bound to, and closes the connection again.
    The socket is closed even if reading the local address fails.

    Raises OSError if the connection cannot be established.
    """
    with socket.create_connection((remote_host, remote_port)) as connection:
        return connection.getsockname()[0]
37
38
def getIfaceForIP(ip):
    """Return the name of the network interface that carries IPv4 address `ip`.

    Walks every interface reported by netifaces and returns the first one
    whose AF_INET address list contains `ip`.  Returns None when no
    interface matches.
    """
    for iface_name in netifaces.interfaces():
        inet_addrs = netifaces.ifaddresses(iface_name).get(netifaces.AF_INET, None)
        if not inet_addrs:
            # Interface has no IPv4 addresses configured.
            continue
        if any(entry.get('addr', None) == ip for entry in inet_addrs):
            return iface_name
    return None
46
47
def setTCPTimestamps(enabled=True):
    """Enable or disable TCP timestamps via /proc and report the old setting.

    Writes b'1' (enabled) or b'0' (disabled) to
    /proc/sys/net/ipv4/tcp_timestamps.

    Returns True if the first byte of the file was b'1' before the write,
    False otherwise, so a caller can restore the previous state later.

    Raises OSError if the file cannot be opened for read/write.
    """
    # The context manager guarantees the handle is released even if the
    # read or write fails part-way through.
    with open('/proc/sys/net/ipv4/tcp_timestamps', 'r+b') as fh:
        previously_enabled = fh.read(1) == b'1'
        fh.seek(0)
        fh.write(b'1' if enabled else b'0')

    return previously_enabled
62
63
def trickleHTTPRequest(ip,port,hostname):
    """Send one deliberately slow HTTP GET and return the local source port.

    The request line is sent first, then (after a pause) the headers are
    sent one byte at a time with a short sleep between bytes, and finally
    the terminating blank line.  The response is then read and discarded
    until the peer closes the connection.

    This is best-effort: any error while connecting or talking to the
    server is swallowed.  Returns the local TCP port of the connection, or
    None if the connection could not be established.
    """
    my_port = None
    sock = None
    try:
        sock = socket.create_connection((ip, port))
        my_port = sock.getsockname()[1]

        sock.sendall(b'GET / HTTP/1.1\r\n')
        time.sleep(0.5)
        rest = b'''Host: '''+hostname.encode('utf-8')+b'''\r\nUser-Agent: Secret Agent Man\r\nX-Extra: extra read all about it!\r\nConnection: close\r\n'''
        # Trickle the headers out one byte per segment.
        for r in rest:
            sock.sendall(bytes([r]))
            time.sleep(0.05)

        time.sleep(0.5)
        # Must be bytes: passing a str here raises TypeError in Python 3,
        # which previously aborted the probe before the request was finished.
        sock.sendall(b'\r\n')

        # Drain the response until the server closes the connection.
        while sock.recv(16) != b'':
            pass
    except Exception:
        # Deliberately best-effort: the caller only needs the source port
        # (if one was obtained) to match this probe against sniffed packets.
        pass
    finally:
        if sock is not None:
            sock.close()

    return my_port
90
91
def runTimestampProbes(host_ip, port, hostname, num_trials, concurrency=4):
    """Run `num_trials` trickled HTTP probes, a few at a time, in threads.

    Each probe calls trickleHTTPRequest(host_ip, port, hostname) in its own
    thread.  After the first `concurrency` probes have been started, one
    finished result is awaited before each further probe is launched.

    Returns a list with one entry per probe: the local source port the
    probe used, or None if the probe failed.
    """
    results = queue.Queue()

    def worker(*probe_args):
        # Always enqueue exactly one item per probe so the launcher below
        # never blocks forever waiting for a result.
        try:
            results.put(trickleHTTPRequest(*probe_args))
        except Exception as e:
            sys.stderr.write("ERROR from trickleHTTPRequest: %s\n" % repr(e))
            results.put(None)

    launched = []
    ports = []
    for trial in range(num_trials):
        if trial >= concurrency:
            # Throttle: wait for one earlier probe to report before starting another.
            ports.append(results.get())
        probe = threading.Thread(target=worker, args=(host_ip, port, hostname))
        probe.start()
        launched.append(probe)

    for probe in launched:
        probe.join()

    # Collect whatever results were not consumed by the throttle above.
    while not results.empty():
        ports.append(results.get())

    return ports
117
118
def computeTimestampPrecision(sniffer_fp, ports):
    """Estimate how much observed time elapses per remote TCP timestamp tick.

    `sniffer_fp` yields JSON lines, one packet each.  Only received packets
    (sent == 0) are used.  For every local port in `ports`, the packets
    with a non-zero TSval on that port are fitted with a straight line of
    observed time against TSval; the slope is that port's estimate.

    Ports with fewer than two usable packets, or whose first TSval exceeds
    the last (wrap-around), are skipped with a warning on stderr.

    Returns (mean_slope, stdev_of_slopes, slopes).  The stdev is None when
    only one slope is available; all three are None when there are none.
    """
    records = (json.loads(line) for line in sniffer_fp)
    received = [(pkt['observed'], pkt['tsval'], int(pkt['local_port']))
                for pkt in records if pkt['sent'] == 0]

    slopes = []
    for port in ports:
        samples = [rec for rec in received if rec[2] == port and rec[1] != 0]

        if len(samples) < 2:
            sys.stderr.write("WARN: Inadequate data points.\n")
            continue

        first_tsval = samples[0][1]
        last_tsval = samples[-1][1]
        if first_tsval > last_tsval:
            sys.stderr.write("WARN: TSval wrap.\n")
            continue

        tsvals = [rec[1] for rec in samples]
        observed = [rec[0] for rec in samples]

        slope, _intercept = OLSRegression(tsvals, observed)
        slopes.append(slope)

    if not slopes:
        return None,None,None

    mean_slope = statistics.mean(slopes)
    spread = statistics.stdev(slopes) if len(slopes) > 1 else None
    return (mean_slope, spread, slopes)
152
153   
def OLSRegression(x,y):
    """Fit a straight line y = m*x + c to the given points.

    Uses numpy's Polynomial.fit with degree 1 and converts the result back
    to the unscaled domain before reading the coefficients.

    Returns (m, c): the slope and the intercept.
    """
    xs = numpy.array(x)
    ys = numpy.array(y)

    # Polynomial.fit works in a scaled domain; convert() maps the
    # coefficients back so they apply to the raw x values.  Coefficients
    # are ordered lowest degree first: [intercept, slope].
    fitted = numpy.polynomial.Polynomial.fit(xs, ys, 1)
    intercept, slope = fitted.convert().coef

    return (slope, intercept)
172
173
def startSniffer(target_ip, target_port, output_file):
    """Launch the ./bin/csamp packet sampler for traffic to the target.

    Determines the local IP and interface used to reach
    (target_ip, target_port), then starts csamp under `chrt -r 99`,
    passing it the interface, local IP, target IP, target port, the output
    file and a trailing '0' argument.

    Returns the subprocess.Popen handle of the started process.
    """
    local_ip = getLocalIP(target_ip, target_port)
    local_iface = getIfaceForIP(local_ip)
    command = ['chrt', '-r', '99', './bin/csamp', local_iface, local_ip,
               target_ip, "%d" % target_port, output_file, '0']
    return subprocess.Popen(command)
179
def stopSniffer(sniffer):
    """Stop a sniffer process, escalating to SIGKILL if it will not exit.

    Sends terminate() and waits up to 2 seconds.  If the process is still
    running after that, it is killed and waited on for up to 1 more second.

    `sniffer` is a subprocess.Popen object (as returned by startSniffer).

    Raises subprocess.TimeoutExpired if the process has still not exited
    one second after being killed.
    """
    sniffer.terminate()
    try:
        sniffer.wait(2)
    except subprocess.TimeoutExpired:
        # wait() raises rather than returning when the timeout elapses, so
        # the escalation has to live in the exception handler.
        sniffer.kill()
        sniffer.wait(1)
186
187       
def setCPUAffinity():
    """Pin the calling process to the highest-numbered CPU.

    The target CPU index is multiprocessing.cpu_count() - 1.

    Returns 0 on success and -1 on failure, matching the return convention
    of the libc sched_setaffinity() call this function used previously.
    """
    import os

    cpus = multiprocessing.cpu_count()
    try:
        # os.sched_setaffinity takes a set of CPU indices, so it has no
        # fixed mask width and works on machines with more than 32 CPUs.
        os.sched_setaffinity(0, {cpus - 1})
    except OSError:
        return -1
    return 0
196
197
198# Monkey patching that instruments the HTTPResponse to collect connection source port info
class MonitoredHTTPResponse(http.client.HTTPResponse):
    """HTTPResponse subclass that records the local address of its socket.

    At construction time the result of sock.getsockname() is stored in
    `local_address`, so the source address/port of the connection remains
    available from the response object.
    """

    # Value of sock.getsockname() captured in __init__; None on the class.
    local_address = None

    def __init__(self, sock, *args, **kwargs):
        # Capture the address before handing the socket to the base class.
        self.local_address = sock.getsockname()
        super().__init__(sock, *args, **kwargs)
205           
# Install the instrumented response class on the urllib3 HTTPConnection used
# by requests, so responses expose the local address of their connection.
_urllib3_http_connection = requests.packages.urllib3.connection.HTTPConnection
_urllib3_http_connection.response_class = MonitoredHTTPResponse
207
208
def removeDuplicatePackets(packets):
    """Collapse packets that share direction, seq, ack and payload length.

    Packets are considered duplicates when their
    (sent, tcpseq, tcpack, payload_len) values are equal.  Of each group of
    duplicates, the one with the smallest 'observed' value is kept (the
    first seen wins on ties).

    Returns (suspect, unique_packets) where `suspect` is 'd' if any
    duplicate was dropped and None otherwise, and `unique_packets` is a
    dict values view of the surviving packets.
    """
    # XXX: Need to review this deduplication algorithm and make sure it is correct
    earliest = {}
    for pkt in packets:
        ident = (pkt['sent'], pkt['tcpseq'], pkt['tcpack'], pkt['payload_len'])
        prior = earliest.get(ident)
        if prior is None or prior['observed'] > pkt['observed']:
            earliest[ident] = pkt

    suspect = 'd' if len(earliest) < len(packets) else None

    return suspect, earliest.values()
227
228
def analyzePackets(packets, timestamp_precision, trim_sent=0, trim_rcvd=0):
    """Compute round-trip estimates for the packets of one probe.

    Duplicates are removed first.  Packets carrying payload are split into
    sent and received lists, each ordered by (observed, tcpseq).  The
    reference sent packet is the one at index `trim_sent` (or the last sent
    packet if there are not that many); the reference received packet is
    `trim_rcvd` places before the end (or the first received packet if
    there are not that many).

    Parameters:
      packets             -- sequence of packet records (mappings with keys
                             such as 'sent', 'observed', 'tcpseq', 'tcpack',
                             'payload_len', 'tsval')
      timestamp_precision -- multiplier applied to TSval differences, or
                             None to skip the TSval-based estimate
      trim_sent/trim_rcvd -- how many packets to trim as described above

    Returns (analysis, num_sent, num_rcvd) where `analysis` is a dict with
    keys 'packet_rtt', 'tsval_rtt', 'suspect', 'sent_trimmed' and
    'rcvd_trimmed', and the two counts are the number of payload-carrying
    sent and received packets.

    Raises IndexError when the probe has no payload-carrying sent or
    received packets.
    """
    suspect,packets = removeDuplicatePackets(packets)

    sort_key = lambda d: (d['observed'],d['tcpseq'])
    sent = sorted((p for p in packets if p['sent']==1 and p['payload_len']>0), key=sort_key)
    rcvd = sorted((p for p in packets if p['sent']==0 and p['payload_len']>0), key=sort_key)

    last_sent = sent[-1] if len(sent) <= trim_sent else sent[trim_sent]
    last_rcvd = rcvd[0] if len(rcvd) <= trim_rcvd else rcvd[len(rcvd)-1-trim_rcvd]

    packet_rtt = last_rcvd['observed'] - last_sent['observed']
    if packet_rtt < 0:
        sys.stderr.write("WARN: Negative packet_rtt. last_rcvd=%s,last_sent=%s\n" % (last_rcvd, last_sent))

    # Earliest received packet whose ack number matches the reference sent
    # packet.  Selecting with key= compares only the 'observed' values, so
    # two candidates observed at the same instant no longer cause a
    # TypeError (packet records themselves are not orderable).
    # NOTE(review): the match adds the *received* packet's payload_len to
    # last_sent's tcpseq -- confirm this is the intended condition.
    ack_candidates = [p for p in packets
                      if p['sent']==0 and p['payload_len']+last_sent['tcpseq']==p['tcpack']]
    last_sent_ack = None
    if ack_candidates:
        last_sent_ack = min(ack_candidates, key=lambda p: p['observed'])
    else:
        sys.stderr.write("WARN: Could not find last_sent_ack.\n")

    tsval_rtt = None
    if timestamp_precision is not None and last_sent_ack is not None:
        tsval_rtt = int(round((last_rcvd['tsval'] - last_sent_ack['tsval'])*timestamp_precision))

    return {'packet_rtt':packet_rtt,
            'tsval_rtt':tsval_rtt,
            'suspect':suspect,
            'sent_trimmed':trim_sent,
            'rcvd_trimmed':trim_rcvd},len(sent),len(rcvd)
267
268
269# trimean and mad for each dist of differences
def evaluateTrim(db, unusual_case, strim, rtrim):
    """Score one (strim, rtrim) trim setting against the unusual test case.

    For every train/test probe of `unusual_case` analyzed with the given
    trim values, the query subtracts the average packet_rtt of the other
    test cases in the same sample (same trim values, train/test only).

    Returns (trimean, mad) of that list of differences.
    """
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.sample AND probes.type in ('train','test'))
      FROM (SELECT probes.sample,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test')) u
    """

    bindings = {"strim":strim,"rtrim":rtrim,"unusual_case":unusual_case}
    cur = db.conn.cursor()
    cur.execute(query, bindings)
    differences = [row[0] for row in cur]

    return trimean(differences),mad(differences)
283
284
285
def _analyzeAllProbes(db, pcursor, kcursor, timestamp_precision, strim=0, rtrim=0):
    """Analyze every probe at one trim setting and store the results.

    Probes whose packets cannot be analyzed are reported and skipped; no
    row is stored for them.

    Returns (probes_seen, sent_counts, rcvd_counts): the number of probes
    visited, and the per-probe sent/received packet counts of the probes
    that were analyzed successfully.
    """
    probes_seen = 0
    sent_counts = []
    rcvd_counts = []
    for pid, in pcursor.execute("SELECT id FROM probes"):
        probes_seen += 1
        kcursor.execute("SELECT * FROM packets WHERE probe_id=?", (pid,))
        try:
            analysis,s,r = analyzePackets(kcursor.fetchall(), timestamp_precision, strim, rtrim)
        except Exception as e:
            print(e)
            sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % pid)
            # Skip this probe: storing `analysis` here would re-insert the
            # previous probe's result (or fail outright on the first probe).
            continue

        analysis['probe_id'] = pid
        sent_counts.append(s)
        rcvd_counts.append(r)
        db.addTrimAnalyses([analysis])

    return probes_seen, sent_counts, rcvd_counts


def analyzeProbes(db):
    """Analyze all stored probes and choose the best packet trim settings.

    Steps:
      1. Clear trim_analysis and analyze every probe untrimmed.
      2. Take the most common sent/received packet counts and analyze every
         probe again for each (strim, rtrim) combination below them.
      3. Seed the analysis table from the untrimmed results, find the
         unusual test case, and evaluate every trim combination.
      4. Greedily pick the largest sent trim, then the largest received
         trim, that lowers the MAD while meeting the delta conditions, and
         repopulate the analysis table from that setting if it is not (0,0).

    Returns the number of probes visited.

    Raises statistics.StatisticsError if no probe could be analyzed (there
    is then no most-common packet count to work from).

    NOTE(review): the final INSERT into analysis is not committed here;
    presumably the caller commits -- confirm.
    """
    db.conn.execute("CREATE INDEX IF NOT EXISTS packets_probe ON packets (probe_id)")
    pcursor = db.conn.cursor()
    kcursor = db.conn.cursor()

    pcursor.execute("SELECT tcpts_mean FROM meta")
    try:
        timestamp_precision = pcursor.fetchone()[0]
    except (TypeError, IndexError):
        # No row in meta (fetchone() returned None) or an empty row.
        timestamp_precision = None

    pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()

    # Pass 1: untrimmed analysis, also collecting packet counts per probe.
    count, sent_tally, rcvd_tally = _analyzeAllProbes(db, pcursor, kcursor, timestamp_precision)
    db.conn.commit()
    num_sent = statistics.mode(sent_tally)
    num_rcvd = statistics.mode(rcvd_tally)
    sent_tally = None
    rcvd_tally = None
    print("num_sent: %d, num_rcvd: %d" % (num_sent,num_rcvd))

    # Pass 2: every other trim combination.
    for strim in range(0,num_sent):
        for rtrim in range(0,num_rcvd):
            if strim == 0 and rtrim == 0:
                continue # no point in doing 0,0 again
            _analyzeAllProbes(db, pcursor, kcursor, timestamp_precision, strim, rtrim)
            db.conn.commit()

    # Populate analysis table so findUnusualTestCase can give us a starting point
    pcursor.execute("DELETE FROM analysis")
    db.conn.commit()
    pcursor.execute("INSERT INTO analysis SELECT id,probe_id,suspect,packet_rtt,tsval_rtt FROM trim_analysis WHERE sent_trimmed=0 AND rcvd_trimmed=0")

    unusual_case,delta = findUnusualTestCase(db)
    evaluations = {}
    for strim in range(0,num_sent):
        for rtrim in range(0,num_rcvd):
            evaluations[(strim,rtrim)] = evaluateTrim(db, unusual_case, strim, rtrim)

    import pprint
    pprint.pprint(evaluations)

    delta_margin = 0.1
    best_strim = 0
    best_rtrim = 0
    good_delta,good_mad = evaluations[(0,0)]

    # Grow the sent trim while the delta stays within the margin of the
    # untrimmed delta and the MAD is lower than the untrimmed MAD.
    for strim in range(1,num_sent):
        delta,cur_mad = evaluations[(strim,0)]
        if abs(good_delta - delta) < abs(delta_margin*good_delta) and cur_mad < good_mad:
            best_strim = strim
        else:
            break

    # Then grow the received trim, also accepting a larger absolute delta.
    good_delta,good_mad = evaluations[(best_strim,0)]
    for rtrim in range(1,num_rcvd):
        delta,cur_mad = evaluations[(best_strim,rtrim)]
        if (abs(delta) > abs(good_delta) or abs(good_delta - delta) < abs(delta_margin*good_delta)) and cur_mad < good_mad:
            best_rtrim = rtrim
        else:
            break

    print("selected trim parameters:",(best_strim,best_rtrim))

    if best_strim != 0 or best_rtrim !=0:
        pcursor.execute("DELETE FROM analysis")
        db.conn.commit()
        pcursor.execute("INSERT INTO analysis SELECT id,probe_id,suspect,packet_rtt,tsval_rtt FROM trim_analysis WHERE sent_trimmed=? AND rcvd_trimmed=?",
                        (best_strim,best_rtrim))

    return count
383
384
385       
def parseJSONLines(fp):
    """Lazily decode an iterable of JSON text lines, yielding one object per line."""
    yield from map(json.loads, fp)
389
390
def associatePackets(sniffer_fp, db):
    """Load sniffed packets into the database and match them to probes.

    Rewinds `sniffer_fp`, derives a matching window of 100 times the
    (integer-truncated) average time per probe, and hands the decoded
    packets plus that window to db.addPackets().  Afterwards a warning is
    written to stderr if any packets were left without a probe_id.

    Always returns None.
    """
    sniffer_fp.seek(0)

    # now combine sampler data with packet data
    cursor = db.conn.cursor()
    cursor.execute("SELECT count(*) count,min(time_of_day) start,max(time_of_day+userspace_rtt) end from probes")
    span = cursor.fetchone()
    mean_interval = (span['end']-span['start'])/span['count']
    window_size = 100*int(mean_interval)
    print("associate window_size:", window_size)

    db.addPackets(parseJSONLines(sniffer_fp), window_size)

    cursor.execute("SELECT count(*) count FROM packets WHERE probe_id is NULL")
    orphaned = cursor.fetchone()['count']
    if orphaned > 0:
        sys.stderr.write("WARNING: %d observed packets didn't find a home...\n" % orphaned)

    return None
411
412
def enumStoredTestCases(db):
    """Return the distinct test_case values present in the probes table."""
    rows = db.conn.cursor().execute("SELECT test_case FROM probes GROUP BY test_case")
    return [row[0] for row in rows]
417
418
def findUnusualTestCase(db):
    """Find the test case whose packet_rtt trimean deviates most from the rest.

    Computes the trimean of packet_rtt over all train/test probes, then the
    trimean for each stored test case, and selects the test case whose
    trimean is farthest from the overall one.

    Returns (test_case, delta) where delta is that test case's trimean
    minus the trimean of all the other test cases combined.
    """
    base_query = "SELECT packet_rtt FROM probes,analysis WHERE probes.id=analysis.probe_id AND probes.type in ('train','test')"
    test_cases = enumStoredTestCases(db)
    cursor = db.conn.cursor()

    def rttTrimean(extra_clause='', params=()):
        # Trimean of packet_rtt for the rows selected by the optional extra filter.
        cursor.execute(base_query + extra_clause, params)
        return trimean([row['packet_rtt'] for row in cursor])

    global_tm = rttTrimean()

    # XXX: if more speed needed, percentile extension to sqlite might be handy...
    tm_map = {}
    deviations = []
    for tc in test_cases:
        tm_map[tc] = rttTrimean(" AND probes.test_case=?", (tc,))
        deviations.append((abs(tm_map[tc]-global_tm), tc))

    _, unusual = max(deviations)
    remaining_tm = rttTrimean(" AND probes.test_case<>?", (unusual,))

    return (unusual, tm_map[unusual]-remaining_tm)
439
440
def reportProgress(db, sample_types, start_time):
    """Print a one-line status with remaining samples and an ETA estimate.

    Parameters:
      db           -- object whose `conn` attribute is the database connection
      sample_types -- sequence of (type_name, requested_count, ...) entries
      start_time   -- collection start, in seconds since the epoch

    For each sample type, counts the distinct samples of that type recorded
    after `start_time` (probes.time_of_day is compared in nanoseconds).
    The total time and ETA are extrapolated from the completion rate so
    far; until at least one sample has completed they are reported as
    "unknown" instead of dividing by zero.
    """
    cursor = db.conn.cursor()
    segments = []
    total_completed = 0
    total_requested = 0
    for st in sample_types:
        cursor.execute("SELECT count(id) c FROM (SELECT id FROM probes WHERE type=? AND time_of_day>? GROUP BY sample)", (st[0],int(start_time*1000000000)))
        count = cursor.fetchone()[0]
        segments.append("%s remaining: %d" % (st[0], st[1]-count))
        total_completed += count
        total_requested += st[1]

    elapsed = time.time() - start_time
    if total_completed > 0 and elapsed > 0:
        rate = total_completed / elapsed
        total_time = total_requested / rate
        eta = datetime.datetime.fromtimestamp(start_time+total_time)
        estimate = "| est. total_time: %s | est. ETA: %s" % (str(datetime.timedelta(seconds=total_time)), str(eta))
    else:
        # Nothing finished yet (or no time elapsed): no rate to extrapolate from.
        estimate = "| est. total_time: unknown | est. ETA: unknown"

    print("STATUS:"," | ".join(segments),estimate)
Note: See TracBrowser for help on using the repository browser.