Skip to content

Commit

Permalink
allow join specific jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Jul 20, 2023
1 parent ae3e5ca commit 1a71474
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
2 changes: 1 addition & 1 deletion executor/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .core import Engine, EngineSetting
from .job import LocalJob, ThreadJob, ProcessJob

__version__ = '0.2.0'
__version__ = '0.2.1'

__all__ = [
'Engine', 'EngineSetting',
Expand Down
16 changes: 12 additions & 4 deletions executor/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,21 @@ async def wait_async(
await asyncio.sleep(time_delta)
total_time -= time_delta

async def join(self, timeout: T.Optional[float] = None):
async def join(
self,
jobs: T.Optional[T.List[Job]] = None,
timeout: T.Optional[float] = None):
"""Join all running and pending jobs."""
running = self.jobs.running.values()
pending = self.jobs.pending.values()
if jobs is None:
jobs_for_join = (
self.jobs.running.values() +
self.jobs.pending.values()
)
else:
jobs_for_join = jobs
tasks = [
asyncio.create_task(job.join())
for job in (running + pending)
for job in jobs_for_join
]
if len(tasks) > 0:
await asyncio.wait(tasks, timeout=timeout)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,17 @@ def sleep_5s():
await engine.cancel_all_async()


@pytest.mark.asyncio
async def test_join_jobs():
engine = Engine()
job1 = ThreadJob(lambda x: x**2, (2,))
job2 = ThreadJob(lambda x: x**2, (3,))
await engine.submit_async(job1, job2)
await engine.join(jobs=[job1, job2])
assert job1.status == "done"
assert job2.result() == 9


def test_engine_start_stop():
engine = Engine()
engine.start()
Expand Down

0 comments on commit 1a71474

Please sign in to comment.