xref: /unit/test/unit/applications/websockets.py (revision 1596:b7e2d4d92624)
1import base64
2import hashlib
3import itertools
4import pytest
5import random
6import re
7import select
8import struct
9
10from unit.applications.proto import TestApplicationProto
11
# Fixed GUID defined by RFC 6455 (section 1.3).  accept() appends it to the
# client's Sec-WebSocket-Key before hashing to get Sec-WebSocket-Accept.
GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
13
14
class TestApplicationWebsocket(TestApplicationProto):
    """Client-side WebSocket helpers used by functional tests.

    Covers the opening handshake (``upgrade``), building and sending frames
    (``frame_to_send``, ``frame_write``, ``message``), and reading single
    frames and whole messages back (``frame_read``, ``message_read``).
    Unexpected conditions on the wire are reported through ``pytest.fail``.
    """

    # Frame opcodes (RFC 6455, section 5.2).
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read() accepts from the server, in addition
    # to the 3000-4999 range it also allows.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def key(self):
        """Return a fresh Sec-WebSocket-Key: 16 random bytes, base64-encoded."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value expected for *key*.

        Computed as base64(SHA-1(key + GUID)).
        """
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Perform the WebSocket opening handshake.

        Sends the request through ``self.get`` (with ``no_recv=True`` so it
        does not consume the reply). Then reads from the returned socket
        until the response contains ``101 Switching Protocols`` and ends
        with an empty line.

        Args:
            headers: request headers to send. When omitted, a standard
                upgrade request with a freshly generated key is used.

        Returns:
            ``(resp, sock, key)``:
            - ``resp`` is the raw response passed through
              ``self._resp_to_dict``.
            - ``sock`` is the socket returned by ``self.get``.
            - ``key`` is the generated Sec-WebSocket-Key, or ``None`` when
              the caller supplied *headers*.

        Fails the test if no data arrives within 60 seconds, or if the
        server closes the connection before the response is complete.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            }

        _, sock = self.get(
            headers=headers,
            no_recv=True,
            start=True,
        )

        resp = ''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                pytest.fail('Can\'t read response from server.')

            chunk = sock.recv(4096)
            if not chunk:
                # recv() returns b'' once the peer has closed the connection,
                # and the socket then stays readable.  Without this check the
                # loop would spin forever instead of failing.
                pytest.fail('Server closed connection during handshake.')

            resp += chunk.decode()

            # Stop only after the complete header block of a 101 response
            # (terminated by an empty line) has arrived.
            if (
                '101 Switching Protocols' in resp
                and resp.endswith('\r\n\r\n')
            ):
                resp = self._resp_to_dict(resp)
                break

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR *data* with *mask*, repeating the mask as needed.

        The operation is symmetric, so it both masks and unmasks.
        """
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close-frame payload: 2-byte big-endian *code* + UTF-8 *reason*."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=60):
        """Read and decode a single frame from *sock*.

        Args:
            sock: connected socket to read from.
            read_timeout: seconds to wait for each chunk of data. With the
                default (60), a missing reply fails the test. Any other value
                means the caller does not necessarily expect a reply. A
                timeout or closed connection then ends the read early with
                whatever was received, in which case unpacking the header
                may raise ``struct.error``.

        Returns:
            A dict with these keys:
            - ``fin``, ``rsv1``, ``rsv2``, ``rsv3``: bools.
            - ``opcode``: the frame opcode.
            - ``mask``: the raw mask bit, 0 or 128.
            - ``data``: the unmasked payload bytes.
            - ``code`` and ``reason``: present only for close frames;
              ``1005`` and ``''`` when the close body is empty.

        Fails the test in three cases: a close frame with an invalid status
        code, a close frame with a 1-byte body, or any masked frame (a
        server must not mask).
        """

        def recv_bytes(sock, size):
            # Read exactly *size* bytes, unless the wait times out or the
            # peer closes the connection first.
            data = b''
            while True:
                rlist = select.select([sock], [], [], read_timeout)[0]
                chunk = sock.recv(size - len(data)) if rlist else b''

                if not chunk:
                    # Timeout, or EOF: recv() returns b'' on a closed socket,
                    # which stays readable, so looping again would spin.
                    # In all current cases, a changed "read_timeout" means
                    # the test does not expect a response from the server.
                    if read_timeout == 60:
                        pytest.fail('Can\'t read response from server.')
                    break

                data += chunk

                if len(data) == size:
                    break

            return data

        frame = {}

        (head1,) = struct.unpack('!B', recv_bytes(sock, 1))
        (head2,) = struct.unpack('!B', recv_bytes(sock, 1))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit payload length. The values 126 and 127 mean the real length
        # follows as a 16-bit or 64-bit big-endian integer.
        length = head2 & 0b01111111
        if length == 126:
            (length,) = struct.unpack('!H', recv_bytes(sock, 2))
        elif length == 127:
            (length,) = struct.unpack('!Q', recv_bytes(sock, 8))

        mask_bits = recv_bytes(sock, 4) if frame['mask'] else None

        data = recv_bytes(sock, length) if length != 0 else b''

        if mask_bits is not None:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                (code,) = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    pytest.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                # An empty close body is reported as 1005 ("no status").
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                pytest.fail('Close frame too short')

        frame['data'] = data

        if frame['mask']:
            pytest.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Encode a single frame and return it as bytes.

        Args:
            opcode: frame opcode (one of the ``OP_*`` constants, or any
                value a test wants to put on the wire).
            data: payload; ``str`` is encoded as UTF-8.
            fin: set the FIN bit.
            length: payload length to write into the header instead of
                ``len(data)``, so a test can send a header that disagrees
                with the payload.
            rsv1, rsv2, rsv3: set the corresponding reserved bits.
            mask: mask the payload with a random 32-bit key. Clients must
                mask; pass ``False`` to test server-side rejection.
        """
        frame = b''

        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        # Choose the 7-bit, 16-bit, or 64-bit length encoding.
        data_length = len(data) if length is None else length
        if data_length < 126:
            frame += struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame += struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame += struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Build a frame with ``frame_to_send(*args, **kwargs)`` and send it.

        The optional ``chopsize`` keyword sends the encoded frame in pieces
        of at most that many bytes. A write that hits ``BrokenPipeError`` is
        abandoned silently, and any remaining pieces are dropped.
        """
        chopsize = kwargs.pop('chopsize', None)

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            try:
                sock.sendall(frame)
            except BrokenPipeError:
                pass

        else:
            pos = 0
            frame_len = len(frame)
            while pos < frame_len:
                end = min(pos + chopsize, frame_len)
                try:
                    sock.sendall(frame[pos:end])
                except BrokenPipeError:
                    end = frame_len
                pos = end

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send *message*, splitting it into fragments if requested.

        Args:
            sock: connected socket.
            type: opcode of the first frame; later fragments use ``OP_CONT``.
            message: payload (``str`` or ``bytes``).
            fragmention_size: maximum payload size per frame. ``None`` sends
                a single frame.
            **kwargs: passed through to ``frame_write`` (e.g. ``chopsize``,
                ``mask``). ``fin`` is controlled here when fragmenting.

        Raises:
            ValueError: if *message* must be fragmented but
                *fragmention_size* is not positive. Such a size would never
                advance through the message.
        """
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            raise ValueError('fragmention_size must be positive')

        pos = 0
        op_code = type
        while pos < message_len:
            end = min(pos + fragmention_size, message_len)
            fin = end == message_len
            self.frame_write(
                sock, op_code, message[pos:end], fin=fin, **kwargs
            )
            op_code = self.OP_CONT
            pos = end

    def message_read(self, sock, read_timeout=60):
        """Read frames until one with FIN set; return the merged frame.

        The first frame's dict is returned. Its ``data`` is the
        concatenation of every frame's payload, and its ``fin`` is taken
        from the last frame read. ``read_timeout`` is passed to every
        ``frame_read`` call.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame
244