# xref: /unit/test/unit/applications/websockets.py (revision 1444:8f7f7970c07a)
import base64
import hashlib
import itertools
import random
import re
import select
import struct

from unit.applications.proto import TestApplicationProto

# Fixed GUID from RFC 6455, section 1.3: the server's Sec-WebSocket-Accept
# header is base64(SHA-1(Sec-WebSocket-Key + GUID)).
GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'


class TestApplicationWebsocket(TestApplicationProto):
    """WebSocket (RFC 6455) client helpers for functional tests.

    Covers the opening handshake, frame serialisation/parsing and
    message-level send/receive.  ``self.get``, ``self._resp_to_dict`` and
    ``self.fail`` are not defined here; they are assumed to be provided by
    ``TestApplicationProto`` (``fail`` is expected to raise).
    """

    # Frame opcodes (RFC 6455, section 5.2).
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes a peer may legitimately send (RFC 6455, section
    # 7.4.1); frame_read() additionally accepts the 3000-4999 range.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def __init__(self, preinit=False):
        # NOTE(review): the base-class __init__ is intentionally not called
        # (kept as-is); preinit is only stored here, its consumer is
        # elsewhere.
        self.preinit = preinit

    def key(self):
        """Return a random base64-encoded 16-byte Sec-WebSocket-Key."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value a server must send for key."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Perform the opening handshake and return ``(resp, sock, key)``.

        When *headers* is None, a default handshake with a fresh random key
        and the ``chat`` subprotocol is sent, and that key is returned so the
        caller can verify Sec-WebSocket-Accept.  Otherwise *headers* is sent
        as given and *key* is None.  *resp* is the parsed response.

        Fails the test if no complete ``101 Switching Protocols`` response
        arrives within 60 seconds, or if the server closes the connection
        before sending one.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            }

        _, sock = self.get(headers=headers, no_recv=True, start=True)

        # Collect raw bytes and decode once at the end, so a multi-byte
        # character split across recv() calls cannot raise mid-read.
        resp = b''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                self.fail('Can\'t read response from server.')

            chunk = sock.recv(4096)
            if not chunk:
                # EOF.  A closed socket stays readable, so without this
                # check the loop would spin forever.
                self.fail('Connection closed by server during handshake.')
            resp += chunk

            if (
                b'101 Switching Protocols' in resp
                and resp.endswith(b'\r\n\r\n')
            ):
                return self._resp_to_dict(resp.decode()), sock, key

    def apply_mask(self, data, mask):
        """XOR *data* with the repeating 4-byte *mask* (RFC 6455, 5.3).

        Masking is its own inverse, so the same call also unmasks.
        """
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Return a close payload: 2-byte big-endian *code* + UTF-8 *reason*."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=60):
        """Read one frame from *sock* and return it as a dict.

        Keys: ``fin``, ``rsv1``, ``rsv2``, ``rsv3`` (bool), ``opcode``
        (int), ``mask`` (the raw mask bit, 0 or 0x80), ``data`` (unmasked
        payload bytes) and, for close frames, ``code`` and ``reason``.

        With the default *read_timeout* of 60 seconds, a timeout or EOF
        fails the test.  With any other value, reading stops quietly and
        decoding continues with whatever arrived (a truncated header then
        raises struct.error).

        Also fails the test on:
          * a masked frame;
          * an invalid close code;
          * a 1-byte close payload;
          * a close reason that is not valid UTF-8.
        """

        def recv_bytes(size):
            # Read exactly *size* bytes unless the wait times out or the
            # peer closes the connection.
            data = b''
            while len(data) < size:
                rlist = select.select([sock], [], [], read_timeout)[0]
                chunk = sock.recv(size - len(data)) if rlist else b''
                if not chunk:
                    # Timeout, or EOF (recv() returned b'').  EOF must also
                    # end the loop: a closed socket stays readable and
                    # retrying would spin forever.  In all current cases,
                    # if "read_timeout" was changed, the test does not
                    # expect a response from the server.
                    if read_timeout == 60:
                        self.fail('Can\'t read response from server.')
                    break
                data += chunk
            return data

        frame = {}

        head1, head2 = struct.unpack('!BB', recv_bytes(2))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # A 7-bit length of 126 or 127 announces a 16- or 64-bit extended
        # payload length.
        length = head2 & 0b01111111
        if length == 126:
            length, = struct.unpack('!H', recv_bytes(2))
        elif length == 127:
            length, = struct.unpack('!Q', recv_bytes(8))

        mask_bits = recv_bytes(4) if frame['mask'] else None

        data = recv_bytes(length) if length else b''
        if mask_bits is not None:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                code, = struct.unpack('!H', data[:2])
                try:
                    reason = data[2:].decode('utf-8')
                except UnicodeDecodeError:
                    # Report as a test failure rather than a raw error.
                    self.fail('Invalid UTF-8 in close reason')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    self.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                # RFC 6455, 7.1.5: no status code present => 1005.
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                self.fail('Close frame too short')

        frame['data'] = data

        # RFC 6455, 5.1: frames from a server must not be masked.
        if frame['mask']:
            self.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Serialise one frame and return it as bytes.

        *data* may be str (encoded as UTF-8) or bytes.  *length*, if given,
        overrides the payload length written in the header while the payload
        itself is still *data*; this allows deliberately malformed frames.
        With *mask* (the default, as required for clients), the payload is
        masked with a random 32-bit key.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )
        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            header = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            header = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            header = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if not mask:
            return header + data

        mask_bits = struct.pack('!I', random.getrandbits(32))
        return header + mask_bits + self.apply_mask(data, mask_bits)

    def frame_write(self, sock, *args, **kwargs):
        """Build a frame with frame_to_send(*args, **kwargs) and send it.

        The optional ``chopsize`` keyword splits the raw frame into separate
        sendall() calls of at most that many bytes.  A BrokenPipeError is
        swallowed, and the rest of the frame is dropped.

        Raises ValueError if ``chopsize`` is not positive (previously this
        looped forever).
        """
        chopsize = kwargs.pop('chopsize', None)
        if chopsize is not None and chopsize <= 0:
            raise ValueError('chopsize must be a positive integer')

        frame = self.frame_to_send(*args, **kwargs)

        # Without chopsize the whole frame goes out in a single sendall().
        step = len(frame) if chopsize is None else chopsize
        for pos in range(0, len(frame), step):
            try:
                sock.sendall(frame[pos:pos + step])
            except BrokenPipeError:
                # Best-effort: the peer may already have closed the
                # connection.
                break

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send *message* with opcode *type*, fragmenting it if requested.

        If *fragmention_size* is None, or not smaller than the message
        length, a single frame is sent.  Otherwise the message is sliced into
        pieces of that many items (characters for str, bytes for bytes).
        The first frame carries *type* and the rest OP_CONT; only the last
        frame has FIN set.  Extra keyword arguments go to frame_write().

        Raises ValueError if fragmentation is needed but *fragmention_size*
        is not positive (previously this looped forever).
        """
        # ``type`` shadows the builtin but is part of the public signature.
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            raise ValueError('fragmention_size must be a positive integer')

        op_code = type
        for pos in range(0, message_len, fragmention_size):
            end = pos + fragmention_size
            self.frame_write(
                sock,
                op_code,
                message[pos:end],
                fin=end >= message_len,
                **kwargs
            )
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=60):
        """Read frames until one has FIN set; return the combined message.

        Returns the first frame's dict, with ``data`` replaced by the
        concatenated payloads of all frames read and ``fin`` set to True.
        Later frames are appended without checking their opcode.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        chunks = [frame['data']]
        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            chunks.append(temp['data'])
            frame['fin'] = temp['fin']
        frame['data'] = b''.join(chunks)

        return frame
245