Coverage for /home/antoine/projects/xpra-git/dist/python3/lib64/python/xpra/net/protocol.py : 64%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of Xpra.
2# Copyright (C) 2011-2020 Antoine Martin <antoine@xpra.org>
3# Copyright (C) 2008, 2009, 2010 Nathaniel Smith <njs@pobox.com>
4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
5# later version. See the file COPYING for details.
7# oh gods it's threads
9# but it works on win32, for whatever that's worth.
11import os
12from socket import error as socket_error
13from threading import Lock, Event
14from queue import Queue
16from xpra.os_util import memoryview_to_bytes, strtobytes, bytestostr, hexstr, monotonic_time
17from xpra.util import repr_ellipsized, ellipsizer, csv, envint, envbool, typedict
18from xpra.make_thread import make_thread, start_thread
19from xpra.net.common import (
20 ConnectionClosedException, may_log_packet,
21 MAX_PACKET_SIZE, FLUSH_HEADER,
22 )
23from xpra.net.bytestreams import ABORT
24from xpra.net import compression
25from xpra.net.compression import (
26 decompress, sanity_checks as compression_sanity_checks,
27 InvalidCompressionException, Compressed, LevelCompressed, Compressible, LargeStructure,
28 )
29from xpra.net import packet_encoding
30from xpra.net.socket_util import guess_packet_type
31from xpra.net.packet_encoding import (
32 decode, sanity_checks as packet_encoding_sanity_checks,
33 InvalidPacketEncodingException,
34 )
35from xpra.net.header import unpack_header, pack_header, FLAGS_CIPHER, FLAGS_NOHEADER, FLAGS_FLUSH, HEADER_SIZE
36from xpra.net.crypto import get_encryptor, get_decryptor, pad, INITIAL_PADDING
37from xpra.log import Logger
39log = Logger("network", "protocol")
40cryptolog = Logger("network", "crypto")
#use numeric aliases for packet types on the wire (negotiated via capabilities):
USE_ALIASES = envbool("XPRA_USE_ALIASES", True)
#how many bytes to request per network read:
READ_BUFFER_SIZE = envint("XPRA_READ_BUFFER_SIZE", 65536)
#merge header and packet if packet is smaller than:
PACKET_JOIN_SIZE = envint("XPRA_PACKET_JOIN_SIZE", READ_BUFFER_SIZE)
#threshold above which encode() warns about large uncompressed items / packets:
LARGE_PACKET_SIZE = envint("XPRA_LARGE_PACKET_SIZE", 4096)
LOG_RAW_PACKET_SIZE = envbool("XPRA_LOG_RAW_PACKET_SIZE", False)
#inline compressed data in packet if smaller than:
INLINE_SIZE = envint("XPRA_INLINE_SIZE", 32768)
#debug/testing: artificial delay injected into packet processing (0 = disabled):
FAKE_JITTER = envint("XPRA_FAKE_JITTER", 0)
#packets smaller than this are not worth compressing:
MIN_COMPRESS_SIZE = envint("XPRA_MIN_COMPRESS_SIZE", 378)
#debug/testing: if non-zero, send a bogus packet after this many seconds:
SEND_INVALID_PACKET = envint("XPRA_SEND_INVALID_PACKET", 0)
#note: os.environ.get returns a str when the variable is set, strtobytes normalizes either way
SEND_INVALID_PACKET_DATA = strtobytes(os.environ.get("XPRA_SEND_INVALID_PACKET_DATA", b"ZZinvalid-packetZZ"))
def sanity_checks():
    """Warn the user if important compression / encoding modules are missing."""
    for check in (compression_sanity_checks, packet_encoding_sanity_checks):
        check()
def exit_queue():
    """Return a Queue pre-loaded with None exit markers.

    Two markers would be enough (one per IO thread), extra ones are
    added for safety.
    """
    q = Queue()
    for marker in [None]*10:
        q.put(marker)
    return q
def force_flush_queue(q):
    """Discard everything in the queue `q` and push the None exit marker.

    Best-effort: any error (ie: a race with another consumer) is ignored.
    Bug fix: the drain loop used `q.read(False)` which is not a Queue
    method, so the AttributeError was silently swallowed and the queue
    was never actually drained; `q.get(False)` is the correct call.
    """
    try:
        #discard all elements in the old queue and push the None marker:
        try:
            while q.qsize()>0:
                q.get(False)
        except Exception:
            pass
        q.put_nowait(None)
    except Exception:
        pass
def verify_packet(packet):
    """ look for None values which may have caused the packet to fail encoding """
    if not isinstance(packet, list):
        return False
    assert packet, "invalid packet: %s" % packet
    root = "'%s' packet" % packet[0]
    return do_verify_packet([root], packet)
def do_verify_packet(tree, packet):
    """Recursively verify `packet`, logging the path (`tree`) of any
    None value or unsupported type found.  Returns True if the whole
    structure is clean."""
    def err(msg):
        log.error("%s in %s", msg, "->".join(tree))
    def branch(label):
        #a copy of the current path with one more element:
        return tree + [label]
    if packet is None:
        err("None value")
        return False
    ok = True
    if isinstance(packet, (list, tuple)):
        for index, item in enumerate(packet):
            #note: the recursive call must always run, so it is the left operand:
            ok = do_verify_packet(branch("[%s]" % index), item) and ok
    elif isinstance(packet, dict):
        for key, value in packet.items():
            ok = do_verify_packet(branch("key for value='%s'" % str(value)), key) and ok
            ok = do_verify_packet(branch("value for key='%s'" % str(key)), value) and ok
    elif not isinstance(packet, (int, bool, str, bytes)):
        err("unsupported type: %s" % type(packet))
        ok = False
    return ok
