source: trunk/lib/bletchley/ssltls.py @ 131

Last change on this file since 131 was 131, checked in by tim, 7 years ago

minor annoyances

File size: 10.4 KB
Line 
1'''
2Utilities for manipulating certificates and SSL/TLS connections.
3
4Copyright (C) 2014,2016,2017 Blindspot Security LLC
5Author: Timothy D. Morgan
6
7 This program is free software: you can redistribute it and/or modify
8 it under the terms of the GNU Lesser General Public License, version 3,
9 as published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program.  If not, see <http://www.gnu.org/licenses/>.
18'''
19
20import sys
21import argparse
22import traceback
23import random
24import time
25import socket
26try:
27    import OpenSSL
28    from OpenSSL import SSL
29except:
30    sys.stderr.write('ERROR: Could not locate pyOpenSSL module.  Under Debian-based systems, try:\n')
31    sys.stderr.write('       # apt-get install python3-openssl\n')
32    sys.stderr.write('NOTE: pyOpenSSL version 0.14 or later is required!\n')
33    sys.exit(2)
34try:
35    import cffi
36except:
37    sys.stderr.write('ERROR: Could not locate cffi module.  Under Debian-based systems, try:\n')
38    sys.stderr.write('       # apt-get install python3-cffi\n')
39    sys.stderr.write('NOTE: This is a requirement because pyOpenSSL does not provide '
40                     'certificate extension removal procedures.  Consider lobbying for the '
41                     'implementation of this:\n  https://github.com/pyca/pyopenssl/issues/152\n')
42    sys.exit(2)
43
44
def createContext(method=SSL.TLSv1_METHOD, key=None, certChain=[]):
    '''
    Build an SSL.Context for the given protocol method.

    Peer verification is configured with SSL.VERIFY_NONE and a callback
    that always returns True.

    method    -- pyOpenSSL protocol method constant
    key       -- optional private key to present
    certChain -- optional sequence of certificates; the first element is
                 installed as the context's certificate and the remaining
                 ones are added as extra chain certificates

    The key and certificates are only installed when a key is given and
    certChain is non-empty.  Returns the configured context.
    '''
    def accept_anything(connection, certificate, errnum, depth, ok):
        return True

    ctx = SSL.Context(method)
    ctx.set_verify(SSL.VERIFY_NONE, accept_anything)

    # Nothing more to do unless we have both a key and a certificate.
    if not key or len(certChain) <= 0:
        return ctx

    ctx.use_privatekey(key)
    ctx.use_certificate(certChain[0])
    for extra_cert in certChain[1:]:
        ctx.add_extra_chain_cert(extra_cert)

    return ctx
56
57
def startSSLTLS(sock, mode='client', protocol=SSL.TLSv1_METHOD, key=None, certChain=[], cipher_list=None, timeout=None):
    '''
    Wrap an existing socket in an SSL/TLS connection.

    sock        -- socket to wrap
    mode        -- 'client' puts the connection in connect state and runs
                   the handshake before returning; any other value puts
                   it in accept state and returns without handshaking
    protocol    -- pyOpenSSL protocol method constant, passed to
                   createContext()
    key         -- optional private key, passed to createContext()
    certChain   -- optional certificate sequence, passed to createContext()
    cipher_list -- optional cipher list; names drawn from:
                     openssl ciphers -v "ALL:@SECLEVEL=0"
    timeout     -- optional number of seconds; applied to the context with
                   set_timeout() and, in client mode, used as the deadline
                   for completing the handshake

    Returns the SSL.Connection.  In client mode with a timeout, raises
    SSL.Error if the handshake has not completed by the deadline (it
    previously returned the half-open connection as if it had succeeded).
    '''

    context = createContext(protocol, key=key, certChain=certChain)
    if cipher_list:
        context.set_cipher_list(cipher_list)
    if timeout:
        context.set_timeout(timeout)

    #if not key and mode == 'server':
    #context.set_options(OpenSSL.SSL.OP_SINGLE_DH_USE)
    #context.set_options(OpenSSL.SSL.OP_EPHEMERAL_RSA)

    conn = SSL.Connection(context, sock)
    if mode != 'client':
        conn.set_accept_state()
        return conn

    conn.set_connect_state()
    if not timeout:
        conn.do_handshake()
        return conn

    # This polling is needed because the socket timeouts have put the
    # socket in non-blocking mode.  A monotonic clock keeps the deadline
    # immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        try:
            conn.do_handshake()
            return conn
        except (SSL.WantReadError, SSL.WantWriteError):
            if time.monotonic() >= deadline:
                # Report the failure instead of handing back a connection
                # whose handshake never finished.
                raise SSL.Error('SSL/TLS handshake did not complete within %s seconds' % timeout)
            time.sleep(0.00001)
