# xref: /unit/test/unit/applications/websockets.py (revision 1156:2649a249721f)
import base64
import hashlib
import itertools
import random
import select
import struct

from unit.applications.proto import TestApplicationProto

# Fixed GUID from RFC 6455 (section 1.3).  It is appended to the client's
# Sec-WebSocket-Key; the SHA-1 digest of that string, base64-encoded, is the
# Sec-WebSocket-Accept value the server must return.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


class TestApplicationWebsocket(TestApplicationProto):
    """WebSocket client helpers used by the test suite.

    Covers the opening handshake, frame encoding/decoding and message-level
    send/receive.  Relies on ``get()`` and ``fail()`` that are not defined
    in this class (presumably provided by ``TestApplicationProto``).
    """

    # Frame opcodes.
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read() accepts in addition to 3000-4999.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def __init__(self, preinit=False):
        # When true, upgrade() sends a plain GET before the handshake request.
        self.preinit = preinit

    def key(self):
        """Return a random base64-encoded 16-byte Sec-WebSocket-Key."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value that corresponds to *key*."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self):
        """Send the WebSocket upgrade request.

        Returns ``(resp, sock, key)``: the response and socket produced by
        ``self.get(..., start=True)`` plus the Sec-WebSocket-Key that was
        sent, so callers can verify the server's accept header.
        """
        key = self.key()

        if self.preinit:
            self.get()

        resp, sock = self.get(
            headers={
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            },
            read_timeout=1,
            start=True,
        )

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR *data* with the repeating *mask* bytes (masking is symmetric)."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close-frame payload: 2-byte big-endian code + UTF-8 reason."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=30):
        """Read and decode a single frame from *sock*.

        Returns a dict with the header flags ``fin``, ``rsv1``-``rsv3``,
        ``opcode`` and ``mask`` (the raw masked bit, 0 or 128), the unmasked
        payload in ``data`` and, for close frames, ``code`` and ``reason``
        (1005 and '' for an empty close payload).

        Calls ``self.fail()`` for close frames with an unexpected status code
        or a 1-byte payload, and for any masked frame.

        *read_timeout* bounds each individual wait for data, not the whole
        frame; on timeout or EOF the partial data read so far is used.
        """

        def recv_bytes(sock, size):
            # Read up to *size* bytes, returning early if no data arrives
            # within read_timeout or the peer closes the connection.
            data = b''
            while len(data) < size and select.select(
                [sock], [], [], read_timeout
            )[0]:
                chunk = sock.recv(size - len(data))
                if not chunk:
                    # EOF: the socket stays readable and recv() keeps
                    # returning b'', so looping again would spin forever.
                    break
                data += chunk
            return data

        frame = {}

        head1, head2 = struct.unpack('!BB', recv_bytes(sock, 2))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126 and 127 announce a 16- or 64-bit extended length.
        length = head2 & 0b01111111
        if length == 126:
            length, = struct.unpack('!H', recv_bytes(sock, 2))
        elif length == 127:
            length, = struct.unpack('!Q', recv_bytes(sock, 8))

        mask_bits = recv_bytes(sock, 4) if frame['mask'] else None

        data = recv_bytes(sock, length) if length != 0 else b''

        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                code, = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    self.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                self.fail('Close frame too short')

        frame['data'] = data

        if frame['mask']:
            self.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Encode one frame and return it as bytes.

        *data* may be ``str`` (sent as UTF-8) or bytes.  *length*, when
        given, is written into the header instead of ``len(data)`` while the
        payload is still *data*; this lets tests build frames whose declared
        length does not match the payload.  With *mask* true a random 4-byte
        masking key is generated and applied to the payload.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Encode a frame via frame_to_send(*args, **kwargs) and send it.

        The optional ``chopsize`` keyword (a positive int) splits the encoded
        frame into separate writes of at most that many bytes.  A
        BrokenPipeError is deliberately ignored and abandons the rest of the
        frame (presumably because tests expect the server may close first).

        Raises ValueError if ``chopsize`` is not positive.
        """
        chopsize = kwargs.pop('chopsize', None)
        if chopsize is not None and chopsize <= 0:
            # A non-positive step would never advance through the frame.
            raise ValueError('chopsize must be a positive integer')

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            chunks = [frame]
        else:
            chunks = (
                frame[pos:pos + chopsize]
                for pos in range(0, len(frame), chopsize)
            )

        for chunk in chunks:
            try:
                sock.sendall(chunk)
            except BrokenPipeError:
                break

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send *message* with opcode *type*, optionally fragmented.

        If ``fragmention_size`` is None or not smaller than ``len(message)``,
        a single frame is sent.  Otherwise the message is split into slices
        of that many characters/bytes: the first frame carries *type*, the
        rest OP_CONT, and only the last has FIN set.  Extra keyword
        arguments are forwarded to frame_write().

        Raises ValueError if fragmentation is needed and ``fragmention_size``
        is not positive.
        """
        message_len = len(message)

        if fragmention_size is None or message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            # A non-positive step would never advance through the message.
            raise ValueError('fragmention_size must be a positive integer')

        op_code = type
        for pos in range(0, message_len, fragmention_size):
            fin = pos + fragmention_size >= message_len
            self.frame_write(
                sock,
                op_code,
                message[pos:pos + fragmention_size],
                fin=fin,
                **kwargs
            )
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=10):
        """Read frames until one with FIN set and return the combined frame.

        The returned dict is the first frame's, with ``data`` extended by the
        payload of every following frame and ``fin`` taken from the last.
        Note that any control frame arriving between fragments is appended
        to ``data`` as well; its opcode is not checked.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame
220