Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/scripts/path_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
"copy/paste",
"delete/recreate",
"edit/score",
"Enable/disable",
"file/console",
"files/functions",
"I/O",
Expand Down
4 changes: 4 additions & 0 deletions ci/vale/styles/config/vocabularies/nat/accept.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,7 @@ Zep
Optuna
[Oo]ptimizable
[Cc]heckpointed
[Ss]anitization
[Uu]nregister(ing|ed|s)?
[Aa]rgs
[Kk]wargs
337 changes: 233 additions & 104 deletions docs/source/build-workflows/advanced/middleware.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/nat/data_models/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing

from .common import BaseModelRegistryTag
Expand Down
6 changes: 2 additions & 4 deletions src/nat/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
"""Middleware implementations for NeMo Agent Toolkit."""

from nat.middleware.cache_middleware import CacheMiddleware
from nat.middleware.function_middleware import FunctionMiddleware
from nat.middleware.function_middleware import FunctionMiddlewareChain
from nat.middleware.function_middleware import validate_middleware
Expand All @@ -24,12 +23,11 @@
from nat.middleware.middleware import Middleware

__all__ = [
"CacheMiddleware",
"CallNext",
"CallNextStream",
"FunctionMiddlewareContext",
"Middleware",
"FunctionMiddleware",
"FunctionMiddlewareChain",
"FunctionMiddlewareContext",
"Middleware",
"validate_middleware",
]
14 changes: 14 additions & 0 deletions src/nat/middleware/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,15 @@
import logging
from collections.abc import AsyncIterator
from typing import Any
from typing import Literal

from pydantic import Field

from nat.builder.context import Context
from nat.builder.context import ContextState
from nat.data_models.middleware import FunctionMiddlewareBaseConfig
from nat.middleware.function_middleware import CallNext
from nat.middleware.function_middleware import CallNextStream
from nat.middleware.function_middleware import FunctionMiddleware
from nat.middleware.function_middleware import FunctionMiddlewareContext
from nat.middleware.middleware import PostInvokeContext
from nat.middleware.middleware import PreInvokeContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,6 +82,23 @@ def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None:
self._similarity_threshold = similarity_threshold
self._cache: dict[str, Any] = {}

# ==================== Abstract Method Implementations ====================

@property
def enabled(self) -> bool:
    """Report whether this middleware is enabled.

    Returns:
        Always ``True``; this property never reports the middleware as
        disabled.
    """
    return True

async def pre_invoke(self, context: PreInvokeContext) -> PreInvokeContext | None:
    """No-op pre-invocation hook.

    Not used - CacheMiddleware overrides function_middleware_invoke.

    Args:
        context: The pre-invocation context; it is ignored here.

    Returns:
        Always ``None``.
    """
    return None

async def post_invoke(self, context: PostInvokeContext) -> PostInvokeContext | None:
    """No-op post-invocation hook.

    Not used - CacheMiddleware overrides function_middleware_invoke.

    Args:
        context: The post-invocation context; it is ignored here.

    Returns:
        Always ``None``.
    """
    return None

# ==================== Cache Logic ====================

def _should_cache(self) -> bool:
"""Check if caching should be enabled based on the current context."""
if self._enabled_mode == "always":
Expand Down Expand Up @@ -145,8 +160,11 @@ def _find_similar_key(self, input_str: str) -> str | None:

return best_match

async def function_middleware_invoke(self, value: Any, call_next: CallNext,
context: FunctionMiddlewareContext) -> Any:
async def function_middleware_invoke(self,
*args: Any,
call_next: CallNext,
context: FunctionMiddlewareContext,
**kwargs: Any) -> Any:
"""Cache middleware for single-output invocations.

Implements the four-phase middleware pattern:
Expand All @@ -157,23 +175,27 @@ async def function_middleware_invoke(self, value: Any, call_next: CallNext,
4. **Continue**: Return the result (cached or fresh)

Args:
value: The input value to process
*args: The positional arguments to process
call_next: Callable to invoke the next middleware or function
context: Metadata about the function being wrapped
**kwargs: Additional function arguments

Returns:
The cached output if found, otherwise the fresh output
"""
# Phase 1: Preprocess - check if caching should be enabled
# Check if caching should be enabled for this invocation
if not self._should_cache():
return await call_next(value)
return await call_next(*args, **kwargs)

# Use first arg as cache key (primary input)
value = args[0] if args else None

# Phase 1: Preprocess - serialize the input
input_str = self._serialize_input(value)
if input_str is None:
# Can't serialize, pass through to next middleware/function
logger.debug("Could not serialize input for function %s, bypassing cache", context.name)
return await call_next(value)
return await call_next(*args, **kwargs)

