Skip to content

Commit

Permalink
Fix routing batch function deadlocks and unordered batches (#649)
Browse files Browse the repository at this point in the history
* Add check that step `input_batch_size` is a multiple

* Fix unordered batches when using `routing_batch_function`

* Fix `can_generate` condition

* Remove metadata and style

* Fix getting data for batch when irregular batch sizes

* Fix steps receiving routed batches getting stuck

* Fix `_last_batch_convergence_step` method

* Fix stop not checking for `None`

* Fix issues related to the queues

* Remove unused variable

* Add integration tests timeout

* Fix deadlock caused by the next expected batch in a convergence step

* Update unit tests

* Add timeout to tests

* Simplify condition

* Fix unit test

* Update timeouts
  • Loading branch information
gabrielmbmb authored May 20, 2024
1 parent 4ea1fc0 commit 690013a
Show file tree
Hide file tree
Showing 13 changed files with 489 additions and 121 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ jobs:

- name: Integration Tests
run: make integration-tests
timeout-minutes: 5
8 changes: 0 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
---
description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs.
hide:
- toc
---

<style>.md-typeset h1, .md-content__button { display: none;}</style>

<div align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://github.com/argilla-io/distilabel/blob/main/docs/assets/distilabel-white.png?raw=true">
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ docs = [
"CairoSVG >= 2.7.1",
"mknotebooks >= 0.8.0",
]
tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio"]
tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio", "pytest-timeout"]

# Optional LLMs, integrations, etc
anthropic = ["anthropic >= 0.20.0"]
Expand Down
20 changes: 10 additions & 10 deletions src/distilabel/distiset.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,22 @@ def create_distiset( # noqa: C901
continue

files = [str(file) for file in list_files_in_dir(file)]
try:
if files:
if files:
try:
ds = load_dataset(
"parquet", name=file.stem, data_files={"train": files}
)
if not enable_metadata and DISTILABEL_METADATA_KEY in ds.column_names:
ds = ds.remove_columns(DISTILABEL_METADATA_KEY)
distiset[file.stem] = ds
else:
logger.warning(
f"No output files for step '{file.stem}', can't create a dataset."
" Did the step produce any data?"
)
except ArrowInvalid:
logger.warning(f"❌ Failed to load the subset from '{file}' directory.")
continue
except ArrowInvalid:
logger.warning(f"❌ Failed to load the subset from '{file}' directory.")
continue
else:
logger.warning(
f"No output files for step '{file.stem}', can't create a dataset."
" Did the step produce any data?"
)

# If there's only one dataset i.e. one config, then set the config name to `default`
if len(distiset.keys()) == 1:
Expand Down
12 changes: 12 additions & 0 deletions src/distilabel/pipeline/_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import networkx as nx

from distilabel.pipeline.constants import (
CONVERGENCE_STEP_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
Expand Down Expand Up @@ -353,6 +354,9 @@ def _validate_convergence_step(
):
return

# Mark the step as a convergence step
self.set_step_attr(step.name, CONVERGENCE_STEP_ATTR_NAME, True) # type: ignore

# Check if all the predecessors of the step are receiving routed batches from the
# same step
previous_steps_predecessors = [
Expand Down Expand Up @@ -431,6 +435,14 @@ def _validate_routing_batch_function(
f" from step '{predecessor_step.name}' to step '{step.name}'."
)

if batch_size % step.input_batch_size != 0: # type: ignore
raise ValueError(
f"Step '{step.name}' should have an `input_batch_size` that is a multiple"
f" of the `input_batch_size` or `batch_size` of the previous step."
f" This is because the batches are being routed with a `routing_batch_function`"
f" from step '{predecessor_step.name}' to step '{step.name}'."
)

return True

def _validate_process_step_input_parameter(
Expand Down
Loading

0 comments on commit 690013a

Please sign in to comment.