source: trunk/lib/nanownlib/__init__.py @ 21

Last change on this file since 21 was 20, checked in by tim, 9 years ago

major code refactoring, better organizing location of library functions

File size: 16.5 KB
Line 
1#!/usr/bin/env python3
2#-*- mode: Python;-*-
3
4import sys
5import time
6import traceback
7import socket
8import datetime
9import http.client
10import subprocess
11import tempfile
12import json
13import gzip
14import statistics
15
16try:
17    import requests
18except:
19    sys.stderr.write('ERROR: Could not import requests module.  Ensure it is installed.\n')
20    sys.stderr.write('       Under Debian, the package name is "python3-requests"\n.')
21    sys.exit(1)
22
23from .stats import *
24
25
def getLocalIP(remote_host, remote_port):
    """Return the local IP address used to reach a remote TCP endpoint.

    A TCP connection is opened to ``(remote_host, remote_port)`` and the
    address part of the connected socket's local name is returned.

    The socket is held in a ``with`` block so that it is closed even when
    ``getsockname()`` raises.  Errors from ``socket.create_connection``
    propagate to the caller unchanged.
    """
    with socket.create_connection((remote_host, remote_port)) as connection:
        return connection.getsockname()[0]
32
33
def getIfaceForIP(ip):
    """Return the name of the network interface that carries IPv4 address *ip*.

    The ``netifaces`` module is imported lazily.  If it cannot be imported an
    error is written to stderr and the process exits with status 1.

    Returns the first interface (in ``netifaces.interfaces()`` order) that
    has an ``AF_INET`` entry whose ``'addr'`` equals *ip*, or ``None`` when no
    interface matches.
    """
    try:
        import netifaces
    except ImportError:
        # Only a failed import is reported as "not installed"; any other
        # error is allowed to surface instead of being mislabelled.
        sys.stderr.write('ERROR: Could not import netifaces module.  Ensure it is installed.\n')
        sys.stderr.write('       Try: pip3 install netifaces\n')
        sys.exit(1)

    for iface in netifaces.interfaces():
        # Interfaces without IPv4 configuration have no AF_INET entry.
        addrs = netifaces.ifaddresses(iface).get(netifaces.AF_INET) or ()
        if any(a.get('addr') == ip for a in addrs):
            return iface

    return None
48
49
class snifferProcess(object):
    """Owns a ``nanown-listen`` capture subprocess for one target endpoint.

    The local IP and interface are resolved once in the constructor.
    ``start()`` launches the listener, which is given the path of a temporary
    spool file; ``openPacketLog()`` opens that file for reading and ``stop()``
    shuts the listener down.
    """
    my_ip = None
    my_iface = None
    target_ip = None
    target_port = None
    _proc = None    # subprocess.Popen handle while a listener is running
    _spool = None   # NamedTemporaryFile whose path is handed to the listener

    def __init__(self, target_ip, target_port):
        """Resolve the local address/interface used to reach the target."""
        self.target_ip = target_ip
        self.target_port = target_port
        self.my_ip = getLocalIP(target_ip, target_port)
        self.my_iface = getIfaceForIP(self.my_ip)
        print(self.my_ip, self.my_iface)

    def start(self):
        """Create a fresh spool file and launch the listener on it."""
        self._spool = tempfile.NamedTemporaryFile('w+t')
        # Launched through `chrt -r 99`.  The meaning of the trailing '0'
        # argument is defined by nanown-listen and is not visible here.
        self._proc = subprocess.Popen(['chrt', '-r', '99', 'nanown-listen',
                                       self.my_iface, self.my_ip,
                                       self.target_ip, "%d" % self.target_port,
                                       self._spool.name, '0'])
        # NOTE(review): presumably gives the listener time to begin capturing
        # before the caller starts sending probes -- confirm.
        time.sleep(0.25)

    def openPacketLog(self):
        """Return a new read-only text handle on the spool file."""
        return open(self._spool.name, 'rt')

    def stop(self):
        """Terminate the listener, escalating to kill() if it lingers.

        Safe to call when no listener is running.
        """
        proc = self._proc
        if proc is None:
            return

        proc.terminate()
        try:
            proc.wait(2)
        except subprocess.TimeoutExpired:
            # wait() raises rather than returning on timeout; swallow it so
            # that the kill() fallback below is actually reached.
            pass
        if proc.poll() is None:
            proc.kill()
            proc.wait(1)
        self._proc = None

    def is_running(self):
        """Return True only while a started listener has not yet exited."""
        return self._proc is not None and self._proc.poll() is None

    def __del__(self):
        self.stop()
