diff --git a/slips/main.py b/slips/main.py index a4bbecca5..e6aea4f4c 100644 --- a/slips/main.py +++ b/slips/main.py @@ -411,7 +411,7 @@ def get_analyzed_flows_percentage(self) -> str: if not hasattr(self, "total_flows"): self.total_flows = self.db.get_total_flows() - processed = self.db.get_flow_analyzed_by_the_profiler_so_far() + processed = self.db.get_flows_analyzed_by_the_profiler_so_far() if not processed: return "" try: diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py index 6f41abefc..be959d22f 100644 --- a/slips_files/core/database/database_manager.py +++ b/slips_files/core/database/database_manager.py @@ -307,6 +307,9 @@ def get_disabled_modules(self, *args, **kwargs): def increment_profiler_workers_started(self, *args, **kwargs): return self.rdb.increment_profiler_workers_started(*args, **kwargs) + def decrement_profiler_workers_started(self, *args, **kwargs): + return self.rdb.decrement_profiler_workers_started(*args, **kwargs) + def get_profiler_workers_started(self, *args, **kwargs): return self.rdb.get_profiler_workers_started(*args, **kwargs) @@ -1149,7 +1152,7 @@ def get_info_about_icmp_flows_using_sport(self, *args, **kwargs): def increment_processed_flows(self, *args, **kwargs): return self.rdb.increment_processed_flows(*args, **kwargs) - def get_flow_analyzed_by_the_profiler_so_far(self, *args, **kwargs): + def get_flows_analyzed_by_the_profiler_so_far(self, *args, **kwargs): return self.rdb.get_flow_analyzed_by_the_profiler_so_far( *args, **kwargs ) diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py index 708ea7248..d4bd46734 100644 --- a/slips_files/core/database/redis_db/database.py +++ b/slips_files/core/database/redis_db/database.py @@ -2008,6 +2008,9 @@ def increment_profiler_workers_started(self) -> int: """increments the number of profiler workers started""" return self.r.incr(self.constants.PROFILER_WORKERS_STARTED, 1) + def decrement_profiler_workers_started(self) -> int: + return self.r.incr(self.constants.PROFILER_WORKERS_STARTED, -1) + def get_profiler_workers_started(self) -> int: """returns number of profiler workers started so far""" count = self.r.get(self.constants.PROFILER_WORKERS_STARTED) diff --git a/slips_files/core/profiler.py b/slips_files/core/profiler.py index 8e06ef569..82c543c59 100644 --- a/slips_files/core/profiler.py +++ b/slips_files/core/profiler.py @@ -18,8 +18,6 @@ import queue import multiprocessing import time -import threading -from multiprocessing import Process from multiprocessing.synchronize import Event, Semaphore from typing import ( List, @@ -42,7 +40,7 @@ from slips_files.core.input_profilers.nfdump import Nfdump from slips_files.core.input_profilers.suricata import Suricata from slips_files.core.input_profilers.zeek import ZeekJSON, ZeekTabs -from slips_files.core.profiler_worker import ProfilerWorker +from slips_files.core.worker_manager_mixin import WorkerManagerMixin SUPPORTED_INPUT_TYPES = { InputType.ZEEK: ZeekJSON, @@ -62,7 +60,7 @@ } -class Profiler(ICore, IObservable): +class Profiler(WorkerManagerMixin, ICore, IObservable): """A class to create the profiles for IPs""" name = "profiler" @@ -107,16 +105,7 @@ def init( # is set by input to indicate it stopped because of a failure self.is_input_failed_event: Optional[Event] = is_input_failed_event self.input_handler_obj = None - # to close them on shutdown - self.profiler_child_processes: List[Process] = [] - # to access their internal attributes if needed - self.workers: List[ProfilerWorker] = [] - # is set by this module to indicate to the monitor thread that - # workers stoppped. - self.did_all_workers_stop: Event = multiprocessing.Event() - self.last_worker_id = -1 - # max parallel profiler workers to start when high throughput is detected - self.max_workers = 6 + self.init_worker_manager() # 30MBs max size of this queue to avoid growing forever in mem self.aid_queue = multiprocessing.Queue(maxsize=30000000) # This starts a process that handles calculatng aid hash and stores @@ -126,13 +115,6 @@ def init( self.db, self.aid_queue, ) - now = time.monotonic() - self.next_throughput_check_time = now + 300 - self.profiler_monitor_thread = threading.Thread( - target=self._run_profiler_monitor_loop, - name="profiler_monitor_loop", - daemon=True, - ) def subscribe_to_channels(self): self.channels = {} @@ -176,23 +158,6 @@ def get_input_type(self, line: dict, input_type: str) -> str: # binetflow, binetflow tabs, nfdump, suricata return input_type - def stop_profiler_workers(self): - """ - wait as long as needed foreach worker to stop - """ - # ensure we don't block forever waiting for workers that will never - # receive the stop sentinel - if self.is_input_done_event is not None: - self.is_input_done_event.wait() - - for process in self.profiler_child_processes: - try: - process.join() - except (OSError, ChildProcessError): - pass - - self.did_all_workers_stop.set() - def mark_self_as_done_processing(self) -> None: """ is called to mark this process as done processing so @@ -223,30 +188,6 @@ def get_msg_from_queue(self, q: multiprocessing.Queue): except Exception: return None - def start_profiler_worker(self, worker_id: int = None): - """starts A profiler worker for faster processing of the flows""" - worker_name = f"profiler_worker_process_{worker_id}" - worker = ProfilerWorker( - logger=self.logger, - output_dir=self.parent_output_dir, - redis_port=self.redis_port, - termination_event=self.termination_event, - conf=self.conf, - ppid=self.ppid, - slips_args=self.args, - bloom_filters_manager=self.bloom_filters, - # module specific kwargs - name=worker_name, - profiler_queue=self.profiler_queue, - input_handler=self.input_handler_obj, - aid_queue=self.aid_queue, - aid_manager=self.aid_manager, - is_input_done_event=self.is_input_done_event, - ) - worker.start() - self.profiler_child_processes.append(worker) - self.db.increment_profiler_workers_started() - def get_handler_obj( self, first_msg: dict ) -> ZeekTabs | ZeekJSON | Argus | Suricata | ZeekTabs | Nfdump: @@ -307,70 +248,11 @@ def shutdown_gracefully(self): self.mark_self_as_done_processing() self.db.set_new_incoming_flows(False) - def did_5min_pass_since_last_throughput_check(self) -> bool: - """ - returns true if 5 mins passed since the last time we checked - the flows read per second - """ - now = time.monotonic() - if now < self.next_throughput_check_time: - return False - - # Advance in 5-min steps to reduce drift on long delays. - while self.next_throughput_check_time <= now: - self.next_throughput_check_time += 300 - return True - - def max_workers_started(self) -> bool: - """ - returns true if the maximum number of profiler workers - is already started - """ - # bc workers start from 0 - if self.last_worker_id + 1 >= self.max_workers: - return True - return False - - def _check_if_high_throughput_and_add_workers(self): - """ - Checks for input and profile flows/sec imbalance and adds more - profiler workers if needed. - """ - if self.max_workers_started(): - return - - if not self.did_5min_pass_since_last_throughput_check(): - return - - profiler_fps = self.db.get_core_module_flows_per_second(self.name) or 0 - input_fps = self.db.get_core_module_flows_per_second("Input") or 0 - if float(input_fps) > ( - float(profiler_fps) * 1.1 - ): # 10% more input fps than profiler fps - worker_id = self.last_worker_id + 1 - self.start_profiler_worker(worker_id) - self.last_worker_id = worker_id - self.print( - f"Warning: High throughput detected. Started " - f"additional worker: " - f"profiler_worker_{worker_id} to handle the flows." - ) - - if self.last_worker_id == self.max_workers - 1: - self.print( - f"Maximum number of profiler workers " - f"({self.max_workers}) started." - ) - def pre_main(self): client_ips = [str(ip) for ip in self.client_ips] if client_ips: self.print(f"Used client IPs: {green(', '.join(client_ips))}") - def _update_lines_read_by_all_workers(self): - # needed by store_flows_read_per_second() - self.lines = sum([worker.received_lines for worker in self.workers]) - def should_stop(self): """ overrides IModule.should_stop(). @@ -381,18 +263,6 @@ def should_stop(self): """ return self.stop_other_workers.is_set() - def _run_profiler_monitor_loop(self): - """ - Does necessary monitoring and stats updating for the profiler while - the workers are - running. - """ - while not self.did_all_workers_stop.is_set(): - self._update_lines_read_by_all_workers() - # implemented in icore.py - self.store_flows_read_per_second() - self._check_if_high_throughput_and_add_workers() - def _is_input_done(self) -> bool: return ( self.is_input_done_event is not None @@ -443,7 +313,6 @@ def main(self): self.db ) else: - self.input_handler_obj = self.get_handler_obj(msg) # put again that msg in queue to be processed by the profilers, # we just checked it here to determine the input handler obj @@ -461,8 +330,7 @@ def main(self): # slips starts with these workers by default until it detects # high throughput that these workers arent enough to handle - num_of_initial_profiler_workers = 3 - for worker_id in range(num_of_initial_profiler_workers): + for worker_id in range(self.num_of_initial_profiler_workers): self.last_worker_id = worker_id self.start_profiler_worker(worker_id) diff --git a/slips_files/core/worker_manager_mixin.py b/slips_files/core/worker_manager_mixin.py new file mode 100644 index 000000000..139000dfd --- /dev/null +++ b/slips_files/core/worker_manager_mixin.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: 2021 Sebastian Garcia +# SPDX-License-Identifier: GPL-2.0-only +import multiprocessing +import time +import threading +from multiprocessing import Process +from typing import List, Optional + +from slips_files.core.profiler_worker import ProfilerWorker + +FIVE_MINS = 300 + + +class WorkerManagerMixin: + """ + Contains all logic for managing, terminating, increasing and decreasing + workers, etc. + """ + + def init_worker_manager(self) -> None: + """ + Initialize profiler worker manager state. + + Return: + None. + """ + # to close them on shutdown + self.profiler_child_processes: List[Process] = [] + # to access their internal attributes if needed + self.workers: List[ProfilerWorker] = [] + # is set by this module to indicate to the monitor thread that + # workers stopped. + self.did_all_workers_stop = multiprocessing.Event() + self.last_worker_id = -1 + self.active_profiler_workers = 0 + self.num_of_initial_profiler_workers = 3 + # max parallel profiler workers to start when high throughput is + # detected + self.max_workers = 6 + now = time.monotonic() + self.next_throughput_check_time = now + FIVE_MINS + self.next_worker_decrease_check_time = now + FIVE_MINS + self.profiler_monitor_thread = threading.Thread( + target=self._run_profiler_workers_manager_loop, + name="profiler_monitor_loop", + daemon=True, + ) + + def stop_profiler_workers(self) -> None: + """ + Wait as long as needed for each worker to stop. + + Return: + None. + """ + # ensure we don't block forever waiting for workers that will never + # receive the stop sentinel + if self.is_input_done_event is not None: + self.is_input_done_event.wait() + + for process in self.profiler_child_processes: + try: + process.join() + except (OSError, ChildProcessError): + pass + + self.did_all_workers_stop.set() + + def start_profiler_worker(self, worker_id: Optional[int] = None) -> None: + """ + Start a profiler worker for faster processing of the flows. + + Parameters: + worker_id: The identifier to include in the worker process name. + + Return: + None. + """ + worker_name = f"profiler_worker_process_{worker_id}" + worker = ProfilerWorker( + logger=self.logger, + output_dir=self.parent_output_dir, + redis_port=self.redis_port, + termination_event=self.termination_event, + conf=self.conf, + ppid=self.ppid, + slips_args=self.args, + bloom_filters_manager=self.bloom_filters, + # module specific kwargs + name=worker_name, + profiler_queue=self.profiler_queue, + input_handler=self.input_handler_obj, + aid_queue=self.aid_queue, + aid_manager=self.aid_manager, + is_input_done_event=self.is_input_done_event, + ) + worker.start() + self.profiler_child_processes.append(worker) + self.workers.append(worker) + self.active_profiler_workers += 1 + self.db.increment_profiler_workers_started() + + def did_5min_pass_since_last_throughput_check(self) -> bool: + """ + Return whether 5 minutes passed since the last throughput check. + + Return: + True when throughput should be checked. + """ + now = time.monotonic() + if now < self.next_throughput_check_time: + return False + + while self.next_throughput_check_time <= now: + self.next_throughput_check_time += FIVE_MINS + return True + + def did_5min_pass_since_last_worker_decrease_check(self) -> bool: + """ + Return whether 5 minutes passed since the last worker decrease check. + + Return: + True when worker decrease should be checked. + """ + now = time.monotonic() + if now < self.next_worker_decrease_check_time: + return False + + while self.next_worker_decrease_check_time <= now: + self.next_worker_decrease_check_time += FIVE_MINS + return True + + def max_workers_started(self) -> bool: + """ + Return whether the maximum number of profiler workers is started. + + Return: + True when no more profiler workers should be started. + """ + if self.active_profiler_workers >= self.max_workers: + return True + return False + + def is_the_min_number_of_workers_active(self) -> bool: + return ( + self.active_profiler_workers + <= self.num_of_initial_profiler_workers + ) + + def _get_flows_per_second(self, module_name: str) -> float: + """ + Get the latest stored flows per second for a core module. + + Parameters: + module_name: The core module name. + + Return: + The module flow rate as a float. + """ + try: + return float( + self.db.get_core_module_flows_per_second(module_name) or 0 + ) + except (TypeError, ValueError): + return 0 + + def _check_if_high_throughput_and_add_workers(self) -> None: + """ + Check for input and profile flows/sec imbalance and add workers. + + Return: + None. + """ + if self.max_workers_started(): + return + + if not self.did_5min_pass_since_last_throughput_check(): + return + + profiler_fps = self._get_flows_per_second(self.name) + input_fps = self._get_flows_per_second("input") + if input_fps > (profiler_fps * 1.1): + worker_id = self.last_worker_id + 1 + self.start_profiler_worker(worker_id) + self.last_worker_id = worker_id + self.print( + f"Warning: High throughput detected. Started " + f"additional worker: " + f"profiler_worker_{worker_id} to handle the flows." + ) + + if self.last_worker_id == self.max_workers - 1: + self.print( + f"Maximum number of profiler workers " + f"({self.max_workers}) started." + ) + + def _update_lines_read_by_all_workers(self) -> None: + """ + Update the number of lines read by all workers. + """ + # needed by store_flows_read_per_second() + self.lines = sum([worker.received_lines for worker in self.workers]) + + def _run_profiler_workers_manager_loop(self) -> None: + """ + Monitor profiler workers and update profiler stats while they run. + """ + while not self.did_all_workers_stop.is_set(): + self._update_lines_read_by_all_workers() + # implemented in icore.py + self.store_flows_read_per_second() + self._check_if_high_throughput_and_add_workers() + self._check_if_stabled_throughput_and_remove_workers() + + def _check_if_stabled_throughput_and_remove_workers(self) -> None: + """ + Remove one extra worker when profiler throughput has stabilized. + """ + if self.is_the_min_number_of_workers_active(): + # can't decrese more than that + return + + if not self.did_5min_pass_since_last_worker_decrease_check(): + return + + profiler_fps = self._get_flows_per_second(self.name) + input_fps = self._get_flows_per_second("input") + + if profiler_fps < input_fps: + # still under high throughput + return + + self.profiler_queue.put("stop") + self.active_profiler_workers -= 1 + self.last_worker_id -= 1 + self.print( + "Stable throughput detected. Requested one additional " + "profiler worker to stop." + ) + self.db.decrement_profiler_workers_started() diff --git a/tests/unit/slips_files/core/test_profiler.py b/tests/unit/slips_files/core/test_profiler.py index 6ef3cd789..a18414053 100644 --- a/tests/unit/slips_files/core/test_profiler.py +++ b/tests/unit/slips_files/core/test_profiler.py @@ -225,7 +225,7 @@ def test_notify_observers_with_correct_message(): observer_mock.update.assert_called_once_with(test_msg) -@patch("slips_files.core.profiler.ProfilerWorker") +@patch("slips_files.core.worker_manager_mixin.ProfilerWorker") def test_start_profiler_worker_uses_parent_output_dir(mock_worker_cls): profiler = ModuleFactory().create_profiler_obj() worker = mock_worker_cls.return_value @@ -257,5 +257,6 @@ def test_start_profiler_worker_uses_parent_output_dir(mock_worker_cls): ) worker.start.assert_called_once() assert profiler.profiler_child_processes == [worker] - assert profiler.workers == [] + assert profiler.workers == [worker] + assert profiler.active_profiler_workers == 1 profiler.db.increment_profiler_workers_started.assert_called_once() diff --git a/tests/unit/slips_files/core/test_worker_manager_mixin.py b/tests/unit/slips_files/core/test_worker_manager_mixin.py new file mode 100644 index 000000000..34db1dae6 --- /dev/null +++ b/tests/unit/slips_files/core/test_worker_manager_mixin.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: 2021 Sebastian Garcia +# SPDX-License-Identifier: GPL-2.0-only +"""Unit tests for the profiler worker manager mixin.""" + +from unittest.mock import Mock, call + +import pytest + +from tests.module_factory import ModuleFactory + + +@pytest.mark.parametrize( + "last_worker_id, should_remove", + [ + (2, False), + (3, True), + ], +) +def test_check_if_stabled_throughput_only_removes_extra_workers( + last_worker_id: int, + should_remove: bool, +) -> None: + """ + Test that stable throughput never removes initial profiler workers. + + Parameters: + last_worker_id: The last started profiler worker identifier. + should_remove: Whether a worker should be removed. + + Return: + None. + """ + profiler = ModuleFactory().create_profiler_obj() + profiler.last_worker_id = last_worker_id + profiler.active_profiler_workers = last_worker_id + 1 + profiler.profiler_queue = Mock() + profiler.print = Mock() + profiler.did_5min_pass_since_last_worker_decrease_check = Mock( + return_value=True + ) + profiler.db.get_core_module_flows_per_second.side_effect = [100, 80] + + profiler._check_if_stabled_throughput_and_remove_workers() + + if should_remove: + profiler.profiler_queue.put.assert_called_once_with("stop") + assert profiler.active_profiler_workers == last_worker_id + assert profiler.last_worker_id == last_worker_id - 1 + profiler.print.assert_called_once() + else: + profiler.profiler_queue.put.assert_not_called() + assert profiler.active_profiler_workers == last_worker_id + 1 + assert profiler.last_worker_id == last_worker_id + profiler.print.assert_not_called() + + +@pytest.mark.parametrize( + "profiler_fps, input_fps, should_remove", + [ + (100, 100, True), + (100, 80, True), + (80, 100, False), + ], +) +def test_check_if_stabled_throughput_compares_rates( + profiler_fps: int, + input_fps: int, + should_remove: bool, +) -> None: + """ + Test that stable throughput removes a worker only when profiler keeps up. + + Parameters: + profiler_fps: The profiler flows per second. + input_fps: The input flows per second. + should_remove: Whether a worker should be removed. + + Return: + None. + """ + profiler = ModuleFactory().create_profiler_obj() + profiler.last_worker_id = 3 + profiler.active_profiler_workers = 4 + profiler.profiler_queue = Mock() + profiler.print = Mock() + profiler.did_5min_pass_since_last_worker_decrease_check = Mock( + return_value=True + ) + profiler.db.get_core_module_flows_per_second.side_effect = [ + profiler_fps, + input_fps, + ] + + profiler._check_if_stabled_throughput_and_remove_workers() + + if should_remove: + profiler.profiler_queue.put.assert_called_once_with("stop") + assert profiler.active_profiler_workers == 3 + assert profiler.last_worker_id == 2 + else: + profiler.profiler_queue.put.assert_not_called() + assert profiler.active_profiler_workers == 4 + assert profiler.last_worker_id == 3 + + +def test_check_if_stabled_throughput_waits_for_interval() -> None: + """ + Test that workers are not removed before the decrease interval elapses. + + Return: + None. + """ + profiler = ModuleFactory().create_profiler_obj() + profiler.last_worker_id = 3 + profiler.active_profiler_workers = 4 + profiler.profiler_queue = Mock() + profiler.did_5min_pass_since_last_worker_decrease_check = Mock( + return_value=False + ) + + profiler._check_if_stabled_throughput_and_remove_workers() + + profiler.profiler_queue.put.assert_not_called() + profiler.db.get_core_module_flows_per_second.assert_not_called() + assert profiler.active_profiler_workers == 4 + assert profiler.last_worker_id == 3 + + +def test_worker_scaling_uses_lowercase_input_metric_key() -> None: + """ + Test that worker scaling reads the stored lowercase input metric key. + + Return: + None. + """ + profiler = ModuleFactory().create_profiler_obj() + profiler.last_worker_id = 2 + profiler.active_profiler_workers = 3 + profiler.start_profiler_worker = Mock() + profiler.print = Mock() + profiler.did_5min_pass_since_last_throughput_check = Mock( + return_value=True + ) + profiler.db.get_core_module_flows_per_second.side_effect = [10, 100] + + profiler._check_if_high_throughput_and_add_workers() + + profiler.db.get_core_module_flows_per_second.assert_has_calls( + [call("profiler"), call("input")] + ) + profiler.start_profiler_worker.assert_called_once_with(3) + + +def test_update_lines_read_sums_worker_received_lines() -> None: + """ + Test that profiler throughput sums lines received by all workers. + + Return: + None. + """ + profiler = ModuleFactory().create_profiler_obj() + profiler.workers = [ + Mock(received_lines=10), + Mock(received_lines=20), + Mock(received_lines=12), + ] + + profiler._update_lines_read_by_all_workers() + + assert profiler.lines == 42 + + +@pytest.mark.parametrize( + "stored_value, expected_fps", + [ + ("42", 42.0), + (None, 0), + ("invalid", 0), + ], +) +def test_get_flows_per_second_handles_database_values( + stored_value: str | None, + expected_fps: float, +) -> None: + """ + Test that stored flow rates are converted defensively. + + Parameters: + stored_value: The value returned by the database. + expected_fps: The expected converted flows per second. + + Return: + None. + """ + profiler = ModuleFactory().create_profiler_obj() + profiler.db.get_core_module_flows_per_second.return_value = stored_value + + assert profiler._get_flows_per_second("Input") == expected_fps