diff --git a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py index cdcc781767648..103db30e807fd 100644 --- a/devel-common/src/tests_common/test_utils/operators/run_deferrable.py +++ b/devel-common/src/tests_common/test_utils/operators/run_deferrable.py @@ -21,10 +21,14 @@ from typing import TYPE_CHECKING, Any from airflow.exceptions import TaskDeferred +from airflow.utils.module_loading import import_string +from airflow.utils.session import NEW_SESSION from tests_common.test_utils.mock_context import mock_context if TYPE_CHECKING: + from sqlalchemy.orm import Session + from airflow.models import Operator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -45,14 +49,25 @@ def execute_operator(operator: Operator) -> tuple[Any, Any]: return asyncio.run(deferrable_operator(context, operator)) -async def deferrable_operator(context, operator): +async def deferrable_operator(context, operator, session: Session = NEW_SESSION): result = None triggered_events = [] try: + if operator.start_from_trigger: + trigger_cls = import_string(operator.start_trigger_args.trigger_cls) + trigger = trigger_cls( + **operator.expand_start_trigger_args(context=context, session=session).trigger_kwargs + ) + raise TaskDeferred( + trigger=trigger, + method_name=operator.start_trigger_args.next_method, + kwargs=operator.start_trigger_args.next_kwargs, + timeout=operator.start_trigger_args.timeout, + ) operator.render_template_fields(context=context) result = operator.execute(context=context) except TaskDeferred as deferred: - task = deferred + task: TaskDeferred | None = deferred while task: events = await run_tigger(task.trigger) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 5e008d50e8449..8d04383f8217a 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -32,7 +32,7 @@ MSGraphTrigger, ResponseSerializer, ) -from airflow.providers.microsoft.azure.version_compat import XCOM_RETURN_KEY, BaseOperator +from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_1_PLUS, XCOM_RETURN_KEY, BaseOperator if TYPE_CHECKING: from io import BytesIO @@ -109,6 +109,7 @@ class MSGraphAsyncOperator(BaseOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ + start_from_trigger = AIRFLOW_V_3_1_PLUS template_fields: Sequence[str] = ( "url", "response_type", @@ -162,27 +163,55 @@ def __init__( self.result_processor = result_processor self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() + if self.start_from_trigger: + try: + from airflow.triggers.base import StartTriggerArgs + + self.start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs=dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ), + next_method=self.execute_complete.__name__, + ) + except ImportError: + self.start_from_trigger = False def execute(self, context: Context) -> None: - self.defer( - trigger=MSGraphTrigger( - url=self.url, - response_type=self.response_type, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - scopes=self.scopes, - api_version=self.api_version, - serializer=type(self.serializer), - ), - method_name=self.execute_complete.__name__, - ) + if not self.start_from_trigger: + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) + return def execute_complete( self, @@ -228,14 +257,14 @@ def execute_complete( self.trigger_next_link( response=response, method_name=self.execute_complete.__name__, context=context ) - except TaskDeferred as exception: + except TaskDeferred as task_deferred: self.append_result( results=results, result=result, append_result_as_list_if_absent=True, ) self.push_xcom(context=context, value=results) - raise exception + raise task_deferred if not results: return result diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 2472ef71c01e3..c0b503a90a3c5 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -25,7 +25,10 @@ from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.operators.msgraph import execute_callable from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer -from airflow.providers.microsoft.azure.version_compat import BaseSensorOperator +from airflow.providers.microsoft.azure.version_compat import ( + AIRFLOW_V_3_1_PLUS, + BaseSensorOperator, +) if TYPE_CHECKING: from datetime import timedelta @@ -60,6 +63,7 @@ class MSGraphSensor(BaseSensorOperator): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ + start_from_trigger = AIRFLOW_V_3_1_PLUS template_fields: Sequence[str] = ( "url", "response_type", @@ -107,8 +111,61 @@ def __init__( self.event_processor = event_processor self.result_processor = result_processor self.serializer = serializer() + if self.start_from_trigger: + try: + from airflow.triggers.base import StartTriggerArgs + + self.start_trigger_args = StartTriggerArgs( + trigger_cls=f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}", + trigger_kwargs=dict( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=f"{type(self.serializer).__module__}.{type(self.serializer).__name__}", + ), + next_method=self.execute_complete.__name__, + ) + except ImportError: + self.start_from_trigger = False + + def execute(self, context: Context) -> None: + if not self.start_from_trigger: + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) + return - def execute(self, context: Context): + def retry_execute( + self, + context: Context, + **kwargs, + ) -> Any: self.defer( trigger=MSGraphTrigger( url=self.url, @@ -129,13 +186,6 @@ def execute(self, context: Context): method_name=self.execute_complete.__name__, ) - def retry_execute( - self, - context: Context, - **kwargs, - ) -> Any: - self.execute(context=context) - def execute_complete( self, context: Context,