diff --git a/executor/engine/__init__.py b/executor/engine/__init__.py index d672cd6..f297759 100644 --- a/executor/engine/__init__.py +++ b/executor/engine/__init__.py @@ -1,7 +1,7 @@ from .core import Engine, EngineSetting from .job import LocalJob, ThreadJob, ProcessJob -__version__ = '0.2.0' +__version__ = '0.2.1' __all__ = [ 'Engine', 'EngineSetting', diff --git a/executor/engine/core.py b/executor/engine/core.py index 6aa7ed7..4902f1b 100644 --- a/executor/engine/core.py +++ b/executor/engine/core.py @@ -277,13 +277,21 @@ async def wait_async( await asyncio.sleep(time_delta) total_time -= time_delta - async def join(self, timeout: T.Optional[float] = None): + async def join( + self, + jobs: T.Optional[T.List[Job]] = None, + timeout: T.Optional[float] = None): """Join all running and pending jobs.""" - running = self.jobs.running.values() - pending = self.jobs.pending.values() + if jobs is None: + jobs_for_join = ( + self.jobs.running.values() + + self.jobs.pending.values() + ) + else: + jobs_for_join = jobs tasks = [ asyncio.create_task(job.join()) - for job in (running + pending) + for job in jobs_for_join ] if len(tasks) > 0: await asyncio.wait(tasks, timeout=timeout) diff --git a/tests/test_engine.py b/tests/test_engine.py index 7bff98c..6413d58 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -269,6 +269,17 @@ def sleep_5s(): await engine.cancel_all_async() +@pytest.mark.asyncio +async def test_join_jobs(): + engine = Engine() + job1 = ThreadJob(lambda x: x**2, (2,)) + job2 = ThreadJob(lambda x: x**2, (3,)) + await engine.submit_async(job1, job2) + await engine.join(jobs=[job1, job2]) + assert job1.status == "done" + assert job2.result() == 9 + + def test_engine_start_stop(): engine = Engine() engine.start()