94
95
def ConnectSSLTLS(host, port, cipher_list=None, timeout=None, handshake_callback=None, verbose=True):
    '''
    Connect to host:port and negotiate SSL/TLS, trying several protocol
    versions in turn until one handshake succeeds.

    host, port         -- address of the server
    cipher_list        -- optional cipher list tried first for each
                          protocol; a built-in backup list is tried second
    timeout            -- optional socket timeout in seconds; it now also
                          bounds the TCP connect, not just the handshake
    handshake_callback -- optional callable given the connected plain
                          socket before the handshake; a false return
                          value or an exception aborts the whole attempt
    verbose            -- when true, write progress messages to stderr

    A new socket is opened for every attempt, and is closed again if that
    attempt fails.  Returns the established connection, or None if the
    TCP connection failed, the callback aborted, or every protocol was
    rejected.
    '''
    backup_cipher_list = b'DES-CBC3-SHA:RC4-MD5:RC4-SHA:AES128-SHA:AES256-SHA:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES128-SHA:ADH-AES256-GCM-SHA384'
    versions = (("SSL 2/3", SSL.SSLv23_METHOD),
                ("TLS 1.0", SSL.TLSv1_METHOD),
                ("TLS 1.1", SSL.TLSv1_1_METHOD),
                ("TLS 1.2", SSL.TLSv1_2_METHOD),
                ("SSL 3.0", SSL.SSLv3_METHOD),
                ("SSL 2.0", SSL.SSLv2_METHOD))
    # Each protocol is tried with the caller's cipher list first, then
    # with the backup list.
    protocols = [(pname, method, ciphers)
                 for pname, method in versions
                 for ciphers in (cipher_list, backup_cipher_list)]

    conn = None
    for pname, p, cl in protocols:
        serverSock = None
        try:
            serverSock = socket.socket()
            if timeout:
                # Set before connect() so the TCP connect is bounded too.
                serverSock.settimeout(timeout)
            serverSock.connect((host, port))
        except Exception:
            if serverSock is not None:
                serverSock.close()
            if verbose:
                sys.stderr.write("Unable to connect to %s:%s\n" % (host,port))
            return None

        try:
            if handshake_callback:
                if not handshake_callback(serverSock):
                    serverSock.close()
                    return None
        except Exception:
            traceback.print_exc(file=sys.stderr)
            serverSock.close()
            return None

        try:
            conn = startSSLTLS(serverSock, mode='client', protocol=p, cipher_list=cl, timeout=timeout)
            break
        except ValueError:
            if verbose:
                sys.stderr.write("%s protocol not supported by your openssl library, trying others...\n" % pname)
        except SSL.Error as e:
            if verbose:
                sys.stderr.write("Exception during %s handshake with server. (%s)" % (pname, e))
                sys.stderr.write("\nThis could happen because the server requires "
                                 "certain SSL/TLS versions or a client certificate."
                                 "  Have no fear, we'll keep trying...\n")
        except Exception:
            sys.stderr.write("Unknown exception during %s handshake with server: \n" % pname)
            traceback.print_exc(file=sys.stderr)

        # Only reached when this attempt failed: release its socket before
        # opening a fresh one for the next protocol.
        serverSock.close()

    return conn
