Skip to content

Commit 420f91c

Browse files
committed
refactor: introduce dedicated error types for workflow invocation polling
1 parent 6eb6fc5 commit 420f91c

3 files changed

Lines changed: 48 additions & 25 deletions

File tree

src/runwayml/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@
3636
UnprocessableEntityError,
3737
APIResponseValidationError,
3838
)
39-
from .lib.polling import TaskFailedError, TaskTimeoutError
39+
from .lib.polling import (
40+
TaskFailedError,
41+
TaskTimeoutError,
42+
WorkflowInvocationFailedError,
43+
WorkflowInvocationTimeoutError,
44+
)
4045
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
4146
from ._utils._logs import setup_logging as _setup_logging
4247

@@ -84,6 +89,8 @@
8489
"DefaultAioHttpClient",
8590
"TaskFailedError",
8691
"TaskTimeoutError",
92+
"WorkflowInvocationFailedError",
93+
"WorkflowInvocationTimeoutError",
8794
]
8895

8996
if not _t.TYPE_CHECKING:

src/runwayml/lib/polling.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ async def wait_for_task_output(self, timeout: Union[float, None] = 60 * 10) -> T
127127

128128

129129
class TaskFailedError(Exception):
130-
def __init__(self, task_details: Union[TaskRetrieveResponse, WorkflowInvocationRetrieveResponse]):
130+
def __init__(self, task_details: TaskRetrieveResponse):
131131
self.task_details = task_details
132132
super().__init__(f"Task failed")
133133

134134

135135
class TaskTimeoutError(Exception):
136-
def __init__(self, task_details: Union[TaskRetrieveResponse, WorkflowInvocationRetrieveResponse]):
136+
def __init__(self, task_details: TaskRetrieveResponse):
137137
self.task_details = task_details
138138
super().__init__(f"Task timed out")
139139

@@ -265,12 +265,14 @@ def wait_for_task_output(self, timeout: Union[float, None] = 60 * 10) -> Workflo
265265
"""
266266
When called, this will block until the workflow invocation is complete.
267267
268-
If the invocation fails or is cancelled, a `TaskFailedError` will be raised.
268+
If the invocation fails or is cancelled, a `WorkflowInvocationFailedError`
269+
will be raised.
269270
270271
Args:
271272
timeout: The maximum amount of time to wait in seconds. If not specified,
272-
the default timeout is 10 minutes. Will raise a `TaskTimeoutError` if
273-
the invocation does not complete within the timeout.
273+
the default timeout is 10 minutes. Will raise a
274+
`WorkflowInvocationTimeoutError` if the invocation does not complete
275+
within the timeout.
274276
275277
Returns:
276278
The workflow invocation details, equivalent to calling
@@ -284,13 +286,15 @@ async def wait_for_task_output(self, timeout: Union[float, None] = 60 * 10) -> W
284286
"""
285287
When called, this will wait until the workflow invocation is complete.
286288
287-
If the invocation fails or is cancelled, a `TaskFailedError` will be raised.
289+
If the invocation fails or is cancelled, a `WorkflowInvocationFailedError`
290+
will be raised.
288291
289292
Args:
290293
timeout: The maximum amount of time to wait in seconds. If not specified,
291-
the default timeout is 10 minutes. Will raise a `TaskTimeoutError` if
292-
the invocation does not complete within the timeout. Setting this to
293-
`None` will wait indefinitely (disabling the timeout).
294+
the default timeout is 10 minutes. Will raise a
295+
`WorkflowInvocationTimeoutError` if the invocation does not complete
296+
within the timeout. Setting this to `None` will wait indefinitely
297+
(disabling the timeout).
294298
295299
Returns:
296300
The workflow invocation details, equivalent to awaiting
@@ -361,6 +365,18 @@ class AsyncAwaitableWISucceeded(AsyncAwaitableWorkflowInvocationResponseMixin, W
361365
]
362366

363367

