source: trunk/lib/bletchley/ssltls.py

Last change on this file was 134, checked in by tim, 7 years ago

cleanup

File size: 9.6 KB
Line 
1'''
2Utilities for manipulating certificates and SSL/TLS connections.
3
4Copyright (C) 2014,2016,2017 Blindspot Security LLC
5Author: Timothy D. Morgan
6
7 This program is free software: you can redistribute it and/or modify
8 it under the terms of the GNU Lesser General Public License, version 3,
9 as published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program.  If not, see <http://www.gnu.org/licenses/>.
18'''
19
20import sys
21import argparse
22import traceback
23import random
24import time
25import socket
26try:
27    import OpenSSL
28    from OpenSSL import SSL
29except:
30    sys.stderr.write('ERROR: Could not locate pyOpenSSL module.  Under Debian-based systems, try:\n')
31    sys.stderr.write('       # apt-get install python3-openssl\n')
32    sys.stderr.write('NOTE: pyOpenSSL version 0.14 or later is required!\n')
33    sys.exit(2)
34
35
def createContext(method=SSL.TLSv1_METHOD, key=None, certChain=None):
    '''
    Build an SSL.Context for the given protocol method with peer certificate
    verification disabled.

    method    - pyOpenSSL protocol method constant (e.g. SSL.TLSv1_METHOD).
    key       - optional private key, installed with Context.use_privatekey().
    certChain - optional sequence of certificates.  The first entry becomes the
                context's own certificate (Context.use_certificate()) and the
                remaining entries are added as extra chain certificates.
                Nothing is installed unless BOTH key and a non-empty certChain
                are supplied.

    Returns the configured SSL.Context.
    '''
    # None instead of a shared [] literal as the default.
    if certChain is None:
        certChain = []

    context = SSL.Context(method)
    # VERIFY_NONE plus a callback that always returns True: any peer
    # certificate is accepted.
    context.set_verify(SSL.VERIFY_NONE, (lambda conn, cert, errnum, depth, ok: True))

    if key and certChain:
        context.use_privatekey(key)
        context.use_certificate(certChain[0])
        for c in certChain[1:]:
            context.add_extra_chain_cert(c)

    return context
47
48
def _handshakeWithDeadline(conn, sock, timeout):
    '''
    Run the handshake on conn, retrying until it completes or timeout seconds
    have elapsed.

    A socket with a timeout set is in non-blocking mode, so do_handshake()
    raises WantReadError/WantWriteError until enough data has been exchanged.
    Rather than busy-polling, each retry waits (via select) until sock is
    readable or writable, as OpenSSL requested.

    Raises SSL.Error if the deadline passes before the handshake finishes;
    any other error from do_handshake() propagates unchanged.
    '''
    import select

    deadline = time.monotonic() + timeout
    while True:
        try:
            conn.do_handshake()
            return
        except SSL.WantReadError:
            readers, writers = [sock], []
        except SSL.WantWriteError:
            readers, writers = [], [sock]

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            # Previously the loop just fell through here and an
            # un-handshaken connection was returned as if it had succeeded.
            raise SSL.Error('SSL/TLS handshake did not complete within %s seconds' % timeout)
        select.select(readers, writers, [], remaining)


def startSSLTLS(sock, mode='client', protocol=SSL.TLSv1_METHOD, key=None, certChain=None, cipher_list=None, timeout=None):
    '''
    Wrap sock in an SSL.Connection and, in client mode, perform the handshake.

    cipher_list names drawn from:
      openssl ciphers -v "ALL:@SECLEVEL=0"

    sock           - the socket to wrap.
    mode           - 'client' puts the connection in connect state and
                     handshakes immediately; any other value puts it in accept
                     (server) state without handshaking.
    protocol       - pyOpenSSL protocol method constant.
    key, certChain - optional identity, passed through to createContext().
    cipher_list    - optional OpenSSL cipher string applied to the context.
    timeout        - optional number of seconds.  In client mode the handshake
                     is retried until this deadline (the socket is expected to
                     have a timeout set, i.e. be non-blocking).

    Returns the SSL.Connection.  In client mode raises SSL.Error if the
    handshake fails or does not finish within timeout.
    '''
    import math

    if certChain is None:
        certChain = []

    context = createContext(protocol, key=key, certChain=certChain)
    if cipher_list:
        context.set_cipher_list(cipher_list)
    if timeout:
        # Context.set_timeout() sets the lifetime of cached sessions (not a
        # handshake timeout) and only accepts integers; round up so float
        # timeouts such as 2.5 don't raise TypeError.
        context.set_timeout(int(math.ceil(timeout)))

    conn = SSL.Connection(context, sock)
    if mode != 'client':
        conn.set_accept_state()
        return conn

    conn.set_connect_state()
    if timeout:
        _handshakeWithDeadline(conn, sock, timeout)
    else:
        conn.do_handshake()
    return conn