class Protocol:
    """
    This class handles sending and receiving packets,
    it will encode and compress them before sending,
    and decompress and decode when receiving.
    """

    #pseudo packet-types injected into the process-packet callback
    #to report error conditions:
    CONNECTION_LOST = "connection-lost"
    GIBBERISH = "gibberish"
    INVALID = "invalid"

    #protocol flavour identifier (subclasses may override):
    TYPE = "xpra"
    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
        You must call this constructor and source_has_more() from the main thread.

        scheduler: provides timeout_add / idle_add / source_remove (glib-style)
        conn: the bytestream connection to read from and write to
        process_packet_cb: called with (self, packet) for every parsed packet
        get_packet_cb: optional source of outgoing packets (see set_packet_source)
        """
        assert scheduler is not None
        assert conn is not None
        self.start_time = monotonic_time()
        #borrow the scheduler's main-loop entry points:
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self.source_remove = scheduler.source_remove
        self.read_buffer_size = READ_BUFFER_SIZE
        #milliseconds to wait before hanging up on gibberish:
        self.hangup_delay = 1000
        self._conn = conn
        if FAKE_JITTER>0: # pragma: no cover
            #debug mode: wrap the packet callback to inject artificial delays
            from xpra.net.fake_jitter import FakeJitter
            fj = FakeJitter(self.timeout_add, process_packet_cb, FAKE_JITTER)
            self._process_packet_cb = fj.process_packet_cb
        else:
            self._process_packet_cb = process_packet_cb
        #subclasses (ie: websocket) may override these header makers:
        self.make_chunk_header = self.make_xpra_header
        self.make_frame_header = self.noframe_header
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._process_read = self.read_queue_put
        self._read_queue_put = self.read_queue_put
        # Invariant: if .source is None, then _source_has_more == False
        self._get_packet_cb = get_packet_cb
        #counters:
        self.input_stats = {}
        self.input_packetcount = 0
        self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = 0
        self.output_raw_packetcount = 0
        #initial value which may get increased by client/server after handshake:
        self.max_packet_size = MAX_PACKET_SIZE
        self.abs_max_packet_size = 256*1024*1024
        #packet types that are allowed to be larger than LARGE_PACKET_SIZE without a warning:
        self.large_packets = [b"hello", b"window-metadata", b"sound-data", b"notify_show", b"setting-change", b"shell-reply"]
        self.send_aliases = {}
        self.send_flush_flag = False
        self.receive_aliases = {}
        self._log_stats = None          #None here means auto-detect
        self._closed = False
        #packet encoder and compressor, re-configured after capability exchange:
        self.encoder = "none"
        self._encoder = packet_encoding.get_encoder("none")
        self.compressor = "none"
        self._compress = compression.get_compressor("none")
        self.compression_level = 0
        #encryption state, disabled until set_cipher_in / set_cipher_out:
        self.cipher_in = None
        self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_in_padding = INITIAL_PADDING
        self.cipher_out = None
        self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self.cipher_out_padding = INITIAL_PADDING
        self._write_lock = Lock()
        self._write_thread = None
        self._read_thread = make_thread(self._read_thread_loop, "read", daemon=True)
        self._read_parser_thread = None         #started when needed
        self._write_format_thread = None        #started when needed
        self._source_has_more = Event()
    #attributes that are snapshotted / restored by save_state() and restore_state(),
    #ie: when upgrading a connection to a different protocol instance:
    STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases", "receive_aliases",
                    "cipher_in", "cipher_in_name", "cipher_in_block_size", "cipher_in_padding",
                    "cipher_out", "cipher_out_name", "cipher_out_block_size", "cipher_out_padding",
                    "compression_level", "encoder", "compressor")
200 def save_state(self):
201 state = {}
202 for x in Protocol.STATE_FIELDS:
203 state[x] = getattr(self, x)
204 return state
206 def restore_state(self, state):
207 assert state is not None
208 for x in Protocol.STATE_FIELDS:
209 assert x in state, "field %s is missing" % x
210 setattr(self, x, state[x])
211 #special handling for compressor / encoder which are named objects:
212 self.enable_compressor(self.compressor)
213 self.enable_encoder(self.encoder)
    def is_closed(self) -> bool:
        #True once close() has been called on this protocol instance
        return self._closed
220 def wait_for_io_threads_exit(self, timeout=None):
221 io_threads = [x for x in (self._read_thread, self._write_thread) if x is not None]
222 for t in io_threads:
223 if t.is_alive():
224 t.join(timeout)
225 exited = True
226 cinfo = self._conn or "cleared connection"
227 for t in io_threads:
228 if t.is_alive():
229 log.warn("Warning: %s thread of %s is still alive (timeout=%s)", t.name, cinfo, timeout)
230 exited = False
231 return exited
    def set_packet_source(self, get_packet_cb):
        #callback polled by the write-format thread to fetch outgoing packets
        self._get_packet_cb = get_packet_cb
    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations, padding):
        """Configure decryption of incoming packet payloads.

        The decryptor and its block size come from xpra.net.crypto;
        `padding` names the scheme used to round payloads up to the block size.
        """
        cryptolog("set_cipher_in%s", (ciphername, iv, password, key_salt, iterations))
        self.cipher_in, self.cipher_in_block_size = get_decryptor(ciphername, iv, password, key_salt, iterations)
        self.cipher_in_padding = padding
        #only log at info level when the cipher actually changes:
        if self.cipher_in_name!=ciphername:
            cryptolog.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername

    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations, padding):
        """Configure encryption of outgoing packet payloads.

        Mirror image of set_cipher_in().
        """
        cryptolog("set_cipher_out%s", (ciphername, iv, password, key_salt, iterations, padding))
        self.cipher_out, self.cipher_out_block_size = get_encryptor(ciphername, iv, password, key_salt, iterations)
        self.cipher_out_padding = padding
        #only log at info level when the cipher actually changes:
        if self.cipher_out_name!=ciphername:
            cryptolog.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername
254 def __repr__(self):
255 return "Protocol(%s)" % self._conn
257 def get_threads(self):
258 return tuple(x for x in (
259 self._write_thread,
260 self._read_thread,
261 self._read_parser_thread,
262 self._write_format_thread,
263 ) if x is not None)
    def accept(self):
        #no-op hook, may be overridden by subclasses
        pass
268 def parse_remote_caps(self, caps : typedict):
269 for k,v in caps.dictget("aliases", {}).items():
270 self.send_aliases[bytestostr(k)] = v
271 if FLUSH_HEADER:
272 self.send_flush_flag = caps.boolget("flush", False)
    def get_info(self, alias_info=True) -> dict:
        """Return a dict of diagnostic information:
        protocol settings, connection info, IO counters and thread liveness.
        """
        info = {
            "large_packets"         : tuple(bytestostr(x) for x in self.large_packets),
            "compression_level"     : self.compression_level,
            "max_packet_size"       : self.max_packet_size,
            "aliases"               : USE_ALIASES,
            "flush"                 : self.send_flush_flag,
            }
        c = self.compressor
        if c:
            info["compressor"] = c
        e = self.encoder
        if e:
            info["encoder"] = e
        if alias_info:
            info["send_alias"] = self.send_aliases
            info["receive_alias"] = self.receive_aliases
        c = self._conn
        if c:
            try:
                #merge the connection's own info dict:
                info.update(c.get_info())
            except Exception:
                log.error("error collecting connection information on %s", c, exc_info=True)
        #add stats to connection info:
        info.setdefault("input", {}).update({
            "buffer-size"           : self.read_buffer_size,
            "hangup-delay"          : self.hangup_delay,
            "packetcount"           : self.input_packetcount,
            "raw_packetcount"       : self.input_raw_packetcount,
            "count"                 : self.input_stats,
            "cipher"                : {"": self.cipher_in_name or "",
                                       "padding" : self.cipher_in_padding,
                                       },
            })
        info.setdefault("output", {}).update({
            "packet-join-size"      : PACKET_JOIN_SIZE,
            "large-packet-size"     : LARGE_PACKET_SIZE,
            "inline-size"           : INLINE_SIZE,
            "min-compress-size"     : MIN_COMPRESS_SIZE,
            "packetcount"           : self.output_packetcount,
            "raw_packetcount"       : self.output_raw_packetcount,
            "count"                 : self.output_stats,
            "cipher"                : {"": self.cipher_out_name or "",
                                       "padding" : self.cipher_out_padding
                                       },
            })
        shm = self._source_has_more
        info["has_more"] = shm and shm.is_set()
        for t in (self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread):
            if t:
                info.setdefault("thread", {})[t.name] = t.is_alive()
        return info
    def start(self):
        """Start the network read thread (from the main loop, so that a
        close() racing with start() is honoured)."""
        def start_network_read_thread():
            if not self._closed:
                self._read_thread.start()
        self.idle_add(start_network_read_thread)
        if SEND_INVALID_PACKET:
            #debug/testing: send garbage after the configured delay
            self.timeout_add(SEND_INVALID_PACKET*1000, self.raw_write, "invalid", SEND_INVALID_PACKET_DATA)
337 def send_disconnect(self, reasons, done_callback=None):
338 self.flush_then_close(["disconnect"]+list(reasons), done_callback=done_callback)
    def send_now(self, packet):
        """Send a single packet immediately by installing a one-shot
        packet source; only valid when no packet source is set."""
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending", packet[0])
            return
        log("send_now(%s ...)", packet[0])
        if self._get_packet_cb:
            raise Exception("cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb)
        #one-shot closure: the tmp_queue guards against double invocation
        tmp_queue = [packet]
        def packet_cb():
            self._get_packet_cb = None
            if not tmp_queue:
                raise Exception("packet callback used more than once!")
            packet = tmp_queue.pop()
            return (packet, )
        self._get_packet_cb = packet_cb
        self.source_has_more()
    def source_has_more(self): #pylint: disable=method-hidden
        """Signal that the packet source has data ready;
        lazily starts the write-format thread on first use."""
        shm = self._source_has_more
        if not shm or self._closed:
            return
        shm.set()
        #start the format thread:
        if not self._write_format_thread and not self._closed:
            self._write_format_thread = make_thread(self.write_format_thread_loop, "format", daemon=True)
            self._write_format_thread.start()
        #from now on, take shortcut:
        #(replaces this bound method on the instance with the Event's set method)
        self.source_has_more = self._source_has_more.set
    def write_format_thread_loop(self):
        """Thread body: wait for the has-more event, then pull packets
        from the packet source and queue them for writing."""
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                self._source_has_more.wait()
                gpc = self._get_packet_cb
                if self._closed or not gpc:
                    return
                #gpc() returns (packet, optional queueing arguments...):
                self._add_packet_to_queue(*gpc())
        except Exception as e:
            if self._closed:
                #errors during close() are expected, ignore them:
                return
            self._internal_error("error in network packet write/format", e, exc_info=True)
    def _add_packet_to_queue(self, packet, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, has_more=False, wait_for_more=False):
        """Encode `packet` and queue the resulting chunks for writing.

        The has-more event is cleared first (before encoding) so that a
        new source_has_more() arriving during encode is not lost.
        """
        if not has_more:
            shm = self._source_has_more
            if shm:
                shm.clear()
        if packet is None:
            return
        #log("add_packet_to_queue(%s ... %s, %s, %s)", packet[0], synchronous, has_more, wait_for_more)
        packet_type = packet[0]
        chunks = self.encode(packet)
        #the write lock serializes multi-chunk packets on the wire:
        with self._write_lock:
            if self._closed:
                return
            try:
                self._add_chunks_to_queue(packet_type, chunks, start_send_cb, end_send_cb, fail_cb, synchronous, has_more or wait_for_more)
            except:
                log.error("Error: failed to queue '%s' packet", packet[0])
                log("add_chunks_to_queue%s", (chunks, start_send_cb, end_send_cb, fail_cb), exc_info=True)
                raise
    def _add_chunks_to_queue(self, packet_type, chunks, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, more=False):
        """ the write_lock must be held when calling this function """
        #each chunk is (proto_flags, index, level, data) as produced by encode()
        items = []
        for proto_flags,index,level,data in chunks:
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                #note: since we are padding: l!=len(data)
                #pad the payload up to a multiple of the cipher block size:
                padding_size = self.cipher_out_block_size - (payload_size % self.cipher_out_block_size)
                if padding_size==0:
                    padded = data
                else:
                    # pad byte value is number of padding bytes added
                    padded = memoryview_to_bytes(data) + pad(self.cipher_out_padding, padding_size)
                    actual_size += padding_size
                assert len(padded)==actual_size, "expected padded size to be %i, but got %i" % (len(padded), actual_size)
                data = self.cipher_out.encrypt(padded)
                assert len(data)==actual_size, "expected encrypted size to be %i, but got %i" % (len(data), actual_size)
                cryptolog("sending %s bytes %s encrypted with %s padding",
                          payload_size, self.cipher_out_name, padding_size)
            if proto_flags & FLAGS_NOHEADER:
                assert not self.cipher_out
                #for plain/text packets (ie: gibberish response)
                log("sending %s bytes without header", payload_size)
                items.append(data)
            else:
                #if the other end can use this flag, expose it:
                #(FLUSH on the last chunk of a packet with nothing else queued)
                if self.send_flush_flag and not more and index==0:
                    proto_flags |= FLAGS_FLUSH
                #the xpra packet header:
                #(WebSocketProtocol may also add a websocket header too)
                header = self.make_chunk_header(packet_type, proto_flags, level, index, payload_size)
                if actual_size<PACKET_JOIN_SIZE:
                    #small packet: merge header and data into a single buffer
                    if not isinstance(data, bytes):
                        data = memoryview_to_bytes(data)
                    items.append(header+data)
                else:
                    items.append(header)
                    items.append(data)
        #WebSocket header may be added here:
        frame_header = self.make_frame_header(packet_type, items)  #pylint: disable=assignment-from-none
        if frame_header:
            item0 = items[0]
            if len(item0)<PACKET_JOIN_SIZE:
                if not isinstance(item0, bytes):
                    item0 = memoryview_to_bytes(item0)
                items[0] = frame_header + item0
            else:
                items.insert(0, frame_header)
        self.raw_write(packet_type, items, start_send_cb, end_send_cb, fail_cb, synchronous, more)
    def make_xpra_header(self, _packet_type, proto_flags, level, index, payload_size) -> bytes:
        #default chunk header: the fixed-size xpra binary header
        return pack_header(proto_flags, level, index, payload_size)

    def noframe_header(self, _packet_type, _items):
        #default frame header: none (subclasses like websocket override this)
        return None

    def start_write_thread(self):
        #started lazily by raw_write() on the first outgoing packet
        self._write_thread = start_thread(self._write_thread_loop, "write", daemon=True)
    def raw_write(self, packet_type, items, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False):
        """ Warning: this bypasses the compression and packet encoder! """
        #lazily start the write thread on first use:
        if self._write_thread is None:
            log("raw_write for %s, starting write thread", packet_type)
            self.start_write_thread()
        self._write_queue.put((items, start_cb, end_cb, fail_cb, synchronous, more))
473 def enable_default_encoder(self):
474 opts = packet_encoding.get_enabled_encoders()
475 assert opts, "no packet encoders available!"
476 self.enable_encoder(opts[0])
478 def enable_encoder_from_caps(self, caps):
479 opts = packet_encoding.get_enabled_encoders(order=packet_encoding.PERFORMANCE_ORDER)
480 log("enable_encoder_from_caps(..) options=%s", opts)
481 for e in opts:
482 if caps.boolget(e, e=="bencode"):
483 self.enable_encoder(e)
484 return True
485 log.error("no matching packet encoder found!")
486 return False
    def enable_encoder(self, e):
        #resolve the encoder first, so a failure leaves self.encoder unchanged:
        self._encoder = packet_encoding.get_encoder(e)
        self.encoder = e
        log("enable_encoder(%s): %s", e, self._encoder)
494 def enable_default_compressor(self):
495 opts = compression.get_enabled_compressors()
496 if opts:
497 self.enable_compressor(opts[0])
498 else:
499 self.enable_compressor("none")
501 def enable_compressor_from_caps(self, caps):
502 if self.compression_level==0:
503 self.enable_compressor("none")
504 return
505 opts = compression.get_enabled_compressors(order=compression.PERFORMANCE_ORDER)
506 compressors = caps.strtupleget("compressors")
507 log("enable_compressor_from_caps(..) options=%s", opts)
508 for c in opts: #ie: [zlib, lz4, lzo]
509 if c=="none":
510 continue
511 if c in compressors or caps.boolget(c):
512 self.enable_compressor(c)
513 return
514 log.warn("Warning: compression disabled, no matching compressor found")
515 self.enable_compressor("none")
    def enable_compressor(self, compressor):
        #resolve the compressor first, so a failure leaves self.compressor unchanged:
        self._compress = compression.get_compressor(compressor)
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, self._compress)
523 def encode(self, packet_in):
524 """
525 Given a packet (tuple or list of items), converts it for the wire.
526 This method returns all the binary packets to send, as an array of:
527 (index, compression_level and compression flags, binary_data)
528 The index, if positive indicates the item to populate in the packet
529 whose index is zero.
530 ie: ["blah", [large binary data], "hello", 200]
531 may get converted to:
532 [
533 (1, compression_level, [large binary data now zlib compressed]),
534 (0, 0, bencoded/rencoded(["blah", '', "hello", 200]))
535 ]
536 """
537 packets = []
538 packet = list(packet_in)
539 level = self.compression_level
540 size_check = LARGE_PACKET_SIZE
541 min_comp_size = MIN_COMPRESS_SIZE
542 for i in range(1, len(packet)):
543 item = packet[i]
544 if item is None:
545 raise TypeError("invalid None value in %s packet at index %s" % (packet[0], i))
546 ti = type(item)
547 if ti in (int, bool, dict, list, tuple):
548 continue
549 try:
550 l = len(item)
551 except TypeError as e:
552 raise TypeError("invalid type %s in %s packet at index %s: %s" % (ti, packet[0], i, e)) from None
553 if ti==LargeStructure:
554 packet[i] = item.data
555 continue
556 if ti==Compressible:
557 #this is a marker used to tell us we should compress it now
558 #(used by the client for clipboard data)
559 item = item.compress()
560 packet[i] = item
561 ti = type(item)
562 #(it may now be a "Compressed" item and be processed further)
563 if ti in (Compressed, LevelCompressed):
564 #already compressed data (usually pixels, cursors, etc)
565 if not item.can_inline or l>INLINE_SIZE:
566 il = 0
567 if ti==LevelCompressed:
568 #unlike Compressed (usually pixels, decompressed in the paint thread),
569 #LevelCompressed is decompressed by the network layer
570 #so we must tell it how to do that and pass the level flag
571 il = item.level
572 packets.append((0, i, il, item.data))
573 packet[i] = b''
574 else:
575 #data is small enough, inline it:
576 packet[i] = item.data
577 min_comp_size += l
578 size_check += l
579 elif ti==bytes and level>0 and l>LARGE_PACKET_SIZE:
580 log.warn("Warning: found a large uncompressed item")
581 log.warn(" in packet '%s' at position %i: %s bytes", packet[0], i, len(item))
582 #add new binary packet with large item:
583 cl, cdata = self._compress(item, level)
584 packets.append((0, i, cl, cdata))
585 #replace this item with an empty string placeholder:
586 packet[i] = ''
587 elif ti not in (str, bytes):
588 log.warn("Warning: unexpected data type %s", ti)
589 log.warn(" in '%s' packet at position %i: %s", packet[0], i, repr_ellipsized(item))
590 #now the main packet (or what is left of it):
591 packet_type = packet[0]
592 self.output_stats[packet_type] = self.output_stats.get(packet_type, 0)+1
593 if USE_ALIASES:
594 alias = self.send_aliases.get(packet_type)
595 if alias:
596 #replace the packet type with the alias:
597 packet[0] = alias
598 else:
599 log("packet type send alias not found for '%s'", packet_type)
600 try:
601 main_packet, proto_flags = self._encoder(packet)
602 except Exception:
603 if self._closed:
604 return [], 0
605 log.error("Error: failed to encode packet: %s", packet, exc_info=True)
606 #make the error a bit nicer to parse: undo aliases:
607 packet[0] = packet_type
608 verify_packet(packet)
609 raise
610 if len(main_packet)>size_check and strtobytes(packet_in[0]) not in self.large_packets:
611 log.warn("Warning: found large packet")
612 log.warn(" '%s' packet is %s bytes: ", packet_type, len(main_packet))
613 log.warn(" argument types: %s", csv(type(x) for x in packet[1:]))
614 log.warn(" sizes: %s", csv(len(strtobytes(x)) for x in packet[1:]))
615 log.warn(" packet: %s", repr_ellipsized(packet))
616 #compress, but don't bother for small packets:
617 if level>0 and len(main_packet)>min_comp_size:
618 try:
619 cl, cdata = self._compress(main_packet, level)
620 except Exception as e:
621 log.error("Error compressing '%s' packet", packet_type)
622 log.error(" %s", e)
623 raise
624 packets.append((proto_flags, 0, cl, cdata))
625 else:
626 packets.append((proto_flags, 0, 0, main_packet))
627 may_log_packet(True, packet_type, packet)
628 return packets
630 def set_compression_level(self, level : int):
631 #this may be used next time encode() is called
632 assert 0<=level<=10, "invalid compression level: %s (must be between 0 and 10" % level
633 self.compression_level = level
    def _io_thread_loop(self, name, callback):
        """Generic IO thread body: run `callback` until it returns a
        falsy value or the protocol is closed, mapping the various
        connection errors to the appropriate shutdown path."""
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed and callback():
                pass
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback, self._closed)
        except ConnectionClosedException:
            log("%s closed", self._conn, exc_info=True)
            if not self._closed:
                #ConnectionClosedException means the warning has been logged already
                self._connection_lost("%s connection %s closed" % (name, self._conn))
        except (OSError, socket_error) as e:
            if not self._closed:
                #only log a full traceback for errnos not in the ABORT list:
                self._internal_error("%s connection %s reset" % (name, self._conn), e, exc_info=e.args[0] not in ABORT)
        except Exception as e:
            #can happen during close(), in which case we just ignore:
            if not self._closed:
                log.error("Error: %s on %s failed: %s", name, self._conn, type(e), exc_info=True)
                self.close()
    def _write_thread_loop(self):
        #thread entry point: keep calling _write() until it returns False
        self._io_thread_loop("write", self._write)
659 def _write(self):
660 items = self._write_queue.get()
661 # Used to signal that we should exit:
662 if items is None:
663 log("write thread: empty marker, exiting")
664 self.close()
665 return False
666 return self.write_items(*items)
    def write_items(self, buf_data, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False):
        """Send the given buffers over the connection, managing the
        socket nodelay/cork options around the write and invoking the
        optional start/end callbacks with the connection byte counter."""
        conn = self._conn
        if not conn:
            return False
        #disable Nagle if more data is coming or we send multiple buffers:
        if more or len(buf_data)>1:
            conn.set_nodelay(False)
        if len(buf_data)>1:
            conn.set_cork(True)
        if start_cb:
            try:
                start_cb(conn.output_bytecount)
            except Exception:
                if not self._closed:
                    log.error("Error on write start callback %s", start_cb, exc_info=True)
        self.write_buffers(buf_data, fail_cb, synchronous)
        if len(buf_data)>1:
            conn.set_cork(False)
        if not more:
            conn.set_nodelay(True)
        if end_cb:
            try:
                end_cb(self._conn.output_bytecount)
            except Exception:
                if not self._closed:
                    log.error("Error on write end callback %s", end_cb, exc_info=True)
        return True
695 def write_buffers(self, buf_data, _fail_cb, _synchronous):
696 con = self._conn
697 if not con:
698 return
699 for buf in buf_data:
700 while buf and not self._closed:
701 written = self.con_write(con, buf)
702 #example test code, for sending small chunks very slowly:
703 #written = con.write(buf[:1024])
704 #import time
705 #time.sleep(0.05)
706 if written:
707 buf = buf[written:]
708 self.output_raw_packetcount += 1
709 self.output_packetcount += 1
    def con_write(self, con, buf):
        #hook point: returns the number of bytes written
        return con.write(buf)


    def _read_thread_loop(self):
        #thread entry point: keep calling _read() until it returns False
        self._io_thread_loop("read", self._read)
717 def _read(self):
718 buf = self._conn.read(self.read_buffer_size)
719 #log("read thread: got data of size %s: %s", len(buf), repr_ellipsized(buf))
720 #add to the read queue (or whatever takes its place - see steal_connection)
721 self._process_read(buf)
722 if not buf:
723 log("read thread: eof")
724 #give time to the parse thread to call close itself
725 #so it has time to parse and process the last packet received
726 self.timeout_add(1000, self.close)
727 return False
728 self.input_raw_packetcount += 1
729 return True
    def _internal_error(self, message="", exc=None, exc_info=False):
        """Log an error and schedule connection loss from the main loop."""
        #log exception info with last log message
        if self._closed:
            return
        ei = exc_info
        if exc:
            ei = None  #log it separately below
        log.error("Error: %s", message, exc_info=ei)
        if exc:
            log.error(" %s", exc, exc_info=exc_info)
            #drop the reference so the exception can be garbage collected:
            exc = None
        self.idle_add(self._connection_lost, message)
    def _connection_lost(self, message="", exc_info=False):
        #close the protocol; returns False so it can be used as a one-shot timer callback
        log("connection lost: %s", message, exc_info=exc_info)
        self.close()
        return False
750 def invalid(self, msg, data):
751 self.idle_add(self._process_packet_cb, self, [Protocol.INVALID, msg, data])
752 # Then hang up:
753 self.timeout_add(1000, self._connection_lost, msg)
755 def gibberish(self, msg, data):
756 self.idle_add(self._process_packet_cb, self, [Protocol.GIBBERISH, msg, data])
757 # Then hang up:
758 self.timeout_add(self.hangup_delay, self._connection_lost, msg)
    #delegates to invalid_header()
    #(so this can more easily be intercepted and overriden
    # see tcp-proxy)
    def invalid_header(self, proto, data, msg="invalid packet header"):
        #overridable entry point, the actual work is in _invalid_header()
        self._invalid_header(proto, data, msg)
    def _invalid_header(self, proto, data, msg=""):
        """Build a descriptive error for an unrecognized packet header
        and hand the data off to gibberish() for disconnection."""
        log("invalid_header(%s, %s bytes: '%s', %s)",
            proto, len(data or ""), msg, ellipsizer(data))
        #try to identify the protocol the peer is actually speaking:
        guess = guess_packet_type(data)
        if guess:
            err = "invalid packet format: %s" % guess
        else:
            err = "%s: 0x%s" % (msg, hexstr(data[:HEADER_SIZE]))
            if len(data)>1:
                err += " read buffer=%s (%i bytes)" % (repr_ellipsized(data), len(data))
        self.gibberish(err, data)
    def process_read(self, data):
        #default read handler: forward raw data to the read queue
        self._read_queue_put(data)

    def read_queue_put(self, data):
        """Queue raw network data for the parse thread,
        lazily starting that thread on first use."""
        #start the parse thread if needed:
        if not self._read_parser_thread and not self._closed:
            if data is None:
                log("empty marker in read queue, exiting")
                self.idle_add(self.close)
                return
            self.start_read_parser_thread()
        self._read_queue.put(data)
        #from now on, take shortcut:
        #(replaces this method on the instance with the queue's put method)
        self._read_queue_put = self._read_queue.put

    def start_read_parser_thread(self):
        #daemon thread that parses the raw data into packets
        self._read_parser_thread = start_thread(self._read_parse_thread_loop, "parse", daemon=True)

    def _read_parse_thread_loop(self):
        #thread entry point: wraps the parse loop with error handling
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except Exception as e:
            if self._closed:
                #errors during close() are expected, ignore them:
                return
            self._internal_error("error in network packet reading/parsing", e, exc_info=True)
807 def do_read_parse_thread_loop(self):
808 """
809 Process the individual network packets placed in _read_queue.
810 Concatenate the raw packet data, then try to parse it.
811 Extract the individual packets from the potentially large buffer,
812 saving the rest of the buffer for later, and optionally decompress this data
813 and re-construct the one python-object-packet from potentially multiple packets (see packet_index).
814 The 8 bytes packet header gives us information on the packet index, packet size and compression.
815 The actual processing of the packet is done via the callback process_packet_cb,
816 this will be called from this parsing thread so any calls that need to be made
817 from the UI thread will need to use a callback (usually via 'idle_add')
818 """
819 header = b""
820 read_buffers = []
821 payload_size = -1
822 padding_size = 0
823 packet_index = 0
824 compression_level = 0
825 raw_packets = {}
826 while not self._closed:
827 buf = self._read_queue.get()
828 if not buf:
829 log("parse thread: empty marker, exiting")
830 self.idle_add(self.close)
831 return
833 read_buffers.append(buf)
834 while read_buffers:
835 #have we read the header yet?
836 if payload_size<0:
837 #try to handle the first buffer:
838 buf = read_buffers[0]
839 if not header and buf[0]!=ord("P"):
840 self.invalid_header(self, buf, "invalid packet header byte")
841 return
842 #how much to we need to slice off to complete the header:
843 read = min(len(buf), HEADER_SIZE-len(header))
844 header += memoryview_to_bytes(buf[:read])
845 if len(header)<HEADER_SIZE:
846 #need to process more buffers to get a full header:
847 read_buffers.pop(0)
848 continue
849 elif len(buf)>read:
850 #got the full header and more, keep the rest of the packet:
851 read_buffers[0] = buf[read:]
852 else:
853 #we only got the header:
854 assert len(buf)==read
855 read_buffers.pop(0)
856 continue
857 #parse the header:
858 # format: struct.pack(b'cBBBL', ...) - HEADER_SIZE bytes
859 _, protocol_flags, compression_level, packet_index, data_size = unpack_header(header)
861 #sanity check size (will often fail if not an xpra client):
862 if data_size>self.abs_max_packet_size:
863 self.invalid_header(self, header, "invalid size in packet header: %s" % data_size)
864 return
866 if protocol_flags & FLAGS_CIPHER:
867 if self.cipher_in_block_size==0 or not self.cipher_in_name:
868 cryptolog.warn("Warning: received cipher block,")
869 cryptolog.warn(" but we don't have a cipher to decrypt it with,")
870 cryptolog.warn(" not an xpra client?")
871 self.invalid_header(self, header, "invalid encryption packet flag (no cipher configured)")
872 return
873 padding_size = self.cipher_in_block_size - (data_size % self.cipher_in_block_size)
874 payload_size = data_size + padding_size
875 else:
876 #no cipher, no padding:
877 padding_size = 0
878 payload_size = data_size
879 assert payload_size>0, "invalid payload size: %i" % payload_size
881 if payload_size>self.max_packet_size:
882 #this packet is seemingly too big, but check again from the main UI thread
883 #this gives 'set_max_packet_size' a chance to run from "hello"
884 def check_packet_size(size_to_check, packet_header):
885 if self._closed:
886 return False
887 log("check_packet_size(%#x, %s) max=%#x",
888 size_to_check, hexstr(packet_header), self.max_packet_size)
889 if size_to_check>self.max_packet_size:
890 msg = "packet size requested is %s but maximum allowed is %s" % \
891 (size_to_check, self.max_packet_size)
892 self.invalid(msg, packet_header)
893 return False
894 self.timeout_add(1000, check_packet_size, payload_size, header)
896 #how much data do we have?
897 bl = sum(len(v) for v in read_buffers)
898 if bl<payload_size:
899 # incomplete packet, wait for the rest to arrive
900 break
902 buf = read_buffers[0]
903 if len(buf)==payload_size:
904 #exact match, consume it all:
905 data = read_buffers.pop(0)
906 elif len(buf)>payload_size:
907 #keep rest of packet for later:
908 read_buffers[0] = buf[payload_size:]
909 data = buf[:payload_size]
910 else:
911 #we need to aggregate chunks,
912 #just concatenate them all:
913 data = b"".join(read_buffers)
914 if bl==payload_size:
915 #nothing left:
916 read_buffers = []
917 else:
918 #keep the left over:
919 read_buffers = [data[payload_size:]]
920 data = data[:payload_size]
922 #decrypt if needed:
923 if self.cipher_in:
924 if not protocol_flags & FLAGS_CIPHER:
925 self.invalid("unencrypted packet dropped", data)
926 return
927 cryptolog("received %i %s encrypted bytes with %i padding",
928 payload_size, self.cipher_in_name, padding_size)
929 data = self.cipher_in.decrypt(data)
930 if padding_size > 0:
931 def debug_str(s):
932 try:
933 return hexstr(s)
934 except Exception:
935 return csv(tuple(s))
936 # pad byte value is number of padding bytes added
937 padtext = pad(self.cipher_in_padding, padding_size)
938 if data.endswith(padtext):
939 cryptolog("found %s %s padding", self.cipher_in_padding, self.cipher_in_name)
940 else:
941 actual_padding = data[-padding_size:]
942 cryptolog.warn("Warning: %s decryption failed: invalid padding", self.cipher_in_name)
943 cryptolog(" cipher block size=%i, data size=%i", self.cipher_in_block_size, data_size)
944 cryptolog(" data does not end with %i %s padding bytes %s (%s)",
945 padding_size, self.cipher_in_padding, debug_str(padtext), type(padtext))
946 cryptolog(" but with %i bytes: %s (%s)",
947 len(actual_padding), debug_str(actual_padding), type(data))
948 cryptolog(" decrypted data (%i bytes): %r..", len(data), data[:128])
949 cryptolog(" decrypted data (hex): %s..", debug_str(data[:128]))
950 self._internal_error("%s encryption padding error - wrong key?" % self.cipher_in_name)
951 return
952 data = data[:-padding_size]
953 #uncompress if needed:
954 if compression_level>0:
955 try:
956 data = decompress(data, compression_level)
957 except InvalidCompressionException as e:
958 self.invalid("invalid compression: %s" % e, data)
959 return
960 except Exception as e:
961 ctype = compression.get_compression_type(compression_level)
962 log("%s packet decompression failed", ctype, exc_info=True)
963 msg = "%s packet decompression failed" % ctype
964 if self.cipher_in:
965 msg += " (invalid encryption key?)"
966 else:
967 #only include the exception text when not using encryption
968 #as this may leak crypto information:
969 msg += " %s" % e
970 del e
971 self.gibberish(msg, data)
972 return
974 if self._closed:
975 return
977 #we're processing this packet,
978 #make sure we get a new header next time
979 header = b""
980 if packet_index>0:
981 #raw packet, store it and continue:
982 raw_packets[packet_index] = data
983 payload_size = -1
984 if len(raw_packets)>=4:
985 self.invalid("too many raw packets: %s" % len(raw_packets), data)
986 return
987 continue
988 #final packet (packet_index==0), decode it:
989 try:
990 packet = list(decode(data, protocol_flags))
991 except InvalidPacketEncodingException as e:
992 self.invalid("invalid packet encoding: %s" % e, data)
993 return
994 except ValueError as e:
995 etype = packet_encoding.get_packet_encoding_type(protocol_flags)
996 log.error("Error parsing %s packet:", etype)
997 log.error(" %s", e)
998 if self._closed:
999 return
1000 log("failed to parse %s packet: %s", etype, hexstr(data[:128]))
1001 log(" %s", e)
1002 log(" data: %s", repr_ellipsized(data))
1003 log(" packet index=%i, packet size=%i, buffer size=%s", packet_index, payload_size, bl)
1004 self.gibberish("failed to parse %s packet" % etype, data)
1005 return
1007 if self._closed:
1008 return
1009 payload_size = -1
1010 #add any raw packets back into it:
1011 if raw_packets:
1012 for index,raw_data in raw_packets.items():
1013 #replace placeholder with the raw_data packet data:
1014 packet[index] = raw_data
1015 raw_packets = {}
1017 packet_type = packet[0]
1018 if self.receive_aliases and isinstance(packet_type, int):
1019 packet_type = self.receive_aliases.get(packet_type)
1020 if packet_type:
1021 packet[0] = packet_type
1022 self.input_stats[packet_type] = self.output_stats.get(packet_type, 0)+1
1023 if LOG_RAW_PACKET_SIZE:
1024 log("%s: %i bytes", packet_type, HEADER_SIZE + payload_size)
1026 self.input_packetcount += 1
1027 log("processing packet %s", bytestostr(packet_type))
1028 self._process_packet_cb(self, packet)
1029 packet = None
    def flush_then_close(self, last_packet, done_callback=None): #pylint: disable=method-hidden
        """ Note: this is best effort only
            the packet may not get sent.

            We try to get the write lock,
            we try to wait for the write queue to flush
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.

            :param last_packet: the final packet to encode and queue before closing
            :param done_callback: optional callable invoked when we are done
                (after the connection is closed, or when we give up trying to send)

            NOTE(review): `done()` (and therefore `done_callback`) can fire more
            than once - both the 5 second fallback timer and `wait_for_packet_sent`
            end up in `close_and_release` - confirm callers tolerate this.
        """
        #replace this method on the instance so any further call becomes a no-op:
        def closing_already(last_packet, done_callback=None):
            log("flush_then_close%s had already been called, this new request has been ignored",
                (last_packet, done_callback))
        self.flush_then_close = closing_already
        log("flush_then_close(%s, %s) closed=%s", last_packet, done_callback, self._closed)
        def done():
            #run the caller supplied callback, if there is one:
            log("flush_then_close: done, callback=%s", done_callback)
            if done_callback:
                done_callback()
        if self._closed:
            #nothing to flush or close:
            log("flush_then_close: already closed")
            done()
            return
        def wait_for_queue(timeout=10):
            #IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                #write queue still has stuff in it..
                if timeout<=0:
                    #we have run out of retries: give up on the last packet
                    log("flush_then_close: queue still busy, closing without sending the last packet")
                    try:
                        self._write_lock.release()
                    except Exception:
                        pass
                    self.close()
                    done()
                else:
                    #poll again in 100ms, up to `timeout` more times:
                    log("flush_then_close: still waiting for queue to flush")
                    self.timeout_add(100, wait_for_queue, timeout-1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing")
                chunks = self.encode(last_packet)
                def close_and_release():
                    #terminal step: close the protocol and free the write lock
                    log("flush_then_close: wait_for_packet_sent() close_and_release()")
                    self.close()
                    try:
                        self._write_lock.release()
                    except Exception:
                        pass
                    done()
                def wait_for_packet_sent():
                    #timer callback: returns True to keep being re-scheduled
                    log("flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s",
                        self._write_queue.empty(), self._closed)
                    if self._write_queue.empty() or self._closed:
                        #it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed #run until we manage to close (here or via the timeout)
                def packet_queued(*_args):
                    #if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s", self._closed)
                    if wait_for_packet_sent():
                        #check again every 100ms
                        self.timeout_add(100, wait_for_packet_sent)
                self._add_chunks_to_queue(last_packet[0], chunks,
                                          start_send_cb=None, end_send_cb=packet_queued,
                                          synchronous=False, more=False)
                #just in case wait_for_packet_sent never fires:
                self.timeout_add(5*1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            wl = self._write_lock
            if not wl:
                #cleaned up already
                return
            #non-blocking acquire so we never stall the calling thread:
            if not wl.acquire(False):
                if timeout<=0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log("flush_then_close: write lock is busy, will retry %s more times", timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout-1)
            else:
                log("flush_then_close: acquired the write lock")
                #we have the write lock - we MUST free it!
                wait_for_queue()
        #normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()
1127 def close(self):
1128 log("Protocol.close() closed=%s, connection=%s", self._closed, self._conn)
1129 if self._closed:
1130 return
1131 self._closed = True
1132 self.idle_add(self._process_packet_cb, self, [Protocol.CONNECTION_LOST])
1133 c = self._conn
1134 if c:
1135 self._conn = None
1136 try:
1137 log("Protocol.close() calling %s", c.close)
1138 c.close()
1139 if self._log_stats is None and c.input_bytecount==0 and c.output_bytecount==0:
1140 #no data sent or received, skip logging of stats:
1141 self._log_stats = False
1142 if self._log_stats:
1143 from xpra.simple_stats import std_unit, std_unit_dec
1144 log.info("connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
1145 std_unit(self.input_packetcount), std_unit_dec(c.input_bytecount),
1146 std_unit(self.output_packetcount), std_unit_dec(c.output_bytecount)
1147 )
1148 except Exception:
1149 log.error("error closing %s", c, exc_info=True)
1150 self.terminate_queue_threads()
1151 self.idle_add(self.clean)
1152 log("Protocol.close() done")
1154 def steal_connection(self, read_callback=None):
1155 #so we can re-use this connection somewhere else
1156 #(frees all protocol threads and resources)
1157 #Note: this method can only be used with non-blocking sockets,
1158 #and if more than one packet can arrive, the read_callback should be used
1159 #to ensure that no packets get lost.
1160 #The caller must call wait_for_io_threads_exit() to ensure that this
1161 #class is no longer reading from the connection before it can re-use it
1162 assert not self._closed, "cannot steal a closed connection"
1163 if read_callback:
1164 self._read_queue_put = read_callback
1165 conn = self._conn
1166 self._closed = True
1167 self._conn = None
1168 if conn:
1169 #this ensures that we exit the untilConcludes() read/write loop
1170 conn.set_active(False)
1171 self.terminate_queue_threads()
1172 return conn
1174 def clean(self):
1175 #clear all references to ensure we can get garbage collected quickly:
1176 self._get_packet_cb = None
1177 self._encoder = None
1178 self._write_thread = None
1179 self._read_thread = None
1180 self._read_parser_thread = None
1181 self._write_format_thread = None
1182 self._process_packet_cb = None
1183 self._process_read = None
1184 self._read_queue_put = None
1185 self._compress = None
1186 self._write_lock = None
1187 self._source_has_more = None
1188 self._conn = None #should be redundant
1189 def noop(): # pragma: no cover
1190 pass
1191 self.source_has_more = noop
1194 def terminate_queue_threads(self):
1195 log("terminate_queue_threads()")
1196 #the format thread will exit:
1197 self._get_packet_cb = None
1198 self._source_has_more.set()
1199 #make all the queue based threads exit by adding the empty marker:
1200 #write queue:
1201 owq = self._write_queue
1202 self._write_queue = exit_queue()
1203 force_flush_queue(owq)
1204 #read queue:
1205 orq = self._read_queue
1206 self._read_queue = exit_queue()
1207 force_flush_queue(orq)
1208 #just in case the read thread is waiting again:
1209 self._source_has_more.set()