Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Optional, TypeVar
from typing import Any, Optional, TypeVar, Union

import durabletask.internal.orchestrator_service_pb2 as pb
from dapr.ext.workflow.logger import Logger, LoggerOptions
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(

async def schedule_new_workflow(
self,
workflow: Workflow,
workflow: Union[Workflow, str],
*,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
Expand All @@ -82,7 +82,7 @@ async def schedule_new_workflow(
"""Schedules a new workflow instance for execution.

Args:
workflow: The workflow to schedule.
workflow: The workflow to schedule. Can be a workflow callable or a workflow name string.
input: The optional input to pass to the scheduled workflow instance. This must be a
serializable value.
instance_id: The unique ID of the workflow instance to schedule. If not specified, a
Expand All @@ -96,11 +96,12 @@ async def schedule_new_workflow(
Returns:
The ID of the scheduled workflow instance.
"""
workflow_name = (
workflow.__dict__['_dapr_alternate_name']
if hasattr(workflow, '_dapr_alternate_name')
else workflow.__name__
)
if isinstance(workflow, str):
workflow_name = workflow
elif hasattr(workflow, '_dapr_alternate_name'):
workflow_name = workflow.__dict__['_dapr_alternate_name']
else:
workflow_name = workflow.__name__
return await self.__obj.schedule_new_orchestration(
workflow_name,
input=input,
Expand Down
22 changes: 10 additions & 12 deletions ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Optional, TypeVar
from typing import Any, Optional, TypeVar, Union

import durabletask.internal.orchestrator_service_pb2 as pb
from dapr.ext.workflow.logger import Logger, LoggerOptions
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(

def schedule_new_workflow(
self,
workflow: Workflow,
workflow: Union[Workflow, str],
*,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
Expand All @@ -85,7 +85,7 @@ def schedule_new_workflow(
"""Schedules a new workflow instance for execution.

Args:
workflow: The workflow to schedule.
workflow: The workflow to schedule. Can be a workflow callable or a workflow name string.
input: The optional input to pass to the scheduled workflow instance. This must be a
serializable value.
instance_id: The unique ID of the workflow instance to schedule. If not specified, a
Expand All @@ -99,16 +99,14 @@ def schedule_new_workflow(
Returns:
The ID of the scheduled workflow instance.
"""
if hasattr(workflow, '_dapr_alternate_name'):
return self.__obj.schedule_new_orchestration(
workflow.__dict__['_dapr_alternate_name'],
input=input,
instance_id=instance_id,
start_at=start_at,
reuse_id_policy=reuse_id_policy,
)
if isinstance(workflow, str):
workflow_name = workflow
elif hasattr(workflow, '_dapr_alternate_name'):
workflow_name = workflow.__dict__['_dapr_alternate_name']
else:
workflow_name = workflow.__name__
return self.__obj.schedule_new_orchestration(
workflow.__name__,
workflow_name,
input=input,
instance_id=instance_id,
start_at=start_at,
Expand Down
12 changes: 12 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def details(self):


class FakeTaskHubGrpcClient:
def __init__(self):
self.last_scheduled_workflow_name = None

def schedule_new_orchestration(
self,
workflow,
Expand All @@ -55,6 +58,7 @@ def schedule_new_orchestration(
start_at,
reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None,
):
self.last_scheduled_workflow_name = workflow
return mock_schedule_result

def get_orchestration_state(self, instance_id, fetch_payloads):
Expand Down Expand Up @@ -112,6 +116,14 @@ class WorkflowClientTest(unittest.TestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')

def test_schedule_workflow_by_name_string(self):
    """A workflow-name string passed to schedule_new_workflow is forwarded verbatim."""
    stub = FakeTaskHubGrpcClient()
    # Replace the real gRPC client so no sidecar connection is attempted.
    with mock.patch('durabletask.client.TaskHubGrpcClient', return_value=stub):
        client = DaprWorkflowClient()
        instance_id = client.schedule_new_workflow(
            workflow='my_registered_workflow', input='data'
        )
        # The stub echoes the scheduled name, proving the string was not mangled.
        self.assertEqual(instance_id, mock_schedule_result)
        self.assertEqual(stub.last_scheduled_workflow_name, 'my_registered_workflow')

def test_client_functions(self):
with mock.patch(
'durabletask.client.TaskHubGrpcClient', return_value=FakeTaskHubGrpcClient()
Expand Down
14 changes: 14 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_client_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def details(self):


class FakeAsyncTaskHubGrpcClient:
def __init__(self):
self.last_scheduled_workflow_name = None

async def schedule_new_orchestration(
self,
workflow,
Expand All @@ -56,6 +59,7 @@ async def schedule_new_orchestration(
start_at,
reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None,
):
self.last_scheduled_workflow_name = workflow
return mock_schedule_result

async def get_orchestration_state(self, instance_id, *, fetch_payloads):
Expand Down Expand Up @@ -113,6 +117,16 @@ class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')

async def test_schedule_workflow_by_name_string(self):
    """A workflow-name string passed to the async schedule_new_workflow is forwarded verbatim."""
    stub = FakeAsyncTaskHubGrpcClient()
    # Patch out the async gRPC client so the test never touches a sidecar.
    with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient', return_value=stub):
        client = DaprWorkflowClient()
        instance_id = await client.schedule_new_workflow(
            workflow='my_registered_workflow', input='data'
        )
        # The stub records the scheduled name, proving the string was not mangled.
        self.assertEqual(instance_id, mock_schedule_result)
        self.assertEqual(stub.last_scheduled_workflow_name, 'my_registered_workflow')

async def test_client_functions(self):
with mock.patch(
'durabletask.aio.client.AsyncTaskHubGrpcClient',
Expand Down
Loading