85
86
def ConnectSSLTLS(host, port, cipher_list=None, timeout=None, handshake_callback=None, verbose=True):
    '''
    Connect to host:port and complete an SSL/TLS client handshake, trying a
    series of protocol versions until one succeeds.

    Each protocol is tried first with cipher_list and then with a broad
    built-in backup cipher list.  Every attempt uses its own TCP connection;
    the sockets of failed attempts are closed.

    host, port         - server address.
    cipher_list        - preferred OpenSSL cipher string (None = library
                         default).
    timeout            - optional socket/handshake timeout in seconds.
    handshake_callback - optional callable invoked with the raw connected
                         socket before the SSL/TLS handshake; a falsy return
                         value (or an exception) aborts the whole operation.
    verbose            - write progress/diagnostic messages to stderr.

    Returns the handshaken SSL.Connection, or None if the TCP connection
    fails, the callback rejects/raises, or every protocol attempt fails.
    '''
    backup_cipher_list = b'DES-CBC3-SHA:RC4-MD5:RC4-SHA:AES128-SHA:AES256-SHA:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES128-SHA:ADH-AES256-GCM-SHA384'
    methods = [("SSL 2/3", SSL.SSLv23_METHOD),
               ("TLS 1.0", SSL.TLSv1_METHOD),
               ("TLS 1.1", SSL.TLSv1_1_METHOD),
               ("TLS 1.2", SSL.TLSv1_2_METHOD),
               ("SSL 3.0", SSL.SSLv3_METHOD),
               ("SSL 2.0", SSL.SSLv2_METHOD)]
    # Each protocol: caller's cipher list first, then the backup list.
    protocols = [(pname, p, cl)
                 for pname, p in methods
                 for cl in (cipher_list, backup_cipher_list)]

    for pname, p, cl in protocols:
        serverSock = None
        try:
            serverSock = socket.socket()
            serverSock.connect((host, port))
            if timeout:
                serverSock.settimeout(timeout)
        except Exception:
            if serverSock is not None:
                serverSock.close()
            if verbose:
                sys.stderr.write("Unable to connect to %s:%s\n" % (host, port))
            return None

        try:
            proceed = (not handshake_callback) or handshake_callback(serverSock)
        except Exception:
            traceback.print_exc(file=sys.stderr)
            proceed = False
        if not proceed:
            serverSock.close()
            return None

        try:
            return startSSLTLS(serverSock, mode='client', protocol=p, cipher_list=cl, timeout=timeout)
        except ValueError:
            if verbose:
                sys.stderr.write("%s protocol not supported by your openssl library, trying others...\n" % pname)
        except SSL.Error as e:
            if verbose:
                sys.stderr.write("Exception during %s handshake with server. (%s)" % (pname, e))
                sys.stderr.write("\nThis could happen because the server requires "
                                 "certain SSL/TLS versions or a client certificate."
                                 "  Have no fear, we'll keep trying...\n")
        except Exception:
            sys.stderr.write("Unknown exception during %s handshake with server: \n" % pname)
            traceback.print_exc(file=sys.stderr)

        # This attempt failed: release its socket before trying the next one
        # (previously every failed attempt leaked a file descriptor).
        serverSock.close()

    return None
140
141
def fetchCertificateChain(connection):
    '''
    Return the certificate chain presented by connection's peer.

    Whatever get_peer_cert_chain() yields is returned as-is when it is
    truthy; any falsy result (None or an empty chain) is reported as None.
    '''
    return connection.get_peer_cert_chain() or None
147
148
def normalizeCertificateName(cert_name):
    '''
    Return a hashable, order-independent key for an X509 name.

    The name's components (as returned by get_components()) are sorted and
    frozen into a tuple, so two names with the same components compare equal
    regardless of component order.
    '''
    return tuple(sorted(cert_name.get_components()))


def normalizeCertificateChain(chain):
    '''
    Reorder the certificates in chain into a linked, leaf-first chain.

    The "root" is the first certificate (in input order) that is either
    self-issued or whose issuer is not among the chain's subjects.  Starting
    from it, each step follows to the certificate issued by the current one,
    and the result is returned leaf-first (root last).  Certificates not
    reachable from the chosen root are dropped.  A warning is written to
    stderr for each additional root candidate found.

    Returns a list (empty if chain is empty or no root candidate exists).
    '''
    names = [(c,
              normalizeCertificateName(c.get_subject()),
              normalizeCertificateName(c.get_issuer()))
             for c in chain]

    # Organize certificates by subject and issuer for quick lookups.
    # Self-issued certificates are kept out of issuer_table: a self-signed
    # root maps its own name back to itself, which previously made the walk
    # below loop forever (e.g. on [leaf, intermediate, root] or on a lone
    # self-signed certificate).
    subject_table = {}
    issuer_table = {}
    for c, s, i in names:
        subject_table[s] = c
        if i != s:
            issuer_table[i] = c

    # Now find root or highest-level intermediary.
    root = None
    for c, s, i in names:
        if (i == s) or (i not in subject_table):
            if root is not None:
                sys.stderr.write("WARN: Multiple root certificates found or broken certificate chain detected.\n")
            else:
                # Go with the first identified "root", since that's more likely to link up with the server cert
                root = c

    # Walk top-down, then flip into leaf-first order.  `seen` guards against
    # issuer/subject cycles among non-self-issued certificates.
    new_chain = []
    seen = set()
    nxt = root
    while nxt is not None and id(nxt) not in seen:
        seen.add(id(nxt))
        new_chain.append(nxt)
        nxt = issuer_table.get(normalizeCertificateName(nxt.get_subject()))

    new_chain.reverse()
    return new_chain
184   
185
def genFakeKey(certificate):
    '''
    Generate a brand-new key pair shaped like certificate's public key.

    The returned OpenSSL.crypto.PKey has the same key type and bit length as
    the certificate's current public key, so it can replace that key when a
    certificate is forged.
    '''
    template = certificate.get_pubkey()
    replacement = OpenSSL.crypto.PKey()
    replacement.generate_key(template.type(), template.bits())
    return replacement
192
193
def getDigestAlgorithm(certificate):
    '''
    Guess the digest name used by certificate's signature algorithm.

    The signature algorithm name (bytes) is split at the first b'With' and
    the part before it is returned as a str, e.g. b'sha256WithRSAEncryption'
    gives 'sha256'.  Returns None when the name contains no b'With'.  The
    match is case-sensitive, so a name such as b'ecdsa-with-SHA256' yields
    None.
    '''
    # XXX: ugly hack because openssl API for this is limited
    digest, separator, _ = certificate.get_signature_algorithm().partition(b'With')
    if not separator:
        return None
    return digest.decode('utf-8')
200
201
def deleteExtension(certificate, index):
    '''
    Remove the extension at position index from certificate, in place.

    A dirty hack until this is implemented in pyOpenSSL. See:
    https://github.com/pyca/pyopenssl/issues/152

    It reaches into pyOpenSSL's private certificate._x509 pointer and the
    OpenSSL bindings pyOpenSSL uses internally (OpenSSL._util.lib).
    '''
    from OpenSSL._util import lib as libssl

    ext = libssl.X509_delete_ext(certificate._x509, index)
    # X509_delete_ext() transfers ownership of the removed extension to the
    # caller, so free it here instead of leaking it.  X509_EXTENSION_free()
    # is a no-op on NULL, which is what an out-of-range index returns.
    libssl.X509_EXTENSION_free(ext)
213
214
def removePeskyExtensions(certificate):
    '''
    Delete every subjectKeyIdentifier and authorityKeyIdentifier extension
    from certificate, in place.  Returns None.
    '''
    pesky = (b'subjectKeyIdentifier', b'authorityKeyIdentifier')
    position = 0
    while position < certificate.get_extension_count():
        if certificate.get_extension(position).get_short_name() in pesky:
            deleteExtension(certificate, position)
            #XXX: would be nice if each of these extensions were re-added with appropriate values
            # Do not advance: the next extension has shifted into this slot.
        else:
            position += 1
232
233
def randomizeSerialNumber(certificate):
    '''
    Give certificate a fresh random serial number, in place.

    The serial is drawn from the operating system's CSPRNG (the default
    `random` generator is predictable) and lies in [1, 2**64 - 1]: RFC 5280
    requires a positive serial, and the previous inclusive upper bound of
    2**64 needed 65 bits.
    '''
    certificate.set_serial_number(random.SystemRandom().randint(1, 2**64 - 1))
236   
def genFakeCertificateChain(cert_chain):
    '''
    Forge a look-alike of a certificate chain, re-signed with fresh keys.

    cert_chain is expected leaf-first, with the highest-level authority last.
    Every certificate object is modified in place:
      - its public key is replaced by a new key from genFakeKey();
      - its subjectKeyIdentifier/authorityKeyIdentifier extensions are removed;
      - its serial number is randomized.
    The top certificate is made self-issued (if it was not) and self-signed.
    Each lower certificate is signed with the new key of the one above it.
    The order of the list passed in is left unchanged.

    Returns (key, chain): the new private key of the last certificate
    processed (the leaf) and the forged certificates in leaf-first order.
    Raises IndexError if cert_chain is empty.
    '''
    if not cert_chain:
        raise IndexError('cannot forge an empty certificate chain')

    ret_val = []
    parent_key = None
    # Walk from the highest-level authority down.  Iterating reversed()
    # avoids the old in-place cert_chain.reverse(), which reordered the
    # caller's list as a side effect.
    for depth, c in enumerate(reversed(cert_chain)):
        is_top = (depth == 0)
        if is_top:
            i = normalizeCertificateName(c.get_issuer())
            s = normalizeCertificateName(c.get_subject())
            if s != i:
                # XXX: consider retrieving root locally and including a forged version instead
                c.set_issuer(c.get_subject())
        k = genFakeKey(c)
        c.set_pubkey(k)
        removePeskyExtensions(c)
        randomizeSerialNumber(c)
        # The top certificate signs itself; the rest are signed by their parent.
        c.sign(k if is_top else parent_key, getDigestAlgorithm(c))
        ret_val.append(c)
        parent_key = k

    ret_val.reverse()
    return k, ret_val
Note: See TracBrowser for help on using the repository browser.