# compatibile Windows 11 # compatibile Ubuntu 24.10 # compatibile python 3.12.7 import socket import select import logging import threading import sys import errno import time import os from cryptography.hazmat.primitives.asymmetric import rsa, padding from cryptography.hazmat.primitives import hashes, serialization from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend import signal # Per la gestione dei segnali class UDPServer: def __init__(self, host, port, buffer_size=1024, timeout=5, fragment_timeout=10, max_fragment_size=486): """Inizializza.""" self.host = host self.port = port self.buffer_size = buffer_size self.socket_timeout = timeout self.fragment_timeout = fragment_timeout self.max_fragment_size = max_fragment_size self.udp_socket = None self.is_active = False self._stop_event = threading.Event() self.receiver_thread = None self.on_message = None self.logger = self._setup_logger() self._socket_lock = threading.Lock() self.client_message_fragment_buffers = {} self.private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) self.public_key = self.private_key.public_key() self.public_key_pem = self.public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo) self.client_public_keys = {} def _setup_logger(self): """Configura il logger.""" logger = logging.getLogger('UDPServer') logger.setLevel(logging.DEBUG) console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.DEBUG) console_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') console_handler.setFormatter(console_formatter) logger.addHandler(console_handler) return logger def start(self): """Avvia il server.""" if self.is_active: self.logger.warning("Server già attivo.") return try: with self._socket_lock: self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.udp_socket.settimeout(self.socket_timeout) self.udp_socket.bind((self.host, self.port)) self.is_active = True self._stop_event.clear() self.receiver_thread = threading.Thread(target=self._receive_loop, daemon=True) self.receiver_thread.start() self.logger.info(f"Server UDP avviato su {self.host}:{self.port}") except socket.error as e: self._handle_error(f"Errore socket durante l'avvio: {e}") self.close() except Exception as e: self._handle_error(f"Errore generico durante l'avvio: {e}") self.close() def _receive_loop(self): """Loop principale di ricezione.""" try: while not self._stop_event.is_set(): with self._socket_lock: if self.udp_socket is None: break ready = select.select([self.udp_socket], [], [], self.socket_timeout) if ready[0]: try: with self._socket_lock: if self.udp_socket is None: break data, client_address = self.udp_socket.recvfrom(self.buffer_size) if not data: self.logger.warning("Ricevuto pacchetto vuoto.") continue self._process_received_data(client_address, data) except socket.timeout: continue except socket.error as e: if e.errno == errno.ECONNREFUSED: self.logger.warning(f"Nessun host in ascolto: {e}") else: self._handle_error(f"Errore socket: {e}") except Exception as e: self._handle_error(f"Errore inatteso in _receive_loop: {e}") break # Esci dal ciclo in caso di errore inatteso else: self._cleanup_fragment_buffers() except Exception as ex: self._handle_error(f"Errore nel thread di ricezione: {ex}") def _cleanup_fragment_buffers(self): """Pulisce i buffer scaduti.""" clients_to_cleanup = [] for client_address, message_buffers in self.client_message_fragment_buffers.items(): message_ids_to_cleanup = [] for message_id, buffer_info in message_buffers.items(): if 'timer' in buffer_info and buffer_info['timer'] and not buffer_info['timer'].is_alive(): self.logger.debug(f"Timeout buffer frammenti msg_id {message_id} da {client_address}.") message_ids_to_cleanup.append(message_id) for message_id in message_ids_to_cleanup: del message_buffers[message_id] if not message_buffers: clients_to_cleanup.append(client_address) for client_address in clients_to_cleanup: del self.client_message_fragment_buffers[client_address] def _reset_fragment_timer(self, client_address, message_id): """Resetta il timer.""" if client_address in self.client_message_fragment_buffers and message_id in self.client_message_fragment_buffers[client_address] and 'timer' in self.client_message_fragment_buffers[client_address][message_id] and self.client_message_fragment_buffers[client_address][message_id]['timer']: self.client_message_fragment_buffers[client_address][message_id]['timer'].cancel() timer = threading.Timer(self.fragment_timeout, self._timeout_fragment_buffer, args=[client_address, message_id]) self.client_message_fragment_buffers[client_address][message_id]['timer'] = timer timer.start() def _timeout_fragment_buffer(self, client_address, message_id): """Gestisce il timeout.""" if client_address in self.client_message_fragment_buffers and message_id in self.client_message_fragment_buffers[client_address]: self.logger.warning(f"Timeout frammenti msg_id {message_id} da {client_address}.") del self.client_message_fragment_buffers[client_address][message_id] if not self.client_message_fragment_buffers[client_address]: del self.client_message_fragment_buffers[client_address] def _process_received_data(self, client_address, data): """Processa i dati grezzi ricevuti.""" self.logger.debug(f"Ricevuto dati da {client_address}: {data[:100]}...") # Gestisci la frammentazione *PRIMA* di K e R e dell'estrazione del TLV if len(data) >= 7: # Assicurati che ci siano abbastanza byte per l'intestazione (7 byte ora) msg_id = int.from_bytes(data[:4], byteorder='big') frag_num = data[4] total_frags = int.from_bytes(data[5:7], byteorder='big') payload = data[7:] if msg_id > 0 and total_frags > 1: self._process_fragment(client_address, msg_id, frag_num, total_frags, payload) return else: #Messaggio non frammentato data = data[7:] #Rimuovi l'header # Estrai TLV *DOPO* la gestione della frammentazione. try: type = data[0:1] length_bytes = data[1:3] length = int.from_bytes(length_bytes, byteorder='big', signed=False) value = data[3:3 + length] except Exception as e: self._handle_error(f"Errore durante il parsing TLV iniziale nel server: {e}") return # Gestisci 'K' e 'R' *DOPO* la frammentazione e l'estrazione del TLV. if type == b'K': self.logger.debug(f"Rilevata chiave pubblica da {client_address}.") try: public_key = serialization.load_pem_public_key(value, backend=default_backend()) self.client_public_keys[client_address] = public_key self.logger.info(f"Chiave pubblica ricevuta e caricata da {client_address}.") return except ValueError as e: self._handle_error(f"Errore nella lettura della chiave da {client_address}: {e}") return elif type == b'R': self.logger.debug(f"Ricevuta richiesta chiave pubblica (tipo R) da {client_address}") if value == b"REQ_PUB_KEY": self.logger.info(f"Richiesta chiave pubblica da {client_address}") self.send_tlv(client_address, b'K', self.public_key_pem) return # Se arrivi qui, il messaggio NON e' un frammento. # Decritta PRIMA di processare il TLV (se necessario e se NON e' K o R). decrypted_data = self._decrypt_message(client_address, data) if decrypted_data is not None: self._process_complete_message(client_address, decrypted_data) #Passa solo data else: return #Se la decrittazione fallisce e ritorna None def _process_fragment(self, client_address, message_id, fragment_number, total_fragments, payload): """Gestisce un frammento.""" if client_address not in self.client_message_fragment_buffers: self.client_message_fragment_buffers[client_address] = {} if message_id not in self.client_message_fragment_buffers[client_address]: self.client_message_fragment_buffers[client_address][message_id] = { 'fragments': {}, 'total_fragments': total_fragments, 'timer': None } message_buffer = self.client_message_fragment_buffers[client_address][message_id] # --- MODIFICA QUI: Decritta il frammento *PRIMA* di aggiungerlo al buffer --- decrypted_payload = self._decrypt_message(client_address, payload) if decrypted_payload is None: # Gestisci l'errore di decrittazione, ad es. scartando il frammento self.logger.error(f"Errore di decrittazione frammento {fragment_number} di {message_id} da {client_address}. Frammento scartato.") return # <--- Importante: esci se la decrittazione fallisce message_buffer['fragments'][fragment_number] = decrypted_payload # Salva il frammento *decrittato* if len(message_buffer['fragments']) == total_fragments: self.logger.debug(f"Tutti frammenti ricevuti per msg_id {message_id} da {client_address}.") reassembled = bytearray() for i in range(1, total_fragments + 1): if i not in message_buffer['fragments']: self.logger.error(f"Frammento #{i} mancante per msg_id {message_id} da {client_address}!") del self.client_message_fragment_buffers[client_address][message_id] return reassembled.extend(message_buffer['fragments'][i]) # Chiamata corretta a _process_complete_message: self._process_complete_message(client_address, bytes(reassembled)) # <--- CORREZIONE QUI del self.client_message_fragment_buffers[client_address][message_id] else: self._reset_fragment_timer(client_address, message_id) self.logger.debug(f"Frammento #{fragment_number}/{total_fragments} per msg_id {message_id} da {client_address}.") def _process_complete_message(self, client_address, data): """Processa un messaggio completo (dopo eventuale decrittazione e riassemblaggio).""" # Estrai TLV *DOPO* la decrittazione (o se non era criptato). try: type = data[0:1] length = int.from_bytes(data[1:3], byteorder='big') value = data[3:3 + length] if type == b'M': message = value.decode('utf-8') self.logger.info(f"Messaggio da {client_address}: {message[:50]}...") print(f"Messaggio ASCII da {client_address}: {value.decode('utf-8', errors='replace')}") # Stampa leggibile if self.on_message: self.on_message(message, client_address) self.send_tlv(client_address, b'A', b"ACK") # Invia ACK elif type == b'A': # Gestione ACK (opzionale) - per ora non facciamo nulla pass else: self.logger.warning(f"Tipo TLV sconosciuto da {client_address}: {type}") except (ValueError, IndexError, UnicodeDecodeError) as e: self.logger.error(f"Errore nell'estrazione o decodifica TLV da {client_address}: {e}") except Exception as e: # Cattura altre possibili eccezioni self.logger.error(f"Errore imprevisto in _process_complete_message da {client_address}: {e}") def _decrypt_message(self, client_address, message_bytes): """Decripta il messaggio (o frammento).""" if client_address in self.client_public_keys: try: decrypted = self.private_key.decrypt( message_bytes, padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None) ) return decrypted except (ValueError, InvalidSignature): #Gestione di InvalidSignature self.logger.debug(f"Impossibile decrittare da {client_address} (potrebbe non essere criptato o firma non valida).") return message_bytes #Restituisce comunque i byte per gestire anche messaggi non criptati except Exception as e: self.logger.error(f"Errore imprevisto durante la decrittazione da {client_address}: {e}") return None # In caso di errore imprevisto, restituisci None else: self.logger.warning(f"Chiave pubblica non disponibile per {client_address}.") return None def close(self): """Chiude il server.""" if not self.is_active: return self.is_active = False self._stop_event.set() # Imposta l'evento di stop try: with self._socket_lock: if self.udp_socket: self.udp_socket.setblocking(0) self.udp_socket.close() self.udp_socket = None except socket.error as e: self._handle_error(f"Errore socket in chiusura: {e}") except Exception as e: self._handle_error(f"Errore in chiusura: {e}") finally: self.logger.info("Server UDP arrestato.") def send_tlv(self, client_address, type, value): """Invia un messaggio TLV al client.""" if not self.is_active: self._handle_error("Server non attivo.") return if isinstance(value, str): value = value.encode('utf-8') length = len(value) length_bytes = length.to_bytes(2, byteorder='big') tlv_message = type + length_bytes + value if type != b'K' and type != b'R': if client_address in self.client_public_keys: try: encrypted_message = self.client_public_keys[client_address].encrypt( tlv_message, padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None) ) message_to_send = encrypted_message except ValueError as e: self.logger.error(f"Errore durante la crittografia per {client_address}: {e}") return else: self.logger.warning(f"Chiave pubblica non trovata per {client_address}. Invio non criptato.") message_to_send = tlv_message else: message_to_send = tlv_message try: with self._socket_lock: if self.udp_socket: self.udp_socket.sendto(message_to_send, client_address) self.logger.debug(f"Inviato a {client_address}: {message_to_send[:50]}...") if type == b'K': self.logger.debug(f"Invio Server Public Key a {client_address}") else: self._handle_error("Socket non inizializzato.") except (socket.error, ValueError) as e: self._handle_error(f"Errore invio a {client_address}: {e}") except Exception as e: self._handle_error(f"Errore generico invio a {client_address}: {e}") def _handle_error(self, message): """Gestisce gli errori.""" self.logger.error(message) def stop(self): # Metodo per fermare esplicitamente il server """Ferma il server (metodo pubblico).""" self.close() #self.logger.info("Server stoppato da chiamata esterna") # --> NON serve più def print_message_callback(messaggio, indirizzo): """Callback di esempio.""" print(f"Callback: Messaggio: {messaggio}, da: {indirizzo}") # --- Funzione per gestire SIGINT (Ctrl+C) --- def signal_handler(sig, frame): print("Ricevuto SIGINT (Ctrl+C). Arresto in corso...") server.stop() # Chiama .stop() direttamente # --- MAIN --- if __name__ == "__main__": host = "localhost" port = 5001 print(f"Server in ascolto su {host}:{port}") server = UDPServer(host, port, fragment_timeout=5) server.on_message = print_message_callback # Imposta la callback # Imposta il signal handler signal.signal(signal.SIGINT, signal_handler) server.start() try: while not server._stop_event.is_set(): # Usa _stop_event time.sleep(1) except KeyboardInterrupt: # Questa parte non dovrebbe più essere necessaria print("Arresto in corso (KB)...") server.stop() #Ridondante, ma non fa male finally: if server.is_active: #Chiamata finale a close per sicurezza server.close()