xref: /unit/test/unit/applications/websockets.py (revision 2616:ab2896c980ab)
1import base64
2import hashlib
3import itertools
4import random
5import select
6import struct
7
8import pytest
9
10from unit.applications.proto import ApplicationProto
11
# Fixed GUID (the WebSocket handshake GUID from RFC 6455) that accept()
# appends to a Sec-WebSocket-Key before SHA-1 hashing it.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
13
14
class ApplicationWebsocket(ApplicationProto):
    """Client-side WebSocket helpers: handshake, frame building and parsing.

    Violations detected while reading (timeouts with the default timeout,
    masked frames, bad close payloads) are reported through pytest.fail().
    """

    # Frame opcodes.
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read() accepts, in addition to 3000-4999.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def key(self):
        """Return 16 random bytes, base64-encoded, as a str."""
        raw_key = random.getrandbits(128).to_bytes(16, 'big')
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return base64(SHA-1(key + GUID)) as a str."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Send an upgrade request and read the response head.

        When headers is None, a default set of WebSocket handshake headers
        with a freshly generated Sec-WebSocket-Key is sent.  Otherwise
        headers is passed to self.get() unchanged and the returned key is
        None.

        Returns a (resp, sock, key) tuple, where resp is the result of
        self._resp_to_dict() applied to everything received up to the point
        where the end of the response head was seen.

        Fails the test if nothing can be read within 60 seconds or if the
        connection is closed before the response head is complete.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat, phone, video',
                'Sec-WebSocket-Version': 13,
            }

        sock = self.get(
            headers=headers,
            no_recv=True,
        )

        # Accumulate raw bytes and decode once, so that a multi-byte
        # character split across two recv() calls cannot break decoding.
        raw = b''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                pytest.fail("Can't read response from server.")

            chunk = sock.recv(4096)
            if not chunk:
                # EOF: the socket would stay "readable" forever and this
                # loop would spin without ever making progress.
                pytest.fail("Can't read response from server.")

            raw += chunk

            if raw.startswith(b'HTTP/') and b'\r\n\r\n' in raw:
                break

        resp = self._resp_to_dict(raw.decode())

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR every byte of data with the mask bytes, repeating the mask."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close payload: 2-byte big-endian code + UTF-8 reason."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def _recv_exact(self, sock, size, read_timeout):
        """Read size bytes from sock, or fewer if the data runs out.

        The data "runs out" when select() times out or the peer closes the
        connection.  With the default timeout of 60 this fails the test;
        with any other timeout the bytes read so far are returned.
        """
        data = b''
        while len(data) < size:
            rlist = select.select([sock], [], [], read_timeout)[0]
            # recv() returning b'' means EOF.  It is handled like a timeout,
            # otherwise the loop would spin on an always-readable socket.
            chunk = sock.recv(size - len(data)) if rlist else b''
            if not chunk:
                # For all current cases if the "read_timeout" was changed
                # then the test does not expect a response from the server.
                if read_timeout == 60:
                    pytest.fail("Can't read response from server.")
                break

            data += chunk

        return data

    def frame_read(self, sock, read_timeout=60):
        """Read and parse a single frame from sock.

        Returns a dict with the keys 'fin', 'rsv1', 'rsv2', 'rsv3' (bool),
        'opcode' (int), 'mask' (the raw mask bit, 0 when unset) and 'data'
        (bytes).  Close frames additionally carry 'code' and 'reason'; an
        empty close payload yields code 1005 and an empty reason.

        Fails the test if the frame is masked, if a close frame has a
        one-byte payload or a status code that is neither in CLOSE_CODES
        nor in the 3000-4999 range, or if reading times out while
        read_timeout is left at its default of 60.
        """
        head1, head2 = struct.unpack(
            '!BB', self._recv_exact(sock, 2, read_timeout)
        )

        frame = {
            'fin': bool(head1 & 0b10000000),
            'rsv1': bool(head1 & 0b01000000),
            'rsv2': bool(head1 & 0b00100000),
            'rsv3': bool(head1 & 0b00010000),
            'opcode': head1 & 0b00001111,
            'mask': head2 & 0b10000000,
        }

        # 126 and 127 announce a 16-bit or 64-bit extended payload length.
        length = head2 & 0b01111111
        if length == 126:
            (length,) = struct.unpack(
                '!H', self._recv_exact(sock, 2, read_timeout)
            )
        elif length == 127:
            (length,) = struct.unpack(
                '!Q', self._recv_exact(sock, 8, read_timeout)
            )

        mask_bits = None
        if frame['mask']:
            mask_bits = self._recv_exact(sock, 4, read_timeout)

        data = b''
        if length != 0:
            data = self._recv_exact(sock, length, read_timeout)

        if mask_bits is not None:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                (code,) = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    pytest.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                pytest.fail('Close frame too short')

        frame['data'] = data

        # Checked last so that the rest of the frame is consumed first.
        if frame['mask']:
            pytest.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Serialize one frame and return it as bytes.

        data may be bytes or str (str is encoded as UTF-8).  length
        overrides the payload length written into the header without
        touching the payload itself, so a header that disagrees with the
        payload can be produced.  When mask is true a random 4-byte mask is
        generated, written after the header and applied to the payload.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Build a frame with frame_to_send() and send it over sock.

        All arguments except the keyword 'chopsize' are forwarded to
        frame_to_send().  When chopsize is given, the frame is sent in
        pieces of at most chopsize bytes, one sendall() call per piece.
        A BrokenPipeError stops the transmission silently.

        Raises ValueError if chopsize is not positive (this used to loop
        forever).
        """
        chopsize = kwargs.pop('chopsize', None)
        if chopsize is not None and chopsize <= 0:
            raise ValueError('chopsize must be a positive integer')

        frame = self.frame_to_send(*args, **kwargs)

        # A frame is never empty (the header is at least two bytes), so the
        # step is always positive.
        step = len(frame) if chopsize is None else chopsize
        for pos in range(0, len(frame), step):
            try:
                sock.sendall(frame[pos : pos + step])
            except BrokenPipeError:
                # Deliberately ignored: nothing more can be sent.
                break

    def message(self, sock, mes_type, message, fragmention_size=None, **kwargs):
        """Send message, split into frames of at most fragmention_size items.

        A message that fits into fragmention_size (or any message when
        fragmention_size is None) is sent as a single frame with the given
        kwargs.  Otherwise the first fragment uses mes_type, the following
        ones use OP_CONT, and only the last one has fin set.  The size is
        counted in len(message) units, i.e. characters for str input.

        Raises ValueError if fragmentation is required but
        fragmention_size is not positive (this used to loop forever).
        """
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, mes_type, message, **kwargs)
            return

        if fragmention_size <= 0:
            raise ValueError('fragmention_size must be a positive integer')

        op_code = mes_type
        for pos in range(0, message_len, fragmention_size):
            end = min(pos + fragmention_size, message_len)
            fin = end == message_len
            self.frame_write(sock, op_code, message[pos:end], fin=fin, **kwargs)
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=60):
        """Read frames until one has fin set and return them merged.

        The returned dict is the first frame read, with the 'data' of all
        following frames appended and 'fin' taken from the last one.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame
238