# Phase 1: Preprocess - look for a similar cached input
similar_key = self._find_similar_key(input_str)
Expand All @@ -187,7 +209,7 @@ async def function_middleware_invoke(self, value: Any, call_next: CallNext,

# Phase 2: Call next - no cache hit, call next middleware/function
logger.debug("Cache miss for function %s", context.name)
result = await call_next(value)
result = await call_next(*args, **kwargs)

# Phase 3: Postprocess - cache the result for future use
self._cache[input_str] = result
Expand All @@ -197,9 +219,10 @@ async def function_middleware_invoke(self, value: Any, call_next: CallNext,
return result

async def function_middleware_stream(self,
value: Any,
*args: Any,
call_next: CallNextStream,
context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
context: FunctionMiddlewareContext,
**kwargs: Any) -> AsyncIterator[Any]:
"""Cache middleware for streaming invocations - bypasses caching.

Streaming results are not cached as they would need to be buffered
Expand All @@ -213,9 +236,10 @@ async def function_middleware_stream(self,
4. **Continue**: Complete the stream

Args:
value: The input value to process
*args: The positional arguments to process
call_next: Callable to invoke the next middleware or function stream
context: Metadata about the function being wrapped
**kwargs: Additional function arguments

Yields:
Chunks from the stream (unmodified)
Expand All @@ -224,33 +248,7 @@ async def function_middleware_stream(self,
logger.debug("Streaming call for function %s, bypassing cache", context.name)

# Phase 2-3: Call next and process chunks - yield chunks as they arrive
async for chunk in call_next(value):
async for chunk in call_next(*args, **kwargs):
yield chunk

# Phase 4: Continue - stream is complete (implicit)


class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
    """Settings for the cache middleware (declared with ``name="cache"``).

    The cache middleware memoizes function outputs based on input similarity,
    with support for both exact and fuzzy matching.

    Attributes:
        enabled_mode: Controls when caching is active:
            - "always": Cache is always enabled
            - "eval": Cache only active when Context.is_evaluating is True
        similarity_threshold: Float between 0 and 1 (inclusive) for input
            matching:
            - 1.0: Exact string matching (fastest)
            - < 1.0: Fuzzy matching using difflib similarity
    """

    # Defaults to "eval"; only "always" and "eval" are accepted values.
    enabled_mode: Literal["always", "eval"] = Field(
        default="eval",
        description="When caching is enabled: 'always' or 'eval' (only during evaluation)",
    )

    # Constrained to the closed interval [0.0, 1.0]; defaults to 1.0.
    similarity_threshold: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Similarity threshold between 0 and 1. Use 1.0 for exact matching",
    )


__all__ = ["CacheMiddleware", "CacheMiddlewareConfig"]
44 changes: 44 additions & 0 deletions src/nat/middleware/cache/cache_middleware_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for cache middleware."""

from typing import Literal

from pydantic import Field

from nat.data_models.middleware import FunctionMiddlewareBaseConfig


class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
    """Settings for the cache middleware (declared with ``name="cache"``).

    The cache middleware memoizes function outputs based on input similarity,
    with support for both exact and fuzzy matching.

    Attributes:
        enabled_mode: Controls when caching is active:
            - "always": Cache is always enabled
            - "eval": Cache only active when Context.is_evaluating is True
        similarity_threshold: Float between 0 and 1 (inclusive) for input
            matching:
            - 1.0: Exact string matching (fastest)
            - < 1.0: Fuzzy matching using difflib similarity
    """

    # Defaults to "eval"; only "always" and "eval" are accepted values.
    enabled_mode: Literal["always", "eval"] = Field(
        default="eval",
        description="When caching is enabled: 'always' or 'eval' (only during evaluation)",
    )

    # Constrained to the closed interval [0.0, 1.0]; defaults to 1.0.
    similarity_threshold: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Similarity threshold between 0 and 1. Use 1.0 for exact matching",
    )
33 changes: 33 additions & 0 deletions src/nat/middleware/cache/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nat.builder.builder import Builder
from nat.cli.register_workflow import register_middleware
from nat.middleware.cache.cache_middleware import CacheMiddleware
from nat.middleware.cache.cache_middleware_config import CacheMiddlewareConfig


@register_middleware(config_type=CacheMiddlewareConfig)
async def cache_middleware(config: CacheMiddlewareConfig, builder: Builder):
    """Construct the cache middleware described by ``config``.

    Args:
        config: The cache middleware configuration; its ``enabled_mode`` and
            ``similarity_threshold`` values are forwarded to the middleware.
        builder: The workflow builder (unused but required by component pattern)

    Yields:
        A configured cache middleware instance
    """
    middleware = CacheMiddleware(
        enabled_mode=config.enabled_mode,
        similarity_threshold=config.similarity_threshold,
    )
    yield middleware
14 changes: 14 additions & 0 deletions src/nat/middleware/dynamic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading