import base64
import hashlib
import itertools
import pytest
import random
import re
import select
import struct

from unit.applications.proto import TestApplicationProto

# Fixed GUID concatenated with Sec-WebSocket-Key to derive
# Sec-WebSocket-Accept (see accept()).
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


class TestApplicationWebsocket(TestApplicationProto):
    """Client-side WebSocket helpers for tests.

    Provides the opening handshake, frame encoding/decoding and message
    (de)fragmentation over a raw socket obtained from ``self.get()``.
    """

    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read() accepts in addition to 3000-4999.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def key(self):
        """Return a random 16-byte value, base64-encoded, for use as
        Sec-WebSocket-Key."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value expected for ``key``:
        base64(SHA-1(key + GUID))."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Open a connection and perform the WebSocket opening handshake.

        :param headers: request headers to send.  When ``None``, default
            upgrade headers with a freshly generated key are used.
        :returns: ``(resp, sock, key)`` where ``resp`` is the response
            parsed by ``self._resp_to_dict``, ``sock`` is the socket from
            ``self.get()`` and ``key`` is the generated Sec-WebSocket-Key
            (``None`` when caller-supplied ``headers`` were used).

        Fails the test if nothing arrives within 60 seconds or if the
        server closes the connection before a complete
        "101 Switching Protocols" response has been received.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            }

        _, sock = self.get(
            headers=headers,
            no_recv=True,
            start=True,
        )

        resp = ''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                pytest.fail('Can\'t read response from server.')

            chunk = sock.recv(4096)
            if not chunk:
                # EOF: without this check select() keeps reporting the
                # socket readable and recv() keeps returning b'', so the
                # loop would never terminate.
                pytest.fail('Connection closed before handshake completed.')

            resp += chunk.decode()

            # Stop once the 101 status is present and the header block
            # (terminated by an empty line) is complete.
            if (
                re.search('101 Switching Protocols', resp)
                and resp.endswith('\r\n\r\n')
            ):
                resp = self._resp_to_dict(resp)
                break

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR ``data`` with ``mask`` repeated cyclically; applying it twice
        restores the original bytes."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close-frame payload: 2-byte big-endian ``code`` followed
        by the UTF-8 encoded ``reason``."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=60):
        """Read and decode a single frame from ``sock``.

        :param read_timeout: seconds to wait for each chunk.  With the
            default (60) a timeout or EOF fails the test; with any other
            value the read silently stops short (the test is assumed not
            to expect a response), which typically surfaces as a
            ``struct.error`` while unpacking the frame header.
        :returns: dict with ``fin``, ``rsv1``-``rsv3``, ``opcode``,
            ``mask`` and ``data``; close frames additionally get ``code``
            and ``reason`` (1005 / '' for an empty close payload).

        Fails the test on an invalid close code, a 1-byte close payload,
        or a masked frame.
        """

        def recv_bytes(sock, size):
            # Read exactly ``size`` bytes unless a timeout or EOF occurs.
            data = b''
            while len(data) < size:
                rlist = select.select([sock], [], [], read_timeout)[0]
                if not rlist:
                    # In all current cases, if "read_timeout" was changed
                    # the test does not expect a response from the server.
                    if read_timeout == 60:
                        pytest.fail('Can\'t read response from server.')
                    break

                chunk = sock.recv(size - len(data))
                if not chunk:
                    # Peer closed the connection: treat like a timeout
                    # instead of spinning on an always-readable socket.
                    if read_timeout == 60:
                        pytest.fail('Can\'t read response from server.')
                    break

                data += chunk

            return data

        frame = {}

        head1, = struct.unpack('!B', recv_bytes(sock, 1))
        head2, = struct.unpack('!B', recv_bytes(sock, 1))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126 and 127 announce 16-bit and 64-bit extended
        # length fields respectively.
        length = head2 & 0b01111111
        if length == 126:
            data = recv_bytes(sock, 2)
            length, = struct.unpack('!H', data)
        elif length == 127:
            data = recv_bytes(sock, 8)
            length, = struct.unpack('!Q', data)

        if frame['mask']:
            mask_bits = recv_bytes(sock, 4)

        data = b''

        if length != 0:
            data = recv_bytes(sock, length)

        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                code, = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    pytest.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                pytest.fail('Close frame too short')

        frame['data'] = data

        # Frames from the server must not be masked; the payload has been
        # consumed above so the stream stays in sync before failing.
        if frame['mask']:
            pytest.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Encode a frame.

        :param data: payload; ``str`` is UTF-8 encoded.
        :param length: payload length written into the header; defaults
            to ``len(data)``.  Passing a different value lets tests send
            frames whose declared length does not match the payload.
        :param mask: when true, set the mask bit and XOR the payload with
            a random 32-bit key.
        :returns: the encoded frame as ``bytes``.
        """
        frame = b''

        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame += struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame += struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame += struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Encode a frame with ``frame_to_send(*args, **kwargs)`` and send it.

        An optional ``chopsize`` keyword sends the frame in pieces of at
        most that many bytes.  ``BrokenPipeError`` is deliberately ignored
        (remaining pieces are skipped) so tests can observe the server
        closing the connection mid-write.

        :raises ValueError: if ``chopsize`` is given and not positive.
        """
        chopsize = kwargs.pop('chopsize', None)

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            try:
                sock.sendall(frame)
            except BrokenPipeError:
                pass
            return

        if chopsize <= 0:
            # A non-positive chunk size would never advance ``pos``.
            raise ValueError('chopsize must be positive')

        pos = 0
        frame_len = len(frame)
        while pos < frame_len:
            end = min(pos + chopsize, frame_len)
            try:
                sock.sendall(frame[pos:end])
            except BrokenPipeError:
                end = frame_len
            pos = end

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send ``message`` with opcode ``type``, fragmenting if needed.

        When ``fragmention_size`` is ``None`` or not smaller than the
        message, a single frame is sent.  Otherwise the first fragment
        carries ``type``, the rest ``OP_CONT``, and only the last has FIN
        set.  Extra ``kwargs`` are passed through to ``frame_write``.

        :raises ValueError: if fragmentation is needed and
            ``fragmention_size`` is not positive.
        """
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            # A non-positive fragment size would never advance ``pos``.
            raise ValueError('fragmention_size must be positive')

        pos = 0
        op_code = type
        while pos < message_len:
            end = min(pos + fragmention_size, message_len)
            fin = end == message_len
            self.frame_write(
                sock, op_code, message[pos:end], fin=fin, **kwargs
            )
            op_code = self.OP_CONT
            pos = end

    def message_read(self, sock, read_timeout=60):
        """Read frames until one with FIN set and return the first frame
        with every frame's ``data`` concatenated.

        NOTE(review): every non-final follow-up frame is appended,
        including control frames that might be interleaved with fragments.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame