xref: /unit/test/unit/applications/websockets.py (revision 1477:b93d1acf81bd)
1import base64
2import hashlib
3import itertools
4import random
5import re
6import select
7import struct
8
9from unit.applications.proto import TestApplicationProto
10
# Fixed GUID string that accept() appends to the client's handshake key
# before hashing it with SHA-1 (the constant comes from the WebSocket
# opening-handshake definition in RFC 6455).
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
12
13
class TestApplicationWebsocket(TestApplicationProto):
    """Client-side WebSocket helpers used by the tests.

    Covers the opening handshake (upgrade()), building and sending frames
    (frame_to_send(), frame_write(), message()) and reading frames back
    from the server with basic validation (frame_read(), message_read()).
    """

    # Frame opcodes (low four bits of the first header byte).
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A

    # Close status codes frame_read() accepts, in addition to any code in
    # the 3000..4999 range.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def __init__(self, preinit=False):
        # Stored as-is; no method of this class reads it (presumably it is
        # consumed by the base classes -- confirm there).
        self.preinit = preinit

    def key(self):
        """Return a random 16-byte value, base64-encoded, as a str.

        Used by upgrade() as the 'Sec-WebSocket-Key' header value.
        """
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return base64(SHA-1(key + GUID)) as a str for the given key."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Perform the opening handshake and wait for the 101 response.

        headers -- request headers to send.  When None, a default set is
                   built with a freshly generated 'Sec-WebSocket-Key'.

        Returns a (resp, sock, key) tuple: the response as converted by
        self._resp_to_dict(), the connected socket, and the generated key
        (None when the caller supplied its own headers).

        Fails the test if nothing arrives within 60 seconds or if the peer
        closes the connection before the handshake response is complete.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            }

        _, sock = self.get(
            headers=headers,
            no_recv=True,
            start=True,
        )

        # Accumulate raw bytes and decode once at the end, so a multi-byte
        # character split across two recv() calls cannot break decoding.
        raw = b''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                self.fail('Can\'t read response from server.')

            chunk = sock.recv(4096)
            if not chunk:
                # recv() returning b'' means the peer closed the connection;
                # without this check the loop would spin forever, because
                # select() keeps reporting the socket as readable.
                self.fail(
                    'Connection closed before the handshake was completed.'
                )

            raw += chunk

            if (
                b'101 Switching Protocols' in raw
                and raw.endswith(b'\r\n\r\n')
            ):
                break

        resp = self._resp_to_dict(raw.decode())

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR every byte of data with the (cyclically repeated) mask."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a Close payload: 2-byte big-endian code + UTF-8 reason."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=60):
        """Read one frame from sock and return it as a dict.

        The dict has the keys 'fin', 'rsv1', 'rsv2', 'rsv3' (bool),
        'opcode' (int), 'mask' (the raw mask bit of the second header byte)
        and 'data' (payload bytes).  Close frames additionally get 'code'
        and 'reason'; an empty Close payload is reported as code 1005.

        read_timeout -- seconds to wait for each piece of the frame.  With
                        the default of 60 a timeout fails the test; with any
                        other value reading just stops early.

        Fails the test on an invalid close code, a 1-byte Close payload or
        a masked frame.  If reading stops early (timeout with a non-default
        read_timeout, or the peer closed the connection), the short header
        makes struct.unpack() raise struct.error, and a short payload is
        returned truncated.
        """

        def recv_bytes(size):
            # Read up to `size` bytes; may return fewer, see the docstring.
            data = b''
            while len(data) < size:
                rlist = select.select([sock], [], [], read_timeout)[0]
                if not rlist:
                    # For all current cases, if "read_timeout" was changed
                    # then the test does not expect a response from server.
                    if read_timeout == 60:
                        self.fail('Can\'t read response from server.')
                    break

                chunk = sock.recv(size - len(data))
                if not chunk:
                    # Peer closed the connection: no more data will ever
                    # arrive, so stop instead of looping forever.
                    break

                data += chunk

            return data

        frame = {}

        head1, head2 = struct.unpack('!BB', recv_bytes(2))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126 and 127 announce a 16-bit / 64-bit length field.
        length = head2 & 0b01111111
        if length == 126:
            length, = struct.unpack('!H', recv_bytes(2))
        elif length == 127:
            length, = struct.unpack('!Q', recv_bytes(8))

        if frame['mask']:
            mask_bits = recv_bytes(4)

        data = b''

        if length != 0:
            data = recv_bytes(length)

        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                code, = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    self.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                self.fail('Close frame too short')

        frame['data'] = data

        # Checked last so the frame is fully consumed from the socket first.
        if frame['mask']:
            self.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Serialize a single frame and return it as bytes.

        opcode  -- OR-ed into the first header byte as given (not checked).
        data    -- payload; a str is encoded as UTF-8.
        fin, rsv1, rsv2, rsv3 -- header flag bits.
        length  -- value to write into the length field instead of
                   len(data); lets a caller announce a length that differs
                   from the real payload size.
        mask    -- when true, set the mask bit, add a random 4-byte mask
                   and mask the payload with it.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, chopsize=None, **kwargs):
        """Build a frame with frame_to_send() and send it over sock.

        chopsize -- when given, the frame is sent in pieces of at most this
                    many bytes (one sendall() per piece); must be positive.

        All other arguments are passed through to frame_to_send().

        A BrokenPipeError while sending is deliberately ignored (the rest
        of the frame is dropped).

        Raises ValueError for a non-positive chopsize, which would
        otherwise never advance through the frame.
        """
        if chopsize is not None and chopsize <= 0:
            raise ValueError('chopsize must be a positive number of bytes')

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            pieces = [frame]
        else:
            pieces = (
                frame[pos:pos + chopsize]
                for pos in range(0, len(frame), chopsize)
            )

        for piece in pieces:
            try:
                sock.sendall(piece)
            except BrokenPipeError:
                # Best effort: the peer is gone, stop sending.
                break

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send message, optionally split into several frames.

        type             -- opcode of the (first) frame; the name shadows
                            the builtin but is kept for existing callers.
        message          -- payload (str or bytes); it is sliced as given,
                            so a str is fragmented on character boundaries.
        fragmention_size -- maximum payload items per frame.  None, or a
                            value not smaller than len(message), sends one
                            frame.  Otherwise the first fragment carries
                            `type`, the following ones OP_CONT, and only the
                            last one has FIN set.

        Extra keyword arguments are passed through to frame_write().

        Raises ValueError if fragmentation is required but
        fragmention_size is not positive (it would never make progress).
        """
        message_len = len(message)

        if fragmention_size is None or message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            raise ValueError('fragmention_size must be a positive number')

        op_code = type
        for pos in range(0, message_len, fragmention_size):
            end = pos + fragmention_size
            self.frame_write(
                sock,
                op_code,
                message[pos:end],
                fin=end >= message_len,
                **kwargs
            )
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=60):
        """Read frames until one with FIN set and return the merged result.

        The returned dict is the first frame read; the payloads of all
        following frames are appended to its 'data' and its 'fin' is
        updated from the last frame read.  read_timeout is passed to every
        frame_read() call.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame
246