xref: /unit/test/unit/applications/tls.py (revision 1843:1dab4306e8da)
1import os
2import ssl
3import subprocess
4
5from unit.applications.proto import TestApplicationProto
6from unit.option import option
7
8
class TestApplicationTLS(TestApplicationProto):
    """Test helpers for exercising TLS.

    Provides self-signed certificate generation through the ``openssl``
    CLI, loading of key/certificate bundles through the base class's
    ``conf`` helper, and TLS-wrapped GET/POST helpers that skip server
    certificate verification.
    """

    def setup_method(self):
        """Create a client TLS context that does not verify the server.

        Hostname checking and certificate verification are disabled because
        the certificates used in these tests are self-signed throwaways
        (see :meth:`certificate`).
        """
        self.context = ssl.create_default_context()
        # check_hostname must be cleared before verify_mode may be CERT_NONE.
        self.context.check_hostname = False
        self.context.verify_mode = ssl.CERT_NONE

    def certificate(self, name='default', load=True):
        """Generate a self-signed certificate and key named *name*.

        Writes ``<temp_dir>/<name>.crt`` and ``<temp_dir>/<name>.key`` with
        ``openssl req -x509``, using the shared ``openssl.conf``. That config
        is created on first use and is not rewritten afterwards. When *load*
        is true, the pair is then passed to :meth:`certificate_load`.

        Raises:
            subprocess.CalledProcessError: if ``openssl`` exits non-zero, so
                a failed generation is reported here rather than later as a
                missing (or stale, left over from an earlier run) file.
        """
        self.openssl_conf()

        subprocess.check_call(
            [
                'openssl',
                'req',
                '-x509',
                '-new',
                '-subj',    '/CN=' + name + '/',
                '-config',  option.temp_dir + '/openssl.conf',
                '-out',     option.temp_dir + '/' + name + '.crt',
                '-keyout',  option.temp_dir + '/' + name + '.key',
            ],
            # Fold openssl's diagnostics into stdout alongside its output.
            stderr=subprocess.STDOUT,
        )

        if load:
            self.certificate_load(name)

    def certificate_load(self, crt, key=None):
        """Send a key + certificate bundle to ``/certificates/<crt>``.

        Reads ``<temp_dir>/<key>.key`` followed by ``<temp_dir>/<crt>.crt``
        and passes the concatenated bytes to the base class's ``conf``
        helper. *key* defaults to *crt*, meaning both files share a base
        name.

        Returns:
            Whatever ``self.conf`` returns (presumably the control API's
            response; confirm against the base class).
        """
        if key is None:
            key = crt

        key_path = option.temp_dir + '/' + key + '.key'
        crt_path = option.temp_dir + '/' + crt + '.crt'

        with open(key_path, 'rb') as k, open(crt_path, 'rb') as c:
            return self.conf(k.read() + c.read(), '/certificates/' + crt)

    def get_ssl(self, **kwargs):
        """Issue ``self.get`` over TLS using the unverified client context.

        All keyword arguments are forwarded to ``get``. ``wrapper`` is
        always supplied here as ``self.context.wrap_socket``, so callers
        must not pass it themselves.
        """
        return self.get(wrapper=self.context.wrap_socket, **kwargs)

    def post_ssl(self, **kwargs):
        """Issue ``self.post`` over TLS using the unverified client context.

        All keyword arguments are forwarded to ``post``. ``wrapper`` is
        always supplied here as ``self.context.wrap_socket``, so callers
        must not pass it themselves.
        """
        return self.post(wrapper=self.context.wrap_socket, **kwargs)

    def get_server_certificate(self, addr=('127.0.0.1', 7080)):
        """Return the PEM-encoded certificate presented by the server at *addr*.

        Uses the most general protocol constant this Python's ``ssl`` module
        provides: ``PROTOCOL_TLS``, otherwise ``PROTOCOL_TLSv1_2``, otherwise
        ``PROTOCOL_TLSv1_1``.
        """
        for proto_name in ('PROTOCOL_TLS', 'PROTOCOL_TLSv1_2'):
            if hasattr(ssl, proto_name):
                ssl_version = getattr(ssl, proto_name)
                break
        else:
            ssl_version = ssl.PROTOCOL_TLSv1_1

        return ssl.get_server_certificate(addr, ssl_version=ssl_version)

    def openssl_conf(self, rewrite=False, alt_names=()):
        """Write the ``openssl req`` config to ``<temp_dir>/openssl.conf``.

        Args:
            rewrite: overwrite an existing config file. By default an
                existing file is left untouched and nothing is written.
            alt_names: DNS names to list as ``subjectAltName`` entries
                through a ``req_extensions`` section. When empty, no
                extension section is emitted. An immutable default is used
                so no state is shared between calls.
        """
        conf_path = option.temp_dir + '/openssl.conf'

        if not rewrite and os.path.exists(conf_path):
            return

        ext_section = ''

        if alt_names:
            # One "DNS.<n> = <name>" line per name, numbered from 1.
            dns_entries = ''.join(
                'DNS.%d = %s\n' % (i, dns_name)
                for i, dns_name in enumerate(alt_names, 1)
            )

            # Section for sign-request extensions referencing the names.
            ext_section = """req_extensions = myca_req_extensions

[ myca_req_extensions ]
subjectAltName = @alt_names

[alt_names]
{dns_entries}""".format(dns_entries=dns_entries)

        with open(conf_path, 'w') as f:
            f.write(
                """[ req ]
default_bits = 2048
encrypt_key = no
distinguished_name = req_distinguished_name

{ext_section}
[ req_distinguished_name ]""".format(ext_section=ext_section)
            )

    def load(self, script, name=None):
        """Configure a Python WSGI application from ``python/<script>``.

        Installs a single listener on ``*:7080`` that passes to application
        *name*, which defaults to *script*. The configuration is applied via
        the base class's ``_load_conf``. The application's ``path`` and
        ``working_directory`` both point at the script directory, the
        ``wsgi`` module is used, and no spare processes are kept.
        """
        if name is None:
            name = script

        script_path = option.test_dir + '/python/' + script

        self._load_conf(
            {
                "listeners": {"*:7080": {"pass": "applications/" + name}},
                "applications": {
                    name: {
                        "type": "python",
                        "processes": {"spare": 0},
                        "path": script_path,
                        "working_directory": script_path,
                        "module": "wsgi",
                    }
                },
            }
        )
116