368+
class WorkflowInvocationFailedError(Exception):
369+
def __init__(self, invocation_details: WorkflowInvocationRetrieveResponse):
370+
self.invocation_details = invocation_details
371+
super().__init__("Workflow invocation failed")
372+
373+
374+
class WorkflowInvocationTimeoutError(Exception):
375+
def __init__(self, invocation_details: WorkflowInvocationRetrieveResponse):
376+
self.invocation_details = invocation_details
377+
super().__init__("Workflow invocation timed out")
378+
379+
364380
def _make_sync_wait_for_workflow_invocation_output(
365381
client: "RunwayML",
366382
) -> Callable[["AwaitableWorkflowInvocationResponseMixin", Union[float, None]], WorkflowInvocationRetrieveResponse]:
@@ -374,9 +390,9 @@ def wait_for_task_output(
374390
if details.status == "SUCCEEDED":
375391
return details
376392
if details.status == "FAILED" or details.status == "CANCELLED":
377-
raise TaskFailedError(details)
393+
raise WorkflowInvocationFailedError(details)
378394
if timeout is not None and time.time() - start_time > timeout:
379-
raise TaskTimeoutError(details)
395+
raise WorkflowInvocationTimeoutError(details)
380396

381397
return wait_for_task_output
382398

@@ -406,9 +422,9 @@ async def wait_for_task_output(
406422
if details.status == "SUCCEEDED":
407423
return details
408424
if details.status == "FAILED" or details.status == "CANCELLED":
409-
raise TaskFailedError(details)
425+
raise WorkflowInvocationFailedError(details)
410426
if timeout is not None and anyio.current_time() - start_time > timeout:
411-
raise TaskTimeoutError(details)
427+
raise WorkflowInvocationTimeoutError(details)
412428

413429
return wait_for_task_output
414430

tests/test_workflow_invocation_polling.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from runwayml import RunwayML, AsyncRunwayML, TaskFailedError, TaskTimeoutError
9+
from runwayml import RunwayML, AsyncRunwayML, WorkflowInvocationFailedError, WorkflowInvocationTimeoutError
1010
from runwayml.types.workflow_invocation_retrieve_response import (
1111
Failed,
1212
Pending,
@@ -84,10 +84,10 @@ def test_raises_on_failed(self, _mock_sleep: MagicMock) -> None:
8484
response = _make_pending()
8585
patched: Any = inject_sync_workflow_invocation_wait_method(client, response)
8686

87-
with pytest.raises(TaskFailedError) as exc_info:
87+
with pytest.raises(WorkflowInvocationFailedError) as exc_info:
8888
patched.wait_for_task_output()
8989

90-
assert exc_info.value.task_details.status == "FAILED"
90+
assert exc_info.value.invocation_details.status == "FAILED"
9191

9292
@patch("runwayml.lib.polling.time.sleep", return_value=None)
9393
def test_raises_on_cancelled(self, _mock_sleep: MagicMock) -> None:
@@ -100,10 +100,10 @@ def test_raises_on_cancelled(self, _mock_sleep: MagicMock) -> None:
100100
response = _make_pending()
101101
patched: Any = inject_sync_workflow_invocation_wait_method(client, response)
102102

103-
with pytest.raises(TaskFailedError) as exc_info:
103+
with pytest.raises(WorkflowInvocationFailedError) as exc_info:
104104
patched.wait_for_task_output()
105105

106-
assert exc_info.value.task_details.status == "CANCELLED"
106+
assert exc_info.value.invocation_details.status == "CANCELLED"
107107

108108
@patch("runwayml.lib.polling.time.time")
109109
@patch("runwayml.lib.polling.time.sleep", return_value=None)
@@ -119,7 +119,7 @@ def test_raises_on_timeout(self, _mock_sleep: MagicMock, mock_time: MagicMock) -
119119
response = _make_pending()
120120
patched: Any = inject_sync_workflow_invocation_wait_method(client, response)
121121

122-
with pytest.raises(TaskTimeoutError):
122+
with pytest.raises(WorkflowInvocationTimeoutError):
123123
patched.wait_for_task_output(timeout=600)
124124

125125

@@ -169,10 +169,10 @@ async def test_raises_on_failed(self, _mock_time: MagicMock, _mock_sleep: AsyncM
169169
response = _make_pending()
170170
patched: Any = inject_async_workflow_invocation_wait_method(client, response)
171171

172-
with pytest.raises(TaskFailedError) as exc_info:
172+
with pytest.raises(WorkflowInvocationFailedError) as exc_info:
173173
await patched.wait_for_task_output()
174174

175-
assert exc_info.value.task_details.status == "FAILED"
175+
assert exc_info.value.invocation_details.status == "FAILED"
176176

177177
@patch("runwayml.lib.polling.anyio.sleep", new_callable=AsyncMock)
178178
@patch("runwayml.lib.polling.anyio.current_time", return_value=0.0)
@@ -186,10 +186,10 @@ async def test_raises_on_cancelled(self, _mock_time: MagicMock, _mock_sleep: Asy
186186
response = _make_pending()
187187
patched: Any = inject_async_workflow_invocation_wait_method(client, response)
188188

189-
with pytest.raises(TaskFailedError) as exc_info:
189+
with pytest.raises(WorkflowInvocationFailedError) as exc_info:
190190
await patched.wait_for_task_output()
191191

192-
assert exc_info.value.task_details.status == "CANCELLED"
192+
assert exc_info.value.invocation_details.status == "CANCELLED"
193193

194194
@patch("runwayml.lib.polling.anyio.sleep", new_callable=AsyncMock)
195195
@patch("runwayml.lib.polling.anyio.current_time")
@@ -205,5 +205,5 @@ async def test_raises_on_timeout(self, mock_time: MagicMock, _mock_sleep: AsyncM
205205
response = _make_pending()
206206
patched: Any = inject_async_workflow_invocation_wait_method(client, response)
207207

208-
with pytest.raises(TaskTimeoutError):
208+
with pytest.raises(WorkflowInvocationTimeoutError):
209209
await patched.wait_for_task_output(timeout=600)

0 commit comments

Comments
 (0)