90
91           
def startSniffer(target_ip, target_port, output_file):
    """Launch ``nanown-listen`` for the given target and return the Popen handle.

    The local IP is discovered with ``getLocalIP`` and mapped to an interface
    with ``getIfaceForIP``; the listener is asked to write to *output_file*.
    """
    local_ip = getLocalIP(target_ip, target_port)
    local_iface = getIfaceForIP(local_ip)
    command = [
        'chrt', '-r', '99', 'nanown-listen',
        local_iface, local_ip,
        target_ip, "%d" % target_port,
        output_file, '0',
    ]
    return subprocess.Popen(command)
97
def stopSniffer(sniffer):
    """Stop a sniffer process started with ``startSniffer``.

    Sends terminate(), waits up to 2 seconds, and falls back to kill() (with
    a further 1 second wait) if the process has still not exited.

    *sniffer* is expected to provide the ``subprocess.Popen`` methods
    ``terminate``, ``wait``, ``poll`` and ``kill``.
    """
    sniffer.terminate()
    try:
        sniffer.wait(2)
    except subprocess.TimeoutExpired:
        # wait() signals a timeout by raising; without this handler the
        # kill() escalation below could never run.
        pass
    if sniffer.poll() is None:
        sniffer.kill()
        sniffer.wait(1)
104
105
106# Monkey patching that instruments the HTTPResponse to collect connection source port info
class MonitoredHTTPResponse(http.client.HTTPResponse):
    """HTTPResponse variant that records its socket's local address.

    ``local_address`` holds whatever ``sock.getsockname()`` returned when the
    response object was constructed.
    """
    local_address = None

    def __init__(self, sock, *args, **kwargs):
        # Record the local endpoint first, then let the stock HTTPResponse
        # initialise itself with the untouched arguments.
        self.local_address = sock.getsockname()
        super().__init__(sock, *args, **kwargs)
114           
# Replace the response class on the urllib3 HTTPConnection reachable through
# requests.packages with the instrumented subclass defined in this module.
# NOTE(review): presumably this makes every response obtained via requests
# carry ``local_address`` -- confirm against the installed requests/urllib3.
requests.packages.urllib3.connection.HTTPConnection.response_class = MonitoredHTTPResponse
116
117
def removeDuplicatePackets(packets):
    """Collapse packets sharing (sent, tcpseq, tcpack, payload_len).

    For each such key the first packet encountered is kept, unless a later
    packet with the same key has a strictly smaller ``observed`` value, in
    which case that later packet replaces it and a flag is recorded:
    ``'s'`` when ``sent == 1`` and ``'r'`` when ``sent == 0``.  Duplicates
    that are not earlier than the kept packet are dropped without a flag.

    Returns ``(suspect, packets)`` where *suspect* is the string of recorded
    flags and *packets* is a view of the retained packet dicts.
    """
    # XXX: Need to review this deduplication algorithm and make sure it is correct
    direction_flag = {1: 's', 0: 'r'}
    flags = []
    earliest = {}
    for pkt in packets:
        ident = (pkt['sent'], pkt['tcpseq'], pkt['tcpack'], pkt['payload_len'])
        if ident not in earliest:
            earliest[ident] = pkt
            continue

        flag = direction_flag.get(pkt['sent'])
        if flag is None:
            # Neither a sent nor a received packet: keep what we have.
            continue
        if earliest[ident]['observed'] > pkt['observed']:
            earliest[ident] = pkt
            flags.append(flag)

    return ''.join(flags), earliest.values()
