1*1130Szelenkov@nginx.comimport random
2*1130Szelenkov@nginx.comimport base64
3*1130Szelenkov@nginx.comimport struct
4*1130Szelenkov@nginx.comimport select
5*1130Szelenkov@nginx.comimport hashlib
6*1130Szelenkov@nginx.comimport itertools
7*1130Szelenkov@nginx.comfrom unit.applications.proto import TestApplicationProto
8*1130Szelenkov@nginx.com
# Fixed GUID of the WebSocket opening handshake (RFC 6455); accept() appends
# it to the client key before hashing.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
10*1130Szelenkov@nginx.com
11*1130Szelenkov@nginx.com
class TestApplicationWebsocket(TestApplicationProto):
    """Client-side WebSocket helpers: handshake, frame encoding/decoding."""

    # Frame opcodes; frame_read() takes them from the low four bits of the
    # first header byte and frame_to_send() writes them there.
    OP_CONT = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A
    # Close status codes frame_read() accepts, besides the 3000-4999 range.
    CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]
21*1130Szelenkov@nginx.com
22*1130Szelenkov@nginx.com    def __init__(self, preinit=False):
23*1130Szelenkov@nginx.com        self.preinit = preinit
24*1130Szelenkov@nginx.com
25*1130Szelenkov@nginx.com    def key(self):
26*1130Szelenkov@nginx.com        raw_key = bytes(random.getrandbits(8) for _ in range(16))
27*1130Szelenkov@nginx.com        return base64.b64encode(raw_key).decode()
28*1130Szelenkov@nginx.com
29*1130Szelenkov@nginx.com    def accept(self, key):
30*1130Szelenkov@nginx.com        sha1 = hashlib.sha1((key + GUID).encode()).digest()
31*1130Szelenkov@nginx.com        return base64.b64encode(sha1).decode()
32*1130Szelenkov@nginx.com
33*1130Szelenkov@nginx.com    def upgrade(self):
34*1130Szelenkov@nginx.com        key = self.key()
35*1130Szelenkov@nginx.com
36*1130Szelenkov@nginx.com        if self.preinit:
37*1130Szelenkov@nginx.com            self.get()
38*1130Szelenkov@nginx.com
39*1130Szelenkov@nginx.com        resp, sock = self.get(
40*1130Szelenkov@nginx.com            headers={
41*1130Szelenkov@nginx.com                'Host': 'localhost',
42*1130Szelenkov@nginx.com                'Upgrade': 'websocket',
43*1130Szelenkov@nginx.com                'Connection': 'Upgrade',
44*1130Szelenkov@nginx.com                'Sec-WebSocket-Key': key,
45*1130Szelenkov@nginx.com                'Sec-WebSocket-Protocol': 'chat',
46*1130Szelenkov@nginx.com                'Sec-WebSocket-Version': 13,
47*1130Szelenkov@nginx.com            },
48*1130Szelenkov@nginx.com            read_timeout=1,
49*1130Szelenkov@nginx.com            start=True,
50*1130Szelenkov@nginx.com        )
51*1130Szelenkov@nginx.com
52*1130Szelenkov@nginx.com        return (resp, sock, key)
53*1130Szelenkov@nginx.com
54*1130Szelenkov@nginx.com    def apply_mask(self, data, mask):
55*1130Szelenkov@nginx.com        return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask)))
56*1130Szelenkov@nginx.com
57*1130Szelenkov@nginx.com    def serialize_close(self, code = 1000, reason = ''):
58*1130Szelenkov@nginx.com        return struct.pack('!H', code) + reason.encode('utf-8')
59*1130Szelenkov@nginx.com
60*1130Szelenkov@nginx.com    def frame_read(self, sock, read_timeout=1):
61*1130Szelenkov@nginx.com        def recv_bytes(sock, bytes):
62*1130Szelenkov@nginx.com            data = b''
63*1130Szelenkov@nginx.com            while select.select([sock], [], [], read_timeout)[0]:
64*1130Szelenkov@nginx.com                try:
65*1130Szelenkov@nginx.com                    if bytes < 65536:
66*1130Szelenkov@nginx.com                        data = sock.recv(bytes)
67*1130Szelenkov@nginx.com                    else:
68*1130Szelenkov@nginx.com                        data = self.recvall(
69*1130Szelenkov@nginx.com                            sock,
70*1130Szelenkov@nginx.com                            read_timeout=read_timeout,
71*1130Szelenkov@nginx.com                            buff_size=bytes,
72*1130Szelenkov@nginx.com                        )
73*1130Szelenkov@nginx.com                    break
74*1130Szelenkov@nginx.com                except:
75*1130Szelenkov@nginx.com                    break
76*1130Szelenkov@nginx.com
77*1130Szelenkov@nginx.com            return data
78*1130Szelenkov@nginx.com
79*1130Szelenkov@nginx.com        frame = {}
80*1130Szelenkov@nginx.com
81*1130Szelenkov@nginx.com        head1, = struct.unpack('!B', recv_bytes(sock, 1))
82*1130Szelenkov@nginx.com        head2, = struct.unpack('!B', recv_bytes(sock, 1))
83*1130Szelenkov@nginx.com
84*1130Szelenkov@nginx.com        frame['fin'] = bool(head1 & 0b10000000)
85*1130Szelenkov@nginx.com        frame['rsv1'] = bool(head1 & 0b01000000)
86*1130Szelenkov@nginx.com        frame['rsv2'] = bool(head1 & 0b00100000)
87*1130Szelenkov@nginx.com        frame['rsv3'] = bool(head1 & 0b00010000)
88*1130Szelenkov@nginx.com        frame['opcode'] = head1 & 0b00001111
89*1130Szelenkov@nginx.com        frame['mask'] = head2 & 0b10000000
90*1130Szelenkov@nginx.com
91*1130Szelenkov@nginx.com        length = head2 & 0b01111111
92*1130Szelenkov@nginx.com        if length == 126:
93*1130Szelenkov@nginx.com            data = recv_bytes(sock, 2)
94*1130Szelenkov@nginx.com            length, = struct.unpack('!H', data)
95*1130Szelenkov@nginx.com        elif length == 127:
96*1130Szelenkov@nginx.com            data = recv_bytes(sock, 8)
97*1130Szelenkov@nginx.com            length, = struct.unpack('!Q', data)
98*1130Szelenkov@nginx.com
99*1130Szelenkov@nginx.com        if frame['mask']:
100*1130Szelenkov@nginx.com            mask_bits = recv_bytes(sock, 4)
101*1130Szelenkov@nginx.com
102*1130Szelenkov@nginx.com        data = recv_bytes(sock, length)
103*1130Szelenkov@nginx.com        if frame['mask']:
104*1130Szelenkov@nginx.com            data = self.apply_mask(data, mask_bits)
105*1130Szelenkov@nginx.com
106*1130Szelenkov@nginx.com        if frame['opcode'] == self.OP_CLOSE:
107*1130Szelenkov@nginx.com            if length >= 2:
108*1130Szelenkov@nginx.com                code, = struct.unpack('!H', data[:2])
109*1130Szelenkov@nginx.com                reason = data[2:].decode('utf-8')
110*1130Szelenkov@nginx.com                if not (code in self.CLOSE_CODES or 3000 <= code < 5000):
111*1130Szelenkov@nginx.com                    self.fail('Invalid status code')
112*1130Szelenkov@nginx.com                frame['code'] = code
113*1130Szelenkov@nginx.com                frame['reason'] = reason
114*1130Szelenkov@nginx.com            elif length == 0:
115*1130Szelenkov@nginx.com                frame['code'] = 1005
116*1130Szelenkov@nginx.com                frame['reason'] = ''
117*1130Szelenkov@nginx.com            else:
118*1130Szelenkov@nginx.com                self.fail('Close frame too short')
119*1130Szelenkov@nginx.com
120*1130Szelenkov@nginx.com        frame['data'] = data
121*1130Szelenkov@nginx.com
122*1130Szelenkov@nginx.com        if frame['mask']:
123*1130Szelenkov@nginx.com            self.fail('Received frame with mask')
124*1130Szelenkov@nginx.com
125*1130Szelenkov@nginx.com        return frame
126*1130Szelenkov@nginx.com
127*1130Szelenkov@nginx.com    def frame_to_send(
128*1130Szelenkov@nginx.com        self,
129*1130Szelenkov@nginx.com        opcode,
130*1130Szelenkov@nginx.com        data,
131*1130Szelenkov@nginx.com        fin=True,
132*1130Szelenkov@nginx.com        length=None,
133*1130Szelenkov@nginx.com        rsv1=False,
134*1130Szelenkov@nginx.com        rsv2=False,
135*1130Szelenkov@nginx.com        rsv3=False,
136*1130Szelenkov@nginx.com        mask=True,
137*1130Szelenkov@nginx.com    ):
138*1130Szelenkov@nginx.com        frame = b''
139*1130Szelenkov@nginx.com
140*1130Szelenkov@nginx.com        if isinstance(data, str):
141*1130Szelenkov@nginx.com            data = data.encode('utf-8')
142*1130Szelenkov@nginx.com
143*1130Szelenkov@nginx.com        head1 = (
144*1130Szelenkov@nginx.com            (0b10000000 if fin else 0)
145*1130Szelenkov@nginx.com            | (0b01000000 if rsv1 else 0)
146*1130Szelenkov@nginx.com            | (0b00100000 if rsv2 else 0)
147*1130Szelenkov@nginx.com            | (0b00010000 if rsv3 else 0)
148*1130Szelenkov@nginx.com            | opcode
149*1130Szelenkov@nginx.com        )
150*1130Szelenkov@nginx.com
151*1130Szelenkov@nginx.com        head2 = 0b10000000 if mask else 0
152*1130Szelenkov@nginx.com
153*1130Szelenkov@nginx.com        data_length = len(data) if length is None else length
154*1130Szelenkov@nginx.com        if data_length < 126:
155*1130Szelenkov@nginx.com            frame += struct.pack('!BB', head1, head2 | data_length)
156*1130Szelenkov@nginx.com        elif data_length < 65536:
157*1130Szelenkov@nginx.com            frame += struct.pack('!BBH', head1, head2 | 126, data_length)
158*1130Szelenkov@nginx.com        else:
159*1130Szelenkov@nginx.com            frame += struct.pack('!BBQ', head1, head2 | 127, data_length)
160*1130Szelenkov@nginx.com
161*1130Szelenkov@nginx.com        if mask:
162*1130Szelenkov@nginx.com            mask_bits = struct.pack('!I', random.getrandbits(32))
163*1130Szelenkov@nginx.com            frame += mask_bits
164*1130Szelenkov@nginx.com
165*1130Szelenkov@nginx.com        if mask:
166*1130Szelenkov@nginx.com            frame += self.apply_mask(data, mask_bits)
167*1130Szelenkov@nginx.com        else:
168*1130Szelenkov@nginx.com            frame += data
169*1130Szelenkov@nginx.com
170*1130Szelenkov@nginx.com        return frame
171*1130Szelenkov@nginx.com
172*1130Szelenkov@nginx.com    def frame_write(self, sock, *args, **kwargs):
173*1130Szelenkov@nginx.com        chopsize = kwargs.pop('chopsize') if 'chopsize' in kwargs else None
174*1130Szelenkov@nginx.com
175*1130Szelenkov@nginx.com        frame = self.frame_to_send(*args, **kwargs)
176*1130Szelenkov@nginx.com
177*1130Szelenkov@nginx.com        if chopsize is None:
178*1130Szelenkov@nginx.com            sock.sendall(frame)
179*1130Szelenkov@nginx.com
180*1130Szelenkov@nginx.com        else:
181*1130Szelenkov@nginx.com            pos = 0
182*1130Szelenkov@nginx.com            frame_len = len(frame)
183*1130Szelenkov@nginx.com            while (pos < frame_len):
184*1130Szelenkov@nginx.com                end = min(pos + chopsize, frame_len)
185*1130Szelenkov@nginx.com                sock.sendall(frame[pos:end])
186*1130Szelenkov@nginx.com                pos = end
187*1130Szelenkov@nginx.com
188*1130Szelenkov@nginx.com    def message(self, sock, type, message, fragmention_size=None, **kwargs):
189*1130Szelenkov@nginx.com        message_len = len(message)
190*1130Szelenkov@nginx.com
191*1130Szelenkov@nginx.com        if fragmention_size is None:
192*1130Szelenkov@nginx.com            fragmention_size = message_len
193*1130Szelenkov@nginx.com
194*1130Szelenkov@nginx.com        if message_len <= fragmention_size:
195*1130Szelenkov@nginx.com            self.frame_write(sock, type, message, **kwargs)
196*1130Szelenkov@nginx.com            return
197*1130Szelenkov@nginx.com
198*1130Szelenkov@nginx.com        pos = 0
199*1130Szelenkov@nginx.com        op_code = type
200*1130Szelenkov@nginx.com        while(pos < message_len):
201*1130Szelenkov@nginx.com            end = min(pos + fragmention_size, message_len)
202*1130Szelenkov@nginx.com            fin = (end == message_len)
203*1130Szelenkov@nginx.com            self.frame_write(sock, op_code, message[pos:end], fin=fin, **kwargs)
204*1130Szelenkov@nginx.com            op_code = self.OP_CONT
205*1130Szelenkov@nginx.com            pos = end
206*1130Szelenkov@nginx.com
207*1130Szelenkov@nginx.com    def message_read(self, sock, read_timeout=1):
208*1130Szelenkov@nginx.com        frame = self.frame_read(sock, read_timeout=read_timeout)
209*1130Szelenkov@nginx.com
210*1130Szelenkov@nginx.com        while(not frame['fin']):
211*1130Szelenkov@nginx.com            temp = self.frame_read(sock, read_timeout=read_timeout)
212*1130Szelenkov@nginx.com            frame['data'] += temp['data']
213*1130Szelenkov@nginx.com            frame['fin'] = temp['fin']
214*1130Szelenkov@nginx.com
215*1130Szelenkov@nginx.com        return frame
216