xref: /unit/test/unit/applications/websockets.py (revision 1848:4bd548074e2c)
1import base64
2import hashlib
3import itertools
4import random
5import select
6import struct
7
8import pytest
9from unit.applications.proto import TestApplicationProto
10
# Fixed GUID that accept() concatenates with the handshake key before
# hashing it with SHA-1.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
12
13
class TestApplicationWebsocket(TestApplicationProto):
    """Helpers for WebSocket tests: handshake, frame encoding and decoding."""

    # Frame opcodes, carried in the low four bits of the first header byte.
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes that frame_read() accepts in addition to the
    # 3000-4999 range.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]
23
24    def key(self):
25        raw_key = bytes(random.getrandbits(8) for _ in range(16))
26        return base64.b64encode(raw_key).decode()
27
28    def accept(self, key):
29        sha1 = hashlib.sha1((key + GUID).encode()).digest()
30        return base64.b64encode(sha1).decode()
31
32    def upgrade(self, headers=None):
33        key = None
34
35        if headers is None:
36            key = self.key()
37            headers = {
38                'Host': 'localhost',
39                'Upgrade': 'websocket',
40                'Connection': 'Upgrade',
41                'Sec-WebSocket-Key': key,
42                'Sec-WebSocket-Protocol': 'chat, phone, video',
43                'Sec-WebSocket-Version': 13,
44            }
45
46        _, sock = self.get(headers=headers, no_recv=True, start=True,)
47
48        resp = ''
49        while True:
50            rlist = select.select([sock], [], [], 60)[0]
51            if not rlist:
52                pytest.fail('Can\'t read response from server.')
53
54            resp += sock.recv(4096).decode()
55
56            if resp.startswith('HTTP/') and '\r\n\r\n' in resp:
57                resp = self._resp_to_dict(resp)
58                break
59
60        return (resp, sock, key)
61
62    def apply_mask(self, data, mask):
63        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))
64
65    def serialize_close(self, code=1000, reason=''):
66        return struct.pack('!H', code) + reason.encode('utf-8')
67
68    def frame_read(self, sock, read_timeout=60):
69        def recv_bytes(sock, bytes):
70            data = b''
71            while True:
72                rlist = select.select([sock], [], [], read_timeout)[0]
73                if not rlist:
74                    # For all current cases if the "read_timeout" was changed
75                    # than test do not expect to get a response from server.
76                    if read_timeout == 60:
77                        pytest.fail('Can\'t read response from server.')
78                    break
79
80                data += sock.recv(bytes - len(data))
81
82                if len(data) == bytes:
83                    break
84
85            return data
86
87        frame = {}
88
89        (head1,) = struct.unpack('!B', recv_bytes(sock, 1))
90        (head2,) = struct.unpack('!B', recv_bytes(sock, 1))
91
92        frame['fin'] = bool(head1 & 0b10000000)
93        frame['rsv1'] = bool(head1 & 0b01000000)
94        frame['rsv2'] = bool(head1 & 0b00100000)
95        frame['rsv3'] = bool(head1 & 0b00010000)
96        frame['opcode'] = head1 & 0b00001111
97        frame['mask'] = head2 & 0b10000000
98
99        length = head2 & 0b01111111
100        if length == 126:
101            data = recv_bytes(sock, 2)
102            (length,) = struct.unpack('!H', data)
103        elif length == 127:
104            data = recv_bytes(sock, 8)
105            (length,) = struct.unpack('!Q', data)
106
107        if frame['mask']:
108            mask_bits = recv_bytes(sock, 4)
109
110        data = b''
111
112        if length != 0:
113            data = recv_bytes(sock, length)
114
115        if frame['mask']:
116            data = self.apply_mask(data, mask_bits)
117
118        if frame['opcode'] == self.OP_CLOSE:
119            if length >= 2:
120                (code,) = struct.unpack('!H', data[:2])
121                reason = data[2:].decode('utf-8')
122                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
123                    pytest.fail('Invalid status code')
124                frame['code'] = code
125                frame['reason'] = reason
126            elif length == 0:
127                frame['code'] = 1005
128                frame['reason'] = ''
129            else:
130                pytest.fail('Close frame too short')
131
132        frame['data'] = data
133
134        if frame['mask']:
135            pytest.fail('Received frame with mask')
136
137        return frame
138
139    def frame_to_send(
140        self,
141        opcode,
142        data,
143        fin=True,
144        length=None,
145        rsv1=False,
146        rsv2=False,
147        rsv3=False,
148        mask=True,
149    ):
150        frame = b''
151
152        if isinstance(data, str):
153            data = data.encode('utf-8')
154
155        head1 = (
156            (0b10000000 if fin else 0)
157            | (0b01000000 if rsv1 else 0)
158            | (0b00100000 if rsv2 else 0)
159            | (0b00010000 if rsv3 else 0)
160            | opcode
161        )
162
163        head2 = 0b10000000 if mask else 0
164
165        data_length = len(data) if length is None else length
166        if data_length < 126:
167            frame += struct.pack('!BB', head1, head2 | data_length)
168        elif data_length < 65536:
169            frame += struct.pack('!BBH', head1, head2 | 126, data_length)
170        else:
171            frame += struct.pack('!BBQ', head1, head2 | 127, data_length)
172
173        if mask:
174            mask_bits = struct.pack('!I', random.getrandbits(32))
175            frame += mask_bits
176
177        if mask:
178            frame += self.apply_mask(data, mask_bits)
179        else:
180            frame += data
181
182        return frame
183
184    def frame_write(self, sock, *args, **kwargs):
185        chopsize = kwargs.pop('chopsize') if 'chopsize' in kwargs else None
186
187        frame = self.frame_to_send(*args, **kwargs)
188
189        if chopsize is None:
190            try:
191                sock.sendall(frame)
192            except BrokenPipeError:
193                pass
194
195        else:
196            pos = 0
197            frame_len = len(frame)
198            while pos < frame_len:
199                end = min(pos + chopsize, frame_len)
200                try:
201                    sock.sendall(frame[pos:end])
202                except BrokenPipeError:
203                    end = frame_len
204                pos = end
205
206    def message(self, sock, type, message, fragmention_size=None, **kwargs):
207        message_len = len(message)
208
209        if fragmention_size is None:
210            fragmention_size = message_len
211
212        if message_len <= fragmention_size:
213            self.frame_write(sock, type, message, **kwargs)
214            return
215
216        pos = 0
217        op_code = type
218        while pos < message_len:
219            end = min(pos + fragmention_size, message_len)
220            fin = end == message_len
221            self.frame_write(
222                sock, op_code, message[pos:end], fin=fin, **kwargs
223            )
224            op_code = self.OP_CONT
225            pos = end
226
227    def message_read(self, sock, read_timeout=60):
228        frame = self.frame_read(sock, read_timeout=read_timeout)
229
230        while not frame['fin']:
231            temp = self.frame_read(sock, read_timeout=read_timeout)
232            frame['data'] += temp['data']
233            frame['fin'] = temp['fin']
234
235        return frame
236