141
142
def analyzePackets(packets, timestamp_precision, trim_sent=0, trim_rcvd=0):
    """Derive round-trip measurements for one probe from its packets.

    Packets are de-duplicated first.  Only packets with a payload are
    considered when picking the reference sent packet (index *trim_sent* in
    observation order) and the reference received packet (*trim_rcvd*
    positions before the last one).

    Suspect flags appended here, in addition to those from de-duplication:
      'd'  a trim offset fell outside the available packets
      'R'  time ordering and sequence ordering of received packets disagree
           at the chosen position
      'N'  the packet RTT or the tsval RTT came out negative

    Returns ``(analysis, n_sent, n_rcvd)`` where *analysis* is a dict with
    keys packet_rtt, tsval_rtt, suspect, sent_trimmed and rcvd_trimmed, and
    the counts are the numbers of payload-carrying sent/received packets.
    An IndexError escapes when there is no payload-carrying sent or received
    packet at all.
    """
    suspect, packets = removeDuplicatePackets(packets)

    def by_time(pkt):
        return (pkt['observed'], pkt['tcpseq'])

    def by_sequence(pkt):
        return (pkt['tcpseq'], pkt['observed'])

    sent = sorted((p for p in packets if p['sent'] == 1 and p['payload_len'] > 0),
                  key=by_time)
    rcvd_payload = [p for p in packets if p['sent'] == 0 and p['payload_len'] > 0]
    rcvd = sorted(rcvd_payload, key=by_time)
    rcvd_alt = sorted(rcvd_payload, key=by_sequence)

    if trim_sent < len(sent):
        s_off = trim_sent
    else:
        suspect += 'd'  # dropped packet?
        s_off = -1
    last_sent = sent[s_off]

    r_off = len(rcvd) - trim_rcvd - 1
    if r_off < 0:
        suspect += 'd'  # dropped packet?
        r_off = 0
    last_rcvd = rcvd[r_off]
    if last_rcvd != rcvd_alt[r_off]:
        suspect += 'R'  # reordered received packets

    # Among received packets passing the filter below, take the one with the
    # smallest (tcpack, observed).
    # NOTE(review): the filter adds the *received* packet's payload_len to
    # last_sent's tcpseq -- confirm this is the intended ACK test.
    last_sent_ack = None
    try:
        candidates = ((p['tcpack'], p['observed'], p) for p in packets
                      if p['sent'] == 0
                      and p['payload_len'] + last_sent['tcpseq'] >= p['tcpack'])
        last_sent_ack = min(candidates)[2]
    except Exception as e:
        sys.stderr.write("WARN: Could not find last_sent_ack.\n")

    packet_rtt = last_rcvd['observed'] - last_sent['observed']

    tsval_rtt = None
    if None not in (timestamp_precision, last_sent_ack):
        tsval_delta = last_rcvd['tsval'] - last_sent_ack['tsval']
        tsval_rtt = int(round(tsval_delta * timestamp_precision))

    negative_tsval = tsval_rtt is not None and tsval_rtt < 0
    if packet_rtt < 0 or negative_tsval:
        suspect += 'N'

    analysis = {
        'packet_rtt': packet_rtt,
        'tsval_rtt': tsval_rtt,
        'suspect': suspect,
        'sent_trimmed': trim_sent,
        'rcvd_trimmed': trim_rcvd,
    }
    return analysis, len(sent), len(rcvd)
188
189
190# septasummary and mad for each dist of differences
def evaluateTrim(db, unusual_case, strim, rtrim):
    """Summarise how far *unusual_case* sits from the other cases for one trim.

    For every 'train'/'test' probe of *unusual_case* analysed with
    ``sent_trimmed=strim`` and ``rcvd_trimmed=rtrim``, the query computes its
    packet_rtt minus the average packet_rtt of the other test cases' probes
    in the same sample (same trim, same probe types).

    Returns ``(septasummary(differences), mad(differences))``.
    """
    cursor = db.conn.cursor()
    # An earlier variant of this query, which additionally skipped samples
    # containing a probe flagged 'R' in trim_analysis.suspect, used to be
    # assigned here and then immediately overwritten.  It never executed, so
    # the dead assignment has been dropped.
    query="""
      SELECT packet_rtt-(SELECT avg(packet_rtt) FROM probes,trim_analysis
                         WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case!=:unusual_case AND sample=u.s AND probes.type in ('train','test'))
      FROM (SELECT probes.sample s,packet_rtt FROM probes,trim_analysis WHERE sent_trimmed=:strim AND rcvd_trimmed=:rtrim AND trim_analysis.probe_id=probes.id AND probes.test_case=:unusual_case AND probes.type in ('train','test')) u
    """
    #TODO: check for "N" in suspect field and return a flag

    params = {"strim":strim,"rtrim":rtrim,"unusual_case":unusual_case}
    cursor.execute(query, params)
    differences = [row[0] for row in cursor]

    return septasummary(differences),mad(differences)
210
211
212
def analyzeProbes(db, trim=None, recompute=False):
    """Analyse all not-yet-analysed probes and store the chosen results.

    Parameters:
      db         project database wrapper; ``db.conn`` and
                 ``db.addTrimAnalyses`` are used here.
      trim       optional ``(sent_trim, rcvd_trim)`` pair.  When given, only
                 that trim is analysed.  When ``None``, every combination up
                 to the typical sent/received packet counts is analysed and
                 the best one is selected.
      recompute  when true, the ``analysis`` table is emptied first so that
                 every probe is analysed again.

    The ``trim_analysis`` table is always emptied and rebuilt; rows for the
    chosen trim are then copied into ``analysis``.

    Returns the number of probes whose packets were loaded (0 when nothing
    was left to analyse).  ``statistics.StatisticsError`` propagates if no
    loaded probe could be analysed at all.
    """
    db.conn.execute("CREATE INDEX IF NOT EXISTS packets_probe ON packets (probe_id)")
    db.conn.commit()

    pcursor = db.conn.cursor()
    pcursor.execute("SELECT tcpts_mean FROM meta")
    try:
        timestamp_precision = pcursor.fetchone()[0]
    except (TypeError, IndexError):
        # No row in meta (fetchone() returned None) or an empty row.
        timestamp_precision = None

    pcursor.execute("DELETE FROM trim_analysis")
    db.conn.commit()
    if recompute:
        pcursor.execute("DELETE FROM analysis")
        db.conn.commit()

    def loadPackets(db):
        """Return [(probe_id, [packet dict, ...]), ...] for unanalysed probes."""
        cursor = db.conn.cursor()
        cursor.execute("SELECT * FROM packets WHERE probe_id NOT IN (SELECT probe_id FROM analysis) ORDER BY probe_id")

        probe_id = None
        entry = []
        ret_val = []
        for p in cursor:
            # Rows arrive ordered by probe_id; a change closes the group.
            if entry and p['probe_id'] != probe_id:
                ret_val.append((probe_id, entry))
                entry = []
            probe_id = p['probe_id']
            entry.append(dict(p))
        # Only flush a real group: an empty result must yield an empty list,
        # not a placeholder (None, []) entry.
        if entry:
            ret_val.append((probe_id, entry))
        return ret_val

    def processPackets(packet_cache, strim, rtrim):
        """Analyse every cached probe at one trim and store the results.

        Returns the most common sent and received payload-packet counts.
        """
        sent_tally = []
        rcvd_tally = []
        analyses = []
        for probe_id, packets in packet_cache:
            try:
                # The trim values must be forwarded: the stored rows are
                # later selected by sent_trimmed/rcvd_trimmed.
                analysis, s, r = analyzePackets(packets, timestamp_precision,
                                                strim, rtrim)
                analysis['probe_id'] = probe_id
                analyses.append(analysis)
                sent_tally.append(s)
                rcvd_tally.append(r)
            except Exception:
                sys.stderr.write("WARN: couldn't find enough packets for probe_id=%s\n" % probe_id)
        db.addTrimAnalyses(analyses)
        db.conn.commit()
        return statistics.mode(sent_tally), statistics.mode(rcvd_tally)

    packet_cache = loadPackets(db)
    if not packet_cache:
        # Nothing new to analyse; trim_analysis is empty, so the final copy
        # into analysis would be a no-op anyway.
        return 0

    if trim is not None:
        best_strim, best_rtrim = trim
        processPackets(packet_cache, best_strim, best_rtrim)
    else:
        num_sent, num_rcvd = processPackets(packet_cache, 0, 0)
        print("num_sent: %d, num_rcvd: %d" % (num_sent, num_rcvd))

        for strim in range(0, num_sent):
            for rtrim in range(0, num_rcvd):
                if strim == 0 and rtrim == 0:
                    continue # no point in doing 0,0 again
                processPackets(packet_cache, strim, rtrim)

        unusual_case, _ = findUnusualTestCase(db, (0, 0))
        evaluations = {}
        for strim in range(0, num_sent):
            for rtrim in range(0, num_rcvd):
                evaluations[(strim, rtrim)] = evaluateTrim(db, unusual_case, strim, rtrim)

        import pprint
        pprint.pprint(evaluations)

        delta_margin = 0.15

        def acceptable(candidate, reference):
            """True when *candidate* keeps the sign of the reference delta,
            does not shrink it by delta_margin or more, and has a smaller
            spread."""
            delta, spread = candidate
            good_delta, good_spread = reference
            return (delta*good_delta > 0.0
                    and (abs(good_delta) - abs(delta)) < abs(delta_margin*good_delta)
                    and spread < good_spread)

        # Grow the sent trim first (received trim fixed at 0), stopping at
        # the first value that is not acceptable relative to (0,0).
        best_strim = 0
        best_rtrim = 0
        reference = evaluations[(0, 0)]
        for strim in range(1, num_sent):
            if not acceptable(evaluations[(strim, 0)], reference):
                break
            best_strim = strim

        # Then grow the received trim relative to (best_strim, 0).
        reference = evaluations[(best_strim, 0)]
        for rtrim in range(1, num_rcvd):
            if not acceptable(evaluations[(best_strim, rtrim)], reference):
                break
            best_rtrim = rtrim

        print("selected trim parameters:", (best_strim, best_rtrim))

    pcursor.execute("""INSERT OR IGNORE INTO analysis
                         SELECT id,probe_id,suspect,packet_rtt,tsval_rtt
                           FROM trim_analysis
                           WHERE sent_trimmed=? AND rcvd_trimmed=?""",
                    (best_strim, best_rtrim))
    db.conn.commit()

    return len(packet_cache)
