import base64
import hashlib
import itertools
import random
import select
import struct

import pytest
from unit.applications.proto import TestApplicationProto

# Fixed GUID concatenated with Sec-WebSocket-Key when computing the
# Sec-WebSocket-Accept header (RFC 6455, section 1.3).
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


class TestApplicationWebsocket(TestApplicationProto):
    """Helpers for driving a WebSocket connection frame by frame in tests.

    Provides the opening handshake plus raw frame encoding/decoding, so a
    test can send arbitrary (including deliberately malformed) frames and
    inspect exactly what the server sends back.
    """

    # Frame opcodes (RFC 6455, section 5.2).
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes below 3000 that frame_read() accepts; codes in
    # 3000-4999 are accepted separately.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def key(self):
        """Return a random base64-encoded 16-byte Sec-WebSocket-Key."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value expected for ``key``."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self, headers=None):
        """Send a WebSocket upgrade request and read the response header.

        Args:
            headers: request headers to send. When None, a default set of
                handshake headers is built around a freshly generated key.

        Returns:
            ``(resp, sock, key)``: ``resp`` is the response text converted
            by ``self._resp_to_dict``, ``sock`` is the still-open socket
            returned by ``self.get``, and ``key`` is the generated
            Sec-WebSocket-Key, or None when the caller supplied ``headers``.

        Fails the test if no data arrives within 60 seconds, or if the
        server closes the connection before the header is complete.
        """
        key = None

        if headers is None:
            key = self.key()
            headers = {
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat, phone, video',
                'Sec-WebSocket-Version': 13,
            }

        _, sock = self.get(headers=headers, no_recv=True, start=True)

        resp = ''
        while True:
            rlist = select.select([sock], [], [], 60)[0]
            if not rlist:
                pytest.fail('Can\'t read response from server.')

            chunk = sock.recv(4096)
            if not chunk:
                # recv() returning b'' means the peer closed the connection.
                # select() would keep reporting the socket readable, so
                # without this check the loop would spin forever.
                pytest.fail('Connection closed before response was received.')

            resp += chunk.decode()

            if resp.startswith('HTTP/') and '\r\n\r\n' in resp:
                resp = self._resp_to_dict(resp)
                break

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR ``data`` with the repeating 4-byte ``mask`` (masks and unmasks)."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close-frame payload: big-endian 2-byte code + UTF-8 reason."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=60):
        """Read and decode one frame from ``sock``.

        Args:
            sock: connected socket to read from.
            read_timeout: seconds to wait for each chunk of data. With the
                default (60), a timeout or a closed connection fails the
                test. With any other value, reading stops and whatever was
                read so far is used. A short header read then makes
                ``struct.unpack`` raise ``struct.error``.

        Returns:
            A dict with keys ``fin``, ``rsv1``, ``rsv2``, ``rsv3``,
            ``opcode``, ``mask``, ``data`` (the unmasked payload bytes)
            and, for close frames, ``code`` and ``reason``. A close frame
            with an empty payload is reported as code 1005.

        Fails the test on an invalid close code, a 1-byte close payload,
        or a masked frame from the server.
        """

        def recv_bytes(sock, size):
            # Read exactly `size` bytes, or fewer if the read gives up.
            data = b''
            while len(data) < size:
                rlist = select.select([sock], [], [], read_timeout)[0]
                chunk = sock.recv(size - len(data)) if rlist else b''
                if not chunk:
                    # Timed out, or the peer closed the connection (recv()
                    # returned b''; retrying would spin forever).
                    # For all current cases, if "read_timeout" was changed,
                    # the test does not expect a response from the server.
                    if read_timeout == 60:
                        pytest.fail('Can\'t read response from server.')
                    break

                data += chunk

            return data

        frame = {}

        (head1,) = struct.unpack('!B', recv_bytes(sock, 1))
        (head2,) = struct.unpack('!B', recv_bytes(sock, 1))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126 and 127 mean a 16-bit or 64-bit extended
        # length follows.
        length = head2 & 0b01111111
        if length == 126:
            data = recv_bytes(sock, 2)
            (length,) = struct.unpack('!H', data)
        elif length == 127:
            data = recv_bytes(sock, 8)
            (length,) = struct.unpack('!Q', data)

        if frame['mask']:
            mask_bits = recv_bytes(sock, 4)

        data = b''

        if length != 0:
            data = recv_bytes(sock, length)

        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                (code,) = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    pytest.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                # No status code present.
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                pytest.fail('Close frame too short')

        frame['data'] = data

        # Server-to-client frames must not be masked.
        if frame['mask']:
            pytest.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Encode a single frame.

        Args:
            opcode: frame opcode (one of the ``OP_*`` constants, or any
                other value to test invalid opcodes).
            data: payload as ``bytes``, or ``str`` which is UTF-8 encoded.
            fin, rsv1, rsv2, rsv3: header flag bits.
            length: payload length written into the header. Defaults to
                ``len(data)``. Passing a different value produces a frame
                whose declared length does not match its payload.
            mask: when true, a random 4-byte masking key is generated and
                the payload is masked with it.

        Returns:
            The encoded frame as ``bytes``.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Encode a frame with ``frame_to_send`` and send it on ``sock``.

        All arguments except ``sock`` are forwarded to ``frame_to_send``.
        An optional ``chopsize`` keyword splits the encoded frame into
        chunks of that many bytes, each sent with a separate ``sendall``.
        ``BrokenPipeError`` is ignored on purpose: tests may write to a
        connection the server has already closed. In chunked mode, the
        remaining chunks are then skipped.
        """
        chopsize = kwargs.pop('chopsize', None)

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            try:
                sock.sendall(frame)
            except BrokenPipeError:
                pass

        else:
            pos = 0
            frame_len = len(frame)
            while pos < frame_len:
                end = min(pos + chopsize, frame_len)
                try:
                    sock.sendall(frame[pos:end])
                except BrokenPipeError:
                    end = frame_len
                pos = end

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send ``message``, optionally split into fragments.

        Args:
            sock: socket to write to.
            type: opcode of the first frame. Later fragments use
                ``OP_CONT``.
            message: payload to send (``str`` or ``bytes``).
            fragmention_size: maximum payload size per frame, measured in
                items of ``message``. When None or at least
                ``len(message)``, one frame is sent.
            **kwargs: forwarded to ``frame_write`` (e.g. ``chopsize``,
                ``mask``).

        Raises:
            ValueError: if ``fragmention_size`` is not positive and the
                message would have to be fragmented. Previously this
                looped forever.
        """
        message_len = len(message)

        if fragmention_size is None:
            fragmention_size = message_len

        if message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            raise ValueError('fragmention_size must be positive')

        pos = 0
        op_code = type
        while pos < message_len:
            end = min(pos + fragmention_size, message_len)
            fin = end == message_len
            self.frame_write(
                sock, op_code, message[pos:end], fin=fin, **kwargs
            )
            op_code = self.OP_CONT
            pos = end

    def message_read(self, sock, read_timeout=60):
        """Read frames until one with FIN set; return the merged frame.

        The returned dict is the first frame's dict. Its ``data`` holds
        the concatenated payloads and ``fin`` comes from the last frame.
        The opcodes of the following frames are not checked, so an
        interleaved control frame's payload would also be appended.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame