Skip to content

Commit

Permalink
Create new distilabel.constants module to store constants and avoid…
Browse files Browse the repository at this point in the history
… circular imports (#861)

* Bump version to `1.4.0`

* Refactor constants modules to a higher level to avoid circular imports

* Update src/distilabel/__init__.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
plaguss and gabrielmbmb authored Aug 7, 2024
1 parent 63f948b commit 1b43450
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/distilabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

from rich import traceback as rich_traceback

__version__ = "1.3.0"
__version__ = "1.3.1"

rich_traceback.install(show_locals=True)
5 changes: 1 addition & 4 deletions src/distilabel/cli/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@
from pydantic import HttpUrl, ValidationError
from pydantic.type_adapter import TypeAdapter

from distilabel.pipeline.constants import (
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
from distilabel.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME
from distilabel.pipeline.local import Pipeline

if TYPE_CHECKING:
Expand Down
37 changes: 37 additions & 0 deletions src/distilabel/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Final

# Steps related constants
DISTILABEL_METADATA_KEY: Final[str] = "distilabel_metadata"

# Pipeline related constants
STEP_ATTR_NAME: Final[str] = "step"
INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue"
RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches"
ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function"
CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent"


__all__ = [
"STEP_ATTR_NAME",
"INPUT_QUEUE_ATTR_NAME",
"RECEIVES_ROUTED_BATCHES_ATTR_NAME",
"ROUTING_BATCH_FUNCTION_ATTR_NAME",
"CONVERGENCE_STEP_ATTR_NAME",
"LAST_BATCH_SENT_FLAG",
"DISTILABEL_METADATA_KEY",
]
4 changes: 2 additions & 2 deletions src/distilabel/distiset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pyarrow.lib import ArrowInvalid
from typing_extensions import Self

from distilabel.pipeline.constants import STEP_ATTR_NAME
from distilabel.constants import STEP_ATTR_NAME
from distilabel.utils.card.dataset_card import (
DistilabelDatasetCard,
size_categories_parser,
Expand Down Expand Up @@ -536,7 +536,7 @@ def create_distiset( # noqa: C901
>>> distiset = create_distiset(Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname")
```
"""
from distilabel.steps.constants import DISTILABEL_METADATA_KEY
from distilabel.constants import DISTILABEL_METADATA_KEY

logger = logging.getLogger("distilabel.distiset")

Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import networkx as nx

from distilabel.pipeline.constants import (
from distilabel.constants import (
CONVERGENCE_STEP_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
Expand Down
12 changes: 6 additions & 6 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@
from upath import UPath

from distilabel import __version__
from distilabel.distiset import create_distiset
from distilabel.mixins.requirements import RequirementsMixin
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.batch_manager import _BatchManager
from distilabel.pipeline.constants import (
from distilabel.constants import (
CONVERGENCE_STEP_ATTR_NAME,
INPUT_QUEUE_ATTR_NAME,
LAST_BATCH_SENT_FLAG,
RECEIVES_ROUTED_BATCHES_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
from distilabel.distiset import create_distiset
from distilabel.mixins.requirements import RequirementsMixin
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.batch_manager import _BatchManager
from distilabel.pipeline.write_buffer import _WriteBuffer
from distilabel.steps.base import GeneratorStep
from distilabel.steps.generators.utils import make_generator_step
Expand Down
6 changes: 3 additions & 3 deletions src/distilabel/pipeline/batch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union

from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.constants import (
from distilabel.constants import (
RECEIVES_ROUTED_BATCHES_ATTR_NAME,
STEP_ATTR_NAME,
)
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.steps.base import _Step
from distilabel.utils.files import list_files_in_dir
from distilabel.utils.serialization import (
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from distilabel.constants import INPUT_QUEUE_ATTR_NAME
from distilabel.distiset import create_distiset
from distilabel.llms.vllm import vLLM
from distilabel.pipeline.base import BasePipeline
from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME
from distilabel.pipeline.step_wrapper import _StepWrapper
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.serialization import TYPE_INFO_KEY
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/step_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from queue import Queue
from typing import Any, Dict, List, Optional, Union, cast

from distilabel.constants import LAST_BATCH_SENT_FLAG
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG
from distilabel.pipeline.typing import StepLoadStatus
from distilabel.steps.base import GeneratorStep, Step, _Step

Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic import Field
from typing_extensions import override

from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.llms.base import LLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import (
Expand All @@ -26,7 +27,6 @@
StepInput,
_Step,
)
from distilabel.steps.constants import DISTILABEL_METADATA_KEY
from distilabel.utils.dicts import group_dicts

if TYPE_CHECKING:
Expand Down

0 comments on commit 1b43450

Please sign in to comment.