325
326
327       
def parseJSONLines(fp):
    """Lazily decode a stream containing one JSON document per line.

    *fp* is any iterable of lines.  Lines that are empty or contain only
    whitespace are skipped instead of being handed to the JSON decoder,
    which would reject them.  Malformed non-blank lines still raise from
    ``json.loads``.
    """
    for line in fp:
        if line.strip():
            yield json.loads(line)
331
332
def associatePackets(sniffer_fp, db):
    """Load sniffer output into the database and report unmatched packets.

    *sniffer_fp* is rewound and parsed as JSON lines; the packets are passed
    to ``db.addPackets`` together with a window size equal to 100 times the
    (truncated) mean time per probe, computed from the probes table.  A
    warning is written to stderr if any packet ends up without a probe_id.

    Always returns ``None``.
    """
    sniffer_fp.seek(0)

    cursor = db.conn.cursor()
    cursor.execute("SELECT count(*) count,min(time_of_day) start,max(time_of_day+userspace_rtt) end from probes")
    span = cursor.fetchone()
    mean_probe_time = (span['end'] - span['start']) / span['count']
    window_size = 100 * int(mean_probe_time)

    # now combine sampler data with packet data
    db.addPackets(parseJSONLines(sniffer_fp), window_size)

    cursor.execute("SELECT count(*) count FROM packets WHERE probe_id is NULL")
    orphans = cursor.fetchone()['count']
    if orphans > 0:
        sys.stderr.write("WARNING: %d observed packets didn't find a home...\n" % orphans)

    return None
353
354
def enumStoredTestCases(db):
    """Return the distinct ``test_case`` values present in the probes table."""
    cursor = db.conn.cursor()
    cursor.execute("SELECT test_case FROM probes GROUP BY test_case")
    names = []
    for row in cursor:
        names.append(row[0])
    return names
359
360
def findUnusualTestCase(db, trim=None):
    """Pick the test case whose packet_rtt summary deviates most from the rest.

    With *trim* given as ``(sent_trim, rcvd_trim)`` the measurements come
    from ``trim_analysis`` restricted to that trim; otherwise they come from
    ``analysis``.  Only probes of type 'train' or 'test' are considered.

    Each test case's ``quadsummary`` is compared with the summary over all
    cases, and the case with the largest absolute difference is chosen.

    Returns ``(test_case, delta)`` where *delta* is the chosen case's summary
    minus the summary of all remaining cases.
    """
    test_cases = enumStoredTestCases(db)
    if trim is None:
        params = {}
        qsuffix = ""
        table = "analysis"
    else:
        params = {'strim': trim[0], 'rtrim': trim[1]}
        qsuffix = " AND sent_trimmed=:strim AND rcvd_trimmed=:rtrim"
        table = "trim_analysis"

    cursor = db.conn.cursor()

    def summarize(sql):
        # Run one packet_rtt query with the current params and reduce it.
        cursor.execute(sql, params)
        return quadsummary([row['packet_rtt'] for row in cursor])

    global_tm = summarize("SELECT packet_rtt FROM probes,"+table+" a WHERE probes.id=a.probe_id AND probes.type in ('train','test')"+qsuffix)

    same_case_sql = ("SELECT packet_rtt FROM probes," + table + " a\n"
                     "                   WHERE probes.id=a.probe_id AND probes.type in ('train','test')\n"
                     "                   AND probes.test_case=:test_case" + qsuffix)
    other_cases_sql = ("SELECT packet_rtt FROM probes," + table + " a\n"
                       "               WHERE probes.id=a.probe_id AND probes.type in ('train','test')\n"
                       "               AND probes.test_case<>:test_case" + qsuffix)

    # XXX: if more speed needed, percentile extension to sqlite might be handy...
    tm_map = {}
    tm_abs = []
    for case in test_cases:
        params['test_case'] = case
        tm_map[case] = summarize(same_case_sql)
        tm_abs.append((abs(tm_map[case] - global_tm), case))

    magnitude, tc = max(tm_abs)
    params['test_case'] = tc
    remaining_tm = summarize(other_cases_sql)

    delta = tm_map[tc] - remaining_tm
    # Hack to make the chosen unusual_case more intuitive to the user:
    # with exactly two cases and a negative delta, report the other case
    # with a positive delta instead.
    if len(test_cases) == 2 and delta < 0.0:
        tc = [other for other in test_cases if other != tc][0]
        delta = abs(delta)

    return tc, delta
