import base64
import hashlib
import itertools
import random
import select
import struct

import pytest
from unit.applications.proto import TestApplicationProto

# Fixed GUID appended to Sec-WebSocket-Key when computing
# Sec-WebSocket-Accept (RFC 6455 opening handshake).
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


class TestApplicationWebsocket(TestApplicationProto):
    """Client-side WebSocket helpers used by the tests.

    Covers the HTTP upgrade handshake, building (optionally masked) frames,
    writing them whole or in chops, sending fragmented messages, and
    reading/decoding frames and fragmented messages from a raw socket.
    """

    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read() accepts in addition to 3000-4999.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def key(self):
        """Return a random Sec-WebSocket-Key (16 random bytes, base64)."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value expected for *key*."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Send a WebSocket upgrade request and read the HTTP response.

        Args:
            headers: request headers to send. When None, a default upgrade
                request with a freshly generated Sec-WebSocket-Key is used.

        Returns:
            (resp, sock, key): the response headers as converted by
            ``self._resp_to_dict`` (defined in the base class), the socket
            the request was sent on, and the generated key (None when
            *headers* was supplied by the caller).

        Fails the test if no complete response header arrives within 60
        seconds or the server closes the connection first.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat, phone, video',
                'Sec-WebSocket-Version': 13,
            }

        # NOTE(review): no_recv/start semantics come from the base class's
        # get(); presumably "send only, keep the socket open" — confirm.
        _, sock = self.get(
            headers=headers,
            no_recv=True,
            start=True,
        )

        resp = ''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                pytest.fail('Can\'t read response from server.')

            chunk = sock.recv(4096)
            if not chunk:
                # Peer closed the connection before finishing the headers;
                # recv() would keep returning b'' and the loop would spin.
                pytest.fail('Can\'t read response from server.')

            resp += chunk.decode()

            if resp.startswith('HTTP/') and '\r\n\r\n' in resp:
                resp = self._resp_to_dict(resp)
                break

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR *data* with the repeating *mask* bytes (masks and unmasks)."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close-frame payload: 2-byte big-endian code + reason."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=60):
        """Read and decode one WebSocket frame from *sock*.

        Args:
            sock: socket to read from.
            read_timeout: seconds to wait for each chunk. With the default
                (60) a timeout or closed connection fails the test. With
                any other value reading simply stops early; an incomplete
                header then makes ``struct.unpack`` raise ``struct.error``.

        Returns:
            dict with 'fin', 'rsv1', 'rsv2', 'rsv3' (bools), 'opcode',
            'mask' (non-zero when the mask bit was set), 'data' (unmasked
            payload bytes) and, for close frames, 'code' and 'reason'.

        Fails the test on a masked frame, a 1-byte close payload, or a
        close code outside CLOSE_CODES and 3000-4999.
        """

        def recv_bytes(sock, count):
            # Read exactly *count* bytes unless a timeout/EOF cuts it short.
            data = b''
            while len(data) < count:
                rlist = select.select([sock], [], [], read_timeout)[0]
                if not rlist:
                    # For all current cases if the "read_timeout" was changed
                    # then the test does not expect a response from server.
                    if read_timeout == 60:
                        pytest.fail('Can\'t read response from server.')
                    break

                chunk = sock.recv(count - len(data))
                if not chunk:
                    # Connection closed: recv() keeps returning b'', so
                    # without this check the loop would never terminate.
                    if read_timeout == 60:
                        pytest.fail('Can\'t read response from server.')
                    break

                data += chunk

            return data

        frame = {}

        (head1,) = struct.unpack('!B', recv_bytes(sock, 1))
        (head2,) = struct.unpack('!B', recv_bytes(sock, 1))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126/127 announce a 16-/64-bit extended length.
        length = head2 & 0b01111111
        if length == 126:
            (length,) = struct.unpack('!H', recv_bytes(sock, 2))
        elif length == 127:
            (length,) = struct.unpack('!Q', recv_bytes(sock, 8))

        if frame['mask']:
            mask_bits = recv_bytes(sock, 4)

        data = b''

        if length != 0:
            data = recv_bytes(sock, length)

        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                (code,) = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    pytest.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                # Empty close payload: report 1005 ("no status received").
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                pytest.fail('Close frame too short')

        frame['data'] = data

        # Frames sent by the server must not be masked.
        if frame['mask']:
            pytest.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Serialize one WebSocket frame.

        Args:
            opcode: frame opcode (one of the OP_* constants).
            data: payload as str (UTF-8 encoded) or bytes.
            fin, rsv1, rsv2, rsv3: header flag bits.
            length: payload length to advertise in the header; defaults to
                the real payload length. A different value deliberately
                produces a malformed frame.
            mask: when true, a random 4-byte mask is added and applied.

        Returns:
            The encoded frame as bytes.
        """
        frame = b''

        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame += struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame += struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame += struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Build a frame with frame_to_send() and send it on *sock*.

        The optional ``chopsize`` keyword sends the frame in pieces of at
        most that many bytes (must be positive). A BrokenPipeError from
        the peer is ignored and stops further sending.

        Raises:
            ValueError: if ``chopsize`` is given and is not positive.
        """
        chopsize = kwargs.pop('chopsize', None)

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            try:
                sock.sendall(frame)
            except BrokenPipeError:
                pass
            return

        if chopsize <= 0:
            # Would otherwise never advance through the frame.
            raise ValueError('chopsize must be positive')

        for pos in range(0, len(frame), chopsize):
            try:
                sock.sendall(frame[pos:pos + chopsize])
            except BrokenPipeError:
                break

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send *message*, optionally split into fragments.

        Args:
            sock: socket to write to.
            type: opcode of the first frame; continuation frames use
                OP_CONT and only the last frame has FIN set.
            message: payload (str or bytes).
            fragmention_size: maximum payload per frame; None sends the
                whole message in a single frame.
            **kwargs: forwarded to frame_write() (e.g. ``chopsize``).

        Raises:
            ValueError: if fragmentation is needed and fragmention_size is
                not positive.
        """
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            # Would otherwise emit empty fragments forever.
            raise ValueError('fragmention_size must be positive')

        op_code = type
        for pos in range(0, message_len, fragmention_size):
            end = min(pos + fragmention_size, message_len)
            fin = end == message_len
            self.frame_write(sock, op_code, message[pos:end], fin=fin, **kwargs)
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=60):
        """Read frames until one has FIN set; return the first frame with
        the payloads of all frames concatenated into 'data'."""
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame