From 5b2397487dcfe7940e9c34cca2ee92eac1306382 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Fri, 13 Mar 2026 08:42:31 -0500 Subject: [PATCH] feat: enable wf to start via just name like go sdk Signed-off-by: Samantha Coyle --- .../ext/workflow/aio/dapr_workflow_client.py | 17 +++++++------- .../dapr/ext/workflow/dapr_workflow_client.py | 22 +++++++++---------- .../tests/test_workflow_client.py | 12 ++++++++++ .../tests/test_workflow_client_aio.py | 14 ++++++++++++ 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py index cd5e632f1..781a14f90 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py @@ -16,7 +16,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Optional, TypeVar +from typing import Any, Optional, TypeVar, Union import durabletask.internal.orchestrator_service_pb2 as pb from dapr.ext.workflow.logger import Logger, LoggerOptions @@ -72,7 +72,7 @@ def __init__( async def schedule_new_workflow( self, - workflow: Workflow, + workflow: Union[Workflow, str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -82,7 +82,7 @@ async def schedule_new_workflow( """Schedules a new workflow instance for execution. Args: - workflow: The workflow to schedule. + workflow: The workflow to schedule. Can be a workflow callable or a workflow name string. input: The optional input to pass to the scheduled workflow instance. This must be a serializable value. instance_id: The unique ID of the workflow instance to schedule. If not specified, a @@ -96,11 +96,12 @@ async def schedule_new_workflow( Returns: The ID of the scheduled workflow instance. """ - workflow_name = ( - workflow.__dict__['_dapr_alternate_name'] - if hasattr(workflow, '_dapr_alternate_name') - else workflow.__name__ - ) + if isinstance(workflow, str): + workflow_name = workflow + elif hasattr(workflow, '_dapr_alternate_name'): + workflow_name = workflow.__dict__['_dapr_alternate_name'] + else: + workflow_name = workflow.__name__ return await self.__obj.schedule_new_orchestration( workflow_name, input=input, diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 36a731c47..527ca4a67 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -16,7 +16,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Optional, TypeVar +from typing import Any, Optional, TypeVar, Union import durabletask.internal.orchestrator_service_pb2 as pb from dapr.ext.workflow.logger import Logger, LoggerOptions @@ -75,7 +75,7 @@ def __init__( def schedule_new_workflow( self, - workflow: Workflow, + workflow: Union[Workflow, str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -85,7 +85,7 @@ def schedule_new_workflow( """Schedules a new workflow instance for execution. Args: - workflow: The workflow to schedule. + workflow: The workflow to schedule. Can be a workflow callable or a workflow name string. input: The optional input to pass to the scheduled workflow instance. This must be a serializable value. instance_id: The unique ID of the workflow instance to schedule. If not specified, a @@ -99,16 +99,14 @@ def schedule_new_workflow( Returns: The ID of the scheduled workflow instance. """ - if hasattr(workflow, '_dapr_alternate_name'): - return self.__obj.schedule_new_orchestration( - workflow.__dict__['_dapr_alternate_name'], - input=input, - instance_id=instance_id, - start_at=start_at, - reuse_id_policy=reuse_id_policy, - ) + if isinstance(workflow, str): + workflow_name = workflow + elif hasattr(workflow, '_dapr_alternate_name'): + workflow_name = workflow.__dict__['_dapr_alternate_name'] + else: + workflow_name = workflow.__name__ return self.__obj.schedule_new_orchestration( - workflow.__name__, + workflow_name, input=input, instance_id=instance_id, start_at=start_at, diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client.py b/ext/dapr-ext-workflow/tests/test_workflow_client.py index a12a8844b..7d66d68bc 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client.py @@ -47,6 +47,9 @@ def details(self): class FakeTaskHubGrpcClient: + def __init__(self): + self.last_scheduled_workflow_name = None + def schedule_new_orchestration( self, workflow, @@ -55,6 +58,7 @@ def schedule_new_orchestration( start_at, reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None, ): + self.last_scheduled_workflow_name = workflow return mock_schedule_result def get_orchestration_state(self, instance_id, fetch_payloads): @@ -112,6 +116,14 @@ class WorkflowClientTest(unittest.TestCase): def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') + def test_schedule_workflow_by_name_string(self): + fake_client = FakeTaskHubGrpcClient() + with mock.patch('durabletask.client.TaskHubGrpcClient', return_value=fake_client): + wfClient = DaprWorkflowClient() + result = wfClient.schedule_new_workflow(workflow='my_registered_workflow', input='data') + assert result == mock_schedule_result + assert fake_client.last_scheduled_workflow_name == 'my_registered_workflow' + def test_client_functions(self): with mock.patch( 'durabletask.client.TaskHubGrpcClient', return_value=FakeTaskHubGrpcClient() diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py b/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py index c84fcbfe6..d27047ced 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py @@ -47,6 +47,9 @@ def details(self): class FakeAsyncTaskHubGrpcClient: + def __init__(self): + self.last_scheduled_workflow_name = None + async def schedule_new_orchestration( self, workflow, @@ -56,6 +59,7 @@ async def schedule_new_orchestration( start_at, reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None, ): + self.last_scheduled_workflow_name = workflow return mock_schedule_result async def get_orchestration_state(self, instance_id, *, fetch_payloads): @@ -113,6 +117,16 @@ class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase): def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') + async def test_schedule_workflow_by_name_string(self): + fake_client = FakeAsyncTaskHubGrpcClient() + with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient', return_value=fake_client): + wfClient = DaprWorkflowClient() + result = await wfClient.schedule_new_workflow( + workflow='my_registered_workflow', input='data' + ) + assert result == mock_schedule_result + assert fake_client.last_scheduled_workflow_name == 'my_registered_workflow' + async def test_client_functions(self): with mock.patch( 'durabletask.aio.client.AsyncTaskHubGrpcClient',