404
405
def reportProgress(db, sample_types, start_time):
    """Print a one-line status report with remaining counts and an ETA.

    Parameters:
      db            object whose ``conn`` provides ``cursor()``.
      sample_types  sequence of items where item[0] is a probe type and
                    item[1] is the number of samples requested for it.
      start_time    collection start, in seconds since the epoch; probes
                    are counted when their ``time_of_day`` exceeds this
                    value converted to nanoseconds.

    The total time and ETA are extrapolated from the completion rate so far.
    While nothing has completed yet (or no time has elapsed) no rate can be
    computed, so both are reported as ``unknown`` instead of dividing by
    zero.
    """
    cursor = db.conn.cursor()
    fields = []
    total_completed = 0
    total_requested = 0
    for st in sample_types:
        cursor.execute("SELECT count(id) c FROM (SELECT id FROM probes WHERE type=? AND time_of_day>? GROUP BY sample)", (st[0],int(start_time*1000000000)))
        count = cursor.fetchone()[0]
        fields.append("%s remaining: %6d" % (st[0], st[1]-count))
        total_completed += count
        total_requested += st[1]

    elapsed = time.time() - start_time
    if total_completed > 0 and elapsed > 0:
        rate = total_completed / elapsed
        total_time = total_requested / rate
        eta = datetime.datetime.fromtimestamp(start_time+total_time)
        estimate = "| est. total_time: %s | ETA: %s" % (str(datetime.timedelta(seconds=total_time)), eta.strftime("%Y-%m-%d %X"))
    else:
        estimate = "| est. total_time: unknown | ETA: unknown"

    print("STATUS:", " | ".join(fields), estimate)
422
423
424
def evaluateTestResults(db):
    """Pick the best stored 'test' trial for every classifier.

    For each classifier in ``classifier_results`` (alphabetical order):
      * if some test trial has (false_positives+false_negatives)/2.0 below
        5.0, the one with the fewest observations (ties broken by lower
        combined error) is added to the first list;
      * otherwise the test trial with the lowest combined error (ties broken
        by fewer observations) is added to the second list;
      * if the classifier has no test trials, a warning is written to stderr.

    Returns ``(best_obs, best_error)``, two lists of dicts with the keys
    classifier, params, num_observations and error.
    """
    under_threshold_sql = """
        SELECT classifier,params,num_observations,(false_positives+false_negatives)/2 error
        FROM classifier_results
        WHERE trial_type='test'
         AND classifier=:classifier
         AND (false_positives+false_negatives)/2.0 < 5.0
        ORDER BY num_observations,(false_positives+false_negatives)
        LIMIT 1
        """
    lowest_error_sql = """
            SELECT classifier,params,num_observations,(false_positives+false_negatives)/2 error
            FROM classifier_results
            WHERE trial_type='test' and classifier=:classifier
            ORDER BY (false_positives+false_negatives),num_observations
            LIMIT 1
            """

    cursor = db.conn.cursor()
    query = """
      SELECT classifier FROM classifier_results GROUP BY classifier ORDER BY classifier;
    """
    cursor.execute(query)
    classifiers = [row[0] for row in cursor]

    best_obs = []
    best_error = []
    for classifier in classifiers:
        binding = {'classifier': classifier}

        cursor.execute(under_threshold_sql, binding)
        row = cursor.fetchone()
        if row is not None:
            best_obs.append(dict(row))
            continue

        # Nothing under the error threshold: fall back to the lowest error.
        cursor.execute(lowest_error_sql, binding)
        row = cursor.fetchone()
        if row is None:
            sys.stderr.write("WARN: couldn't find test results for classifier '%s'.\n" % classifier)
            continue
        best_error.append(dict(row))

    return best_obs, best_error
Note: See TracBrowser for help on using the repository browser.