Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dapr.ext.workflow.workflow_state import WorkflowState
from grpc.aio import AioRpcError

from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync
from dapr.clients import DaprInternalError
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=[DaprClientTimeoutInterceptorAsync()],
)

async def schedule_new_workflow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from grpc import RpcError

from dapr.clients import DaprInternalError
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=[DaprClientTimeoutInterceptor()],
)

def schedule_new_workflow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dapr.ext.workflow.workflow_context import Workflow

from dapr.clients import DaprInternalError
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -71,13 +72,17 @@ def __init__(
raise DaprInternalError(f'{error}') from error

options = self._logger.get_options()
all_interceptors = []
if interceptors:
all_interceptors.extend(interceptors)
all_interceptors.append(DaprClientTimeoutInterceptor())
self.__worker = worker.TaskHubGrpcWorker(
host_address=uri.endpoint,
metadata=metadata,
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=interceptors,
interceptors=all_interceptors,
concurrency_options=worker.ConcurrencyOptions(
maximum_concurrent_activity_work_items=maximum_concurrent_activity_work_items,
maximum_concurrent_orchestration_work_items=maximum_concurrent_orchestration_work_items,
Expand Down
15 changes: 15 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio
)


class WorkflowClientTimeoutInterceptorTest(unittest.TestCase):
def test_timeout_interceptor_is_passed_to_client(self):
with mock.patch('durabletask.client.TaskHubGrpcClient') as mock_client_cls:
DaprWorkflowClient()
mock_client_cls.assert_called_once()
call_kwargs = mock_client_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)


class WorkflowClientTest(unittest.TestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')
Expand Down Expand Up @@ -186,3 +199,5 @@ def test_client_functions(self):

actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
15 changes: 15 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_client_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio
)


class WorkflowClientAioTimeoutInterceptorTest(unittest.IsolatedAsyncioTestCase):
async def test_timeout_interceptor_is_passed_to_client(self):
with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient') as mock_client_cls:
DaprWorkflowClient()
mock_client_cls.assert_called_once()
call_kwargs = mock_client_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptorAsync)


class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')
Expand Down Expand Up @@ -190,3 +203,5 @@ async def test_client_functions(self):

actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
55 changes: 55 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List
from unittest import mock

import grpc
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext
from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name
Expand All @@ -39,6 +40,56 @@ def add_named_activity(self, name: str, fn):
self._activity_fns[name] = fn


class WorkflowRuntimeTimeoutInterceptorTest(unittest.TestCase):
def setUp(self):
listActivities.clear()
listOrchestrators.clear()
self._registry_patch = mock.patch(
'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()
)
self._registry_patch.start()

def tearDown(self):
mock.patch.stopall()

def test_timeout_interceptor_is_prepended(self):
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime()
mock_worker_cls.assert_called_once()
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)

def test_timeout_interceptor_with_custom_interceptors(self):
custom_interceptor = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime(interceptors=[custom_interceptor])
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 2)
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)
self.assertIs(interceptors[1], custom_interceptor)

def test_timeout_interceptor_preserves_custom_interceptor_order(self):
custom1 = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
custom2 = mock.MagicMock(spec=grpc.UnaryStreamClientInterceptor)
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime(interceptors=[custom1, custom2])
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 3)
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)
self.assertIs(interceptors[1], custom1)
self.assertIs(interceptors[2], custom2)


class WorkflowRuntimeTest(unittest.TestCase):
def setUp(self):
listActivities.clear()
Expand Down Expand Up @@ -630,3 +681,7 @@ def my_fn(ctx):
with self.assertRaises(ValueError) as ctx:
alternate_name(name='second')(my_fn)
self.assertIn('already has an alternate name', str(ctx.exception))

with self.assertRaises(ValueError) as ctx:
alternate_name(name='second')(my_fn)
self.assertIn('already has an alternate name', str(ctx.exception))