xref: /unit/test/unit/applications/websockets.py (revision 1246:1b6adb628960)
1import re
2import random
3import base64
4import struct
5import select
6import hashlib
7import itertools
8from unit.applications.proto import TestApplicationProto
9
# Fixed GUID (RFC 6455, section 1.3) that is appended to the client's
# Sec-WebSocket-Key before hashing to produce Sec-WebSocket-Accept.
GUID: str = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
11
12
class TestApplicationWebsocket(TestApplicationProto):
    """Client-side WebSocket helpers for tests.

    Covers the opening handshake (``upgrade``), frame encoding and sending
    (``frame_to_send`` / ``frame_write``), fragmented message sending
    (``message``) and frame / message decoding (``frame_read`` /
    ``message_read``) over a raw socket.
    """

    # Frame opcodes (RFC 6455, section 5.2).
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes accepted by frame_read, in addition to the
    # 3000-4999 range.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def __init__(self, preinit=False):
        # NOTE(review): super().__init__() is not called here - confirm
        # this is intentional.
        self.preinit = preinit

    def key(self):
        """Return a random Sec-WebSocket-Key (16 random bytes, base64)."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value expected for ``key``.

        This is base64(SHA-1(key + GUID)).
        """
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self):
        """Send a WebSocket upgrade request and read the handshake response.

        The request carries the upgrade headers (subprotocol ``chat``,
        version 13) and a fresh key. It is sent through the ``get`` helper,
        which is defined outside this file; ``no_recv=True`` presumably
        leaves the response unread for this method (confirm).

        The response is read until a complete ``101 Switching Protocols``
        header block has arrived. Reading also stops when the peer closes
        the connection or when no data arrives for 30 seconds.

        Returns:
            A ``(resp, sock, key)`` tuple.
            ``resp`` is the ``_resp_to_dict`` result on success; otherwise
            it is the raw response text received so far.
            ``key`` is the Sec-WebSocket-Key that was sent, so callers can
            verify the server's answer with ``accept``.
        """
        key = self.key()
        _, sock = self.get(
            headers={
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            },
            no_recv=True,
            start=True,
        )

        resp = ''
        while select.select([sock], [], [], 30)[0]:
            chunk = sock.recv(4096)
            if not chunk:
                # Peer closed the connection: select() would keep reporting
                # the socket readable, so stop instead of spinning forever.
                break

            resp += chunk.decode()

            if (
                re.search('101 Switching Protocols', resp)
                and resp.endswith('\r\n\r\n')
            ):
                resp = self._resp_to_dict(resp)
                break

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR ``data`` with ``mask`` repeated cyclically.

        Applying the same mask twice restores the input.
        """
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a Close payload.

        The payload is a 2-byte big-endian status code followed by the
        UTF-8 encoded reason.
        """
        return struct.pack('!H', code) + reason.encode('utf-8')

    def _recv_exact(self, sock, size, read_timeout):
        """Read up to ``size`` bytes from ``sock``.

        Returns early, with fewer than ``size`` bytes, if no data arrives
        within ``read_timeout`` seconds of a wait or if the peer closes the
        connection.
        """
        data = bytearray()
        while len(data) < size:
            readable, _, _ = select.select([sock], [], [], read_timeout)
            if not readable:
                break  # timed out

            chunk = sock.recv(size - len(data))
            if not chunk:
                break  # peer closed the connection; recv() returns b''

            data += chunk

        return bytes(data)

    def frame_read(self, sock, read_timeout=30):
        """Read one frame from ``sock`` and return it as a dict.

        Dict keys:
            ``fin``, ``rsv1``, ``rsv2``, ``rsv3``: bools.
            ``opcode``: the frame opcode.
            ``mask``: the raw mask bit (non-zero if set).
            ``data``: the unmasked payload bytes.
            ``code`` and ``reason``: Close frames only; 1005 and '' when
            the payload is empty.

        Each wait for data is bounded by ``read_timeout`` seconds. A
        timeout or peer close can leave a read short. A short header or
        length field raises ``struct.error``; a short payload is returned
        truncated.

        ``self.fail`` is called for an invalid close code, a 1-byte close
        payload, or a masked frame (servers must not mask frames).
        """
        frame = {}

        head1, head2 = struct.unpack(
            '!BB', self._recv_exact(sock, 2, read_timeout)
        )

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126 and 127 announce a 16- or 64-bit extended length.
        length = head2 & 0b01111111
        if length == 126:
            raw = self._recv_exact(sock, 2, read_timeout)
            length, = struct.unpack('!H', raw)
        elif length == 127:
            raw = self._recv_exact(sock, 8, read_timeout)
            length, = struct.unpack('!Q', raw)

        if frame['mask']:
            mask_bits = self._recv_exact(sock, 4, read_timeout)

        # _recv_exact returns b'' immediately for a zero length.
        data = self._recv_exact(sock, length, read_timeout)

        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                code, = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    self.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                self.fail('Close frame too short')

        frame['data'] = data

        if frame['mask']:
            self.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Encode one frame and return its bytes.

        ``data`` may be ``str`` (it is UTF-8 encoded) or bytes.

        ``length``, when given, overrides the payload length written in the
        header. The full ``data`` is still appended, which allows
        deliberately malformed frames.

        With ``mask`` true, a random 4-byte masking key is added and the
        payload is masked with it.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Encode a frame with ``frame_to_send`` and send it on ``sock``.

        The optional keyword ``chopsize`` sends the encoded frame in pieces
        of at most that many bytes. It must be positive.

        ``BrokenPipeError`` is swallowed deliberately, and the rest of the
        frame is dropped, so tests can keep writing to a peer that has
        already closed.
        """
        chopsize = kwargs.pop('chopsize', None)
        if chopsize is not None and chopsize <= 0:
            raise ValueError('chopsize must be positive')

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            chopsize = len(frame)  # one piece: the whole frame

        for pos in range(0, len(frame), chopsize):
            try:
                sock.sendall(frame[pos:pos + chopsize])
            except BrokenPipeError:
                break

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send ``message`` as one frame or as a sequence of fragments.

        If ``fragmention_size`` is None, or at least ``len(message)``, a
        single frame of opcode ``type`` is sent. Otherwise the first
        fragment uses ``type`` and the rest use ``OP_CONT``; only the last
        fragment has FIN set. ``fragmention_size`` must then be positive.

        Extra keyword arguments are passed on to ``frame_write``.
        """
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            # Would otherwise never advance through the message.
            raise ValueError('fragmention_size must be positive')

        op_code = type
        for pos in range(0, message_len, fragmention_size):
            end = min(pos + fragmention_size, message_len)
            self.frame_write(
                sock,
                op_code,
                message[pos:end],
                fin=end == message_len,
                **kwargs,
            )
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=10):
        """Read frames until one with FIN set and return the merged frame.

        Returns the first frame's dict, with ``data`` replaced by the
        concatenated payloads of all fragments and ``fin`` taken from the
        last fragment.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)
        parts = [frame['data']]

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            parts.append(temp['data'])
            frame['fin'] = temp['fin']

        frame['data'] = b''.join(parts)
        return frame
228