149
150
def fetchCertificateChain(connection):
    '''
    Return the peer certificate chain reported by connection, or None
    when the connection reports no (or an empty) chain.
    '''
    return connection.get_peer_cert_chain() or None
156
157
def normalizeCertificateName(cert_name):
    '''
    Return the components of cert_name as a sorted tuple, so that two
    names with the same components compare equal and can be used as
    dictionary keys regardless of component order.
    '''
    return tuple(sorted(cert_name.get_components()))
162
163
def normalizeCertificateChain(chain):
    '''
    Re-order a certificate chain so that each certificate is followed by
    its issuer, ending with the root (or highest-level intermediary).

    chain -- iterable of certificates providing get_subject() and
             get_issuer()

    The top of the chain is the first certificate that is either
    self-issued or whose issuer is not present in chain.  From there the
    chain is followed downwards by matching each subject against the
    issuers of the other certificates.  Certificates that cannot be
    linked to that top certificate are omitted.  Returns a new list; an
    empty chain yields an empty list.
    '''
    # Organize certificates by subject and issuer for quick lookups
    subjects = set()
    issuer_table = {}
    names = []
    for c in chain:
        s = normalizeCertificateName(c.get_subject())
        i = normalizeCertificateName(c.get_issuer())
        names.append((c, s, i))
        subjects.add(s)
        # A self-issued certificate must not be recorded as its own child:
        # that would make the walk below loop forever on the root and
        # could also overwrite the entry for the root's real child.
        if i != s:
            issuer_table[i] = c

    # Now find root or highest-level intermediary
    root = None
    for c, s, i in names:
        if (i == s) or (i not in subjects):
            if root is not None:
                sys.stderr.write("WARN: Multiple root certificates found or broken certificate chain detected.\n")
            else:
                # Go with the first identified "root", since that's more likely to link up with the server cert
                root = c

    # Finally, walk from the top down; the result is reversed afterwards
    # so that the root ends up last.
    top_down = []
    visited = set()
    nxt = root
    while nxt is not None and id(nxt) not in visited:
        # The visited check stops malformed chains (e.g. duplicate
        # subjects forming a cycle) from looping indefinitely.
        visited.add(id(nxt))
        top_down.append(nxt)
        nxt = issuer_table.get(normalizeCertificateName(nxt.get_subject()))

    top_down.reverse()
    return top_down
193   
194
def genFakeKey(certificate):
    '''
    Generate and return a brand new key whose type and bit length are
    copied from the public key of certificate.
    '''
    replacement = OpenSSL.crypto.PKey()
    original = certificate.get_pubkey()
    key_type = original.type()
    key_bits = original.bits()
    replacement.generate_key(key_type, key_bits)
    return replacement
201
202
def getDigestAlgorithm(certificate):
    '''
    Derive the digest name from a certificate's signature algorithm name.

    certificate -- object whose get_signature_algorithm() returns the
                   algorithm name as bytes

    Handles the "<digest>With<key>" form (e.g. sha256WithRSAEncryption)
    as well as names where the key algorithm comes first, such as
    dsaWithSHA1, ecdsa-with-SHA256 and dsa_with_SHA256.

    Returns the digest name as a str, or None when the name does not
    contain a recognizable digest.
    '''
    # XXX: ugly hack because openssl API for this is limited
    algo = certificate.get_signature_algorithm()
    for separator in (b'With', b'-with-', b'_with_'):
        before, found, after = algo.partition(separator)
        if found:
            # DSA/ECDSA names put the key algorithm first and the digest
            # last; the RSA-style names are the other way around.
            digest = after if before.lower() in (b'dsa', b'ecdsa') else before
            return digest.decode('utf-8')
    return None
