diff --git a/CHANGELOG.md b/CHANGELOG.md index f0af6dffc..d17592969 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Can be disabled with `--disable-gc` - Alias for `merlin database` command so it can be called with `merlin db` - Status of run entities in the database (this will differ from task statuses) +- New classes for formatting `query-workers` output: + - `WorkerFormatter`: base class for defining formatted output for workers + - `RichWorkerFormatter`: implementation of `WorkerFormatter` for output using the rich library + - `JSONWorkerFormatter`: implementation of `WorkerFormatter` for JSON output + - `WorkerFormatterFactory`: factory class for selecting the desired worker formatter + +### Changed +- Changes to the `query-workers` command: + - Output now displays a lot more information, including logical and physical worker specific info + - Output now formatted using rich tables or json + - Now handled through worker classes rather than functions in `celeryadapter.py` file + - Behind the scenes this is now querying the new Merlin Database ## [2.0.0b2] diff --git a/docs/assets/images/monitoring/queues_and_workers/connected-workers.png b/docs/assets/images/monitoring/queues_and_workers/connected-workers.png deleted file mode 100644 index 2f030aa71..000000000 Binary files a/docs/assets/images/monitoring/queues_and_workers/connected-workers.png and /dev/null differ diff --git a/docs/assets/images/monitoring/queues_and_workers/no-connected-workers.png b/docs/assets/images/monitoring/queues_and_workers/no-connected-workers.png deleted file mode 100644 index 1dec3631d..000000000 Binary files a/docs/assets/images/monitoring/queues_and_workers/no-connected-workers.png and /dev/null differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-queues-all-workers.png 
b/docs/assets/images/monitoring/queues_and_workers/query-workers-queues-all-workers.png new file mode 100644 index 000000000..be35e264d Binary files /dev/null and b/docs/assets/images/monitoring/queues_and_workers/query-workers-queues-all-workers.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-queues-option.png b/docs/assets/images/monitoring/queues_and_workers/query-workers-queues-option.png new file mode 100644 index 000000000..8d0365cfd Binary files /dev/null and b/docs/assets/images/monitoring/queues_and_workers/query-workers-queues-option.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-all-workers.png b/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-all-workers.png index 8620ef77f..c4bad03bd 100644 Binary files a/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-all-workers.png and b/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-all-workers.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-option.png b/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-option.png index 8a50700de..5875a4826 100644 Binary files a/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-option.png and b/docs/assets/images/monitoring/queues_and_workers/query-workers-spec-option.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-worker-entities-do-not-exist.png b/docs/assets/images/monitoring/queues_and_workers/query-workers-worker-entities-do-not-exist.png new file mode 100644 index 000000000..833c298b6 Binary files /dev/null and b/docs/assets/images/monitoring/queues_and_workers/query-workers-worker-entities-do-not-exist.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-worker-entities-exist.png b/docs/assets/images/monitoring/queues_and_workers/query-workers-worker-entities-exist.png new file mode 
100644 index 000000000..5d27f5239 Binary files /dev/null and b/docs/assets/images/monitoring/queues_and_workers/query-workers-worker-entities-exist.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/query-workers-workers-option.png b/docs/assets/images/monitoring/queues_and_workers/query-workers-workers-option.png new file mode 100644 index 000000000..6fadf0593 Binary files /dev/null and b/docs/assets/images/monitoring/queues_and_workers/query-workers-workers-option.png differ diff --git a/docs/assets/images/monitoring/queues_and_workers/queues-example-all-workers.png b/docs/assets/images/monitoring/queues_and_workers/queues-example-all-workers.png deleted file mode 100644 index fdba79fcd..000000000 Binary files a/docs/assets/images/monitoring/queues_and_workers/queues-example-all-workers.png and /dev/null differ diff --git a/docs/assets/images/monitoring/queues_and_workers/queues-example-filtered-workers.png b/docs/assets/images/monitoring/queues_and_workers/queues-example-filtered-workers.png deleted file mode 100644 index 0ca463a1d..000000000 Binary files a/docs/assets/images/monitoring/queues_and_workers/queues-example-filtered-workers.png and /dev/null differ diff --git a/docs/assets/images/monitoring/queues_and_workers/workers-option-with-regex.png b/docs/assets/images/monitoring/queues_and_workers/workers-option-with-regex.png deleted file mode 100644 index ec9687110..000000000 Binary files a/docs/assets/images/monitoring/queues_and_workers/workers-option-with-regex.png and /dev/null differ diff --git a/docs/assets/images/monitoring/queues_and_workers/workers-option-with-worker-names.png b/docs/assets/images/monitoring/queues_and_workers/workers-option-with-worker-names.png deleted file mode 100644 index 147b29108..000000000 Binary files a/docs/assets/images/monitoring/queues_and_workers/workers-option-with-worker-names.png and /dev/null differ diff --git a/docs/user_guide/command_line.md b/docs/user_guide/command_line.md index 
a6c6d0394..507b6f7ba 100644 --- a/docs/user_guide/command_line.md +++ b/docs/user_guide/command_line.md @@ -1308,7 +1308,7 @@ The Merlin library comes equipped with several commands to help monitor your wor - *[detailed-status](#detailed-status-merlin-detailed-status)*: Display task-by-task status information for a study - *[monitor](#monitor-merlin-monitor)*: Keep your allocation alive while tasks are being processed -- *[query-workers](#query-workers-merlin-query-workers)*: Communicate with Celery to view information on active workers +- *[query-workers](#query-workers-merlin-query-workers)*: View information on [worker entities](./database/entities.md#worker-entities) - *[queue-info](#queue-info-merlin-queue-info)*: Communicate with Celery to view the status of queues in your workflow(s) - *[status](#status-merlin-status)*: Display a summary of the status of a study @@ -1477,9 +1477,9 @@ merlin monitor [OPTIONS] SPECIFICATION ### Query Workers (`merlin query-workers`) -Check which workers are currently connected to the task server. +View information on [worker entities](./database/entities.md#worker-entities) in your database. -This will broadcast a command to all connected workers and print the names of any that respond and the queues they're attached to. This is useful for interacting with workers, such as via `merlin stop-workers --workers`. +This will query the [Merlin Database](./database/index.md) for information on [logical worker entities](./database/entities.md#logical-worker-entity) and [physical worker entities](./database/entities.md#physical-worker-entity). This can be useful for seeing which workers are running and where. For more information, see the [Query Workers documentation](./monitoring/queues_and_workers.md#query-workers). @@ -1497,7 +1497,9 @@ merlin query-workers [OPTIONS] | `--task_server` | string | Task server type for which to query workers. Currently only "celery" is implemented. 
| "celery" | | `--spec` | filename | Query for the workers named in the `merlin` block of the spec file given here | None | | `--queues` | List[string] | Takes a space-delimited list of queues as input. This will query for workers associated with the names of the queues you provide here. | None | -| `--workers` | List[regex] | A space-delimited list of regular expressions representing workers to query | None | +| `--workers` | List[string] | A space-delimited list of logical worker names to query | None | +| `-f`, `--format` | choice(`rich` \| `json`) | Output format | rich | +| `-l`, `--local-db` | boolean | Use the local Merlin database for querying workers | `False` | **Examples:** @@ -1527,14 +1529,6 @@ merlin query-workers [OPTIONS] merlin query-workers --workers step_1_worker ``` -!!! example "Query Workers Using Regex" - - This will query only workers whose names start with `step`: - - ```bash - merlin query-workers --workers ^step - ``` - ### Queue Info (`merlin queue-info`) !!! note diff --git a/docs/user_guide/monitoring/queues_and_workers.md b/docs/user_guide/monitoring/queues_and_workers.md index 101c37ec5..7cba3228c 100644 --- a/docs/user_guide/monitoring/queues_and_workers.md +++ b/docs/user_guide/monitoring/queues_and_workers.md @@ -280,28 +280,30 @@ merlin queue-info --spec --vars = ## Query Workers -Merlin provides users with the [`merlin query-workers`](../command_line.md#query-workers-merlin-query-workers) command to help users see which workers are running and what task queues they're watching. +Merlin provides users with the [`merlin query-workers`](../command_line.md#query-workers-merlin-query-workers) command to help users see information about the [worker entities](../database/entities.md#worker-entities) that currently exist in their database. -This command will output content to a table format with two columns: workers and queues. The workers column will contain one connected worker per row. 
The queues column will contain a comma-delimited list of queues that the connected worker is watching. +This command will output a summary of both the [logical worker entities](../database/entities.md#logical-worker-entity) and the [physical worker entities](../database/entities.md#physical-worker-entity), detailed information on the physical workers that exist, and information on logical workers that don't have any physical instances yet. -**Usage:** +There are two different ways that output of the `query-workers` command can be formatted: `rich` or `json`. By default, this is set to `rich`. + +**Basic Usage:** ```bash merlin query-workers ``` -??? example "Example Query-Workers Output With No Connected Workers" +??? example "Example Query-Workers Output With No Worker Entities in the Database"
- ![Output of Query-Workers When No Workers Are Connected](../../assets/images/monitoring/queues_and_workers/no-connected-workers.png) -
Output of Query-Workers When No Workers Are Connected
+ ![Output of Query-Workers When Worker Entities Do Not Exist](../../assets/images/monitoring/queues_and_workers/query-workers-worker-entities-do-not-exist.png) +
Output of Query-Workers When Worker Entities Do Not Exist
-??? example "Example Query-Workers Output With Connected Workers" +??? example "Example Query-Workers Output With Worker Entities in the Database"
- ![Output of Query-Workers When There Are Workers Connected](../../assets/images/monitoring/queues_and_workers/connected-workers.png) -
Output of Query-Workers When There Are Workers Connected
+ ![Output of Query-Workers When Worker Entities Exist](../../assets/images/monitoring/queues_and_workers/query-workers-worker-entities-exist.png) +
Output of Query-Workers When Worker Entities Exist
### Query Workers by Spec File @@ -456,7 +458,7 @@ merlin query-workers --queues Say we have the below spec file with four workers `creator`, `trainer`, `predictor`, and `verifier` that are each attached to their respective steps/task queues. In other words, `creator` will be connected to the `create_data` task queue, `trainer` will be connected to the `train` task queue, etc.: - ```yaml title="demo_workflow.yaml" hl_lines="33-44" + ```yaml title="demo_workflow_queues_option.yaml" hl_lines="33-44" description: name: demo_workflow description: a very simple merlin workflow @@ -506,14 +508,14 @@ merlin query-workers --queues We can start these workers with: ```bash - merlin run-workers demo_workflow.yaml + merlin run-workers demo_workflow_queues_option.yaml ``` Now if we query the workers *without* the `--queues` option, we'll see all four workers alive and connected to their respective queues:
- ![All Four Workers From 'demo_workflow.yaml' Being Queried](../../assets/images/monitoring/queues_and_workers/queues-example-all-workers.png) -
All Four Workers From 'demo_workflow.yaml' Being Queried
+ ![All Four Workers From 'demo_workflow_queues_option.yaml' Being Queried](../../assets/images/monitoring/queues_and_workers/query-workers-queues-all-workers.png) +
All Four Workers From 'demo_workflow_queues_option.yaml' Being Queried
Let's refine this query to just view the workers connected to the `train` and `predict` queues: @@ -525,25 +527,25 @@ merlin query-workers --queues As we can see in the output below, only the `trainer` and `predictor` workers are now displayed:
- ![Output of Query-Workers Using the Queues Option](../../assets/images/monitoring/queues_and_workers/queues-example-filtered-workers.png) + ![Output of Query-Workers Using the Queues Option](../../assets/images/monitoring/queues_and_workers/query-workers-queues-option.png)
Output of Query-Workers Using the Queues Option
### Query Workers by Worker Regex -There will be instances when you know precisely which workers you want to query. In such cases, the `--workers` option in the `query-workers` command proves useful. This option facilitates querying workers using [regular expressions](https://docs.python.org/3/library/re.html). As full strings are accepted as regular expressions, you can also query workers by worker name. +There will be instances when you know precisely which workers you want to query. In such cases, the `--workers` option in the `query-workers` command proves useful. This option facilitates querying workers by their logical names. **Usage:** ```bash -merlin query-workers --workers +merlin query-workers --workers ``` ??? example "Example of Using the `--workers` Option With Query-Workers" Say we have the following spec file with 3 workers `step_1_worker`, `step_2_worker`, and `other_worker`: - ```yaml title="demo_workflow.yaml" hl_lines="27-35" + ```yaml title="demo_workflow_workers_option.yaml" hl_lines="27-35" description: name: demo_workflow description: a very simple merlin workflow @@ -590,19 +592,6 @@ merlin query-workers --workers In our output we see that both workers that we asked for were queried but `other_worker` was ignored:
- ![Output of Query-Workers Using the Workers Option With Worker Names](../../assets/images/monitoring/queues_and_workers/workers-option-with-worker-names.png) + ![Output of Query-Workers Using the Workers Option With Worker Names](../../assets/images/monitoring/queues_and_workers/query-workers-workers-option.png)
Output of Query-Workers Using the Workers Option With Worker Names
- - Alternatively, we can do the exact same query using a regular expression: - - ```bash - merlin query-workers --workers ^step - ``` - - The `^` operator for regular expressions will match the beginning of a string. In this example when we say `^step` we're asking Merlin to match any worker starting with the word `step`, which in this case is `step_1_worker` and `step_2_worker`. We can see this in the output below: - -
- ![Output of Query-Workers Using the Workers Option With RegEx](../../assets/images/monitoring/queues_and_workers/workers-option-with-regex.png) -
Output of Query-Workers Using the Workers Option With RegEx
-
diff --git a/merlin/abstracts/factory.py b/merlin/abstracts/factory.py index 22719864a..343a6ad1d 100644 --- a/merlin/abstracts/factory.py +++ b/merlin/abstracts/factory.py @@ -124,19 +124,6 @@ def _discover_plugins_via_entry_points(self): except ImportError: LOG.debug("pkg_resources not available for plugin discovery") - def _discover_builtin_modules(self): - """ - Optional hook to discover built-in components by scanning local modules. - - Default implementation does nothing. - - Subclasses can override this method to implement package/module scanning. - """ - LOG.warning( - f"Class {self.__class__.__name__} did not override _discover_builtin_modules(). " - "Built-in module discovery will be skipped." - ) - def _discover_plugins(self): """ Discover and register plugin components via entry points. @@ -144,7 +131,6 @@ def _discover_plugins(self): Subclasses can override this to support more discovery mechanisms. """ self._discover_plugins_via_entry_points() - self._discover_builtin_modules() def _raise_component_error_class(self, msg: str) -> Type[Exception]: """ @@ -227,6 +213,7 @@ def _get_component_class(self, canonical_name: str, component_type: str) -> Any: return component_class + # TODO should we change 'config' to 'kwargs'? def create(self, component_type: str, config: Dict = None) -> Any: """ Instantiate and return a component of the specified type. 
diff --git a/merlin/celery.py b/merlin/celery.py index 19fada108..5c279259c 100644 --- a/merlin/celery.py +++ b/merlin/celery.py @@ -214,7 +214,7 @@ def handle_worker_startup(sender: str = None, **kwargs): "physical_worker", name=str(sender), host=host, - status=WorkerStatus.RUNNING, + worker_status=WorkerStatus.RUNNING.value, logical_worker_id=logical_worker.get_id(), pid=os.getpid(), ) diff --git a/merlin/cli/commands/query_workers.py b/merlin/cli/commands/query_workers.py index f09225eeb..81d34c0bf 100644 --- a/merlin/cli/commands/query_workers.py +++ b/merlin/cli/commands/query_workers.py @@ -20,9 +20,11 @@ from merlin.ascii_art import banner_small from merlin.cli.commands.command_entry_point import CommandEntryPoint -from merlin.router import query_workers +from merlin.config.configfile import initialize_config from merlin.spec.specification import MerlinSpec from merlin.utils import verify_filepath +from merlin.workers.formatters.formatter_factory import worker_formatter_factory +from merlin.workers.handlers.handler_factory import worker_handler_factory LOG = logging.getLogger("merlin") @@ -66,7 +68,21 @@ def add_parser(self, subparsers: ArgumentParser): action="store", nargs="+", default=None, - help="Regex match for specific workers to query.", + help="Specific logical worker names to query.", + ) + format_default = "rich" + query.add_argument( + "-f", + "--format", + choices=worker_formatter_factory.list_available(), + default=format_default, + help=f"Output format. 
Default: {format_default}", + ) + query.add_argument( + "-l", + "--local-db", + action="store_true", + help="Use the local Merlin database for querying workers.", ) def process_command(self, args: Namespace): @@ -85,15 +101,24 @@ def process_command(self, args: Namespace): """ print(banner_small) - # Get the workers from the spec file if --spec provided + if args.local_db: + initialize_config(local_mode=True) + worker_names = [] + if args.workers: + worker_names.extend(args.workers) + + # Get the workers from the spec file if --spec provided + spec = None if args.spec: spec_path = verify_filepath(args.spec) spec = MerlinSpec.load_specification(spec_path) - worker_names = spec.get_worker_names() + worker_names.extend(spec.get_worker_names()) for worker_name in worker_names: if "$" in worker_name: LOG.warning(f"Worker '{worker_name}' is unexpanded. Target provenance spec instead?") LOG.debug(f"Searching for the following workers to stop based on the spec {args.spec}: {worker_names}") - query_workers(args.task_server, worker_names, args.queues, args.workers) + task_server = spec.merlin["resources"]["task_server"] if spec else args.task_server + worker_handler = worker_handler_factory.create(task_server) + worker_handler.query_workers(args.format, queues=args.queues, workers=worker_names, local_db=args.local_db) diff --git a/merlin/common/enums.py b/merlin/common/enums.py index 41b6cc6b0..242e6808b 100644 --- a/merlin/common/enums.py +++ b/merlin/common/enums.py @@ -53,10 +53,10 @@ class WorkerStatus(Enum): REBOOTING (str): Indicates the worker is actively restarting itself. String value: "rebooting". 
""" - RUNNING = "running" - STALLED = "stalled" - STOPPED = "stopped" - REBOOTING = "rebooting" + RUNNING = "RUNNING" + STALLED = "STALLED" + STOPPED = "STOPPED" + REBOOTING = "REBOOTING" class RunStatus(Enum): diff --git a/merlin/db_scripts/data_models.py b/merlin/db_scripts/data_models.py index 4ba017d67..013681e17 100644 --- a/merlin/db_scripts/data_models.py +++ b/merlin/db_scripts/data_models.py @@ -417,7 +417,7 @@ class PhysicalWorkerModel(BaseDataModel): # pylint: disable=too-many-instance-a name (str): The name of the physical worker. pid (str): The process ID (PID) of the worker process. restart_count (int): The number of times this worker has been restarted. - worker_status (WorkerStatus): The current status of the worker (e.g., running, stopped). + worker_status (str): The current status of the worker (e.g., running, stopped). """ id: str = field(default_factory=lambda: str(uuid.uuid4())) # pylint: disable=invalid-name @@ -426,7 +426,7 @@ class PhysicalWorkerModel(BaseDataModel): # pylint: disable=too-many-instance-a launch_cmd: str = None args: Dict = field(default_factory=dict) pid: str = None - worker_status: WorkerStatus = WorkerStatus.STOPPED + worker_status: str = field(default=WorkerStatus.STOPPED.value) heartbeat_timestamp: datetime = field(default_factory=datetime.now) latest_start_time: datetime = field(default_factory=datetime.now) host: str = None diff --git a/merlin/db_scripts/entities/physical_worker_entity.py b/merlin/db_scripts/entities/physical_worker_entity.py index 0f775ffae..96bdc0277 100644 --- a/merlin/db_scripts/entities/physical_worker_entity.py +++ b/merlin/db_scripts/entities/physical_worker_entity.py @@ -202,9 +202,17 @@ def get_pid(self) -> Optional[int]: The process ID for this worker or None if not set. 
""" self.reload_data() - return int(self.entity_info.pid) if self.entity_info.pid else None + if not self.entity_info.pid: + return None - def set_pid(self, pid: str): + # Handle both int strings and float strings + try: + # Convert to float first, then to int + return int(float(self.entity_info.pid)) + except (ValueError, TypeError): + return None + + def set_pid(self, pid: int): """ Set the PID of this worker. @@ -223,7 +231,7 @@ def get_worker_status(self) -> WorkerStatus: the status of this worker. """ self.reload_data() - return self.entity_info.worker_status + return WorkerStatus(self.entity_info.worker_status) def set_worker_status(self, status: WorkerStatus): """ @@ -233,7 +241,7 @@ def set_worker_status(self, status: WorkerStatus): status: A [`WorkerStatus`][common.enums.WorkerStatus] enum representing the new status of the worker. """ - self.entity_info.worker_status = status + self.entity_info.worker_status = status.value self.save() def get_heartbeat_timestamp(self) -> str: @@ -294,7 +302,7 @@ def get_restart_count(self) -> int: The number of times that this worker has been restarted. """ self.reload_data() - return self.entity_info.restart_count + return int(float(self.entity_info.restart_count)) def increment_restart_count(self): """ diff --git a/merlin/db_scripts/entity_managers/entity_manager.py b/merlin/db_scripts/entity_managers/entity_manager.py index cebd9eed3..3d0472755 100644 --- a/merlin/db_scripts/entity_managers/entity_manager.py +++ b/merlin/db_scripts/entity_managers/entity_manager.py @@ -127,7 +127,7 @@ def _matches_filters(self, entity: T, filters: Dict) -> bool: entity: The entity instance to check against the filters. filters: A dictionary of filter keys and values used to narrow down the query results. Filter keys must correspond to entries in the `_filter_accessor_map` defined - by the subclass. Values are compared against the entity’s corresponding attributes + by the subclass. 
Values are compared against the entity's corresponding attributes or methods (e.g., {"name": "foo"}, {"queues": ["queue1", "queue2"]}). Returns: diff --git a/merlin/exceptions/__init__.py b/merlin/exceptions/__init__.py index e5c496108..ec919ddad 100644 --- a/merlin/exceptions/__init__.py +++ b/merlin/exceptions/__init__.py @@ -114,6 +114,12 @@ class MerlinWorkerHandlerNotSupportedError(Exception): """ +class MerlinWorkerFormatterNotSupportedError(Exception): + """ + Exception to signal that the provided worker formatter is not supported by Merlin. + """ + + class MerlinWorkerNotSupportedError(Exception): """ Exception to signal that the provided worker is not supported by Merlin. diff --git a/merlin/monitor/celery_monitor.py b/merlin/monitor/celery_monitor.py index 0e11fee9b..400b738c1 100644 --- a/merlin/monitor/celery_monitor.py +++ b/merlin/monitor/celery_monitor.py @@ -23,10 +23,14 @@ import time from typing import List, Set +from celery import Celery + from merlin.db_scripts.entities.run_entity import RunEntity +from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.exceptions import NoWorkersException from merlin.monitor.task_server_monitor import TaskServerMonitor -from merlin.study.celeryadapter import get_workers_from_app, query_celery_queues +from merlin.study.celeryadapter import query_celery_queues +from merlin.workers.handlers.celery_handler import CeleryWorkerHandler LOG = logging.getLogger(__name__) @@ -38,6 +42,9 @@ class CeleryMonitor(TaskServerMonitor): for Celery task servers. This class provides methods to monitor Celery workers, tasks, and workflows. + Attributes: + worker_handler (CeleryWorkerHandler): The worker handler for managing Celery workers. + Methods: wait_for_workers: Wait for Celery workers to start up. check_workers_processing: Check if any Celery workers are still processing tasks. 
@@ -47,6 +54,16 @@ class CeleryMonitor(TaskServerMonitor): check_tasks: Checks the status of tasks in the Celery queues for a given workflow run. """ + def __init__(self, merlin_db: MerlinDatabase = None, app: Celery = None): + """ + Constructor for CeleryMonitor. + + Args: + merlin_db: The MerlinDatabase instance or None. + app: The Celery application instance or None. + """ + self.worker_handler: CeleryWorkerHandler = CeleryWorkerHandler(merlin_db=merlin_db, app=app) + def wait_for_workers(self, workers: List[str], sleep: int): """ Wait for Celery workers to start up. @@ -61,7 +78,7 @@ def wait_for_workers(self, workers: List[str], sleep: int): count = 0 max_count = 10 while count < max_count: - worker_status = get_workers_from_app() + worker_status = self.worker_handler.get_workers_from_app() LOG.debug(f"CeleryMonitor: checking for workers, running workers = {worker_status} ...") # Check if any of the desired workers have started @@ -114,6 +131,7 @@ def _restart_workers(self, workers: List[str]): except Exception as e: # pylint: disable=broad-exception-caught LOG.error(f"CeleryMonitor: Failed to restart worker '{worker}'. Error: {e}") + # TODO when we create worker watchdog process we may need a method like this in the CeleryWorkerHandler def _get_dead_workers(self, workers: List[str]) -> Set[str]: """ Identify unresponsive Celery workers from a given list. 
diff --git a/merlin/monitor/monitor.py b/merlin/monitor/monitor.py index 5538b37c6..fc803c733 100644 --- a/merlin/monitor/monitor.py +++ b/merlin/monitor/monitor.py @@ -85,8 +85,8 @@ def __init__(self, spec: MerlinSpec, sleep: int, task_server: str, no_restart: b self.spec: MerlinSpec = spec self.sleep: int = sleep self.no_restart: bool = no_restart - self.task_server_monitor: TaskServerMonitor = monitor_factory.create(task_server) self.merlin_db = MerlinDatabase() + self.task_server_monitor: TaskServerMonitor = monitor_factory.create(task_server, {"merlin_db": self.merlin_db}) # Run garbage collection if enabled if auto_cleanup: diff --git a/merlin/router.py b/merlin/router.py index 48fdb6532..2b8e376fe 100644 --- a/merlin/router.py +++ b/merlin/router.py @@ -26,7 +26,6 @@ get_workers_from_app, purge_celery_tasks, query_celery_queues, - query_celery_workers, run_celery, stop_celery_workers, ) @@ -159,24 +158,6 @@ def query_queues( return {} -def query_workers(task_server: str, spec_worker_names: List[str], queues: List[str], workers_regex: str): - """ - Retrieves information from workers associated with the specified task server. - - Args: - task_server: The task server to query. - spec_worker_names: A list of specific worker names to query. - queues: A list of queues to search for associated workers. - workers_regex: A regex pattern used to filter worker names during the query. 
- """ - LOG.info("Searching for workers...") - - if task_server == "celery": - query_celery_workers(spec_worker_names, queues, workers_regex) - else: - LOG.error("Celery is not specified as the task server!") - - def get_workers(task_server: str) -> List[str]: """ This function queries the designated task server to obtain a list of all diff --git a/merlin/study/celeryadapter.py b/merlin/study/celeryadapter.py index 87a689780..c1c1ea1ee 100644 --- a/merlin/study/celeryadapter.py +++ b/merlin/study/celeryadapter.py @@ -15,7 +15,6 @@ from amqp.exceptions import ChannelError from celery import Celery -from tabulate import tabulate from merlin.common.dumper import dump_handler from merlin.config import Config @@ -152,41 +151,6 @@ def get_active_celery_queues(app: Celery) -> Tuple[Dict[str, List[str]], List[st return queues, [*active_workers] -def get_active_workers(app: Celery) -> Dict[str, List[str]]: - """ - Retrieve a mapping of active workers to their associated queues for a Celery application. - - This function serves as the inverse of - [`get_active_celery_queues()`][study.celeryadapter.get_active_celery_queues]. It constructs - a dictionary where each key is a worker's name and the corresponding value is a - list of queues that the worker is connected to. This allows for easy identification - of which queues are being handled by each worker. - - Args: - app: The Celery application instance. - - Returns: - A dictionary mapping active worker names to lists of queue names they are - attached to. If no active workers are found, an empty dictionary is returned. 
- """ - # Get the information we need from celery - i = app.control.inspect() - active_workers = i.active_queues() - if active_workers is None: - active_workers = {} - - # Build the mapping dictionary - worker_queue_map = {} - for worker, queues in active_workers.items(): - for queue in queues: - if worker in worker_queue_map: - worker_queue_map[worker].append(queue["name"]) - else: - worker_queue_map[worker] = [queue["name"]] - - return worker_queue_map - - def celerize_queues(queues: List[str], config: SimpleNamespace = None): """ Prepend a queue tag to each queue in the provided list to conform to Celery's @@ -208,106 +172,6 @@ def celerize_queues(queues: List[str], config: SimpleNamespace = None): queues[i] = f"{config.celery.queue_tag}{queue}" -def _build_output_table(worker_list: List[str], output_table: List[Tuple[str, str]]): - """ - Construct an output table for displaying the status of workers and their associated queues. - - This helper function populates the provided output table with entries for each worker - in the given worker list. It retrieves the mapping of active workers to their queues - and formats the data accordingly. - - Args: - worker_list: A list of worker names to be included in the output table. - output_table: A list of tuples where each entry will be of the form - (worker name, associated queues). 
- """ - from merlin.celery import app # pylint: disable=C0415 - - # Get a mapping between workers and the queues they're watching - worker_queue_map = get_active_workers(app) - - # Loop through the list of workers and add an entry in the table - # of the form (worker name, queues attached to this worker) - for worker in worker_list: - if "celery@" not in worker: - worker = f"celery@{worker}" - output_table.append((worker, ", ".join(worker_queue_map[worker]))) - - -def query_celery_workers(spec_worker_names: List[str], queues: List[str], workers_regex: List[str]): - """ - Query and filter existing Celery workers based on specified criteria, - and print a table of the workers along with their associated queues. - - This function retrieves the list of active Celery workers and filters them - according to the provided specifications, including worker names from a - spec file, specific queues, and regular expressions for worker names. - It then constructs and displays a table of the matching workers and their - associated queues. - - Args: - spec_worker_names: A list of worker names defined in a spec file - to filter the workers. - queues: A list of queues to filter the workers by. - workers_regex: A list of regular expressions to filter the worker names. 
- """ - from merlin.celery import app # pylint: disable=C0415 - - # Ping all workers and grab which ones are running - workers = get_workers_from_app() - if not workers: - LOG.warning("No workers found!") - return - - # Remove prepended celery tag while we filter - workers = [worker.replace("celery@", "") for worker in workers] - workers_to_query = [] - - # --queues flag - if queues: - # Get a mapping between queues and the workers watching them - queue_worker_map, _ = get_active_celery_queues(app) - # Remove duplicates and prepend the celery queue tag to all queues - queues = list(set(queues)) - celerize_queues(queues) - # Add the workers associated to each queue to the list of workers we're - # going to query - for queue in queues: - try: - workers_to_query.extend(queue_worker_map[queue]) - except KeyError: - LOG.warning(f"No workers connected to {queue}.") - - # --spec flag - if spec_worker_names: - apply_list_of_regex(spec_worker_names, workers, workers_to_query) - - # --workers flag - if workers_regex: - apply_list_of_regex(workers_regex, workers, workers_to_query) - - # Remove any potential duplicates - workers = set(workers) - workers_to_query = set(workers_to_query) - - # If there were filters and nothing was found then we can't display a table - if (queues or spec_worker_names or workers_regex) and not workers_to_query: - LOG.warning("No workers found that match your filters.") - return - - # Build the output table based on our filters - table = [] - if workers_to_query: - _build_output_table(workers_to_query, table) - else: - _build_output_table(workers, table) - - # Display the output table - LOG.info("Found these connected workers:") - print(tabulate(table, headers=["Workers", "Queues"])) - print() - - def build_csv_queue_info(query_return: List[Tuple[str, int, int]], date: str) -> Dict[str, List]: """ Construct a dictionary containing queue information and column labels diff --git a/merlin/workers/__init__.py b/merlin/workers/__init__.py index 
06a19c240..34fb591e1 100644 --- a/merlin/workers/__init__.py +++ b/merlin/workers/__init__.py @@ -20,12 +20,12 @@ responsible for launching and managing groups of workers. Modules: - worker.py: Defines the `MerlinWorker` abstract base class, which represents a single + worker: Defines the `MerlinWorker` abstract base class, which represents a single task server worker and provides a common interface for launching and configuring worker instances. - celery_worker.py: Implements `CeleryWorker`, a concrete subclass of `MerlinWorker` that uses + celery_worker: Implements `CeleryWorker`, a concrete subclass of `MerlinWorker` that uses Celery to process tasks from configured queues. Supports local and batch launch modes. - worker_factory.py: Defines the `WorkerFactory`, which manages the registration, validation, + worker_factory: Defines the `WorkerFactory`, which manages the registration, validation, and instantiation of individual worker implementations such as `CeleryWorker`. Supports plugin discovery via entry points. """ diff --git a/merlin/workers/celery_worker.py b/merlin/workers/celery_worker.py index 037845968..7bea76bc2 100644 --- a/merlin/workers/celery_worker.py +++ b/merlin/workers/celery_worker.py @@ -81,7 +81,7 @@ def __init__( """ super().__init__(name, config, env) self.args = self.config.get("args", "") - self.queues = self.config.get("queues", {"[merlin]_merlin"}) + self.queues = self.config.get("queues", {"merlin"}) self.batch = self.config.get("batch", {}) self.machines = self.config.get("machines", []) self.overlap = overlap diff --git a/merlin/workers/formatters/__init__.py b/merlin/workers/formatters/__init__.py new file mode 100644 index 000000000..50ef8e183 --- /dev/null +++ b/merlin/workers/formatters/__init__.py @@ -0,0 +1,29 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. 
See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Merlin Worker Formatters Package. + +This package provides classes and utilities for formatting and displaying +Merlin worker information. Worker formatters can render logical and physical +worker data in multiple formats, including JSON for programmatic consumption +and Rich for interactive terminal visualization. The package also includes +a factory for managing supported formatter implementations. + +Modules: + formatter_factory: WorkerFormatterFactory for managing supported worker + formatters. Allows creation by name or alias and ensures consistent + handling of different output formats. + json_formatter: JSONWorkerFormatter that outputs structured, machine-readable + JSON data, including detailed logical and physical worker records, + applied filters, timestamps, and summary statistics. + rich_formatter: RichWorkerFormatter and related layout classes for formatting + and displaying worker information in the terminal with responsive layouts, + summary panels, compact views, and rich styling. Adapts to terminal width + for optimal readability. + worker_formatter: WorkerFormatter abstract base class defining the standard + interface for all worker formatters. +""" diff --git a/merlin/workers/formatters/formatter_factory.py b/merlin/workers/formatters/formatter_factory.py new file mode 100644 index 000000000..4767130fb --- /dev/null +++ b/merlin/workers/formatters/formatter_factory.py @@ -0,0 +1,87 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. 
+############################################################################## + +""" +Worker formatter factory for Merlin. + +This module provides the `WorkerFormatterFactory`, a central registry and +factory class for managing supported worker formatter implementations. +It allows clients to create worker formatters by name or alias, ensuring +consistent handling of different output formats (e.g., JSON, Rich). +""" + +from typing import Any, Type + +from merlin.abstracts import MerlinBaseFactory +from merlin.exceptions import MerlinWorkerFormatterNotSupportedError +from merlin.workers.formatters.json_formatter import JSONWorkerFormatter +from merlin.workers.formatters.rich_formatter import RichWorkerFormatter +from merlin.workers.formatters.worker_formatter import WorkerFormatter + + +class WorkerFormatterFactory(MerlinBaseFactory): + """ + Factory class for managing and instantiating supported Merlin worker formatters. + + This subclass of `MerlinBaseFactory` handles registration, validation, + and instantiation of worker formatters (e.g., rich, json). + + Attributes: + _registry (Dict[str, WorkerFormatter]): Maps canonical formatter names to formatter classes. + _aliases (Dict[str, str]): Maps legacy or alternate names to canonical formatter names. + + Methods: + register: Register a new formatter class and optional aliases. + list_available: Return a list of supported formatter names. + create: Instantiate a formatter class by name or alias. + get_component_info: Return metadata about a registered formatter. + """ + + def _register_builtins(self): + """ + Register built-in worker formatter implementations. + """ + self.register("json", JSONWorkerFormatter) + self.register("rich", RichWorkerFormatter) + + def _validate_component(self, component_class: Any): + """ + Ensure registered component is a subclass of WorkerFormatter. + + Args: + component_class: The class to validate. + + Raises: + TypeError: If the component does not subclass WorkerFormatter. 
+ """ + if not issubclass(component_class, WorkerFormatter): + raise TypeError(f"{component_class} must inherit from WorkerFormatter") + + def _entry_point_group(self) -> str: + """ + Entry point group used for discovering worker formatter plugins. + + Returns: + The entry point namespace for Merlin worker formatter plugins. + """ + return "merlin.workers.formatters" + + def _raise_component_error_class(self, msg: str) -> Type[Exception]: + """ + Raise an appropriate exception when an invalid component is requested. + + Subclasses should override this to raise more specific exceptions. + + Args: + msg: The message to add to the error being raised. + + Raises: + MerlinWorkerFormatterNotSupportedError: Always raised, with `msg` as the error message. + """ + raise MerlinWorkerFormatterNotSupportedError(msg) + + +worker_formatter_factory = WorkerFormatterFactory() diff --git a/merlin/workers/formatters/json_formatter.py b/merlin/workers/formatters/json_formatter.py new file mode 100644 index 000000000..aba1b8585 --- /dev/null +++ b/merlin/workers/formatters/json_formatter.py @@ -0,0 +1,122 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +JSON-based worker information formatter for Merlin. + +This module provides a JSON formatter for displaying worker information +in a structured, machine-readable format. It is primarily intended for +programmatic consumption by downstream tools, scripts, or external systems +that need to parse and analyze worker data rather than display it in a +human-friendly format.
+ +The formatter includes:\n + - Detailed records of logical workers and their associated queues + - Physical worker details such as ID, host, PID, status, restart counts, + and timestamps + - Relationships between logical and physical workers + - Applied filters and generation timestamp metadata + - Summary statistics for logical and physical workers +""" + +import json +from datetime import datetime +from typing import Dict, List + +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.workers.formatters.worker_formatter import WorkerFormatter + + +class JSONWorkerFormatter(WorkerFormatter): + """ + JSON formatter for programmatic worker information consumption. + + This formatter generates structured JSON output representing logical + and physical worker entities. The output includes worker details, + relationships between logical and physical workers, and comprehensive + statistics. Designed for use cases where downstream tools or scripts + need to parse worker information in a machine-readable format. + + Attributes: + console (rich.console.Console): A Rich Console object used for displaying + output to the terminal. + + Methods: + format_and_display: Format and print worker information as structured JSON, + including details for logical and physical workers, filters, timestamp, + and summary statistics. + get_worker_statistics: Compute worker statistics, including counts of logical + and physical workers by status, for inclusion in JSON output. + """ + + def format_and_display( + self, + logical_workers: List[LogicalWorkerEntity], + filters: Dict, + merlin_db: MerlinDatabase, + ): + """ + Format and display worker information as JSON. 
+ + This method produces JSON output containing:\n + - A record of applied filters + - A timestamp of when the report was generated + - Detailed logical worker entries (name, queues, associated physical workers) + - Detailed physical worker entries (ID, host, PID, status, restart count, timestamps) + - A summary of worker statistics + + Args: + logical_workers (List[db_scripts.entities.logical_worker_entity.LogicalWorkerEntity]): + A list of logical worker entities to format. + filters (Dict): A dictionary of filters applied to the query. + merlin_db (db_scripts.merlin_db.MerlinDatabase): Database interface for retrieving + physical worker details. + """ + data = { + "filters": filters, + "timestamp": datetime.now().isoformat(), + "logical_workers": [], + "summary": self.get_worker_statistics(logical_workers, merlin_db), + } + + for logical_worker in logical_workers: + logical_data = { + "name": logical_worker.get_name(), + "queues": sorted( + [q[len("[merlin]_") :] if q.startswith("[merlin]_") else q for q in logical_worker.get_queues()] + ), + "physical_workers": [], + } + + physical_worker_ids = logical_worker.get_physical_workers() + physical_workers = [merlin_db.get("physical_worker", pid) for pid in physical_worker_ids] + + for physical_worker in physical_workers: + physical_data = { + "id": physical_worker.get_id() if hasattr(physical_worker, "get_id") else None, + "name": physical_worker.get_name(), + "host": physical_worker.get_host(), + "pid": physical_worker.get_pid(), + "worker_status": physical_worker.get_worker_status().value, + "restart_count": physical_worker.get_restart_count(), + "latest_start_time": ( + physical_worker.get_latest_start_time().isoformat() + if physical_worker.get_latest_start_time() + else None + ), + "heartbeat_timestamp": ( + physical_worker.get_heartbeat_timestamp().isoformat() + if physical_worker.get_heartbeat_timestamp() + else None + ), + } + logical_data["physical_workers"].append(physical_data) + + 
data["logical_workers"].append(logical_data) + + print(f"data: {data}") + self.console.print(json.dumps(data, indent=2)) diff --git a/merlin/workers/formatters/rich_formatter.py b/merlin/workers/formatters/rich_formatter.py new file mode 100644 index 000000000..62c43360d --- /dev/null +++ b/merlin/workers/formatters/rich_formatter.py @@ -0,0 +1,851 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Rich-based worker formatter with responsive layout for Merlin. + +This module provides classes and utilities to format and display +logical and physical worker information in a terminal using Rich. +It adapts automatically to different terminal widths, providing +compact views for narrow screens and detailed tables and panels +for wider screens. Key features include: + +- Responsive layouts: Adjusts tables and panels based on terminal width. +- Summary panels: Displays aggregated statistics for logical and + physical workers, including running, stopped, stalled, and rebooting counts. +- Compact view: Condensed text output for narrow terminals where full tables + would not fit. +- Rich styling: Uses colors, icons, and text formatting to make worker + statuses and information visually distinct. + +Classes: + LayoutSize: Enum defining terminal width categories for responsive layouts. + ColumnConfig: Dataclass defining display properties for individual table columns. + LayoutConfig: Dataclass defining full layout configuration for a given size. + ResponsiveLayoutManager: Selects and provides layout configurations based + on terminal width. 
+ RichWorkerFormatter: Formats and renders worker information using Rich + components with responsive layouts. + +This module depends on the Merlin database interface and worker entities +([`LogicalWorkerEntity`][db_scripts.entities.logical_worker_entity.LogicalWorkerEntity] +and [`PhysicalWorkerEntity`][db_scripts.entities.physical_worker_entity.PhysicalWorkerEntity]) +to retrieve and display worker information. +""" + +from dataclasses import dataclass +from datetime import datetime, timedelta +from enum import Enum +from typing import Any, Callable, Dict, List, Optional + +from rich.columns import Columns +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from merlin.common.enums import WorkerStatus +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.entities.physical_worker_entity import PhysicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.workers.formatters.worker_formatter import WorkerFormatter + + +class LayoutSize(Enum): + """ + Enumeration of terminal width categories for responsive layouts. + + This enum is used to adapt table and panel rendering based on + the current terminal width, enabling responsive designs that + remain readable across narrow and wide displays. + + Attributes: + COMPACT (str): Very small terminals (< 60 characters wide). + NARROW (str): Narrow terminals (60-79 characters wide). + MEDIUM (str): Standard terminals (80-119 characters wide). + WIDE (str): Wide terminals (>= 120 characters wide). + """ + + COMPACT = "compact" # < 60 chars + NARROW = "narrow" # 60-79 chars + MEDIUM = "medium" # 80-119 chars + WIDE = "wide" # >= 120 chars + + +@dataclass +class ColumnConfig: # pylint: disable=too-many-instance-attributes + """ + Configuration for an individual table column. 
+ + Defines how a column should be displayed within a table, + including its label, alignment, width constraints, and + optional formatting logic. + + Attributes: + key (str): The field or attribute name mapped to this column. + title (str): The display name of the column header. + style (str): Text style to apply (e.g., "white", "bold red"). + width (Optional[int]): Fixed width of the column, if specified. + max_width (Optional[int]): Maximum allowed width before truncation or wrapping. + justify (str): Text alignment ("left", "center", or "right"). + no_wrap (bool): Whether to prevent text wrapping in this column. + formatter (Optional[Callable[[Any], Any]]): Optional callable to transform cell + values before display. + """ + + key: str + title: str + style: str = "white" + width: Optional[int] = None + max_width: Optional[int] = None + justify: str = "left" + no_wrap: bool = False + formatter: Optional[Callable[[Any], Any]] = None + + +@dataclass +class LayoutConfig: + """ + Configuration for rendering worker information at a given layout size. + + Determines which elements (tables, panels) are displayed and how + they are arranged, based on the current terminal width category. + + Attributes: + size (workers.formatters.rich_formatter.LayoutSize): The layout + size category. + show_summary_panels (bool): Whether to display summary panels. + panels_horizontal (bool): If True, display panels side-by-side + (horizontal), otherwise stack them vertically. + physical_worker_columns (List[workers.formatters.rich_formatter.ColumnConfig]): + Column configurations for physical worker tables. + logical_worker_columns (List[workers.formatters.rich_formatter.ColumnConfig]): + Column configurations for logical worker tables. + use_compact_view (bool): Whether to enable a simplified, space-saving + view of worker information. 
+ """ + + size: LayoutSize + show_summary_panels: bool = True + panels_horizontal: bool = True + physical_worker_columns: List[ColumnConfig] = None + logical_worker_columns: List[ColumnConfig] = None + use_compact_view: bool = False + + +class ResponsiveLayoutManager: + """ + Manage and provide responsive layout configurations for terminal output. + + This class adapts the display of worker information (both physical and logical) + to different terminal widths. It selects column visibility, ordering, and + formatting rules depending on the width category (compact, narrow, medium, wide). + + Attributes: + layouts (Dict[workers.formatters.rich_formatter.LayoutSize, workers.formatters.rich_formatter.LayoutConfig]): + A mapping of layout sizes to their respective configurations, including + column definitions and display options. + + Methods: + get_layout_size: Determine the appropriate layout size category based on terminal width. + get_layout_config: Retrieve the full layout configuration object for a given terminal width. + _format_status: Format a worker status into a styled Rich Text object with icons. + """ + + def __init__(self): + """ + Initialize the responsive layout manager. + + Predefines layout configurations for all supported terminal width categories + (compact, narrow, medium, wide). Each configuration specifies which columns + should be shown for physical and logical worker tables, whether summary panels + are displayed, and how they are arranged. 
+ """ + self.layouts = { + LayoutSize.COMPACT: LayoutConfig( + size=LayoutSize.COMPACT, + show_summary_panels=False, + use_compact_view=True, + physical_worker_columns=[], + logical_worker_columns=[], + ), + LayoutSize.NARROW: LayoutConfig( + size=LayoutSize.NARROW, + show_summary_panels=True, + panels_horizontal=False, + physical_worker_columns=[ + ColumnConfig(key="worker", title="Worker", style="bold cyan", max_width=12), + ColumnConfig(key="host", title="Host", style="blue", max_width=10), + ColumnConfig(key="pid", title="PID", style="yellow", width=8, justify="right"), + ColumnConfig(key="worker_status", title="Status", style="bold", width=10, formatter=self._format_status), + ], + logical_worker_columns=[ + ColumnConfig(key="worker", title="Worker", style="bold white", max_width=20), + ColumnConfig(key="status", title="Status", style="bold red", width=12), + ], + ), + LayoutSize.MEDIUM: LayoutConfig( + size=LayoutSize.MEDIUM, + show_summary_panels=True, + panels_horizontal=True, + physical_worker_columns=[ + ColumnConfig(key="worker", title="Worker", style="bold cyan", max_width=15), + ColumnConfig(key="instance", title="Instance", style="bold magenta", max_width=25), + ColumnConfig(key="host", title="Host", style="blue", max_width=12), + ColumnConfig(key="pid", title="PID", style="yellow", width=8, justify="right"), + ColumnConfig(key="worker_status", title="Status", style="bold", width=10, formatter=self._format_status), + ColumnConfig(key="runtime", title="Runtime", style="cyan", width=8), + ], + logical_worker_columns=[ + ColumnConfig(key="worker", title="Worker Name", style="bold white", max_width=25), + ColumnConfig(key="queues", title="Queues", style="green", max_width=30, no_wrap=True), + ColumnConfig(key="status", title="Status", style="bold red", width=12), + ], + ), + LayoutSize.WIDE: LayoutConfig( + size=LayoutSize.WIDE, + show_summary_panels=True, + panels_horizontal=True, + physical_worker_columns=[ + ColumnConfig(key="worker", title="Logical 
Worker", style="bold cyan", max_width=15), + ColumnConfig(key="queues", title="Queues", style="green", max_width=20, no_wrap=True), + ColumnConfig(key="instance", title="Instance Name", style="bold magenta", max_width=30), + ColumnConfig(key="host", title="Host", style="blue", max_width=12), + ColumnConfig(key="pid", title="PID", style="yellow", width=8, justify="right"), + ColumnConfig(key="worker_status", title="Status", style="bold", width=10, formatter=self._format_status), + ColumnConfig(key="runtime", title="Runtime", style="cyan", width=8), + ColumnConfig(key="heartbeat", title="Heartbeat", style="bright_blue", width=10), + ColumnConfig(key="restarts", title="Restarts", style="red", width=8, justify="right"), + ], + logical_worker_columns=[ + ColumnConfig(key="worker", title="Worker Name", style="bold white", max_width=25), + ColumnConfig(key="queues", title="Queues", style="green", max_width=30, no_wrap=True), + ColumnConfig(key="status", title="Status", style="bold red", width=12), + ], + ), + } + + def get_layout_size(self, width: int) -> LayoutSize: + """ + Determine the layout size category for a given terminal width. + + Args: + width (int): The terminal width in characters. + + Returns: + The corresponding layout size category (COMPACT, NARROW, MEDIUM, or WIDE). + """ + if width < 60: + return LayoutSize.COMPACT + if width < 80: + return LayoutSize.NARROW + if width < 120: + return LayoutSize.MEDIUM + return LayoutSize.WIDE + + def get_layout_config(self, width: int) -> LayoutConfig: + """ + Retrieve the layout configuration for a given terminal width. + + Args: + width (int): The terminal width in characters. + + Returns: + The full layout configuration for the determined size, + including column definitions and panel options. + """ + size = self.get_layout_size(width) + return self.layouts[size] + + def _format_status(self, status: WorkerStatus) -> Text: + """ + Format a worker status value into a Rich `Text` object. 
+ + Adds an icon and applies color styling to make worker statuses + visually distinguishable in tables. + + Args: + status (common.enums.WorkerStatus): The worker status value. + + Returns: + A Rich Text object containing the styled status with an icon. + """ + status_config = { + "RUNNING": ("✓", "bold green"), + "STALLED": ("⚠", "bold yellow"), + "STOPPED": ("✗", "bold red"), + "REBOOTING": ("↻", "bold cyan"), + } + + icon, color = status_config.get(status.value, ("?", "white")) + return Text(f"{icon} {status.value}", style=color) + + +class RichWorkerFormatter(WorkerFormatter): + """ + Format and display worker information using Rich with responsive layouts. + + This class provides a Rich-based implementation of a worker formatter that + adapts to terminal width. It uses responsive tables, panels, and compact + views to display logical and physical worker information in a clear, + visually rich way. Layouts are selected automatically based on terminal + width using a [`ResponsiveLayoutManager`][workers.formatters.rich_formatter.ResponsiveLayoutManager]. + + Attributes: + console (rich.console.Console): A Rich Console object used for displaying + output to the terminal. + layout_manager (workers.formatters.rich_formatter.ResponsiveLayoutManager): + The layout manager responsible for selecting column and panel + configurations based on terminal width. + + Methods: + format_and_display: Format and display worker information with responsive Rich components. + get_worker_statistics: Compute summary statistics for logical and physical workers. + _display_compact_view: Display a simplified worker summary for very narrow terminals. + _display_summary_panels: Render summary panels (logical, physical, filters) + depending on layout settings. + _build_responsive_table: Construct a Rich table using the given column configuration and data. + _get_physical_worker_data: Extract detailed physical worker data for table display. 
+ _get_logical_workers_without_instances_data: Extract logical workers that have no + physical instances. + _sort_physical_workers: Sort physical worker data by status (running first), then by + worker and instance name. + _format_status: Format a worker status with icons and Rich color highlighting. + _format_uptime_or_downtime: Return uptime for running workers or downtime for stopped + workers in human-readable format. + _format_time_duration: Format a time duration into a human-readable string (e.g., "2h 15m"). + _format_last_heartbeat: Format last heartbeat timestamp with color coding based on recency. + _build_summary_panels: Build summary panels showing filters, logical workers, and + physical workers. + _build_compact_view: Build a compact text-only view of worker status for very narrow terminals. + """ + + def __init__(self): + """ + Initialize the Rich-based worker formatter. + + This constructor sets up the responsive layout manager + that determines how worker data will be displayed based + on the terminal width. + """ + super().__init__() + self.layout_manager = ResponsiveLayoutManager() + + def format_and_display( + self, + logical_workers: List[LogicalWorkerEntity], + filters: Dict, + merlin_db: MerlinDatabase, + ): + """ + Format and display worker information using Rich components. + + This method generates a responsive console output of worker + statistics, tables, and summary panels. The output adapts + to the current terminal width and user-defined filters. + + Args: + logical_workers (List[db_scripts.entities.logical_worker_entity.LogicalWorkerEntity]): + A list of logical worker objects to display and summarize. + filters (Dict): Active filters applied to the display + (e.g., by worker type or status). + merlin_db (db_scripts.merlin_db.MerlinDatabase): Reference to + the Merlin database, used to query worker-related data.
+ """ + # Calculate statistics + stats = self.get_worker_statistics(logical_workers, merlin_db) + console_width = self.console.size.width + layout_config = self.layout_manager.get_layout_config(console_width) + + self.console.print() # Empty line + + # Handle compact view for very narrow terminals + if layout_config.use_compact_view: + compact_view = self._build_compact_view(logical_workers, merlin_db) + self._display_compact_view(compact_view, filters, stats) + return + + # Display summary panels + if layout_config.show_summary_panels: + self._display_summary_panels(stats, filters, layout_config) + self.console.print() # Empty line + + # Show physical workers table if any exist + if stats["total_physical"] > 0: + physical_table = self._build_responsive_table( + "[bold magenta]Physical Worker Instances[/bold magenta]", + layout_config.physical_worker_columns, + self._get_physical_worker_data(logical_workers, merlin_db), + self._sort_physical_workers, + ) + self.console.print(physical_table) + self.console.print() # Empty line + + # Show logical workers without instances if any exist + if stats["logical_without_instances"] > 0: + no_instances_table = self._build_responsive_table( + "[bold cyan]Logical Workers Without Instances[/bold cyan]", + layout_config.logical_worker_columns, + self._get_logical_workers_without_instances_data(logical_workers), + lambda data: sorted(data, key=lambda x: x["worker"]), + ) + self.console.print(no_instances_table) + + def _get_queues_str(self, queues: List[str]) -> str: + """ + Given a list of queue names, remove the '[merlin]_' prefix and + combine them into a comma-delimited string. + + Args: + queues (List[str]): The list of queue names to combine. + + Returns: + A comma-delimited string of queues without the '[merlin]_' prefix. 
+ """ + return ", ".join(sorted(q[len("[merlin]_") :] if q.startswith("[merlin]_") else q for q in queues)) + + def _display_compact_view( + self, + compact_view: str, + filters: Dict, + stats: Dict[str, int], + ): + """ + Display a compact view of worker information for narrow terminals. + + This fallback view is used when the terminal width is too small + to render full tables or summary panels. It prints a concise + summary of worker status, applied filters, and worker details. + + Args: + compact_view (str): Multi-line string representing the compact worker view. + filters (Dict): Active filters applied to the display. + stats (Dict[str, int]): Aggregated worker statistics (e.g., running counts). + """ + self.console.print("[bold cyan]Worker Status[/bold cyan]") + if filters: + filter_text = " | ".join([f"{k}: {','.join(v)}" for k, v in filters.items()]) + self.console.print(f"[dim]Filters: {filter_text}[/dim]") + + self.console.print( + f"[bold]Summary:[/bold] {stats['physical_running']}/{stats['total_physical']} " + f"running, {stats['logical_without_instances']} logical workers without instances\n" + ) + + self.console.print(compact_view) + + def _display_summary_panels(self, stats: Dict[str, int], filters: Dict, layout_config: LayoutConfig): + """ + Display summary panels based on the given layout configuration. + + This method renders one or more Rich panels summarizing worker statistics. + The panels are built dynamically from the provided stats and filters, and + displayed either horizontally in columns or vertically in sequence, + depending on the layout configuration. + + Args: + stats (Dict[str, int]): + A dictionary of aggregated worker statistics to summarize. + filters (Dict): + A dictionary of filter criteria used to generate the summary content. + layout_config (workers.formatters.rich_formatter.LayoutConfig): + An object defining layout preferences (e.g., whether to + arrange panels horizontally). 
+ """ + summary_panels = self._build_summary_panels(stats, filters) + + if layout_config.panels_horizontal: + self.console.print(Columns(summary_panels, equal=True)) + else: + for panel in summary_panels: + self.console.print(panel) + + def _build_responsive_table( + self, title: str, columns: List[ColumnConfig], data: List[Dict], sort_func: Callable = None + ) -> Table: + """ + Build and return a Rich table with responsive column configuration. + + This method creates a table with headers, applies styles and sizing rules + based on column configuration, and optionally sorts the data before + rendering. Each row of the table is constructed by extracting values + from the input data and applying any specified formatters. + + Args: + title (str): + Title displayed above the table. + columns (List[workers.formatters.rich_formatter.ColumnConfig]): + A list of column configuration objects defining headers, + widths, alignment, and formatting functions. + data (List[Dict]): + List of dictionaries containing the row data to display. + sort_func (Optional[Callable]): + A function to sort the data before rendering. Defaults to None. + + Returns: + A fully constructed Rich Table object ready for display. 
+ """ + table = Table( + show_header=True, + header_style="bold white", + title=title, + ) + + # Add columns based on configuration + for col in columns: + table.add_column( + col.title, style=col.style, width=col.width, max_width=col.max_width, justify=col.justify, no_wrap=col.no_wrap + ) + + # Sort data if sort function provided + if sort_func: + data = sort_func(data) + + # Add rows + for row_data in data: + row_values = [] + for col in columns: + value = row_data.get(col.key, "-") + if col.formatter: + value = col.formatter(value) + row_values.append(value) + table.add_row(*row_values) + + return table + + def _get_physical_worker_data(self, logical_workers: List[LogicalWorkerEntity], merlin_db: MerlinDatabase) -> List[Dict]: + """ + Extract and format physical worker data for table display. + + This method retrieves all physical workers associated with the given + logical workers, queries the database for their details, and formats + the data into a list of dictionaries suitable for table rendering. + Each entry includes identifiers, host, status, runtime, heartbeat, + and restart count. + + Args: + logical_workers (List[db_scripts.entities.logical_worker_entity.LogicalWorkerEntity]): + A list of logical worker entities, each referencing one or more + associated physical workers. + merlin_db (db_scripts.merlin_db.MerlinDatabase): + The database interface used to fetch physical worker entities. + + Returns: + A list of dictionaries, where each dictionary contains:\n + - worker (str): Logical worker name. + - queues (str): Comma-separated queue names without "[merlin]_". + - instance (str): Instance/worker name. + - host (str): Hostname where the worker is running. + - pid (str): Process ID of the worker, or "-" if unavailable. + - status (common.enums.WorkerStatus): Raw worker status object (used for formatting). + - runtime (str): Formatted uptime/downtime string. + - heartbeat (str): Last heartbeat timestamp or "-". 
+ - restarts (str): Number of times the worker has restarted. + - _sort_status (str): String representation of the status + used for sorting. + """ + data = [] + + for logical_worker in logical_workers: + worker_name = logical_worker.get_name() + queues_str = self._get_queues_str(logical_worker.get_queues()) + + physical_worker_ids = logical_worker.get_physical_workers() + physical_workers = [merlin_db.get("physical_worker", pid) for pid in physical_worker_ids] + + for physical_worker in physical_workers: + status = physical_worker.get_worker_status() + + # Only show heartbeat for running workers + heartbeat_text = "-" + if status == WorkerStatus.RUNNING: + heartbeat_text = str(self._format_last_heartbeat(physical_worker.get_heartbeat_timestamp())) + + instance_name = physical_worker.get_name() or "-" + + data.append( + { + "worker": worker_name, + "queues": queues_str, + "instance": instance_name, + "host": physical_worker.get_host() or "-", + "pid": str(physical_worker.get_pid()) if physical_worker.get_pid() else "-", + "worker_status": status, + "runtime": self._format_uptime_or_downtime(physical_worker), + "heartbeat": heartbeat_text, + "restarts": str(physical_worker.get_restart_count()), + "_sort_status": status.value, + } + ) + + return data + + def _get_logical_workers_without_instances_data(self, logical_workers: List[LogicalWorkerEntity]) -> List[Dict]: + """ + Extract logical workers that have no associated physical instances. + + This method identifies all logical workers that currently do not have + any physical worker instances. It formats their data into a list of + dictionaries suitable for table rendering, including a status field + indicating "NO INSTANCES". + + Args: + logical_workers (List[db_scripts.entities.logical_worker_entity.LogicalWorkerEntity]): + A list of logical worker entities to check. + + Returns: + A list of dictionaries with the following keys:\n + - worker (str): Name of the logical worker. 
+ - queues (str): Comma-separated list of queues the worker is associated with. + - status (Text): Rich-formatted status indicating no instances. + """ + data = [] + + for logical_worker in logical_workers: + physical_worker_ids = logical_worker.get_physical_workers() + if not physical_worker_ids: + queues_str = self._get_queues_str(logical_worker.get_queues()) + + data.append( + { + "worker": logical_worker.get_name(), + "queues": queues_str, + "status": Text("NO INSTANCES", style="bold red"), + } + ) + + return data + + def _sort_physical_workers(self, data: List[Dict]) -> List[Dict]: + """ + Sort a list of physical worker data dictionaries for display. + + Sorting prioritizes workers that are currently running first, + followed by sorting alphabetically by logical worker name and + then by physical instance name. + + Args: + data (List[Dict]): List of dictionaries containing physical + worker data, including a '_sort_status' key. + + Returns: + Sorted list of physical worker dictionaries. + """ + return sorted(data, key=lambda row: (0 if row["_sort_status"] == "RUNNING" else 1, row["worker"], row["instance"])) + + def _format_status(self, status: WorkerStatus) -> Text: + """ + Format a worker status into a Rich Text object with an icon and color. + + Converts the raw status into a human-readable string with an + associated Unicode icon and color highlighting for easier visualization. + + Args: + status (common.enums.WorkerStatus): Raw worker status object. + + Returns: + Rich Text object combining an icon and colored status string. + Status mapping:\n + - "RUNNING": ✓ green + - "STALLED": ⚠ yellow + - "STOPPED": ✗ red + - "REBOOTING": ↻ cyan + - Unknown: ? 
white + """ + status_config = { + "RUNNING": ("✓", "bold green"), + "STALLED": ("⚠", "bold yellow"), + "STOPPED": ("✗", "bold red"), + "REBOOTING": ("↻", "bold cyan"), + } + + icon, color = status_config.get(status.value, ("?", "white")) + return Text(f"{icon} {status.value}", style=color) + + def _format_uptime_or_downtime(self, physical_worker: PhysicalWorkerEntity) -> str: + """ + Format the uptime for running workers or downtime for stopped workers. + + For running workers, this method calculates the elapsed time since the + latest start and returns a human-readable string. For stopped workers, + it calculates the time since the stop event and prefixes it with "down". + If no start or stop times are available, returns a placeholder string. + + Args: + physical_worker (db_scripts.entities.physical_worker_entity.PhysicalWorkerEntity): + The physical worker entity to calculate uptime/downtime for. + + Returns: + Human-readable uptime or downtime string. + """ + status = physical_worker.get_worker_status().value + + if status == "RUNNING": + start_time = physical_worker.get_latest_start_time() + if start_time: + uptime = datetime.now() - start_time + return self._format_time_duration(uptime) + else: + # TODO when we refactor stop-workers, add this in + stop_time = getattr(physical_worker, "get_stop_time", lambda: None)() + if stop_time: + downtime = datetime.now() - stop_time + return f"down {self._format_time_duration(downtime)}" + return "stopped" + + return "-" + + def _format_time_duration(self, duration: timedelta) -> str: + """ + Convert a timedelta into a compact, human-readable string. + + Formats the duration using days, hours, minutes, and seconds in a + concise format suitable for table display. + + Args: + duration (timedelta): Time duration to format. + + Returns: + Formatted duration string. 
+ """ + if duration.days > 0: + return f"{duration.days}d {duration.seconds // 3600}h {duration.seconds % 3600 // 60}m" + if duration.seconds >= 3600: + return f"{duration.seconds // 3600}h {(duration.seconds % 3600) // 60}m" + if duration.seconds >= 60: + return f"{duration.seconds // 60}m" + return f"{duration.seconds}s" + + def _format_last_heartbeat(self, heartbeat_timestamp: datetime) -> Text: + """ + Format the last heartbeat timestamp with color coding based on recency. + + Displays a human-readable time difference between the current time + and the last heartbeat. Uses color to indicate how recent the heartbeat was. + + Args: + heartbeat_timestamp (datetime): Timestamp of the last heartbeat. + + Returns: + Rich Text object containing the formatted heartbeat. + """ + if not heartbeat_timestamp: + return Text("-", style="dim") + + time_diff = datetime.now() - heartbeat_timestamp + + if time_diff < timedelta(minutes=1): # Less than 1 minute + return Text("Just now", style="green") + + if time_diff.total_seconds() < 3600: # Less than 1 hour + minutes = int(time_diff.total_seconds() // 60) + if minutes < 5: + return Text(f"{minutes}m ago", style="yellow") + return Text(f"{minutes}m ago", style="orange3") + + # More than 1 hour + hours = int(time_diff.total_seconds() // 3600) + return Text(f"{hours}h ago", style="red") + + def _build_summary_panels(self, stats: Dict[str, int], filters: Dict) -> List[Panel]: + """ + Build Rich summary panels displaying worker statistics and applied filters. + + Creates panels for:\n + - Applied filters (queues, workers) if any. + - Logical worker counts, with and without instances. + - Physical worker counts, categorized by running, stopped, stalled, and rebooting. + + Args: + stats (Dict[str, int]): Dictionary of worker statistics, as returned by `get_worker_statistics`. + filters (Dict): Dictionary of applied filters, e.g., queues or worker names. + + Returns: + List of Rich Panel objects ready for display in the console. 
+ """ + panels = [] + + # Filter information + if filters: + filter_parts = [] + if "queues" in filters: + queues_str = self._get_queues_str(filters["queues"]) + filter_parts.append(f"Queues: {queues_str}") + if "name" in filters: + filter_parts.append(f"Workers: {', '.join(filters['name'])}") + + filter_text = "\n".join(filter_parts) + panels.append(Panel(filter_text, title="[bold blue]Applied Filters[/bold blue]", border_style="blue")) + + # Logical worker summary + logical_summary = ( + f"Total: [bold white]{stats['total_logical']}[/bold white]\n" + f"With Instances: [bold green]{stats['logical_with_instances']}[/bold green]\n" + f"Without Instances: [bold dim]{stats['logical_without_instances']}[/bold dim]" + ) + + panels.append(Panel(logical_summary, title="[bold cyan]Logical Workers[/bold cyan]", border_style="cyan")) + + # Physical worker summary + if stats["total_physical"] > 0: + physical_summary = ( + f"Total: [bold white]{stats['total_physical']}[/bold white]\n" + f"Running: [bold green]{stats['physical_running']}[/bold green]\n" + f"Stopped: [bold red]{stats['physical_stopped']}[/bold red]" + ) + + if stats["physical_stalled"] > 0: + physical_summary += f"\nStalled: [bold yellow]{stats['physical_stalled']}[/bold yellow]" + if stats["physical_rebooting"] > 0: + physical_summary += f"\nRebooting: [bold cyan]{stats['physical_rebooting']}[/bold cyan]" + + panels.append( + Panel(physical_summary, title="[bold magenta]Physical Instances[/bold magenta]", border_style="magenta") + ) + + return panels + + def _build_compact_view(self, logical_workers: List[LogicalWorkerEntity], merlin_db: MerlinDatabase) -> str: + """ + Build a compact text-based view of workers for narrow terminals. + + Displays each logical worker and its physical instances in a concise format. + Logical workers without instances are marked explicitly as "NO INSTANCES". + Each physical worker shows status, host, and PID with basic coloring for status. 
+ + Args: + logical_workers (List[db_scripts.entities.logical_worker_entity.LogicalWorkerEntity]): + List of logical worker entities. + merlin_db (db_scripts.merlin_db.MerlinDatabase): Database interface to fetch physical + worker details. + + Returns: + Multi-line string representing the compact worker view. + """ + output_lines = [] + + for logical_worker in logical_workers: + worker_name = logical_worker.get_name() + physical_worker_ids = logical_worker.get_physical_workers() + + if not physical_worker_ids: + output_lines.append(f"[bold white]{worker_name}[/bold white]: [bold red]NO INSTANCES[/bold red]") + else: + physical_workers = [merlin_db.get("physical_worker", pid) for pid in physical_worker_ids] + + for physical_worker in physical_workers: + status = physical_worker.get_worker_status().value + host = physical_worker.get_host() or "?" + pid = physical_worker.get_pid() or "-" + + status_icon = "✓" if status == "RUNNING" else "✗" + color = "green" if status == "RUNNING" else "red" + + output_lines.append( + f"[bold white]{worker_name}[/bold white]@[blue]{host}[/blue] " + f"[{color}]{status_icon} {status}[/{color}] (PID: {pid})" + ) + + return "\n".join(output_lines) diff --git a/merlin/workers/formatters/worker_formatter.py b/merlin/workers/formatters/worker_formatter.py new file mode 100644 index 000000000..37097be89 --- /dev/null +++ b/merlin/workers/formatters/worker_formatter.py @@ -0,0 +1,136 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Worker formatter base module for displaying worker query results. 
+ +This module defines the abstract base class `WorkerFormatter`, which provides a +standard interface for formatting and displaying information about Merlin workers. +Worker formatters are responsible for presenting logical and physical worker +information in a structured, user-friendly manner (e.g., through text, tables, +or rich console output). + +Intended Usage:\n + Subclasses of `WorkerFormatter` (e.g., those using Rich for terminal + visualization) should implement `format_and_display` to render + worker information, while reusing `get_worker_statistics` for + consistent metrics across implementations. +""" + +from abc import ABC, abstractmethod +from typing import Dict, List + +from rich.console import Console + +from merlin.common.enums import WorkerStatus +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase + + +class WorkerFormatter(ABC): + """ + Abstract base class for formatting and displaying worker query results. + + Provides a consistent interface for formatting logical and physical worker + information, including a utility method to calculate worker statistics. + + Attributes: + console (rich.console.Console): A Rich Console object used for displaying + output to the terminal. + + Methods: + format_and_display: Abstract method that formats and outputs worker + information. Must be implemented by subclasses. + get_worker_statistics: Compute counts and statuses of logical and + physical workers, including totals and breakdown by status. + """ + + def __init__(self): + """ + Initializer for the WorkerFormatter class. + + Sets up a rich console object so that subclasses can easily print formatted + output to the console. + """ + self.console = Console() + + @abstractmethod + def format_and_display(self, logical_workers: List, filters: Dict, merlin_db: MerlinDatabase): + """ + Format and display information about logical and physical workers. 
+ + This method must be implemented by subclasses to define the output + format (e.g., JSON, Rich tables, text). Implementations should make + use of `get_worker_statistics` if worker summary metrics are required. + + Args: + logical_workers (List[LogicalWorkerEntity]): List of logical worker + entities to be displayed. + filters (Dict): Optional filters applied to the worker query. + merlin_db (MerlinDatabase): Database interface for retrieving + physical worker details. + """ + raise NotImplementedError("Subclasses of `WorkerFormatter` must implement a `format_and_display` method.") + + def get_worker_statistics(self, logical_workers: List[LogicalWorkerEntity], merlin_db: MerlinDatabase) -> Dict[str, int]: + """ + Calculate comprehensive statistics for logical and physical workers. + + Iterates through all logical workers and their associated physical + instances to compute counts of running, stopped, stalled, and rebooting + workers, as well as counts of logical workers with or without instances. + + Args: + logical_workers (List[db_scripts.entities.logical_worker_entity.LogicalWorkerEntity]): + List of logical worker entities. + merlin_db (db_scripts.merlin_db.MerlinDatabase): Database interface to fetch physical + worker details. + + Returns: + Dictionary containing worker statistics:\n + - total_logical: Total number of logical workers. + - logical_with_instances: Number of logical workers with physical instances. + - logical_without_instances: Number of logical workers without physical instances. + - total_physical: Total number of physical workers. + - physical_running: Count of running physical workers. + - physical_stopped: Count of stopped physical workers. + - physical_stalled: Count of stalled physical workers. + - physical_rebooting: Count of rebooting physical workers. 
+ """ + stats = { + "total_logical": len(logical_workers), + "logical_with_instances": 0, + "logical_without_instances": 0, + "total_physical": 0, + "physical_running": 0, + "physical_stopped": 0, + "physical_stalled": 0, + "physical_rebooting": 0, + } + + for logical_worker in logical_workers: + physical_worker_ids = logical_worker.get_physical_workers() + + if physical_worker_ids: + stats["logical_with_instances"] += 1 + physical_workers = [merlin_db.get("physical_worker", pid) for pid in physical_worker_ids] + + for physical_worker in physical_workers: + stats["total_physical"] += 1 + status = physical_worker.get_worker_status() + + if status == WorkerStatus.RUNNING: + stats["physical_running"] += 1 + elif status == WorkerStatus.STOPPED: + stats["physical_stopped"] += 1 + elif status == WorkerStatus.STALLED: + stats["physical_stalled"] += 1 + elif status == WorkerStatus.REBOOTING: + stats["physical_rebooting"] += 1 + else: + stats["logical_without_instances"] += 1 + + return stats diff --git a/merlin/workers/handlers/__init__.py b/merlin/workers/handlers/__init__.py index 30969814f..f1a79e28e 100644 --- a/merlin/workers/handlers/__init__.py +++ b/merlin/workers/handlers/__init__.py @@ -15,11 +15,11 @@ interface while enabling future integration with additional systems such as Kafka. Modules: - handler_factory.py: Factory for registering and instantiating Merlin worker + handler_factory: Factory for registering and instantiating Merlin worker handler implementations. - worker_handler.py: Abstract base class that defines the interface for all Merlin + worker_handler: Abstract base class that defines the interface for all Merlin worker handlers. - celery_handler.py: Celery-specific implementation of the worker handler interface. + celery_handler: Celery-specific implementation of the worker handler interface. 
""" diff --git a/merlin/workers/handlers/celery_handler.py b/merlin/workers/handlers/celery_handler.py index 2fd7b494e..aaba14cb5 100644 --- a/merlin/workers/handlers/celery_handler.py +++ b/merlin/workers/handlers/celery_handler.py @@ -14,9 +14,15 @@ """ import logging -from typing import List +from typing import Dict, List +from celery import Celery + +from merlin.common.enums import WorkerStatus +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.workers import CeleryWorker +from merlin.workers.formatters.formatter_factory import worker_formatter_factory from merlin.workers.handlers.worker_handler import MerlinWorkerHandler @@ -32,12 +38,21 @@ class CeleryWorkerHandler(MerlinWorkerHandler): Celery-specific behavior, including launching workers with optional command-line overrides, stopping workers, and querying their status. + Attributes: + merlin_db (MerlinDatabase): The database instance used for worker management. + Methods: start_workers: Launch or echo Celery workers with optional arguments. stop_workers: Attempt to stop active Celery workers. query_workers: Return a basic summary of Celery worker status. """ + def __init__(self, merlin_db: MerlinDatabase = None, app: Celery = None): + super().__init__(merlin_db=merlin_db) + if app is None: + from merlin.celery import app # pylint: disable=import-outside-toplevel + self.app = app + def start_workers(self, workers: List[CeleryWorker], **kwargs): """ Launch or echo Celery workers with optional override behavior. @@ -68,7 +83,126 @@ def stop_workers(self): Attempt to stop Celery workers. """ - def query_workers(self): + def get_workers_from_app(self) -> List[str]: + """ + Retrieve a list of all workers connected to the Celery application. + + This method uses the Celery control interface to inspect the current state + of the application and returns a list of workers that are currently connected. 
+ If no workers are found, an empty list is returned. + + Args: + app: The Celery application instance. + + Returns: + A list of worker names that are currently connected to the Celery application. + If no workers are connected, an empty list is returned. + """ + i = self.app.control.inspect() + workers = i.ping() + if workers is None: + return [] + return [*workers] + + def get_active_workers(self) -> Dict[str, List[str]]: """ - Query the status of Celery workers. + Retrieve a mapping of active workers to their associated queues for a Celery application. + + This function serves as the inverse of + [`get_active_celery_queues()`][study.celeryadapter.get_active_celery_queues]. It constructs + a dictionary where each key is a worker's name and the corresponding value is a + list of queues that the worker is connected to. This allows for easy identification + of which queues are being handled by each worker. + + Returns: + A dictionary mapping active worker names to lists of queue names they are + attached to. If no active workers are found, an empty dictionary is returned. """ + # Get the information we need from celery + i = self.app.control.inspect() + active_workers = i.active_queues() + if active_workers is None: + active_workers = {} + + # Build the mapping dictionary + worker_queue_map = {} + for worker, queues in active_workers.items(): + for queue in queues: + if worker in worker_queue_map: + worker_queue_map[worker].append(queue["name"]) + else: + worker_queue_map[worker] = [queue["name"]] + + return worker_queue_map + + def _build_filters(self, queues: List[str], workers: List[str]) -> Dict[str, List[str]]: + """ + Build filters dictionary for database queries. + + Args: + queues: List of queue names to filter by. + workers: List of worker names to filter by. + + Returns: + Dictionary containing filter criteria. 
+ """ + filters = {} + if queues: + filters["queues"] = [queue if queue.startswith("[merlin]_") else f"[merlin]_{queue}" for queue in queues] + if workers: + filters["name"] = workers + return filters + + def _validate_worker_status(self, logical_workers: List[LogicalWorkerEntity]): + """ + Cross-check database state with live Celery workers. + Update status for workers that are actually dead but marked running. + + Args: + logical_workers: List of logical worker entities to validate. + """ + # Get actual running workers from Celery + live_workers = self.get_active_workers() # Uses Celery inspection + + for logical_worker in logical_workers: + physical_ids = logical_worker.get_physical_workers() + for pid in physical_ids: + physical = self.merlin_db.get("physical_worker", pid) + + # If database says running but Celery doesn't know about it + if physical.get_worker_status() == WorkerStatus.RUNNING: + worker_name = physical.get_name() + if worker_name not in live_workers: + # Mark as stalled in database + LOG.warning(f"Worker {worker_name} marked running but not found in Celery") + physical.set_worker_status(WorkerStatus.STALLED) + + def query_workers( + self, + formatter: str, + queues: List[str] = None, + workers: List[str] = None, + local_db: bool = False, + ): + """ + Query the status of Celery workers and display using the configured formatter. + + Args: + formatter: The worker formatter to use (rich or json). + queues: List of queue names to filter by (optional). + workers: List of worker names to filter by (optional). + local_db: Whether to use the local database for querying (optional). 
+ """ + # Build filters dictionary + filters = self._build_filters(queues, workers) + + # Retrieve workers from database + logical_workers = self.merlin_db.get_all("logical_worker", filters=filters) + + # Validate/enrich with live Celery data + if not local_db: + self._validate_worker_status(logical_workers) + + # Use formatter to display the results + formatter = worker_formatter_factory.create(formatter) + formatter.format_and_display(logical_workers, filters, self.merlin_db) diff --git a/merlin/workers/handlers/worker_handler.py b/merlin/workers/handlers/worker_handler.py index 5c6d5ab1c..03df36ad2 100644 --- a/merlin/workers/handlers/worker_handler.py +++ b/merlin/workers/handlers/worker_handler.py @@ -13,8 +13,9 @@ """ from abc import ABC, abstractmethod -from typing import Any, List +from typing import List +from merlin.db_scripts.merlin_db import MerlinDatabase from merlin.workers.worker import MerlinWorker @@ -25,14 +26,23 @@ class MerlinWorkerHandler(ABC): Subclasses must implement the methods to launch, stop, and query workers using a particular task server (e.g., Celery, Kafka, etc.). + Attributes: + merlin_db (MerlinDatabase): The database instance used for worker management. + Methods: start_workers: Launch a list of MerlinWorker instances with optional configuration. stop_workers: Stop running worker processes managed by this handler. query_workers: Query the status of running workers and return summary information. """ - def __init__(self): - """Initialize the worker handler.""" + def __init__(self, merlin_db: MerlinDatabase = None): + """ + Initialize the worker handler. + + Args: + merlin_db: The database instance used for worker management or None. 
+ """ + self.merlin_db = merlin_db or MerlinDatabase() @abstractmethod def start_workers(self, workers: List[MerlinWorker], **kwargs): @@ -55,12 +65,14 @@ def stop_workers(self): raise NotImplementedError("Subclasses of `MerlinWorkerHandler` must implement a `stop_workers` method.") @abstractmethod - def query_workers(self) -> Any: + def query_workers(self, formatter: str, queues: List[str] = None, workers: List[str] = None, local_db: bool = False): """ Query the status of all currently running workers. - Returns: - Subclasses should return an appropriate data structure summarizing - the current state of managed workers (e.g., dict, list, string). + Args: + formatter: The worker formatter to use (rich or json). + queues: List of queue names to filter by (optional). + workers: List of worker names to filter by (optional). + local_db: Whether to use the local database for querying (optional). """ raise NotImplementedError("Subclasses of `MerlinWorkerHandler` must implement a `query_workers` method.") diff --git a/tests/conftest.py b/tests/conftest.py index 5a40e0cd0..e6131154e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,11 +13,13 @@ from glob import glob from time import sleep from typing import Dict +from unittest.mock import MagicMock import pytest import yaml from _pytest.tmpdir import TempPathFactory from celery import Celery +from pytest_mock import MockerFixture from redis import Redis from merlin.config.configfile import CONFIG @@ -211,6 +213,40 @@ def merlin_server_dir(temp_output_dir: FixtureStr) -> FixtureStr: return server_dir +@pytest.fixture +def mock_db_class(mocker: MockerFixture) -> MagicMock: + """ + Mock MerlinDatabase globally for all tests. + + This fixture mocks MerlinDatabase at its source, so all imports + across the codebase will use this mock. + + Args: + mocker: Pytest mocker fixture. + + Returns: + A mocked MerlinDatabase class. 
+ """ + mock_db_class = mocker.patch("merlin.db_scripts.merlin_db.MerlinDatabase", autospec=True) + return mock_db_class + + +@pytest.fixture +def mock_db_instance(mock_db_class: MagicMock) -> MagicMock: + """ + Returns a mocked instance of MerlinDatabase. + + Use this when you need an instance rather than the class itself. + + Args: + mock_db_class: The mocked MerlinDatabase class. + + Returns: + A mocked MerlinDatabase instance. + """ + return mock_db_class.return_value + + @pytest.fixture(scope="session") def redis_server(merlin_server_dir: FixtureStr, test_encryption_key: FixtureBytes) -> FixtureStr: """ diff --git a/tests/integration/commands/test_monitor.py b/tests/integration/commands/test_monitor.py index 0173a6cba..649e62775 100644 --- a/tests/integration/commands/test_monitor.py +++ b/tests/integration/commands/test_monitor.py @@ -130,7 +130,12 @@ def test_auto_restart( f"merlin purge -f {monitor_setup.auto_restart_yaml}".split(), capture_output=True, text=True ) - monitor_stdout, monitor_stderr = monitor_proc.communicate() + # Obtain stdout and stderr from the monitor process + try: + monitor_stdout, monitor_stderr = monitor_proc.communicate(timeout=30) + except subprocess.TimeoutExpired: + monitor_proc.kill() + monitor_stdout, monitor_stderr = monitor_proc.communicate() # Define our test conditions study_name = "monitor_auto_restart_test" diff --git a/tests/integration/commands/test_stop_and_query_workers.py b/tests/integration/commands/test_stop_workers.py similarity index 64% rename from tests/integration/commands/test_stop_and_query_workers.py rename to tests/integration/commands/test_stop_workers.py index a2f5d3376..3fb40a5de 100644 --- a/tests/integration/commands/test_stop_and_query_workers.py +++ b/tests/integration/commands/test_stop_workers.py @@ -5,8 +5,7 @@ ############################################################################## """ -This module will contain the testing logic for -the `stop-workers` and `query-workers` commands. 
+This module will contain the testing logic for the `stop-workers` command. """ import os @@ -15,8 +14,6 @@ from enum import Enum from typing import List -import pytest - from tests.context_managers.celery_workers_manager import CeleryWorkersManager from tests.fixture_data_classes import RedisBrokerAndBackend from tests.fixture_types import FixtureStr @@ -35,15 +32,14 @@ class WorkerMessages(Enum): """ NO_WORKERS_MSG_STOP = "No workers found to stop" - NO_WORKERS_MSG_QUERY = "No workers found!" STEP_1_WORKER = "step_1_merlin_test_worker" STEP_2_WORKER = "step_2_merlin_test_worker" OTHER_WORKER = "other_merlin_test_worker" -class TestStopAndQueryWorkersCommands: +class TestStopWorkersCommands: """ - Tests for the `merlin stop-workers` and `merlin query-workers` commands. + Tests for the `merlin stop-workers` command. Most of these tests will: 1. Start workers from a spec file used for testing - Use CeleryWorkerManager for this to ensure safe stoppage of workers @@ -57,7 +53,6 @@ def run_test_with_workers( # pylint: disable=too-many-arguments path_to_test_specs: FixtureStr, merlin_server_dir: FixtureStr, conditions: List[Condition], - command: str, flag: str = None, ): """ @@ -90,8 +85,6 @@ def run_test_with_workers( # pylint: disable=too-many-arguments conditions: A list of `Condition` instances that need to pass in order for this test to be successful. - command: - The command that we're testing. E.g. "merlin stop-workers" flag: An optional flag to add to the command that we're testing so we can test different functionality for the command. 
@@ -110,8 +103,10 @@ def run_test_with_workers( # pylint: disable=too-many-arguments copy_app_yaml_to_cwd(merlin_server_dir) # Run the test - cmd_to_test = f"{command} {flag}" if flag else command - result = subprocess.run(cmd_to_test, capture_output=True, text=True, shell=True) + command = "merlin stop-workers" + if flag: + command += f" {flag}" + result = subprocess.run(command, capture_output=True, text=True, shell=True) info = { "stdout": result.stdout, @@ -124,34 +119,13 @@ def run_test_with_workers( # pylint: disable=too-many-arguments yield - def get_no_workers_msg(self, command_to_test: str) -> WorkerMessages: - """ - Retrieve the appropriate "no workers" found message. - - This method checks the command to test and returns a corresponding - message based on whether the command is to stop workers or query for them. - - Returns: - The message indicating that no workers are available, depending on the - command being tested. - """ - no_workers_msg = None - if command_to_test == "merlin stop-workers": - no_workers_msg = WorkerMessages.NO_WORKERS_MSG_STOP.value - else: - no_workers_msg = WorkerMessages.NO_WORKERS_MSG_QUERY.value - return no_workers_msg - - @pytest.mark.parametrize("command_to_test", ["merlin stop-workers", "merlin query-workers"]) def test_no_workers( self, redis_broker_and_backend_function: RedisBrokerAndBackend, merlin_server_dir: FixtureStr, - command_to_test: str, ): """ - Test the `merlin stop-workers` and `merlin query-workers` commands with no workers - started in the first place. + Test the `merlin stop-workers` command with no workers started in the first place. This test will: 0. Setup the pytest fixtures which include: @@ -170,11 +144,9 @@ def test_no_workers( merlin_server_dir: A fixture to provide the path to the merlin_server directory that will be created by the `redis_server` fixture. - command_to_test: - The command that we're testing, obtained from the parametrize call. 
""" conditions = [ - HasRegex(self.get_no_workers_msg(command_to_test)), + HasRegex(WorkerMessages.NO_WORKERS_MSG_STOP.value), HasRegex(WorkerMessages.STEP_1_WORKER.value, negate=True), HasRegex(WorkerMessages.STEP_2_WORKER.value, negate=True), HasRegex(WorkerMessages.OTHER_WORKER.value, negate=True), @@ -184,7 +156,7 @@ def test_no_workers( copy_app_yaml_to_cwd(merlin_server_dir) # Run the test - result = subprocess.run(command_to_test, capture_output=True, text=True, shell=True) + result = subprocess.run("merlin stop-workers", capture_output=True, text=True, shell=True) info = { "stdout": result.stdout, "stderr": result.stderr, @@ -194,19 +166,16 @@ def test_no_workers( # Ensure all test conditions are satisfied check_test_conditions(conditions, info) - @pytest.mark.parametrize("command_to_test", ["merlin stop-workers", "merlin query-workers"]) def test_no_flags( self, redis_broker_and_backend_function: RedisBrokerAndBackend, path_to_test_specs: FixtureStr, merlin_server_dir: FixtureStr, - command_to_test: str, ): """ - Test the `merlin stop-workers` and `merlin query-workers` commands with no flags. + Test the `merlin stop-workers` command with no flags. - Run the commands referenced above and ensure the text output from Merlin is correct. - For the `stop-workers` command, we check if all workers are stopped as well. + Run the command and ensure the text output from Merlin is correct. To see more information on exactly what this test is doing, see the `run_test_with_workers()` method. @@ -218,39 +187,32 @@ def test_no_flags( merlin_server_dir: A fixture to provide the path to the merlin_server directory that will be created by the `redis_server` fixture. - command_to_test: - The command that we're testing, obtained from the parametrize call. 
""" conditions = [ - HasRegex(self.get_no_workers_msg(command_to_test), negate=True), + HasRegex(WorkerMessages.NO_WORKERS_MSG_STOP.value, negate=True), HasRegex(WorkerMessages.STEP_1_WORKER.value), HasRegex(WorkerMessages.STEP_2_WORKER.value), HasRegex(WorkerMessages.OTHER_WORKER.value), ] - with self.run_test_with_workers(path_to_test_specs, merlin_server_dir, conditions, command_to_test): - if command_to_test == "merlin stop-workers": - # After the test runs and before the CeleryWorkersManager exits, ensure there are no workers on the app - from merlin.celery import app as celery_app + with self.run_test_with_workers(path_to_test_specs, merlin_server_dir, conditions): + # After the test runs and before the CeleryWorkersManager exits, ensure there are no workers on the app + from merlin.celery import app as celery_app - active_queues = celery_app.control.inspect().active_queues() - assert active_queues is None + active_queues = celery_app.control.inspect().active_queues() + assert active_queues is None - @pytest.mark.parametrize("command_to_test", ["merlin stop-workers", "merlin query-workers"]) def test_spec_flag( self, redis_broker_and_backend_function: RedisBrokerAndBackend, path_to_test_specs: FixtureStr, merlin_server_dir: FixtureStr, - command_to_test: str, ): """ - Test the `merlin stop-workers` and `merlin query-workers` commands with the `--spec` - flag. + Test the `merlin stop-workers` command with the `--spec` flag. - Run the commands referenced above with the `--spec` flag and ensure the text output - from Merlin is correct. For the `stop-workers` command, we check if all workers defined - in the spec file are stopped as well. To see more information on exactly what this test - is doing, see the `run_test_with_workers()` method. + Run the command with the `--spec` flag and ensure the text output + from Merlin is correct. To see more information on exactly what this + test is doing, see the `run_test_with_workers()` method. 
Parameters: redis_broker_and_backend_function: Fixture for setting up Redis broker and @@ -260,11 +222,9 @@ def test_spec_flag( merlin_server_dir: A fixture to provide the path to the merlin_server directory that will be created by the `redis_server` fixture. - command_to_test: - The command that we're testing, obtained from the parametrize call. """ conditions = [ - HasRegex(self.get_no_workers_msg(command_to_test), negate=True), + HasRegex(WorkerMessages.NO_WORKERS_MSG_STOP.value, negate=True), HasRegex(WorkerMessages.STEP_1_WORKER.value), HasRegex(WorkerMessages.STEP_2_WORKER.value), HasRegex(WorkerMessages.OTHER_WORKER.value), @@ -273,30 +233,24 @@ def test_spec_flag( path_to_test_specs, merlin_server_dir, conditions, - command_to_test, flag=f"--spec {os.path.join(path_to_test_specs, 'multiple_workers.yaml')}", ): - if command_to_test == "merlin stop-workers": - from merlin.celery import app as celery_app + from merlin.celery import app as celery_app - active_queues = celery_app.control.inspect().active_queues() - assert active_queues is None + active_queues = celery_app.control.inspect().active_queues() + assert active_queues is None - @pytest.mark.parametrize("command_to_test", ["merlin stop-workers", "merlin query-workers"]) def test_workers_flag( self, redis_broker_and_backend_function: RedisBrokerAndBackend, path_to_test_specs: FixtureStr, merlin_server_dir: FixtureStr, - command_to_test: str, ): """ - Test the `merlin stop-workers` and `merlin query-workers` commands with the `--workers` - flag. + Test the `merlin stop-workers` command with the `--workers` flag. - Run the commands referenced above with the `--workers` flag and ensure the text output - from Merlin is correct. For the `stop-workers` command, we check to make sure that all - workers given with this flag are stopped. To see more information on exactly what this + Run the command with the `--workers` flag and ensure the text output + from Merlin is correct. 
To see more information on exactly what this test is doing, see the `run_test_with_workers()` method. Parameters: @@ -307,11 +261,9 @@ def test_workers_flag( merlin_server_dir: A fixture to provide the path to the merlin_server directory that will be created by the `redis_server` fixture. - command_to_test: - The command that we're testing, obtained from the parametrize call. """ conditions = [ - HasRegex(self.get_no_workers_msg(command_to_test), negate=True), + HasRegex(WorkerMessages.NO_WORKERS_MSG_STOP.value, negate=True), HasRegex(WorkerMessages.STEP_1_WORKER.value), HasRegex(WorkerMessages.STEP_2_WORKER.value), HasRegex(WorkerMessages.OTHER_WORKER.value, negate=True), @@ -320,31 +272,25 @@ def test_workers_flag( path_to_test_specs, merlin_server_dir, conditions, - command_to_test, flag=f"--workers {WorkerMessages.STEP_1_WORKER.value} {WorkerMessages.STEP_2_WORKER.value}", ): - if command_to_test == "merlin stop-workers": - from merlin.celery import app as celery_app + from merlin.celery import app as celery_app - active_queues = celery_app.control.inspect().active_queues() - worker_name = f"celery@{WorkerMessages.OTHER_WORKER.value}" - assert worker_name in active_queues + active_queues = celery_app.control.inspect().active_queues() + worker_name = f"celery@{WorkerMessages.OTHER_WORKER.value}" + assert worker_name in active_queues - @pytest.mark.parametrize("command_to_test", ["merlin stop-workers", "merlin query-workers"]) def test_queues_flag( self, redis_broker_and_backend_function: RedisBrokerAndBackend, path_to_test_specs: FixtureStr, merlin_server_dir: FixtureStr, - command_to_test: str, ): """ - Test the `merlin stop-workers` and `merlin query-workers` commands with the `--queues` - flag. + Test the `merlin stop-workers` command with the `--queues` flag. - Run the commands referenced above with the `--queues` flag and ensure the text output - from Merlin is correct. 
For the `stop-workers` command, we check that only the workers - attached to the given queues are stopped. To see more information on exactly what this + Run the command with the `--queues` flag and ensure the text output + from Merlin is correct. To see more information on exactly what this test is doing, see the `run_test_with_workers()` method. Parameters: @@ -355,11 +301,9 @@ def test_queues_flag( merlin_server_dir: A fixture to provide the path to the merlin_server directory that will be created by the `redis_server` fixture. - command_to_test: - The command that we're testing, obtained from the parametrize call. """ conditions = [ - HasRegex(self.get_no_workers_msg(command_to_test), negate=True), + HasRegex(WorkerMessages.NO_WORKERS_MSG_STOP.value, negate=True), HasRegex(WorkerMessages.STEP_1_WORKER.value), HasRegex(WorkerMessages.STEP_2_WORKER.value, negate=True), HasRegex(WorkerMessages.OTHER_WORKER.value, negate=True), @@ -368,19 +312,17 @@ def test_queues_flag( path_to_test_specs, merlin_server_dir, conditions, - command_to_test, flag="--queues hello_queue", ): - if command_to_test == "merlin stop-workers": - from merlin.celery import app as celery_app - - active_queues = celery_app.control.inspect().active_queues() - workers_that_should_be_alive = [ - f"celery@{WorkerMessages.OTHER_WORKER.value}", - f"celery@{WorkerMessages.STEP_2_WORKER.value}", - ] - for worker_name in workers_that_should_be_alive: - assert worker_name in active_queues + from merlin.celery import app as celery_app + + active_queues = celery_app.control.inspect().active_queues() + workers_that_should_be_alive = [ + f"celery@{WorkerMessages.OTHER_WORKER.value}", + f"celery@{WorkerMessages.STEP_2_WORKER.value}", + ] + for worker_name in workers_that_should_be_alive: + assert worker_name in active_queues # pylint: enable=unused-argument,import-outside-toplevel diff --git a/tests/integration/definitions.py b/tests/integration/definitions.py index 671be7981..0ece8f5e5 100644 --- 
a/tests/integration/definitions.py +++ b/tests/integration/definitions.py @@ -314,7 +314,7 @@ def define_tests(): # pylint: disable=R0914,R0915 }, "default_worker assigned": { "cmds": f"{workers} {test_specs}/default_worker_test.yaml --echo", - "conditions": [HasReturnCode(), HasRegex(r"default_worker.*-Q '\[merlin\]_step_4_queue'")], + "conditions": [HasReturnCode(), HasRegex(r"default_worker.*-Q .*step_4_queue")], "run type": "local", }, "no default_worker assigned": { diff --git a/tests/unit/abstracts/test_factory.py b/tests/unit/abstracts/test_factory.py index a92b14819..2147e5111 100644 --- a/tests/unit/abstracts/test_factory.py +++ b/tests/unit/abstracts/test_factory.py @@ -163,17 +163,15 @@ def test_get_component_info_for_invalid_component(self, factory: TestableFactory ): # raises RuntimeError because of `_raise_component_error_class` factory.get_component_info("not_registered") - def test_discover_plugins_calls_both_hooks(self, mocker: MockerFixture, factory: TestableFactory): + def test_discover_plugins_calls_hooks(self, mocker: MockerFixture, factory: TestableFactory): """ - Test that _discover_plugins calls both plugin and module hooks. + Test that _discover_plugins calls plugin hooks. Args: mocker: PyTest mocker fixture. factory: An instance of the dummy `TestableFactory` class for testing. 
""" plugin_mock = mocker.patch.object(factory, "_discover_plugins_via_entry_points") - builtin_mock = mocker.patch.object(factory, "_discover_builtin_modules") factory._discover_plugins() plugin_mock.assert_called_once() - builtin_mock.assert_called_once() diff --git a/tests/unit/cli/commands/test_query_workers.py b/tests/unit/cli/commands/test_query_workers.py index 09152b3c2..4a9e39091 100644 --- a/tests/unit/cli/commands/test_query_workers.py +++ b/tests/unit/cli/commands/test_query_workers.py @@ -38,29 +38,38 @@ def test_add_parser_sets_up_query_workers_command(create_parser: FixtureCallable def test_process_command_without_spec(mocker: MockerFixture): """ - Ensure `process_command` calls `query_workers` directly if no spec is provided. + Ensure `process_command` calls worker_handler.query_workers if no spec is provided. Args: mocker: PyTest mocker fixture. """ - query_workers_mock = mocker.patch("merlin.cli.commands.query_workers.query_workers") + worker_handler_mock = mocker.Mock() + create_mock = mocker.patch( + "merlin.cli.commands.query_workers.worker_handler_factory.create", + return_value=worker_handler_mock, + ) args = Namespace( task_server="celery", spec=None, queues=["q1", "q2"], workers=["worker1", "worker2"], + format="rich", + local_db=False, ) cmd = QueryWorkersCommand() cmd.process_command(args) - query_workers_mock.assert_called_once_with("celery", [], ["q1", "q2"], ["worker1", "worker2"]) + create_mock.assert_called_once_with("celery") + worker_handler_mock.query_workers.assert_called_once_with( + "rich", queues=["q1", "q2"], workers=["worker1", "worker2"], local_db=False + ) def test_process_command_with_spec(mocker: MockerFixture, caplog: CaptureFixture): """ - Ensure `process_command` loads worker names from spec and passes them to `query_workers`. + Ensure `process_command` loads worker names from spec and passes them to worker_handler.query_workers. Args: mocker: PyTest mocker fixture. 
@@ -70,22 +79,31 @@ def test_process_command_with_spec(mocker: MockerFixture, caplog: CaptureFixture mock_spec = mocker.Mock() mock_spec.get_worker_names.return_value = ["foo", "bar"] + mock_spec.merlin = {"resources": {"task_server": "celery"}} mocker.patch("merlin.cli.commands.query_workers.verify_filepath", return_value="some/path/spec.yaml") mocker.patch("merlin.cli.commands.query_workers.MerlinSpec.load_specification", return_value=mock_spec) - query_workers_mock = mocker.patch("merlin.cli.commands.query_workers.query_workers") + + worker_handler_mock = mocker.Mock() + create_mock = mocker.patch( + "merlin.cli.commands.query_workers.worker_handler_factory.create", + return_value=worker_handler_mock, + ) args = Namespace( - task_server="celery", + task_server="ignored", spec="workflow.yaml", queues=None, workers=None, + format="rich", + local_db=False, ) cmd = QueryWorkersCommand() cmd.process_command(args) - query_workers_mock.assert_called_once_with("celery", ["foo", "bar"], None, None) + create_mock.assert_called_once_with("celery") + worker_handler_mock.query_workers.assert_called_once_with("rich", queues=None, workers=["foo", "bar"], local_db=False) assert "Searching for the following workers to stop" in caplog.text @@ -101,20 +119,30 @@ def test_process_command_logs_warning_for_unexpanded_worker(mocker: MockerFixtur mock_spec = mocker.Mock() mock_spec.get_worker_names.return_value = ["$ENV_VAR", "actual_worker"] + mock_spec.merlin = {"resources": {"task_server": "celery"}} mocker.patch("merlin.cli.commands.query_workers.verify_filepath", return_value="workflow.yaml") mocker.patch("merlin.cli.commands.query_workers.MerlinSpec.load_specification", return_value=mock_spec) - query_workers_mock = mocker.patch("merlin.cli.commands.query_workers.query_workers") + + worker_handler_mock = mocker.Mock() + mocker.patch( + "merlin.cli.commands.query_workers.worker_handler_factory.create", + return_value=worker_handler_mock, + ) args = Namespace( - 
task_server="celery", + task_server="ignored", spec="workflow.yaml", queues=None, workers=None, + format="rich", + local_db=False, ) cmd = QueryWorkersCommand() cmd.process_command(args) assert "Worker '$ENV_VAR' is unexpanded. Target provenance spec instead?" in caplog.text - query_workers_mock.assert_called_once_with("celery", ["$ENV_VAR", "actual_worker"], None, None) + worker_handler_mock.query_workers.assert_called_once_with( + "rich", queues=None, workers=["$ENV_VAR", "actual_worker"], local_db=False + ) diff --git a/tests/unit/cli/commands/test_run_workers.py b/tests/unit/cli/commands/test_run_workers.py index bd784f8d7..9d3d8b937 100644 --- a/tests/unit/cli/commands/test_run_workers.py +++ b/tests/unit/cli/commands/test_run_workers.py @@ -9,6 +9,7 @@ """ from argparse import Namespace +from unittest.mock import MagicMock from _pytest.capture import CaptureFixture from pytest_mock import MockerFixture @@ -78,13 +79,16 @@ def test_process_command_launches_workers(mocker: MockerFixture): mock_log.info.assert_called_once_with("Launching workers from 'workflow.yaml'") -def test_process_command_echo_only_mode_prints_command(mocker: MockerFixture, capsys: CaptureFixture): +def test_process_command_echo_only_mode_prints_command( + mocker: MockerFixture, capsys: CaptureFixture, mock_db_instance: MagicMock +): """ Test `process_command` prints the launch command and initializes config in echo-only mode. Args: mocker: PyTest mocker fixture. capsys: PyTest capsys fixture. + mock_db_instance: Mocked MerlinDatabase instance. 
""" mock_spec = mocker.Mock() mock_spec.get_workers_to_start.return_value = ["workerB"] @@ -97,7 +101,11 @@ def test_process_command_echo_only_mode_prints_command(mocker: MockerFixture, ca mocker.patch("merlin.cli.commands.run_workers.get_merlin_spec_with_override", return_value=(mock_spec, "file.yaml")) mocker.patch("merlin.cli.commands.run_workers.initialize_config") - mocker.patch("merlin.cli.commands.run_workers.worker_handler_factory.create", wraps=lambda _: CeleryWorkerHandler()) + mock_app = mocker.patch("merlin.celery.app") + mocker.patch( + "merlin.cli.commands.run_workers.worker_handler_factory.create", + wraps=lambda _: CeleryWorkerHandler(merlin_db=mock_db_instance, app=mock_app), + ) args = Namespace( specification="spec.yaml", diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index 30aad87a5..6ac8d329c 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -10,8 +10,8 @@ import os -import celery import pytest +from pytest_mock import MockerFixture from merlin.common.security.encrypt import _gen_key, _get_key, _get_key_path, decrypt, encrypt from merlin.common.security.encrypt_backend_traffic import _decrypt_decode, _encrypt_encode, set_backend_funcs @@ -124,22 +124,31 @@ def test_get_key( with open(key_path, "w") as key_file: key_file.write(test_encryption_key.decode("utf-8")) - def test_set_backend_funcs(self): + def test_set_backend_funcs(self, mocker: MockerFixture): """ Test the `set_backend_funcs` function. + + Args: + mocker: Pytest mocker fixture. 
""" - orig_encode = celery.backends.base.Backend.encode - orig_decode = celery.backends.base.Backend.decode + # Mock the Backend class to ensure clean state + mock_backend = mocker.patch("celery.backends.base.Backend") + + # Set up mock encode/decode attributes + mock_backend.encode = mocker.MagicMock() + mock_backend.decode = mocker.MagicMock() - # Make sure these values haven't been set yet - assert celery.backends.base.Backend.encode != _encrypt_encode - assert celery.backends.base.Backend.decode != _decrypt_decode + # Store original values + orig_encode = mock_backend.encode + orig_decode = mock_backend.decode + # Call the function set_backend_funcs() - # Ensure the new functions have been set - assert celery.backends.base.Backend.encode == _encrypt_encode - assert celery.backends.base.Backend.decode == _decrypt_decode + # Verify the functions were replaced + assert mock_backend.encode == _encrypt_encode + assert mock_backend.decode == _decrypt_decode - celery.backends.base.Backend.encode = orig_encode - celery.backends.base.Backend.decode = orig_decode + # Verify they're different from the originals + assert mock_backend.encode != orig_encode + assert mock_backend.decode != orig_decode diff --git a/tests/unit/db_scripts/entities/test_physical_worker_entity.py b/tests/unit/db_scripts/entities/test_physical_worker_entity.py index dcc3d03ff..92679f5e1 100644 --- a/tests/unit/db_scripts/entities/test_physical_worker_entity.py +++ b/tests/unit/db_scripts/entities/test_physical_worker_entity.py @@ -213,7 +213,7 @@ def test_set_worker_status(self, worker_entity: PhysicalWorkerEntity, mock_backe """ new_status = WorkerStatus.STOPPED worker_entity.set_worker_status(new_status) - assert worker_entity.entity_info.worker_status == new_status + assert worker_entity.entity_info.worker_status == new_status.value mock_backend.save.assert_called_once() def test_get_heartbeat_timestamp(self, worker_entity: PhysicalWorkerEntity, mock_model: MagicMock): diff --git 
a/tests/unit/db_scripts/test_data_models.py b/tests/unit/db_scripts/test_data_models.py index a3c7a7737..93e4b4ea9 100644 --- a/tests/unit/db_scripts/test_data_models.py +++ b/tests/unit/db_scripts/test_data_models.py @@ -512,7 +512,7 @@ def test_default_initialization(self): assert worker.launch_cmd is None assert worker.args == {} assert worker.pid is None - assert worker.worker_status == WorkerStatus.STOPPED + assert worker.worker_status == WorkerStatus.STOPPED.value assert isinstance(worker.heartbeat_timestamp, datetime) assert isinstance(worker.latest_start_time, datetime) assert worker.host is None diff --git a/tests/unit/monitor/test_celery_monitor.py b/tests/unit/monitor/test_celery_monitor.py index b82d98de9..5c8032910 100644 --- a/tests/unit/monitor/test_celery_monitor.py +++ b/tests/unit/monitor/test_celery_monitor.py @@ -19,14 +19,19 @@ @pytest.fixture -def monitor() -> CeleryMonitor: +def monitor(mocker: MockerFixture, mock_db_instance: MagicMock) -> CeleryMonitor: """ Fixture to provide a CeleryMonitor instance. + Args: + mocker: Pytest mocker fixture. + mock_db_instance: Mocked MerlinDatabase instance. + Returns: An instance of the `CeleryMonitor` object. """ - return CeleryMonitor() + mock_app = mocker.patch("merlin.celery.Celery") + return CeleryMonitor(merlin_db=mock_db_instance, app=mock_app) def test_wait_for_workers_success(mocker: MockerFixture, monitor: CeleryMonitor): @@ -37,7 +42,7 @@ def test_wait_for_workers_success(mocker: MockerFixture, monitor: CeleryMonitor) mocker: PyTest mocker fixture. monitor: An instance of the `CeleryMonitor` object. 
""" - mock_get_workers = mocker.patch("merlin.monitor.celery_monitor.get_workers_from_app", return_value=["worker1@node"]) + mock_get_workers = monitor.worker_handler.get_workers_from_app = MagicMock(return_value=["worker1@node"]) monitor.wait_for_workers(["worker1"], sleep=1) @@ -52,7 +57,7 @@ def test_wait_for_workers_timeout(mocker: MockerFixture, monitor: CeleryMonitor) mocker: PyTest mocker fixture. monitor: An instance of the `CeleryMonitor` object. """ - mocker.patch("merlin.monitor.celery_monitor.get_workers_from_app", return_value=[]) + monitor.worker_handler.get_workers_from_app = MagicMock(return_value=[]) mocker.patch("time.sleep") with pytest.raises(NoWorkersException): diff --git a/tests/unit/monitor/test_monitor_factory.py b/tests/unit/monitor/test_monitor_factory.py index ebd4934f7..e995281cb 100644 --- a/tests/unit/monitor/test_monitor_factory.py +++ b/tests/unit/monitor/test_monitor_factory.py @@ -8,7 +8,10 @@ Tests for the `monitor_factory.py` module. """ +from unittest.mock import MagicMock + import pytest +from pytest_mock import MockerFixture from merlin.exceptions import MerlinInvalidTaskServerError from merlin.monitor.celery_monitor import CeleryMonitor @@ -94,14 +97,16 @@ def test_list_available_monitors(self, monitor_factory: MonitorFactory): assert "celery" in available assert len(available) == 1 - def test_create_valid_monitor(self, monitor_factory: MonitorFactory): + def test_create_valid_monitor(self, mocker: MockerFixture, monitor_factory: MonitorFactory, mock_db_instance: MagicMock): """ Test that `create` instantiates a monitor for a valid task server. Args: + mocker: Pytest mocker fixture. monitor_factory: Instance of `MonitorFactory` for testing. + mock_db_instance: Mocked MerlinDatabase instance. 
""" - monitor = monitor_factory.create("celery") + monitor = monitor_factory.create("celery", {"merlin_db": mock_db_instance}) assert isinstance(monitor, CeleryMonitor) def test_create_invalid_monitor_raises(self, monitor_factory: MonitorFactory): diff --git a/tests/unit/workers/formatters/test_formatter_factory.py b/tests/unit/workers/formatters/test_formatter_factory.py new file mode 100644 index 000000000..07c247a74 --- /dev/null +++ b/tests/unit/workers/formatters/test_formatter_factory.py @@ -0,0 +1,286 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the `merlin/workers/formatters/formatter_factory.py` module. +""" + +from typing import Dict, List + +import pytest +from pytest_mock import MockerFixture + +from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.exceptions import MerlinWorkerFormatterNotSupportedError +from merlin.workers.formatters.formatter_factory import WorkerFormatterFactory +from merlin.workers.formatters.worker_formatter import WorkerFormatter + + +class DummyJSONFormatter(WorkerFormatter): + """Dummy JSON formatter implementation for testing.""" + + def __init__(self, *args, **kwargs): + pass + + def format_and_display(self, logical_workers: List, filters: Dict, merlin_db: MerlinDatabase): + return f"JSON formatted {len(logical_workers)} workers" + + +class DummyRichFormatter(WorkerFormatter): + """Dummy Rich formatter implementation for testing.""" + + def __init__(self, *args, **kwargs): + pass + + def format_and_display(self, logical_workers: List, filters: Dict, merlin_db: MerlinDatabase): + return f"Rich formatted {len(logical_workers)} workers" + + +class 
DummyCSVFormatter(WorkerFormatter): + """Dummy CSV formatter implementation for testing.""" + + def __init__(self, *args, **kwargs): + pass + + def format_and_display(self, logical_workers: List, filters: Dict, merlin_db: MerlinDatabase): + return f"CSV formatted {len(logical_workers)} workers" + + +class TestWorkerFormatterFactory: + """ + Test suite for the `WorkerFormatterFactory`. + + This class verifies that the factory properly registers, validates, instantiates, + and handles Merlin worker formatters. It mocks built-ins for test isolation. + """ + + @pytest.fixture + def formatter_factory(self, mocker: MockerFixture) -> WorkerFormatterFactory: + """ + A fixture that returns a fresh instance of `WorkerFormatterFactory` with built-in formatters patched. + + Args: + mocker: PyTest mocker fixture. + + Returns: + A factory instance with mocked formatter classes. + """ + mocker.patch("merlin.workers.formatters.formatter_factory.JSONWorkerFormatter", DummyJSONFormatter) + mocker.patch("merlin.workers.formatters.formatter_factory.RichWorkerFormatter", DummyRichFormatter) + return WorkerFormatterFactory() + + def test_list_available_formatters(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `list_available` returns the expected built-in formatter names. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + available = formatter_factory.list_available() + assert set(available) == {"json", "rich"} + + def test_create_valid_formatter(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `create` returns a valid formatter instance for a registered name. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. 
+ """ + json_instance = formatter_factory.create("json") + assert isinstance(json_instance, DummyJSONFormatter) + + rich_instance = formatter_factory.create("rich") + assert isinstance(rich_instance, DummyRichFormatter) + + def test_create_valid_formatter_with_alias(self, formatter_factory: WorkerFormatterFactory): + """ + Test that aliases are resolved to canonical formatter names. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + formatter_factory.register("csv", DummyCSVFormatter, aliases=["comma", "spreadsheet"]) + + instance_by_name = formatter_factory.create("csv") + instance_by_alias = formatter_factory.create("comma") + instance_by_alias2 = formatter_factory.create("spreadsheet") + + assert isinstance(instance_by_name, DummyCSVFormatter) + assert isinstance(instance_by_alias, DummyCSVFormatter) + assert isinstance(instance_by_alias2, DummyCSVFormatter) + + def test_create_invalid_formatter_raises(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `create` raises `MerlinWorkerFormatterNotSupportedError` for unknown formatter types. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + with pytest.raises(MerlinWorkerFormatterNotSupportedError, match="unknown_formatter"): + formatter_factory.create("unknown_formatter") + + def test_invalid_registration_type_error(self, formatter_factory: WorkerFormatterFactory): + """ + Test that trying to register a non-WorkerFormatter raises TypeError. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + + class NotAFormatter: + pass + + with pytest.raises(TypeError, match="must inherit from WorkerFormatter"): + formatter_factory.register("fake_formatter", NotAFormatter) + + def test_register_overwrites_existing_formatter(self, formatter_factory: WorkerFormatterFactory): + """ + Test that registering a formatter with an existing name overwrites it. 
+ + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + # Initially json should be DummyJSONFormatter + instance1 = formatter_factory.create("json") + assert isinstance(instance1, DummyJSONFormatter) + + # Register a different formatter with the same name + formatter_factory.register("json", DummyCSVFormatter) + + # Should now return the new formatter type + instance2 = formatter_factory.create("json") + assert isinstance(instance2, DummyCSVFormatter) + + def test_list_available_includes_registered_formatters(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `list_available` includes dynamically registered formatters. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + # Initial built-in formatters + initial_available = set(formatter_factory.list_available()) + assert initial_available == {"json", "rich"} + + # Register additional formatter + formatter_factory.register("csv", DummyCSVFormatter) + + # Should now include the new formatter + updated_available = set(formatter_factory.list_available()) + assert updated_available == {"json", "rich", "csv"} + + # TODO should the factory list aliases as well? + def test_list_available_excludes_aliases(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `list_available` returns canonical names, not aliases. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. 
+ """ + formatter_factory.register("csv", DummyCSVFormatter, aliases=["comma", "spreadsheet"]) + + available = set(formatter_factory.list_available()) + # Should contain canonical name but not aliases + assert "csv" in available + assert "comma" not in available + assert "spreadsheet" not in available + + # TODO if we change 'config' to 'kwargs' in MerlinBaseFactory class create method, uncomment this + # def test_create_with_constructor_arguments(self, formatter_factory: WorkerFormatterFactory): + # """ + # Test that `create` can pass arguments to formatter constructors. + + # Args: + # formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + # """ + # class ParameterizedFormatter(WorkerFormatter): + # def __init__(self, param1=None, param2=None): + # self.param1 = param1 + # self.param2 = param2 + + # def format_and_display(self, logical_workers: List, filters: Dict, merlin_db: MerlinDatabase): + # return f"Formatted with {self.param1} and {self.param2}" + + # formatter_factory.register("parameterized", ParameterizedFormatter) + + # # Test creating with arguments + # instance = formatter_factory.create("parameterized", param1="test", param2=42) + # assert isinstance(instance, ParameterizedFormatter) + # assert instance.param1 == "test" + # assert instance.param2 == 42 + + def test_entry_point_group_returns_correct_namespace(self, formatter_factory: WorkerFormatterFactory): + """ + Test that the factory uses the correct entry point namespace. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + entry_point_group = formatter_factory._entry_point_group() + assert entry_point_group == "merlin.workers.formatters" + + def test_validate_component_accepts_valid_formatter(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `_validate_component` accepts valid WorkerFormatter subclasses. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. 
+ """ + # Should not raise an exception + formatter_factory._validate_component(DummyJSONFormatter) + formatter_factory._validate_component(DummyRichFormatter) + formatter_factory._validate_component(DummyCSVFormatter) + + def test_validate_component_rejects_invalid_formatter(self, formatter_factory: WorkerFormatterFactory): + """ + Test that `_validate_component` rejects non-WorkerFormatter classes. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + + class InvalidFormatter: + pass + + with pytest.raises(TypeError, match="must inherit from WorkerFormatter"): + formatter_factory._validate_component(InvalidFormatter) + + def test_raise_component_error_class_returns_correct_exception(self, formatter_factory: WorkerFormatterFactory): + """ + Test that the factory raises the correct exception type for invalid components. + + Args: + formatter_factory: Instance of the `WorkerFormatterFactory` for testing. + """ + with pytest.raises(MerlinWorkerFormatterNotSupportedError, match="test message"): + formatter_factory._raise_component_error_class("test message") + + def test_factory_instance_isolation(self, mocker: MockerFixture): + """ + Test that different factory instances don't interfere with each other. + + Args: + mocker: PyTest mocker fixture. 
+ """ + # Mock the built-ins for both factories + mocker.patch("merlin.workers.formatters.formatter_factory.JSONWorkerFormatter", DummyJSONFormatter) + mocker.patch("merlin.workers.formatters.formatter_factory.RichWorkerFormatter", DummyRichFormatter) + + factory1 = WorkerFormatterFactory() + factory2 = WorkerFormatterFactory() + + # Register formatter in only one factory + factory1.register("csv", DummyCSVFormatter) + + # factory1 should have the new formatter + assert "csv" in factory1.list_available() + csv_instance = factory1.create("csv") + assert isinstance(csv_instance, DummyCSVFormatter) + + # factory2 should not have the new formatter + assert "csv" not in factory2.list_available() + with pytest.raises(MerlinWorkerFormatterNotSupportedError): + factory2.create("csv") diff --git a/tests/unit/workers/formatters/test_json_formatter.py b/tests/unit/workers/formatters/test_json_formatter.py new file mode 100644 index 000000000..0c72d8aeb --- /dev/null +++ b/tests/unit/workers/formatters/test_json_formatter.py @@ -0,0 +1,253 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the `merlin/workers/formatters/json_formatter.py` module. 
+""" + +import json +from datetime import datetime, timedelta +from typing import Dict, List +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from merlin.common.enums import WorkerStatus +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.entities.physical_worker_entity import PhysicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.workers.formatters.json_formatter import JSONWorkerFormatter + + +class TestJSONWorkerFormatter: + """Tests for the JSONWorkerFormatter class.""" + + @pytest.fixture + def formatter(self, mocker: MockerFixture) -> JSONWorkerFormatter: + """ + Create a JSONWorkerFormatter instance for testing. + + Args: + mocker: Pytest mocker fixture. + + Returns: + JSONWorkerFormatter instance with mocked console. + """ + # Mock the console from the base class + mock_console = MagicMock() + mocker.patch("merlin.workers.formatters.worker_formatter.Console", return_value=mock_console) + return JSONWorkerFormatter() + + @pytest.fixture + def mock_logical_workers(self) -> List[MagicMock]: + """ + Create mock logical worker entities for testing. + + Returns: + List of mock logical worker entities. 
+ """ + worker1 = MagicMock(spec=LogicalWorkerEntity) + worker1.get_name.return_value = "logical_worker1" + worker1.get_queues.return_value = ["queue1", "queue2", "custom_queue"] + worker1.get_physical_workers.return_value = ["phys1", "phys2"] + + worker2 = MagicMock(spec=LogicalWorkerEntity) + worker2.get_name.return_value = "logical_worker2" + worker2.get_queues.return_value = ["queue3"] + worker2.get_physical_workers.return_value = [] # No physical workers + + worker3 = MagicMock(spec=LogicalWorkerEntity) + worker3.get_name.return_value = "logical_worker3" + worker3.get_queues.return_value = ["queue4"] + worker3.get_physical_workers.return_value = ["phys3"] + + return [worker1, worker2, worker3] + + @pytest.fixture + def mock_physical_workers(self) -> List[MagicMock]: + """ + Create mock physical worker entities for testing. + + Returns: + List of mock physical worker entities. + """ + # Current time for consistent testing + now = datetime.now() + + worker1 = MagicMock(spec=PhysicalWorkerEntity) + worker1.get_id.return_value = "phys1" + worker1.get_name.return_value = "physical_worker1" + worker1.get_host.return_value = "host1.example.com" + worker1.get_pid.return_value = 12345 + worker1.get_worker_status.return_value = WorkerStatus.RUNNING + worker1.get_restart_count.return_value = 0 + worker1.get_latest_start_time.return_value = now - timedelta(hours=2) + worker1.get_heartbeat_timestamp.return_value = now - timedelta(minutes=1) + + worker2 = MagicMock(spec=PhysicalWorkerEntity) + worker2.get_id.return_value = "phys2" + worker2.get_name.return_value = "physical_worker2" + worker2.get_host.return_value = "host2.example.com" + worker2.get_pid.return_value = 54321 + worker2.get_worker_status.return_value = WorkerStatus.STOPPED + worker2.get_restart_count.return_value = 3 + worker2.get_latest_start_time.return_value = None + worker2.get_heartbeat_timestamp.return_value = None + + worker3 = MagicMock(spec=PhysicalWorkerEntity) + worker3.get_id.return_value = "phys3" 
+ worker3.get_name.return_value = "physical_worker3" + worker3.get_host.return_value = "host3.example.com" + worker3.get_pid.return_value = 99999 + worker3.get_worker_status.return_value = WorkerStatus.STALLED + worker3.get_restart_count.return_value = 1 + worker3.get_latest_start_time.return_value = now - timedelta(hours=1) + worker3.get_heartbeat_timestamp.return_value = now - timedelta(minutes=30) + + return [worker1, worker2, worker3] + + @pytest.fixture + def mock_db(self) -> MagicMock: + """ + Create a mock MerlinDatabase for testing. + + Returns: + Mock MerlinDatabase instance. + """ + return MagicMock(spec=MerlinDatabase) + + @pytest.fixture + def sample_stats(self) -> Dict[str, int]: + """ + Create sample worker statistics for testing. + + Returns: + Dictionary containing sample worker statistics. + """ + return { + "total_logical": 3, + "logical_with_instances": 2, + "logical_without_instances": 1, + "total_physical": 3, + "physical_running": 1, + "physical_stopped": 1, + "physical_stalled": 1, + "physical_rebooting": 0, + } + + def test_format_and_display_basic_structure( + self, + formatter: JSONWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + sample_stats: Dict[str, int], + ): + """ + Test that format_and_display outputs valid JSON with correct basic structure. + + Args: + formatter: JSONWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_physical_workers: List of mock physical worker entities. + mock_db: Mock MerlinDatabase instance. + sample_stats: Sample worker statistics dictionary. 
+ """ + # Setup database to return physical workers + mock_db.get.side_effect = mock_physical_workers + + # Mock get_worker_statistics method + formatter.get_worker_statistics = MagicMock(return_value=sample_stats) + + filters = {"queues": ["queue1"], "name": ["worker1"]} + formatter.format_and_display(mock_logical_workers, filters, mock_db) + + # Get the JSON output from the mocked console.print call + formatter.console.print.assert_called_once() + json_output = formatter.console.print.call_args[0][0] + data = json.loads(json_output) + + # Verify top-level structure + assert "filters" in data + assert "timestamp" in data + assert "logical_workers" in data + assert "summary" in data + + # Verify filters are preserved + assert data["filters"] == filters + + # Verify timestamp is ISO format + datetime.fromisoformat(data["timestamp"]) # Should not raise exception + + # Verify summary matches expected stats + assert data["summary"] == sample_stats + + # Verify logical workers array exists + assert isinstance(data["logical_workers"], list) + + def test_format_and_display_with_filters( + self, + formatter: JSONWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + sample_stats: Dict[str, int], + ): + """ + Test JSON output includes filters correctly. + + Args: + formatter: JSONWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_physical_workers: List of mock physical worker entities. + mock_db: Mock MerlinDatabase instance. + sample_stats: Sample worker statistics dictionary. 
+ """ + mock_db.get.side_effect = mock_physical_workers + formatter.get_worker_statistics = MagicMock(return_value=sample_stats) + + filters = {"queues": ["queue1", "queue2"], "name": ["worker_a", "worker_b"]} + formatter.format_and_display(mock_logical_workers, filters, mock_db) + + # Get the JSON output from the mocked console.print call + formatter.console.print.assert_called_once() + json_output = formatter.console.print.call_args[0][0] + data = json.loads(json_output) + + assert data["filters"]["queues"] == ["queue1", "queue2"] + assert data["filters"]["name"] == ["worker_a", "worker_b"] + + def test_format_and_display_with_empty_filters( + self, + formatter: JSONWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + sample_stats: Dict[str, int], + ): + """ + Test JSON output with empty filters. + + Args: + formatter: JSONWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_physical_workers: List of mock physical worker entities. + mock_db: Mock MerlinDatabase instance. + sample_stats: Sample worker statistics dictionary. 
+ """ + mock_db.get.side_effect = mock_physical_workers + formatter.get_worker_statistics = MagicMock(return_value=sample_stats) + + filters = {} + formatter.format_and_display(mock_logical_workers, filters, mock_db) + + # Get the JSON output from the mocked console.print call + formatter.console.print.assert_called_once() + json_output = formatter.console.print.call_args[0][0] + data = json.loads(json_output) + + assert data["filters"] == {} diff --git a/tests/unit/workers/formatters/test_rich_formatter.py b/tests/unit/workers/formatters/test_rich_formatter.py new file mode 100644 index 000000000..893e76176 --- /dev/null +++ b/tests/unit/workers/formatters/test_rich_formatter.py @@ -0,0 +1,736 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the `merlin/workers/formatters/rich_formatter.py` module. 
+""" + +from datetime import datetime, timedelta +from typing import List, Union +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from rich.text import Text + +from merlin.common.enums import WorkerStatus +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.entities.physical_worker_entity import PhysicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.workers.formatters.rich_formatter import ( + ColumnConfig, + LayoutConfig, + LayoutSize, + ResponsiveLayoutManager, + RichWorkerFormatter, +) + + +class TestLayoutSize: + """Tests for the LayoutSize enum.""" + + def test_layout_size_values(self): + """Test that LayoutSize enum has expected values.""" + assert LayoutSize.COMPACT.value == "compact" + assert LayoutSize.NARROW.value == "narrow" + assert LayoutSize.MEDIUM.value == "medium" + assert LayoutSize.WIDE.value == "wide" + + +class TestColumnConfig: + """Tests for the ColumnConfig dataclass.""" + + def test_column_config_defaults(self): + """Test that ColumnConfig has correct default values.""" + config = ColumnConfig("test_key", "Test Title") + + assert config.key == "test_key" + assert config.title == "Test Title" + assert config.style == "white" + assert config.width is None + assert config.max_width is None + assert config.justify == "left" + assert config.no_wrap is False + assert config.formatter is None + + def test_column_config_custom_values(self): + """Test that ColumnConfig accepts custom values.""" + + def formatter(x): + return str(x).upper() + + config = ColumnConfig( + key="custom_key", + title="Custom Title", + style="bold red", + width=10, + max_width=20, + justify="center", + no_wrap=True, + formatter=formatter, + ) + + assert config.key == "custom_key" + assert config.title == "Custom Title" + assert config.style == "bold red" + assert config.width == 10 + assert config.max_width == 20 + assert config.justify == 
"center" + assert config.no_wrap is True + assert config.formatter == formatter + + +class TestLayoutConfig: + """Tests for the LayoutConfig dataclass.""" + + def test_layout_config_defaults(self): + """Test that LayoutConfig has correct default values.""" + config = LayoutConfig(LayoutSize.MEDIUM) + + assert config.size == LayoutSize.MEDIUM + assert config.show_summary_panels is True + assert config.panels_horizontal is True + assert config.physical_worker_columns is None + assert config.logical_worker_columns is None + assert config.use_compact_view is False + + +class TestResponsiveLayoutManager: + """Tests for the ResponsiveLayoutManager class.""" + + @pytest.fixture + def layout_manager(self) -> ResponsiveLayoutManager: + """ + Create a ResponsiveLayoutManager instance for testing. + + Returns: + ResponsiveLayoutManager instance. + """ + return ResponsiveLayoutManager() + + @pytest.mark.parametrize( + "width, expected_size", + [ + (30, LayoutSize.COMPACT), # Compact + (59, LayoutSize.COMPACT), + (60, LayoutSize.NARROW), # Narrow + (70, LayoutSize.NARROW), + (79, LayoutSize.NARROW), + (80, LayoutSize.MEDIUM), # Medium + (100, LayoutSize.MEDIUM), + (119, LayoutSize.MEDIUM), + (120, LayoutSize.WIDE), # Wide + (150, LayoutSize.WIDE), + (200, LayoutSize.WIDE), + ], + ) + def test_get_layout_size(self, layout_manager: ResponsiveLayoutManager, width: int, expected_size: LayoutSize): + """ + Test that get_layout_size returns the correct layout size. + + Args: + layout_manager: ResponsiveLayoutManager instance. + width: Terminal width to test. + expected_size: Expected LayoutSize result. + """ + assert layout_manager.get_layout_size(width) == expected_size + + def test_get_layout_config_returns_correct_config(self, layout_manager: ResponsiveLayoutManager): + """ + Test that get_layout_config returns the correct LayoutConfig. + + Args: + layout_manager: ResponsiveLayoutManager instance. 
+ """ + config = layout_manager.get_layout_config(100) + assert config.size == LayoutSize.MEDIUM + assert isinstance(config, LayoutConfig) + + @pytest.mark.parametrize( + "status, expected_icon, expected_text", + [ + (WorkerStatus.RUNNING, "✓", "RUNNING"), + (WorkerStatus.STOPPED, "✗", "STOPPED"), + (WorkerStatus.STALLED, "⚠", "STALLED"), + (WorkerStatus.REBOOTING, "↻", "REBOOTING"), + ], + ) + def test_format_status( + self, layout_manager: ResponsiveLayoutManager, status: WorkerStatus, expected_icon: str, expected_text: str + ): + """ + Test that various worker statuses are formatted correctly. + + Args: + layout_manager: ResponsiveLayoutManager instance. + status: WorkerStatus to format. + expected_icon: Expected icon in the formatted output. + expected_text: Expected text in the formatted output. + """ + formatted = layout_manager._format_status(status) + assert isinstance(formatted, Text) + assert expected_icon in str(formatted) + assert expected_text in str(formatted) + + +class TestRichWorkerFormatter: + """Tests for the RichWorkerFormatter class.""" + + @pytest.fixture + def formatter(self, mocker: MockerFixture) -> RichWorkerFormatter: + """ + Create a RichWorkerFormatter instance for testing. + + Args: + mocker: Pytest mocker fixture. + + Returns: + RichWorkerFormatter instance. + """ + # Mock the console to control its behavior + mock_console = MagicMock() + mock_console.size.width = 120 # Default to wide layout + mocker.patch("merlin.workers.formatters.worker_formatter.Console", return_value=mock_console) + return RichWorkerFormatter() + + @pytest.fixture + def mock_logical_workers(self) -> List[MagicMock]: + """ + Create mock logical worker entities for testing. + + Returns: + List of mock LogicalWorkerEntity instances. 
+ """ + worker1 = MagicMock(spec=LogicalWorkerEntity) + worker1.get_name.return_value = "logical_worker1" + worker1.get_queues.return_value = ["queue1", "queue2"] + worker1.get_physical_workers.return_value = ["phys1", "phys2"] + + worker2 = MagicMock(spec=LogicalWorkerEntity) + worker2.get_name.return_value = "logical_worker2" + worker2.get_queues.return_value = ["queue3"] + worker2.get_physical_workers.return_value = [] # No physical workers + + return [worker1, worker2] + + @pytest.fixture + def mock_physical_workers(self) -> List[MagicMock]: + """ + Create mock physical worker entities for testing. + + Returns: + List of mock PhysicalWorkerEntity instances. + """ + worker1 = MagicMock(spec=PhysicalWorkerEntity) + worker1.get_id.return_value = "phys1" + worker1.get_name.return_value = "physical_worker1" + worker1.get_host.return_value = "host1" + worker1.get_pid.return_value = 12345 + worker1.get_worker_status.return_value = WorkerStatus.RUNNING + worker1.get_restart_count.return_value = 0 + worker1.get_latest_start_time.return_value = datetime.now() - timedelta(hours=2) + worker1.get_heartbeat_timestamp.return_value = datetime.now() - timedelta(minutes=1) + + worker2 = MagicMock(spec=PhysicalWorkerEntity) + worker2.get_id.return_value = "phys2" + worker2.get_name.return_value = "physical_worker2" + worker2.get_host.return_value = "host2" + worker2.get_pid.return_value = 54321 + worker2.get_worker_status.return_value = WorkerStatus.STOPPED + worker2.get_restart_count.return_value = 2 + worker2.get_latest_start_time.return_value = None + worker2.get_heartbeat_timestamp.return_value = None + + return [worker1, worker2] + + @pytest.fixture + def mock_db(self) -> MagicMock: + """ + Create a mock MerlinDatabase for testing. + + Returns: + Mock MerlinDatabase instance. 
+ """ + return MagicMock(spec=MerlinDatabase) + + def test_get_queues_str_removes_merlin_prefix(self, formatter: RichWorkerFormatter): + """ + Test that _get_queues_str removes [merlin]_ prefix correctly. + + Args: + formatter: RichWorkerFormatter instance. + """ + queues = ["[merlin]_queue1", "[merlin]_queue2", "custom_queue"] + result = formatter._get_queues_str(queues) + assert result == "custom_queue, queue1, queue2" + + def test_get_queues_str_sorts_queues(self, formatter: RichWorkerFormatter): + """ + Test that _get_queues_str sorts queue names. + + Args: + formatter: RichWorkerFormatter instance. + """ + queues = ["[merlin]_zebra", "[merlin]_alpha", "[merlin]_beta"] + result = formatter._get_queues_str(queues) + assert result == "alpha, beta, zebra" + + @pytest.mark.parametrize( + "status, expected_icon, expected_color", + [ + (WorkerStatus.RUNNING, "✓", "green"), + (WorkerStatus.STOPPED, "✗", "red"), + (WorkerStatus.STALLED, "⚠", "yellow"), + (WorkerStatus.REBOOTING, "↻", "cyan"), + ], + ) + def test_format_status( + self, formatter: RichWorkerFormatter, status: Union[WorkerStatus, str], expected_icon: str, expected_color: str + ): + """ + Test that _format_status returns Rich Text with icons. + + Args: + formatter: RichWorkerFormatter instance. + status: WorkerStatus or string to format. + expected_icon: Expected icon in the formatted output. + expected_color: Expected color in the formatted output. 
+ """ + formatted = formatter._format_status(status) + assert isinstance(formatted, Text) + assert expected_icon in str(formatted) + assert expected_color in formatted.style + + assert status.name in str(formatted) + + @pytest.mark.parametrize( + "duration, expected", + [ + (timedelta(days=2, hours=5), "2d 5h 0m"), + (timedelta(days=1, hours=1, minutes=1), "1d 1h 1m"), + (timedelta(hours=3, minutes=30), "3h 30m"), + (timedelta(minutes=45), "45m"), + (timedelta(seconds=30), "30s"), + ], + ) + def test_format_time_duration(self, formatter: RichWorkerFormatter, duration: timedelta, expected: str): + """ + Test time duration formatting. + + Args: + formatter: RichWorkerFormatter instance. + duration: timedelta to format. + expected: Expected formatted string. + """ + result = formatter._format_time_duration(duration) + assert result == expected + + @pytest.mark.parametrize( + "timestamp, expected", + [ + (datetime.now() - timedelta(seconds=30), "Just now"), + (datetime.now() - timedelta(minutes=10), "10m ago"), + (datetime.now() - timedelta(hours=2), "2h ago"), + (None, "-"), + ], + ) + def test_format_last_heartbeat(self, formatter: RichWorkerFormatter, timestamp: datetime, expected: str): + """ + Test heartbeat formatting for very recent heartbeats. + + Args: + formatter: RichWorkerFormatter instance. + timestamp: datetime of the last heartbeat. + expected: Expected formatted string. + """ + result = formatter._format_last_heartbeat(timestamp) + assert isinstance(result, Text) + assert expected in str(result) + + @pytest.mark.parametrize( + "status, timestamp, expected", + [ + (WorkerStatus.RUNNING, datetime.now() - timedelta(hours=1, minutes=30), "1h 30m"), + (WorkerStatus.RUNNING, None, "-"), + ], + ) + def test_format_uptime_or_downtime_running_worker( + self, formatter: RichWorkerFormatter, status: WorkerStatus, timestamp: timedelta, expected: str + ): + """ + Test uptime or downtime formatting. + + Args: + formatter: RichWorkerFormatter instance. 
+ status: WorkerStatus of the worker. + timestamp: datetime of the latest start time. + expected: Expected formatted string. + """ + mock_worker = MagicMock() + mock_worker.get_worker_status.return_value = status + mock_worker.get_latest_start_time.return_value = timestamp + + result = formatter._format_uptime_or_downtime(mock_worker) + assert result == expected + + @pytest.mark.parametrize( + "status, timestamp, expected", + [ + (WorkerStatus.STOPPED, datetime.now() - timedelta(minutes=15), "down 15m"), + (WorkerStatus.STOPPED, None, "stopped"), + ], + ) + def test_format_uptime_or_downtime_stopped_worker( + self, formatter: RichWorkerFormatter, status: WorkerStatus, timestamp: timedelta, expected: str + ): + """ + Test uptime or downtime formatting. + + Args: + formatter: RichWorkerFormatter instance. + status: WorkerStatus of the worker. + timestamp: datetime of the stop time. + expected: Expected formatted string. + """ + mock_worker = MagicMock() + mock_worker.get_worker_status.return_value = status + mock_worker.get_stop_time.return_value = timestamp + + result = formatter._format_uptime_or_downtime(mock_worker) + assert result == expected + + def test_get_physical_worker_data( + self, + formatter: RichWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + ): + """ + Test extraction of physical worker data for table display. + + Args: + formatter: RichWorkerFormatter instance. + mock_logical_workers: List of mock LogicalWorkerEntity instances. + mock_physical_workers: List of mock PhysicalWorkerEntity instances. + mock_db: Mock MerlinDatabase instance. 
+ """ + # Setup database to return physical workers + mock_db.get.side_effect = mock_physical_workers + + data = formatter._get_physical_worker_data(mock_logical_workers, mock_db) + + assert len(data) == 2 # 2 physical workers for first logical, 0 for second logical + assert data[0]["worker"] == "logical_worker1" + assert data[0]["host"] == "host1" + assert data[0]["pid"] == "12345" + assert data[0]["worker_status"] == WorkerStatus.RUNNING + + def test_get_logical_workers_without_instances_data( + self, formatter: RichWorkerFormatter, mock_logical_workers: List[MagicMock] + ): + """ + Test extraction of logical workers without physical instances. + + Args: + formatter: RichWorkerFormatter instance. + mock_logical_workers: List of mock LogicalWorkerEntity instances. + """ + data = formatter._get_logical_workers_without_instances_data(mock_logical_workers) + + # Only logical_worker2 has no physical workers + assert len(data) == 1 + assert data[0]["worker"] == "logical_worker2" + assert data[0]["queues"] == "queue3" + assert isinstance(data[0]["status"], Text) + assert "NO INSTANCES" in str(data[0]["status"]) + + def test_sort_physical_workers(self, formatter: RichWorkerFormatter): + """ + Test that physical workers are sorted correctly. + + Args: + formatter: RichWorkerFormatter instance. 
+ """ + data = [ + {"_sort_status": "STOPPED", "worker": "worker2", "instance": "inst2"}, + {"_sort_status": "RUNNING", "worker": "worker1", "instance": "inst1"}, + {"_sort_status": "RUNNING", "worker": "worker1", "instance": "inst2"}, + {"_sort_status": "STOPPED", "worker": "worker1", "instance": "inst1"}, + ] + + sorted_data = formatter._sort_physical_workers(data) + + # Running workers should come first, then sorted by worker and instance name + assert sorted_data[0]["_sort_status"] == "RUNNING" + assert sorted_data[1]["_sort_status"] == "RUNNING" + assert sorted_data[2]["_sort_status"] == "STOPPED" + assert sorted_data[3]["_sort_status"] == "STOPPED" + + def test_build_summary_panels_with_filters(self, formatter: RichWorkerFormatter): + """ + Test building summary panels with filters applied. + + Args: + formatter: RichWorkerFormatter instance. + """ + stats = { + "total_logical": 5, + "logical_with_instances": 3, + "logical_without_instances": 2, + "total_physical": 8, + "physical_running": 6, + "physical_stopped": 2, + "physical_stalled": 0, + "physical_rebooting": 0, + } + filters = {"queues": ["queue1", "queue2"], "name": ["worker1", "worker2"]} + + panels = formatter._build_summary_panels(stats, filters) + + assert len(panels) == 3 # Filter, Logical, Physical panels + # Check that filter panel contains expected content + filter_panel_content = panels[0] + assert "queue1, queue2" in filter_panel_content.renderable + assert "worker1, worker2" in filter_panel_content.renderable + + def test_build_summary_panels_no_filters(self, formatter: RichWorkerFormatter): + """ + Test building summary panels without filters. + + Args: + formatter: RichWorkerFormatter instance. 
+ """ + stats = { + "total_logical": 3, + "logical_with_instances": 2, + "logical_without_instances": 1, + "total_physical": 0, + "physical_running": 0, + "physical_stopped": 0, + "physical_stalled": 0, + "physical_rebooting": 0, + } + filters = {} + + panels = formatter._build_summary_panels(stats, filters) + + # No filter panel, only logical panel (no physical since total is 0) + assert len(panels) == 1 + + def test_build_compact_view( + self, + formatter: RichWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + ): + """ + Test building compact view for narrow terminals. + + Args: + formatter: RichWorkerFormatter instance. + mock_logical_workers: List of mock LogicalWorkerEntity instances. + mock_physical_workers: List of mock PhysicalWorkerEntity instances. + mock_db: Mock MerlinDatabase instance. + """ + mock_db.get.side_effect = mock_physical_workers + + compact_view = formatter._build_compact_view(mock_logical_workers, mock_db) + + assert "logical_worker1" in compact_view + assert "logical_worker2" in compact_view + assert "NO INSTANCES" in compact_view + assert "host1" in compact_view + + def test_build_responsive_table(self, formatter: RichWorkerFormatter): + """ + Test building a responsive table with column configuration. + + Args: + formatter: RichWorkerFormatter instance. 
+ """ + columns = [ + ColumnConfig(key="name", title="Name", style="bold white"), + ColumnConfig(key="status", title="Status", style="green", formatter=lambda x: f"[{x}]"), + ] + data = [{"name": "worker1", "status": "running"}, {"name": "worker2", "status": "stopped"}] + + table = formatter._build_responsive_table("Test Table", columns, data) + + assert table.title == "Test Table" + assert len(table.columns) == 2 + + def test_format_and_display_compact_view( + self, mocker: MockerFixture, mock_logical_workers: List[MagicMock], mock_db: MagicMock + ): + """ + Test format_and_display uses compact view for narrow terminals. + + Args: + mocker: Pytest mocker fixture. + mock_logical_workers: List of mock LogicalWorkerEntity instances. + mock_db: Mock MerlinDatabase instance. + """ + # Create formatter with narrow console + mock_console = MagicMock() + mock_console.size.width = 40 # Compact layout + mocker.patch("merlin.workers.formatters.worker_formatter.Console", return_value=mock_console) + formatter = RichWorkerFormatter() + + # Mock get_worker_statistics + stats = { + "total_logical": 2, + "total_physical": 2, + "logical_with_instances": 1, + "logical_without_instances": 1, + "physical_running": 1, + "physical_stopped": 1, + "physical_stalled": 0, + "physical_rebooting": 0, + } + mocker.patch.object(formatter, "get_worker_statistics", return_value=stats) + + filters = {} + formatter.format_and_display(mock_logical_workers, filters, mock_db) + + # Should call _display_compact_view instead of normal tables + assert mock_console.print.called + + def test_format_and_display_normal_view( + self, + mocker: MockerFixture, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + ): + """ + Test format_and_display uses normal view for wide terminals. + + Args: + mocker: Pytest mocker fixture. + mock_logical_workers: List of mock LogicalWorkerEntity instances. 
+ mock_physical_workers: List of mock PhysicalWorkerEntity instances. + mock_db: Mock MerlinDatabase instance. + """ + # Create formatter with wide console + mock_console = MagicMock() + mock_console.size.width = 150 # Wide layout + mocker.patch("merlin.workers.formatters.worker_formatter.Console", return_value=mock_console) + formatter = RichWorkerFormatter() + + # Mock database to return physical workers + mock_db.get.side_effect = mock_physical_workers + + # Mock get_worker_statistics + stats = { + "total_logical": 2, + "total_physical": 2, + "logical_with_instances": 1, + "logical_without_instances": 1, + "physical_running": 1, + "physical_stopped": 1, + "physical_stalled": 0, + "physical_rebooting": 0, + } + mocker.patch.object(formatter, "get_worker_statistics", return_value=stats) + + filters = {"queues": ["queue1"]} + formatter.format_and_display(mock_logical_workers, filters, mock_db) + + # Should display summary panels and tables + assert mock_console.print.called + # Should be called multiple times (empty lines, panels, tables) + assert mock_console.print.call_count > 3 + + def test_display_summary_panels_horizontal(self, mocker: MockerFixture, formatter: RichWorkerFormatter): + """ + Test that summary panels are displayed horizontally when configured. + + Args: + mocker: Pytest mocker fixture. + formatter: RichWorkerFormatter instance. 
+ """ + mock_columns = mocker.patch("merlin.workers.formatters.rich_formatter.Columns") + stats = { + "total_logical": 1, + "logical_with_instances": 1, + "logical_without_instances": 0, + "total_physical": 1, + "physical_running": 0, + "physical_stopped": 1, + "physical_stalled": 0, + "physical_rebooting": 0, + } + filters = {} + layout_config = LayoutConfig(LayoutSize.WIDE, panels_horizontal=True) + + formatter._display_summary_panels(stats, filters, layout_config) + + # Should create Columns object for horizontal layout + mock_columns.assert_called_once() + + def test_display_summary_panels_vertical(self, mocker: MockerFixture, formatter: RichWorkerFormatter): + """ + Test that summary panels are displayed vertically when configured. + + Args: + mocker: Pytest mocker fixture. + formatter: RichWorkerFormatter instance. + """ + stats = { + "total_logical": 1, + "logical_with_instances": 0, + "logical_without_instances": 1, + "total_physical": 0, + "physical_running": 0, + "physical_stopped": 0, + "physical_stalled": 0, + "physical_rebooting": 0, + } + filters = {} + layout_config = LayoutConfig(LayoutSize.NARROW, panels_horizontal=False) + + # Mock console.print to track calls + mock_print = mocker.patch.object(formatter.console, "print") + + formatter._display_summary_panels(stats, filters, layout_config) + + # Should print each panel individually (not using Columns) + assert mock_print.called + + def test_display_compact_view(self, mocker: MockerFixture, formatter: RichWorkerFormatter): + """ + Test display of compact view. + + Args: + mocker: Pytest mocker fixture. + formatter: RichWorkerFormatter instance. 
+ """ + compact_view = "Test compact view content" + filters = {"queues": ["queue1"]} + stats = { + "total_logical": 1, + "logical_with_instances": 1, + "logical_without_instances": 0, + "total_physical": 2, + "physical_running": 1, + "physical_stopped": 1, + "physical_stalled": 0, + "physical_rebooting": 0, + } + + mock_print = mocker.patch.object(formatter.console, "print") + + formatter._display_compact_view(compact_view, filters, stats) + + # Should print title, filters, summary, and compact view + assert mock_print.call_count == 4 + + # Check that filter information was included in one of the print calls + print_args = [str(call[0][0]) for call in mock_print.call_args_list] + filter_found = any("queue1" in arg for arg in print_args) + assert filter_found diff --git a/tests/unit/workers/formatters/test_worker_formatter.py b/tests/unit/workers/formatters/test_worker_formatter.py new file mode 100644 index 000000000..b4e09a57d --- /dev/null +++ b/tests/unit/workers/formatters/test_worker_formatter.py @@ -0,0 +1,332 @@ +############################################################################## +# Copyright (c) Lawrence Livermore National Security, LLC and other Merlin +# Project developers. See top-level LICENSE and COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute to Merlin. +############################################################################## + +""" +Tests for the `merlin/workers/formatters/worker_formatter.py` module. 
+""" + +from typing import Dict, List +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from merlin.common.enums import WorkerStatus +from merlin.db_scripts.entities.logical_worker_entity import LogicalWorkerEntity +from merlin.db_scripts.merlin_db import MerlinDatabase +from merlin.workers.formatters.worker_formatter import WorkerFormatter + + +class DummyWorkerFormatter(WorkerFormatter): + """Dummy implementation of WorkerFormatter for testing.""" + + def __init__(self): + super().__init__() + self.formatted_data = None + self.display_called = False + + def format_and_display(self, logical_workers: List, filters: Dict, merlin_db: MerlinDatabase): + """Mock implementation that stores the call parameters.""" + self.formatted_data = {"logical_workers": logical_workers, "filters": filters, "merlin_db": merlin_db} + self.display_called = True + return f"Formatted {len(logical_workers)} workers with filters {filters}" + + +def test_abstract_formatter_cannot_be_instantiated(): + """Test that attempting to instantiate the abstract base class raises a TypeError.""" + with pytest.raises(TypeError): + WorkerFormatter() + + +def test_unimplemented_method_raises_not_implemented(): + """Test that instantiating a subclass that leaves the abstract method unimplemented raises a TypeError.""" + + class IncompleteFormatter(WorkerFormatter): + pass + + # Should raise TypeError due to unimplemented abstract method + with pytest.raises(TypeError): + IncompleteFormatter() + + +class TestWorkerFormatter: + """Tests for the WorkerFormatter abstract base class and its concrete implementations.""" + + @pytest.fixture + def formatter(self, mocker: MockerFixture) -> DummyWorkerFormatter: + """ + Create a DummyWorkerFormatter instance for testing. + + Args: + mocker: Pytest mocker fixture. + + Returns: + DummyWorkerFormatter instance with mocked console.
+ """ + # Mock the console to avoid actual Rich console creation + mock_console = MagicMock() + mocker.patch("merlin.workers.formatters.worker_formatter.Console", return_value=mock_console) + return DummyWorkerFormatter() + + @pytest.fixture + def mock_logical_workers(self) -> List[MagicMock]: + """ + Create mock logical worker entities for testing. + + Returns: + List of mock logical worker entities. + """ + worker1 = MagicMock(spec=LogicalWorkerEntity) + worker1.get_name.return_value = "worker1" + worker1.get_queues.return_value = ["queue1", "queue2"] + worker1.get_physical_workers.return_value = ["phys1", "phys2"] + + worker2 = MagicMock(spec=LogicalWorkerEntity) + worker2.get_name.return_value = "worker2" + worker2.get_queues.return_value = ["queue3"] + worker2.get_physical_workers.return_value = [] + + return [worker1, worker2] + + @pytest.fixture + def mock_physical_workers(self) -> List[MagicMock]: + """ + Create mock physical worker entities for testing. + + Returns: + List of mock physical worker entities. + """ + worker1 = MagicMock() + worker1.get_worker_status.return_value = WorkerStatus.RUNNING + + worker2 = MagicMock() + worker2.get_worker_status.return_value = WorkerStatus.STOPPED + + worker3 = MagicMock() + worker3.get_worker_status.return_value = WorkerStatus.STALLED + + return [worker1, worker2, worker3] + + @pytest.fixture + def mock_db(self) -> MagicMock: + """ + Create a mock MerlinDatabase for testing. + + Returns: + Mock MerlinDatabase instance. + """ + return MagicMock(spec=MerlinDatabase) + + def test_console_initialization(self, formatter: DummyWorkerFormatter): + """ + Test that WorkerFormatter initializes with a Rich Console. + + Args: + formatter: DummyWorkerFormatter instance for testing. 
+ """ + assert hasattr(formatter, "console") + assert formatter.console is not None + + def test_format_and_display_abstract_method_implemented( + self, formatter: DummyWorkerFormatter, mock_logical_workers: List[MagicMock], mock_db: MagicMock + ): + """ + Test that concrete implementation can override format_and_display. + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_db: Mock MerlinDatabase instance. + """ + filters = {"queues": ["test_queue"]} + result = formatter.format_and_display(mock_logical_workers, filters, mock_db) + + assert formatter.display_called is True + assert "Formatted 2 workers" in result + assert formatter.formatted_data["logical_workers"] == mock_logical_workers + assert formatter.formatted_data["filters"] == filters + assert formatter.formatted_data["merlin_db"] == mock_db + + def test_get_worker_statistics_basic_functionality( + self, + formatter: DummyWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + ): + """ + Test that get_worker_statistics computes basic statistics correctly. + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_physical_workers: List of mock physical worker entities. + mock_db: Mock MerlinDatabase instance. 
+ """ + # Setup database to return physical workers + mock_db.get.side_effect = mock_physical_workers + + stats = formatter.get_worker_statistics(mock_logical_workers, mock_db) + + # Verify basic counts + assert stats["total_logical"] == 2 + assert stats["logical_with_instances"] == 1 # Only worker1 has physical workers + assert stats["logical_without_instances"] == 1 # worker2 has no physical workers + assert stats["total_physical"] == 2 # worker1 has 2 physical workers + + def test_get_worker_statistics_status_counts( + self, + formatter: DummyWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + ): + """ + Test that get_worker_statistics counts worker statuses correctly. + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_physical_workers: List of mock physical worker entities. + mock_db: Mock MerlinDatabase instance. + """ + # Setup database to return physical workers for first logical worker only + mock_db.get.side_effect = mock_physical_workers[:2] # First 2 physical workers + + stats = formatter.get_worker_statistics(mock_logical_workers, mock_db) + + # Verify status counts + assert stats["physical_running"] == 1 + assert stats["physical_stopped"] == 1 + assert stats["physical_stalled"] == 0 + assert stats["physical_rebooting"] == 0 + + def test_get_worker_statistics_with_empty_logical_workers(self, formatter: DummyWorkerFormatter, mock_db: MagicMock): + """ + Test get_worker_statistics with empty logical workers list. + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_db: Mock MerlinDatabase instance. 
+ """ + stats = formatter.get_worker_statistics([], mock_db) + + # All counts should be zero + assert stats["total_logical"] == 0 + assert stats["logical_with_instances"] == 0 + assert stats["logical_without_instances"] == 0 + assert stats["total_physical"] == 0 + assert stats["physical_running"] == 0 + assert stats["physical_stopped"] == 0 + assert stats["physical_stalled"] == 0 + assert stats["physical_rebooting"] == 0 + + def test_get_worker_statistics_with_all_status_types(self, formatter: DummyWorkerFormatter, mock_db: MagicMock): + """ + Test get_worker_statistics counts all worker status types correctly. + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_db: Mock MerlinDatabase instance. + """ + # Create logical worker with multiple physical workers of different statuses + logical_worker = MagicMock(spec=LogicalWorkerEntity) + logical_worker.get_physical_workers.return_value = ["p1", "p2", "p3", "p4"] + + # Create physical workers with all status types + physical_workers = [] + statuses = [WorkerStatus.RUNNING, WorkerStatus.STOPPED, WorkerStatus.STALLED, WorkerStatus.REBOOTING] + for i, status in enumerate(statuses): + worker = MagicMock() + worker.get_worker_status.return_value = status + physical_workers.append(worker) + + mock_db.get.side_effect = physical_workers + + stats = formatter.get_worker_statistics([logical_worker], mock_db) + + # Verify all status types are counted + assert stats["total_logical"] == 1 + assert stats["logical_with_instances"] == 1 + assert stats["logical_without_instances"] == 0 + assert stats["total_physical"] == 4 + assert stats["physical_running"] == 1 + assert stats["physical_stopped"] == 1 + assert stats["physical_stalled"] == 1 + assert stats["physical_rebooting"] == 1 + + def test_get_worker_statistics_database_interaction( + self, + formatter: DummyWorkerFormatter, + mock_logical_workers: List[MagicMock], + mock_physical_workers: List[MagicMock], + mock_db: MagicMock, + ): + """ + Test that 
get_worker_statistics interacts with database correctly. + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_logical_workers: List of mock logical worker entities. + mock_physical_workers: List of mock physical worker entities. + mock_db: Mock MerlinDatabase instance. + """ + mock_db.get.side_effect = mock_physical_workers + + formatter.get_worker_statistics(mock_logical_workers, mock_db) + + # Verify database was called for each physical worker ID + expected_calls = [("physical_worker", "phys1"), ("physical_worker", "phys2")] + actual_calls = [call[0] for call in mock_db.get.call_args_list] + + for expected_call in expected_calls: + assert expected_call in actual_calls + + def test_get_worker_statistics_handles_mixed_scenarios(self, formatter: DummyWorkerFormatter, mock_db: MagicMock): + """ + Test get_worker_statistics with mixed scenarios (some workers with/without instances). + + Args: + formatter: DummyWorkerFormatter instance for testing. + mock_db: Mock MerlinDatabase instance. 
+ """ + # Create logical workers: one with multiple instances, one without, one with single instance + logical_workers = [] + + # Worker 1: Has 2 physical workers + worker1 = MagicMock(spec=LogicalWorkerEntity) + worker1.get_physical_workers.return_value = ["p1", "p2"] + logical_workers.append(worker1) + + # Worker 2: No physical workers + worker2 = MagicMock(spec=LogicalWorkerEntity) + worker2.get_physical_workers.return_value = [] + logical_workers.append(worker2) + + # Worker 3: Has 1 physical worker + worker3 = MagicMock(spec=LogicalWorkerEntity) + worker3.get_physical_workers.return_value = ["p3"] + logical_workers.append(worker3) + + # Create physical workers + physical_workers = [ + MagicMock(get_worker_status=lambda: WorkerStatus.RUNNING), + MagicMock(get_worker_status=lambda: WorkerStatus.STOPPED), + MagicMock(get_worker_status=lambda: WorkerStatus.STALLED), + ] + + mock_db.get.side_effect = physical_workers + + stats = formatter.get_worker_statistics(logical_workers, mock_db) + + assert stats["total_logical"] == 3 + assert stats["logical_with_instances"] == 2 # worker1 and worker3 + assert stats["logical_without_instances"] == 1 # worker2 + assert stats["total_physical"] == 3 + assert stats["physical_running"] == 1 + assert stats["physical_stopped"] == 1 + assert stats["physical_stalled"] == 1 + assert stats["physical_rebooting"] == 0 diff --git a/tests/unit/workers/handlers/test_celery_handler.py b/tests/unit/workers/handlers/test_celery_handler.py index 62340f9d3..b73398c6b 100644 --- a/tests/unit/workers/handlers/test_celery_handler.py +++ b/tests/unit/workers/handlers/test_celery_handler.py @@ -14,6 +14,7 @@ import pytest from pytest_mock import MockerFixture +from merlin.common.enums import WorkerStatus from merlin.workers.celery_worker import CeleryWorker from merlin.workers.handlers import CeleryWorkerHandler @@ -43,20 +44,82 @@ class TestCeleryWorkerHandler: """ @pytest.fixture - def handler(self) -> CeleryWorkerHandler: - return 
CeleryWorkerHandler() + def handler(self, mocker: MockerFixture, mock_db_instance: MagicMock) -> CeleryWorkerHandler: + """ + Create a CeleryWorkerHandler instance with mocked database. + + Args: + mocker: Pytest mocker fixture. + mock_db_instance: Mocked MerlinDatabase instance. + + Returns: + CeleryWorkerHandler instance with mocked dependencies. + """ + mock_app = mocker.patch("merlin.celery.app") + return CeleryWorkerHandler(merlin_db=mock_db_instance, app=mock_app) @pytest.fixture def mock_db(self, mocker: MockerFixture) -> MagicMock: + """ + Mock the MerlinDatabase used in CeleryWorker constructor. + + Args: + mocker: Pytest mocker fixture. + + Returns: + A mocked MerlinDatabase instance. + """ return mocker.patch("merlin.workers.celery_worker.MerlinDatabase") @pytest.fixture def workers(self, mock_db: MagicMock) -> List[DummyCeleryWorker]: + """ + Create a list of dummy CeleryWorker instances for testing. + + Args: + mock_db: Mocked MerlinDatabase instance. + + Returns: + List of DummyCeleryWorker instances. + """ return [ DummyCeleryWorker("worker1"), DummyCeleryWorker("worker2"), ] + @pytest.fixture + def mock_logical_workers(self) -> List[MagicMock]: + """ + Create mock logical worker entities for testing query_workers. + + Returns: + List of mock logical worker entities. + """ + worker1 = MagicMock() + worker1.get_name.return_value = "logical_worker1" + worker1.get_queues.return_value = ["[merlin]_queue1", "[merlin]_queue2"] + + worker2 = MagicMock() + worker2.get_name.return_value = "logical_worker2" + worker2.get_queues.return_value = ["[merlin]_queue3"] + + return [worker1, worker2] + + @pytest.fixture + def mock_formatter(self, mocker: MockerFixture) -> MagicMock: + """ + Mock the worker formatter factory and formatter. + + Args: + mocker: Pytest mocker fixture. + + Returns: + Mock formatter instance. 
+ """ + mock_formatter = MagicMock() + mocker.patch("merlin.workers.handlers.celery_handler.worker_formatter_factory.create", return_value=mock_formatter) + return mock_formatter + def test_echo_only_prints_commands( self, handler: CeleryWorkerHandler, workers: List[DummyCeleryWorker], capsys: pytest.CaptureFixture ): @@ -100,3 +163,474 @@ def test_default_kwargs_are_used(self, handler: CeleryWorkerHandler, workers: Li for worker in workers: assert worker.launched_with == ("", False) + + def test_build_filters_with_queues_and_workers(self, handler: CeleryWorkerHandler): + """ + Test that `_build_filters` correctly constructs filters dictionary. + + Args: + handler: CeleryWorkerHandler instance. + """ + queues = ["queue1", "queue2"] + workers = ["worker1", "worker2"] + + filters = handler._build_filters(queues, workers) + + assert filters == {"queues": ["[merlin]_queue1", "[merlin]_queue2"], "name": ["worker1", "worker2"]} + + def test_build_filters_with_only_queues(self, handler: CeleryWorkerHandler): + """ + Test that `_build_filters` handles only queues parameter. + + Args: + handler: CeleryWorkerHandler instance. + """ + queues = ["queue1"] + + filters = handler._build_filters(queues, None) + + assert filters == {"queues": ["[merlin]_queue1"]} + + def test_build_filters_with_only_workers(self, handler: CeleryWorkerHandler): + """ + Test that `_build_filters` handles only workers parameter. + + Args: + handler: CeleryWorkerHandler instance. + """ + workers = ["worker1"] + + filters = handler._build_filters(None, workers) + + assert filters == {"name": ["worker1"]} + + def test_build_filters_with_no_parameters(self, handler: CeleryWorkerHandler): + """ + Test that `_build_filters` returns empty dict when no parameters provided. + + Args: + handler: CeleryWorkerHandler instance. 
+ """ + filters = handler._build_filters(None, None) + + assert filters == {} + + def test_query_workers_calls_database_and_formatter( + self, + handler: CeleryWorkerHandler, + mock_logical_workers: List[MagicMock], + mock_formatter: MagicMock, + mocker: MockerFixture, + ): + """ + Test that `query_workers` retrieves data from database and calls formatter. + + Args: + handler: CeleryWorkerHandler instance. + mock_logical_workers: Mock logical worker entities. + mock_formatter: Mock formatter instance. + mocker: Pytest mocker fixture. + """ + # Mock the database get_all method + handler.merlin_db.get_all.return_value = mock_logical_workers + + # Mock the validation method to avoid Celery inspection + mocker.patch.object(handler, "_validate_worker_status") + + handler.query_workers("rich", queues=["queue1"], workers=["worker1"]) + + # Verify database was called with correct filters + expected_filters = {"queues": ["[merlin]_queue1"], "name": ["worker1"]} + handler.merlin_db.get_all.assert_called_once_with("logical_worker", filters=expected_filters) + + # Verify formatter was created and called + mock_formatter.format_and_display.assert_called_once_with(mock_logical_workers, expected_filters, handler.merlin_db) + + def test_query_workers_with_no_filters( + self, + handler: CeleryWorkerHandler, + mock_logical_workers: List[MagicMock], + mock_formatter: MagicMock, + mocker: MockerFixture, + ): + """ + Test that `query_workers` works correctly when no filters are provided. + + Args: + handler: CeleryWorkerHandler instance. + mock_logical_workers: Mock logical worker entities. + mock_formatter: Mock formatter instance. + mocker: Pytest mocker fixture. 
+ """ + handler.merlin_db.get_all.return_value = mock_logical_workers + + # Mock the validation method to avoid Celery inspection + mocker.patch.object(handler, "_validate_worker_status") + + handler.query_workers("json") + + # Verify database was called with empty filters + handler.merlin_db.get_all.assert_called_once_with("logical_worker", filters={}) + + # Verify formatter was called correctly + mock_formatter.format_and_display.assert_called_once_with(mock_logical_workers, {}, handler.merlin_db) + + def test_query_workers_uses_correct_formatter( + self, handler: CeleryWorkerHandler, mock_logical_workers: List[MagicMock], mocker: MockerFixture + ): + """ + Test that `query_workers` uses the correct formatter type. + + Args: + handler: CeleryWorkerHandler instance. + mock_logical_workers: Mock logical worker entities. + mocker: Pytest mocker fixture. + """ + handler.merlin_db.get_all.return_value = mock_logical_workers + + # Mock the validation method to avoid Celery inspection + mocker.patch.object(handler, "_validate_worker_status") + + mock_factory = mocker.patch("merlin.workers.handlers.celery_handler.worker_formatter_factory") + mock_formatter = MagicMock() + mock_factory.create.return_value = mock_formatter + + handler.query_workers("json", queues=["test_queue"]) + + # Verify the correct formatter type was requested + mock_factory.create.assert_called_once_with("json") + mock_formatter.format_and_display.assert_called_once() + + def test_query_workers_handles_empty_results( + self, handler: CeleryWorkerHandler, mock_formatter: MagicMock, mocker: MockerFixture + ): + """ + Test that `query_workers` handles empty database results gracefully. + + Args: + handler: CeleryWorkerHandler instance. + mock_formatter: Mock formatter instance. + mocker: Pytest mocker fixture. 
+ """ + handler.merlin_db.get_all.return_value = [] + + # Mock the validation method to avoid Celery inspection + mocker.patch.object(handler, "_validate_worker_status") + + handler.query_workers("rich") + + # Verify formatter is still called with empty list + mock_formatter.format_and_display.assert_called_once_with([], {}, handler.merlin_db) + + def test_query_workers_passes_all_parameters_to_formatter( + self, + handler: CeleryWorkerHandler, + mock_logical_workers: List[MagicMock], + mock_formatter: MagicMock, + mocker: MockerFixture, + ): + """ + Test that `query_workers` passes all necessary parameters to formatter. + + Args: + handler: CeleryWorkerHandler instance. + mock_logical_workers: Mock logical worker entities. + mock_formatter: Mock formatter instance. + mocker: Pytest mocker fixture. + """ + handler.merlin_db.get_all.return_value = mock_logical_workers + + # Mock the validation method to avoid Celery inspection + mocker.patch.object(handler, "_validate_worker_status") + + queues = ["[merlin]_queue1", "[merlin]_queue2"] + workers = ["worker1"] + + handler.query_workers("rich", queues=queues, workers=workers) + + expected_filters = {"queues": queues, "name": workers} + + # Verify all parameters are passed correctly + mock_formatter.format_and_display.assert_called_once_with(mock_logical_workers, expected_filters, handler.merlin_db) + + def test_get_active_workers_returns_worker_queue_mapping(self, handler: CeleryWorkerHandler): + """ + Test that `get_active_workers` correctly maps workers to their queues. + + Args: + handler: CeleryWorkerHandler instance. 
+ """ + # Mock Celery app and inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock active queues response + mock_inspect.active_queues.return_value = { + "celery@worker1": [{"name": "[merlin]_queue1"}, {"name": "[merlin]_queue2"}], + "celery@worker2": [{"name": "[merlin]_queue1"}], + } + + result = handler.get_active_workers() + + expected = {"celery@worker1": ["[merlin]_queue1", "[merlin]_queue2"], "celery@worker2": ["[merlin]_queue1"]} + assert result == expected + + def test_get_active_workers_handles_no_active_workers(self, handler: CeleryWorkerHandler): + """ + Test that `get_active_workers` handles case when no workers are active. + + Args: + handler: CeleryWorkerHandler instance. + """ + # Mock Celery app and inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock empty response + mock_inspect.active_queues.return_value = None + + result = handler.get_active_workers() + + assert result == {} + + def test_get_active_workers_handles_empty_worker_dict(self, handler: CeleryWorkerHandler): + """ + Test that `get_active_workers` handles empty worker dictionary. + + Args: + handler: CeleryWorkerHandler instance. + """ + # Mock Celery app and inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock empty dictionary response + mock_inspect.active_queues.return_value = {} + + result = handler.get_active_workers() + + assert result == {} + + def test_validate_worker_status_marks_dead_workers_as_stalled(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `_validate_worker_status` marks workers as stalled when not found in Celery. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. 
+ """ + # # Mock Celery app import + # mock_app = MagicMock() + # mocker.patch("merlin.workers.handlers.celery_handler.app", mock_app) + + # Mock get_active_workers to return empty dict (no live workers) + mocker.patch.object(handler, "get_active_workers", return_value={}) + + # Create mock physical worker that's marked as RUNNING + mock_physical = MagicMock() + mock_physical.get_worker_status.return_value = WorkerStatus.RUNNING + mock_physical.get_name.return_value = "celery@dead_worker" + + # Create mock logical worker + mock_logical = MagicMock() + mock_logical.get_physical_workers.return_value = ["physical_id_1"] + + # Mock database get method + handler.merlin_db.get.return_value = mock_physical + + handler._validate_worker_status([mock_logical]) + + # Verify the status was updated exactly once (the new status value itself is not asserted here) + mock_physical.set_worker_status.assert_called_once() + + def test_validate_worker_status_leaves_running_workers_unchanged( + self, handler: CeleryWorkerHandler, mocker: MockerFixture + ): + """ + Test that `_validate_worker_status` doesn't change status of workers found in Celery. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture.
+ """ + # # Mock Celery app import + # mock_app = MagicMock() + # mocker.patch("merlin.workers.handlers.celery_handler.app", mock_app) + + # Mock get_active_workers to return live worker + mocker.patch.object(handler, "get_active_workers", return_value={"celery@live_worker": ["[merlin]_queue1"]}) + + # Create mock physical worker that's marked as RUNNING + mock_physical = MagicMock() + mock_physical.get_worker_status.return_value = WorkerStatus.RUNNING + mock_physical.get_name.return_value = "celery@live_worker" + + # Create mock logical worker + mock_logical = MagicMock() + mock_logical.get_physical_workers.return_value = ["physical_id_1"] + + # Mock database get method + handler.merlin_db.get.return_value = mock_physical + + handler._validate_worker_status([mock_logical]) + + # Verify status was NOT changed + mock_physical.set_worker_status.assert_not_called() + + def test_validate_worker_status_ignores_stopped_workers(self, handler: CeleryWorkerHandler, mocker: MockerFixture): + """ + Test that `_validate_worker_status` doesn't check workers already marked as stopped. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. 
+ """ + # # Mock Celery app import + # mock_app = MagicMock() + # mocker.patch("merlin.workers.handlers.celery_handler.app", mock_app) + + # Mock get_active_workers (doesn't matter what it returns) + mocker.patch.object(handler, "get_active_workers", return_value={}) + + # Create mock physical worker that's marked as STOPPED + mock_physical = MagicMock() + mock_physical.get_worker_status.return_value = WorkerStatus.STOPPED + mock_physical.get_name.return_value = "celery@stopped_worker" + + # Create mock logical worker + mock_logical = MagicMock() + mock_logical.get_physical_workers.return_value = ["physical_id_1"] + + # Mock database get method + handler.merlin_db.get.return_value = mock_physical + + handler._validate_worker_status([mock_logical]) + + # Verify status was NOT changed (worker already stopped) + mock_physical.set_worker_status.assert_not_called() + + def test_validate_worker_status_handles_multiple_physical_workers( + self, handler: CeleryWorkerHandler, mocker: MockerFixture + ): + """ + Test that `_validate_worker_status` validates all physical workers for a logical worker. + + Args: + handler: CeleryWorkerHandler instance. + mocker: Pytest mocker fixture. 
+ """ + # # Mock Celery app import + # mock_app = MagicMock() + # mocker.patch("merlin.workers.handlers.celery_handler.app", mock_app) + + # Mock get_active_workers - only worker1 is live + mocker.patch.object(handler, "get_active_workers", return_value={"celery@worker1": ["[merlin]_queue1"]}) + + # Create mock physical workers + mock_physical1 = MagicMock() + mock_physical1.get_worker_status.return_value = WorkerStatus.RUNNING + mock_physical1.get_name.return_value = "celery@worker1" + + mock_physical2 = MagicMock() + mock_physical2.get_worker_status.return_value = WorkerStatus.RUNNING + mock_physical2.get_name.return_value = "celery@worker2" + + # Create mock logical worker with multiple physical workers + mock_logical = MagicMock() + mock_logical.get_physical_workers.return_value = ["physical_id_1", "physical_id_2"] + + # Mock database get method to return different workers + handler.merlin_db.get.side_effect = [mock_physical1, mock_physical2] + + handler._validate_worker_status([mock_logical]) + + # Verify worker1 status was NOT changed (it's live) + mock_physical1.set_worker_status.assert_not_called() + + # Verify worker2 status WAS changed (it's not live) + mock_physical2.set_worker_status.assert_called_once() + + def test_get_workers_from_app_returns_worker_list(self, handler: CeleryWorkerHandler): + """ + Test that `get_workers_from_app` returns a list of connected workers. + + Args: + handler: CeleryWorkerHandler instance. 
+ """ + # Mock Celery app inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock ping response with worker names as dict keys + mock_inspect.ping.return_value = { + "celery@worker1": {"ok": "pong"}, + "celery@worker2": {"ok": "pong"}, + "celery@worker3": {"ok": "pong"}, + } + + result = handler.get_workers_from_app() + + expected = ["celery@worker1", "celery@worker2", "celery@worker3"] + assert sorted(result) == sorted(expected) + + def test_get_workers_from_app_handles_no_workers(self, handler: CeleryWorkerHandler): + """ + Test that `get_workers_from_app` returns empty list when no workers are connected. + + Args: + handler: CeleryWorkerHandler instance. + """ + # Mock Celery app inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock ping returning None (no workers) + mock_inspect.ping.return_value = None + + result = handler.get_workers_from_app() + + assert result == [] + + def test_get_workers_from_app_handles_empty_worker_dict(self, handler: CeleryWorkerHandler): + """ + Test that `get_workers_from_app` returns empty list when ping returns empty dict. + + Args: + handler: CeleryWorkerHandler instance. + """ + # Mock Celery app inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock ping returning empty dict + mock_inspect.ping.return_value = {} + + result = handler.get_workers_from_app() + + assert result == [] + + def test_get_workers_from_app_preserves_worker_names(self, handler: CeleryWorkerHandler): + """ + Test that `get_workers_from_app` preserves exact worker names from Celery. + + Args: + handler: CeleryWorkerHandler instance. 
+ """ + # Mock Celery app inspection + mock_inspect = MagicMock() + handler.app.control.inspect.return_value = mock_inspect + + # Mock ping response with various worker name formats + mock_inspect.ping.return_value = { + "celery@worker1.hostname.com": {"ok": "pong"}, + "celery@worker2": {"ok": "pong"}, + "worker3@localhost": {"ok": "pong"}, + } + + result = handler.get_workers_from_app() + + # Verify all names are preserved exactly + assert "celery@worker1.hostname.com" in result + assert "celery@worker2" in result + assert "worker3@localhost" in result + assert len(result) == 3 diff --git a/tests/unit/workers/handlers/test_worker_handler.py b/tests/unit/workers/handlers/test_worker_handler.py index b89b2bc69..7bd566076 100644 --- a/tests/unit/workers/handlers/test_worker_handler.py +++ b/tests/unit/workers/handlers/test_worker_handler.py @@ -9,6 +9,7 @@ """ from typing import Any, Dict, List +from unittest.mock import MagicMock import pytest @@ -28,8 +29,8 @@ def get_metadata(self) -> Dict: class DummyWorkerHandler(MerlinWorkerHandler): - def __init__(self): - super().__init__() + def __init__(self, merlin_db: MagicMock): + super().__init__(merlin_db=merlin_db) self.started = False self.stopped = False self.queried = False @@ -69,11 +70,11 @@ class IncompleteHandler(MerlinWorkerHandler): IncompleteHandler() -def test_launch_workers_calls_worker_launch(): +def test_launch_workers_calls_worker_launch(mock_db_instance: MagicMock): """ Test that `start_workers` calls each worker's `start` method. 
""" - handler = DummyWorkerHandler() + handler = DummyWorkerHandler(merlin_db=mock_db_instance) workers = [DummyWorker("w1", {}, {}), DummyWorker("w2", {}, {})] result = handler.start_workers(workers) @@ -82,22 +83,22 @@ def test_launch_workers_calls_worker_launch(): assert result == ["launched", "launched"] -def test_stop_workers_sets_flag(): +def test_stop_workers_sets_flag(mock_db_instance: MagicMock): """ Test that `stop_workers` sets the internal state and returns expected value. """ - handler = DummyWorkerHandler() + handler = DummyWorkerHandler(merlin_db=mock_db_instance) response = handler.stop_workers() assert handler.stopped assert response == "Stopped all workers" -def test_query_workers_returns_summary(): +def test_query_workers_returns_summary(mock_db_instance: MagicMock): """ Test that `query_workers` returns a valid summary of current worker state. """ - handler = DummyWorkerHandler() + handler = DummyWorkerHandler(merlin_db=mock_db_instance) workers = [DummyWorker("a", {}, {}), DummyWorker("b", {}, {})] handler.start_workers(workers)