import random
import base64
import struct
import select
import hashlib
import itertools
from unit.applications.proto import TestApplicationProto

# Fixed GUID from RFC 6455: the expected Sec-WebSocket-Accept value is the
# base64 SHA-1 digest of the client's Sec-WebSocket-Key followed by this GUID.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


class TestApplicationWebsocket(TestApplicationProto):
    """WebSocket client helpers for tests: handshake, frame encode/decode."""

    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read accepts, in addition to the 3000-4999 range.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]

    def __init__(self, preinit=False):
        # When true, upgrade() issues a plain self.get() before the upgrade.
        self.preinit = preinit

    def key(self):
        """Return a random base64-encoded 16-byte Sec-WebSocket-Key value."""
        raw_key = bytes(random.getrandbits(8) for _ in range(16))
        return base64.b64encode(raw_key).decode()

    def accept(self, key):
        """Return the Sec-WebSocket-Accept value expected for *key*."""
        sha1 = hashlib.sha1((key + GUID).encode()).digest()
        return base64.b64encode(sha1).decode()

    def upgrade(self):
        """Send a WebSocket upgrade request and keep the connection open.

        Returns:
            A ``(resp, sock, key)`` tuple: the response and socket returned
            by ``self.get(..., start=True)`` and the Sec-WebSocket-Key sent.
        """
        key = self.key()

        if self.preinit:
            self.get()

        resp, sock = self.get(
            headers={
                'Host': 'localhost',
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Key': key,
                'Sec-WebSocket-Protocol': 'chat',
                'Sec-WebSocket-Version': 13,
            },
            read_timeout=1,
            start=True,
        )

        return (resp, sock, key)

    def apply_mask(self, data, mask):
        """XOR *data* with the repeating *mask* bytes (its own inverse)."""
        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))

    def serialize_close(self, code=1000, reason=''):
        """Build a close-frame payload: 2-byte big-endian code + UTF-8 reason."""
        return struct.pack('!H', code) + reason.encode('utf-8')

    def frame_read(self, sock, read_timeout=10):
        """Read and decode a single frame from *sock*.

        Args:
            sock: connected socket to read from.
            read_timeout: seconds ``select`` waits for each chunk of data.

        Returns:
            A dict with keys ``fin``, ``rsv1``, ``rsv2``, ``rsv3``,
            ``opcode``, ``mask`` and ``data``.  Close frames also get
            ``code`` and ``reason`` (1005 and '' for an empty payload).
            Calls ``self.fail`` on an invalid close status code, a one-byte
            close payload, or a masked frame.
        """

        def recv_bytes(sock, size):
            # Read up to *size* bytes, stopping early on timeout or EOF.
            # Checking the remaining size first means size == 0 returns at
            # once instead of waiting read_timeout inside select().
            data = b''
            while len(data) < size and select.select(
                [sock], [], [], read_timeout
            )[0]:
                chunk = sock.recv(size - len(data))
                if not chunk:
                    # Peer closed the connection: select() would keep
                    # reporting the socket readable and recv() keep
                    # returning b'', so looping again would never finish.
                    break
                data += chunk

            return data

        frame = {}

        head1, = struct.unpack('!B', recv_bytes(sock, 1))
        head2, = struct.unpack('!B', recv_bytes(sock, 1))

        frame['fin'] = bool(head1 & 0b10000000)
        frame['rsv1'] = bool(head1 & 0b01000000)
        frame['rsv2'] = bool(head1 & 0b00100000)
        frame['rsv3'] = bool(head1 & 0b00010000)
        frame['opcode'] = head1 & 0b00001111
        frame['mask'] = head2 & 0b10000000

        # 7-bit length; 126 and 127 mean a 16-bit / 64-bit extended length.
        length = head2 & 0b01111111
        if length == 126:
            length, = struct.unpack('!H', recv_bytes(sock, 2))
        elif length == 127:
            length, = struct.unpack('!Q', recv_bytes(sock, 8))

        if frame['mask']:
            mask_bits = recv_bytes(sock, 4)

        data = recv_bytes(sock, length)
        if frame['mask']:
            data = self.apply_mask(data, mask_bits)

        if frame['opcode'] == self.OP_CLOSE:
            if length >= 2:
                code, = struct.unpack('!H', data[:2])
                reason = data[2:].decode('utf-8')
                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
                    self.fail('Invalid status code')
                frame['code'] = code
                frame['reason'] = reason
            elif length == 0:
                frame['code'] = 1005
                frame['reason'] = ''
            else:
                self.fail('Close frame too short')

        frame['data'] = data

        if frame['mask']:
            self.fail('Received frame with mask')

        return frame

    def frame_to_send(
        self,
        opcode,
        data,
        fin=True,
        length=None,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        mask=True,
    ):
        """Encode a single frame and return it as bytes.

        Args:
            opcode: frame opcode (one of the ``OP_*`` constants).
            data: payload; ``str`` is encoded as UTF-8.
            fin, rsv1, rsv2, rsv3: header bits to set.
            length: if given, written to the header instead of ``len(data)``;
                the payload itself is still all of *data*.
            mask: when true, a random 4-byte mask is added and applied.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')

        head1 = (
            (0b10000000 if fin else 0)
            | (0b01000000 if rsv1 else 0)
            | (0b00100000 if rsv2 else 0)
            | (0b00010000 if rsv3 else 0)
            | opcode
        )

        head2 = 0b10000000 if mask else 0

        data_length = len(data) if length is None else length
        if data_length < 126:
            frame = struct.pack('!BB', head1, head2 | data_length)
        elif data_length < 65536:
            frame = struct.pack('!BBH', head1, head2 | 126, data_length)
        else:
            frame = struct.pack('!BBQ', head1, head2 | 127, data_length)

        if mask:
            mask_bits = struct.pack('!I', random.getrandbits(32))
            frame += mask_bits + self.apply_mask(data, mask_bits)
        else:
            frame += data

        return frame

    def frame_write(self, sock, *args, **kwargs):
        """Encode a frame with ``frame_to_send`` and send it on *sock*.

        The optional keyword ``chopsize`` splits the encoded frame into
        separate ``sendall`` calls of at most that many bytes.  All other
        arguments are passed through to ``frame_to_send``.  A
        ``BrokenPipeError`` stops sending silently.

        Raises:
            ValueError: if ``chopsize`` is given and is not positive.
        """
        chopsize = kwargs.pop('chopsize', None)

        frame = self.frame_to_send(*args, **kwargs)

        if chopsize is None:
            chopsize = len(frame)
        elif chopsize <= 0:
            # A non-positive step would never advance through the frame.
            raise ValueError('chopsize must be a positive integer')

        for pos in range(0, len(frame), chopsize):
            try:
                sock.sendall(frame[pos:pos + chopsize])
            except BrokenPipeError:
                # Peer went away; drop the remainder of the frame.
                break

    def message(self, sock, type, message, fragmention_size=None, **kwargs):
        """Send *message*, optionally fragmented into several frames.

        Args:
            sock: socket to write to.
            type: opcode of the first frame; later fragments use OP_CONT.
            message: payload (``str`` or ``bytes``), sliced by ``len``.
            fragmention_size: maximum slice per frame; ``None`` sends a
                single frame.
            **kwargs: passed through to ``frame_write``.

        Raises:
            ValueError: if fragmentation is needed and ``fragmention_size``
                is not positive.
        """
        message_len = len(message)

        if fragmention_size is None or message_len <= fragmention_size:
            self.frame_write(sock, type, message, **kwargs)
            return

        if fragmention_size <= 0:
            # A non-positive step would never advance through the message.
            raise ValueError('fragmention_size must be a positive integer')

        op_code = type
        for pos in range(0, message_len, fragmention_size):
            end = pos + fragmention_size
            self.frame_write(
                sock,
                op_code,
                message[pos:end],
                fin=end >= message_len,
                **kwargs,
            )
            op_code = self.OP_CONT

    def message_read(self, sock, read_timeout=10):
        """Read frames until one has FIN set and return the merged frame.

        The first frame's dict is returned, with the ``data`` of every
        following frame appended and ``fin`` taken from the last one.
        """
        frame = self.frame_read(sock, read_timeout=read_timeout)

        while not frame['fin']:
            temp = self.frame_read(sock, read_timeout=read_timeout)
            frame['data'] += temp['data']
            frame['fin'] = temp['fin']

        return frame