209
210
def deleteExtension(certificate, index):
    '''
    Remove the extension at position index from certificate, in place.

    A dirty hack until this is implemented in pyOpenSSL. See:
    https://github.com/pyca/pyopenssl/issues/152

    certificate -- pyOpenSSL certificate; its private _x509 handle is
                   passed straight to X509_delete_ext()
    index       -- position of the extension to remove

    Raises OSError if none of the candidate libssl names can be loaded.
    '''
    ffi = cffi.FFI()
    ffi.cdef('''void* X509_delete_ext(void* x, int loc);''')

    # Try to load libssl using several recent names because package
    # maintainers have the blinders on and don't have a universal
    # symlink to the most recent version.
    candidates = ('libssl.so', 'libssl.so.1.0.2', 'libssl.so.1.0.1', 'libssl.so.1.0.0', 'libssl.so.0.9.8')
    libssl = None
    for libname in candidates:
        try:
            libssl = ffi.dlopen(libname)
            break
        except OSError:
            continue

    if libssl is None:
        # Fail with a clear message rather than an AttributeError on None.
        raise OSError('Unable to load libssl; tried: %s' % ', '.join(candidates))

    #XXX: memory leak.  X509_delete_ext() returns the removed extension,
    #     which is supposed to be freed here
    libssl.X509_delete_ext(certificate._x509, index)
232
233
def removePeskyExtensions(certificate):
    '''
    Delete every subjectKeyIdentifier and authorityKeyIdentifier
    extension from certificate, in place.
    '''
    unwanted = (b'subjectKeyIdentifier', b'authorityKeyIdentifier')

    position = 0
    while position < certificate.get_extension_count():
        extension = certificate.get_extension(position)
        if extension.get_short_name() in unwanted:
            deleteExtension(certificate, position)
            #XXX: would be nice if each of these extensions were re-added with appropriate values
            # Stay on the same position: it has not been examined since
            # the deletion.
            continue
        position += 1
251
252
def randomizeSerialNumber(certificate):
    '''
    Replace the serial number of certificate with a random integer
    between 0 and 2**64 (both inclusive).
    '''
    new_serial = random.randint(0, 1 << 64)
    certificate.set_serial_number(new_serial)
255   
def genFakeCertificateChain(cert_chain):
    '''
    Forge a certificate chain that mimics cert_chain but uses freshly
    generated keys.

    cert_chain -- non-empty list of certificates whose last element is
                  the highest-level authority (the order produced by
                  normalizeCertificateChain())

    Every certificate object is modified in place: it receives a new
    public key, loses its subject/authority key identifier extensions,
    gets a random serial number and is re-signed.  The top certificate
    is made self-issued if it was not already and signs itself; every
    other certificate is signed with the new key of the one above it.

    The list passed in is no longer reordered (it used to be reversed in
    place as a side effect).

    Returns (key, chain): the new private key of the first certificate
    in cert_chain and the forged certificates in the same order as
    cert_chain.  Raises IndexError if cert_chain is empty.
    '''
    # Work top-down on a reversed copy so the caller's list keeps its order.
    top_down = list(reversed(cert_chain))

    top = top_down[0]
    if normalizeCertificateName(top.get_subject()) != normalizeCertificateName(top.get_issuer()):
        # XXX: consider retrieving root locally and including a forged version instead
        top.set_issuer(top.get_subject())

    forged = []
    signing_key = None
    for cert in top_down:
        new_key = genFakeKey(cert)
        cert.set_pubkey(new_key)
        removePeskyExtensions(cert)
        randomizeSerialNumber(cert)
        # The top certificate has nothing above it, so it signs itself.
        if signing_key is None:
            signing_key = new_key
        cert.sign(signing_key, getDigestAlgorithm(cert))
        signing_key = new_key
        forged.append(cert)

    forged.reverse()
    return new_key, forged
Note: See TracBrowser for help on using the repository browser.