Tokenization at scale #287

Draft
wants to merge 10 commits into main
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ repos:
rev: 23.9.1
hooks:
- id: black
language_version: python3.10
language_version: python3.11
stages: [pre-commit]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
src/modalities/__main__.py (4 changes: 3 additions & 1 deletion)
@@ -9,6 +9,7 @@

import click
import click_pathlib
from modalities.utils.logging import get_logger
from pydantic import BaseModel, FilePath

from modalities.api import (
@@ -148,7 +149,7 @@ def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path):


@data.command(name="pack_encoded_data")
@click.argument("config_path", type=FilePath)
@click.option("--config_path", type=FilePath, required=True)
def CMD_entry_point_pack_encoded_data(config_path: FilePath):
"""Utility to encode an indexed, large jsonl-file.
(see also `create_index` for more information)
@@ -158,6 +159,7 @@ def CMD_entry_point_pack_encoded_data(config_path: FilePath):
Args:
config_path (FilePath): Path to the config file describing the tokenization setup.
"""
get_logger().info(f"Loading config from {config_path}.")
config_dict = load_app_config_dict(config_path)

pack_encoded_data(config_dict=config_dict)
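
Note on the CLI change above: `config_path` moves from a positional `click.argument` to a required `--config_path` option. Below is a minimal smoke-test sketch of the new invocation; it assumes the top-level click group in `modalities.__main__` is exposed as `main` and that a tokenization config exists at the (hypothetical) path shown.

```python
from click.testing import CliRunner

from modalities.__main__ import main  # assumption: the top-level click group is named `main`

runner = CliRunner()
# config_path is now passed as a required option instead of a positional argument:
result = runner.invoke(
    main, ["data", "pack_encoded_data", "--config_path", "configs/tokenization.yaml"]
)
print(result.exit_code, result.output)
```
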
src/modalities/api.py (147 changes: 124 additions & 23 deletions)
@@ -1,23 +1,44 @@
#!/usr/bin/env python

import multiprocessing as mp
import os
from enum import Enum
from pathlib import Path

from pydantic import FilePath

import modalities.inference.inference as inference
from modalities.checkpointing.checkpoint_conversion import CheckpointConversion
from modalities.config.component_factory import ComponentFactory
from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel
from modalities.dataloader.create_index import IndexGenerator
from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
from modalities.config.instantiation_models import TokenizationInstantiationModel
from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator
from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController
from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor
from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import (
EmbeddedStreamData,
join_embedded_stream_data,
)
from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader
from modalities.dataloader.preprocessing.tokenization.tokenization_strategies import (
ProcessingStrategyFactory,
WorkerTypes,
populate_reader_q,
)
from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.utils.logging import get_logger


def create_raw_data_index(src_path: Path, index_path: Path):
class FileExistencePolicy(Enum):
SKIP = "skip"
ERROR = "error"
OVERRIDE = "override"


def create_raw_data_index(
src_path: Path, index_path: Path, file_existence_policy: FileExistencePolicy = FileExistencePolicy.ERROR
):
"""Creates the index file for the content of a large jsonl-file. The index file
contains the byte-offsets and lengths of each line in the jsonl-file.
Background is the ability to further process the respective file without loading it,
@@ -31,13 +52,24 @@ def create_raw_data_index(src_path: Path, index_path: Path):
Raises:
ValueError: If the index file already exists.
"""
index_path = LargeFileLinesReader.default_index_path(src_path, index_path)
os.makedirs(index_path.parent, exist_ok=True)
index_path = LocalLargeFileLinesReader.default_index_path(src_path, index_path)
if index_path.exists():
raise ValueError("index already exists. delete it or specify different output folder.")
if file_existence_policy == FileExistencePolicy.SKIP:
get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Skipping index creation.")
return
elif file_existence_policy == FileExistencePolicy.OVERRIDE:
get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Overriding it.")
os.remove(index_path)
elif file_existence_policy == FileExistencePolicy.ERROR:
raise ValueError("index already exists. delete it or specify different output folder.")
else:
raise ValueError(f"Unknown file existence policy: {file_existence_policy}")

get_logger(name="main").info(
f"Reading raw data from {str(src_path)} and" f" writing index to {str(index_path)} ..."
)
os.makedirs(index_path.parent, exist_ok=True)

print(f"reading raw data from {src_path}")
print(f"writing index to {index_path}")
generator = IndexGenerator(src_path)
generator.create_index(index_path)
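
As the docstring above states, the index stores the byte offsets and lengths of every line in the jsonl file, so downstream workers can fetch individual samples without reading the whole file. The sketch below only illustrates that access pattern; the on-disk index format and the reader API in this PR may differ, and the pickled list of (offset, length) tuples is an assumption.

```python
import json
import pickle
from pathlib import Path


def read_sample(raw_data_path: Path, index_path: Path, line_no: int) -> dict:
    # Assumed index layout: a pickled list of (byte_offset, length) tuples, one per line.
    with index_path.open("rb") as f:
        index = pickle.load(f)
    offset, length = index[line_no]
    with raw_data_path.open("rb") as f:
        f.seek(offset)             # jump directly to the requested line ...
        raw_line = f.read(length)  # ... and read only its bytes
    return json.loads(raw_line)
```
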

@@ -88,22 +120,91 @@ def pack_encoded_data(config_dict: dict):
# ResolverRegistry to work dynamically with any type-hinted config object from config.py.
registry = Registry(COMPONENTS)
component_factory = ComponentFactory(registry=registry)
components: PackedDatasetComponentsInstantiationModel = component_factory.build_components(
config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel
instantion_model: TokenizationInstantiationModel = component_factory.build_components(
config_dict=config_dict, components_model_type=TokenizationInstantiationModel
)

generator = PackedDataGenerator(
components.settings.src_path,
index_path=components.settings.index_path,
tokenizer=components.tokenizer,
eod_token=components.settings.eod_token,
jq_pattern=components.settings.jq_pattern,
number_of_processes=components.settings.num_cpus,
processing_batch_size=components.settings.processing_batch_size,
raw_samples_queue_size=components.settings.raw_samples_queue_size,
processed_samples_queue_size=components.settings.processed_samples_queue_size,
# build the queues
reader_q, tokenizer_q, writer_q, logging_message_q = ProcessingStrategyFactory.get_process_queues(
writer_q_maxsize=instantion_model.writer_q_maxsize, tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize
)
generator.run(components.settings.dst_path)

# build the workers
stop_event = mp.Event()

tokenizer_q_key = "tokenizer_q"
writer_q_key = "writer_q"
logging_message_q_key = "logging_message_q"

reader_settings = instantion_model.reader_worker_settings.reader_settings

reader_workers = [
Processor(
in_q=reader_q,
out_qs={tokenizer_q_key: tokenizer_q, logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_reader_strategy(
reader_settings, tokenizer_q_key=tokenizer_q_key, logging_message_q_key=logging_message_q_key
),
process_type=WorkerTypes.READER,
process_id=i,
logging_message_q_key=logging_message_q_key,
stop_event=stop_event,
)
for i in range(instantion_model.reader_worker_settings.num_workers)
]

tokenizer_workers = [
Processor(
in_q=tokenizer_q,
out_qs={writer_q_key: writer_q, logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_tokenizer_strategy(
tokenizer_settings=instantion_model.tokenizer_worker_settings.tokenizer_settings,
writer_q_key=writer_q_key,
logging_message_q_key=logging_message_q_key,
),
process_type=WorkerTypes.TOKENIZER,
process_id=i,
logging_message_q_key=logging_message_q_key,
stop_event=stop_event,
)
for i in range(instantion_model.tokenizer_worker_settings.num_workers)
]

writer_worker = Processor(
in_q=writer_q,
out_qs={logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_writing_strategy(
ww_settings=instantion_model.writer_worker_settings, logging_message_q_key=logging_message_q_key
),
process_type=WorkerTypes.WRITER,
process_id=0,
logging_message_q_key=logging_message_q_key,
stop_event=stop_event,
)

pipeline_steps = [
PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers),
PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers),
PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker]),
]

def populate():
populate_reader_q(
reader_q=reader_q,
index_start=instantion_model.index_start,
num_samples=instantion_model.num_samples,
num_reader_processes=instantion_model.reader_worker_settings.num_workers,
batch_size=instantion_model.batch_size,
)

process_controller = ProcessController(pipeline_steps=pipeline_steps, populate_jobs=populate)
process_controller.run()


def merge_packed_data_files(src_paths: list[Path], target_path: Path):
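
The reworked `pack_encoded_data` above wires reader, tokenizer, and writer workers together through multiprocessing queues under a `ProcessController`. The sketch below shows only the general producer/consumer shape of such a pipeline using plain `multiprocessing` primitives (one worker per stage, sentinel-based shutdown); it is not the `Processor`/`ProcessController` API introduced in this PR, which additionally supports multiple workers per stage, queue timeouts, a logging queue, and a stop event.

```python
import multiprocessing as mp

SENTINEL = None  # signals "no more work" to the next stage


def reader(reader_q: mp.Queue, tokenizer_q: mp.Queue) -> None:
    # Pull batch descriptors, read the raw samples, forward them to the tokenizer stage.
    while (batch := reader_q.get()) is not SENTINEL:
        for raw_sample in batch:
            tokenizer_q.put(raw_sample)
    tokenizer_q.put(SENTINEL)


def tokenizer(tokenizer_q: mp.Queue, writer_q: mp.Queue) -> None:
    while (sample := tokenizer_q.get()) is not SENTINEL:
        writer_q.put(sample.split())  # stand-in for real tokenization
    writer_q.put(SENTINEL)


def writer(writer_q: mp.Queue) -> None:
    # A single writer serializes all results so the output is written from one process.
    while (tokens := writer_q.get()) is not SENTINEL:
        print(f"writing {len(tokens)} tokens")


if __name__ == "__main__":
    reader_q, tokenizer_q, writer_q = mp.Queue(), mp.Queue(), mp.Queue()
    stages = [
        mp.Process(target=reader, args=(reader_q, tokenizer_q)),
        mp.Process(target=tokenizer, args=(tokenizer_q, writer_q)),
        mp.Process(target=writer, args=(writer_q,)),
    ]
    for p in stages:
        p.start()
    reader_q.put(["hello world", "tokenization at scale"])  # analogous to populate_reader_q
    reader_q.put(SENTINEL)
    for p in stages:
        p.join()
```
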
src/modalities/config/component_factory.py (4 changes: 2 additions & 2 deletions)
@@ -75,7 +75,7 @@ def _build_component(
# instantiate component config
component_key = current_component_config["component_key"]
variant_key = current_component_config["variant_key"]
current_component_config = self._instantiate_component_config(
current_component_config = self.instantiate_component_config(
component_key=component_key,
variant_key=variant_key,
config_dict=materialized_component_config["config"],
@@ -139,7 +139,7 @@ def _is_reference_config(config_dict: dict) -> bool:
# TODO instead of field checks, we should introduce an enum for the config type.
return {"instance_key", "pass_type"} == config_dict.keys()

def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel:
def instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel:
component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key)
self._assert_valid_config_keys(
component_key=component_key,
src/modalities/config/instantiation_models.py (75 changes: 59 additions & 16 deletions)
@@ -1,4 +1,3 @@
import os
from pathlib import Path
from typing import Annotated, Any, Optional

@@ -16,10 +15,10 @@
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
PydanticTextInferenceComponentType,
PydanticTokenizerIFType,
)
from modalities.config.utils import parse_torch_device
from modalities.dataloader.dataset import Dataset
from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LargeFileLinesReaderTypes
from modalities.util import warn_rank_0


@@ -191,20 +190,64 @@ def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationMode
return self


class PackedDatasetComponentsInstantiationModel(BaseModel):
class PackedDatasetSettings(BaseModel):
src_path: FilePath
dst_path: Optional[Path] = None
index_path: Optional[FilePath] = None
jq_pattern: str
num_cpus: Annotated[int, Field(strict=True, ge=1)] = os.cpu_count()
eod_token: str
processing_batch_size: Annotated[int, Field(strict=True, ge=1)]
raw_samples_queue_size: Annotated[int, Field(strict=True, ge=1)]
processed_samples_queue_size: Annotated[int, Field(strict=True, ge=1)]

tokenizer: PydanticTokenizerIFType
settings: PackedDatasetSettings
class TokenizationInstantiationModel(BaseModel):
class ReaderWorkerSettings(BaseModel):
class ReaderSettings(BaseModel):
class LocalReaderArgs(BaseModel):
raw_data_path: Path
index_path: Optional[Path] = None
encoding: Optional[str] = "utf-8"

class GlobalReaderArgs(BaseModel):
global_inorder_index_path: Path
raw_data_file_list_path: Path
raw_data_root_path: Path
global_shuffle_index_path: Optional[Path] = None
encoding: Optional[str] = "utf-8"

reader_type: LargeFileLinesReaderTypes
reader_args: LocalReaderArgs | GlobalReaderArgs

num_workers: Annotated[int, Field(strict=True, ge=1)]
reader_settings: ReaderSettings

class TokenizerWorkerSettings(BaseModel):
class TokenizerSettings(BaseModel):
class TokenizerInstantitionSettings(BaseModel):
tokenizer_component_key: str
tokenizer_variant_key: str
config: dict[str, Any]

tokenizer_instantiation_settings: TokenizerInstantitionSettings
eod_token: str
jq_pattern: str

num_workers: Annotated[int, Field(strict=True, ge=1)]
tokenizer_settings: TokenizerSettings

class WriterWorkerSettings(BaseModel):
dst_path: Path
index_start: Annotated[int, Field(strict=True, ge=0)]

@field_validator("dst_path")
def ensure_path_does_not_exist(cls, value):
path = Path(value) # Convert to Path object if it's a string
if path.exists():
raise ValueError(f"The filepath '{path}' already exists.")
return path

paths: dict[str, Path]
reader_worker_settings: ReaderWorkerSettings
tokenizer_worker_settings: TokenizerWorkerSettings
writer_worker_settings: WriterWorkerSettings
tokenizer_q_maxsize: Annotated[int, Field(strict=True, ge=1)]
writer_q_maxsize: Annotated[int, Field(strict=True, ge=1)]
index_start: Annotated[int, Field(strict=True, ge=0)]
num_samples: Annotated[int, Field(strict=True, ge=1)]
batch_size: Annotated[int, Field(strict=True, ge=1)]
logging_interval: Annotated[int, Field(strict=True, ge=1)]
in_q_timeout: Annotated[int, Field(strict=True, ge=0)]
out_q_timeout: Annotated[int, Field(strict=True, ge=0)]


class TextGenerationInstantiationModel(BaseModel):
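
For orientation, a config dictionary mirroring the nested structure of `TokenizationInstantiationModel` above might look as follows. All concrete values (paths, worker counts, queue sizes, tokenizer component/variant keys, the `paths` key, and the `reader_type` member) are illustrative placeholders, not values taken from this PR.

```python
from pathlib import Path

config_dict = {
    "paths": {"log_dir": Path("/tmp/tokenization_logs")},  # placeholder key/value
    "reader_worker_settings": {
        "num_workers": 2,
        "reader_settings": {
            "reader_type": "local",  # assumed member of LargeFileLinesReaderTypes
            "reader_args": {  # matches LocalReaderArgs
                "raw_data_path": Path("data/train.jsonl"),
                "index_path": Path("data/train.idx"),
            },
        },
    },
    "tokenizer_worker_settings": {
        "num_workers": 4,
        "tokenizer_settings": {
            "tokenizer_instantiation_settings": {
                "tokenizer_component_key": "tokenizer",              # placeholder
                "tokenizer_variant_key": "pretrained_hf_tokenizer",  # placeholder
                "config": {"pretrained_model_name_or_path": "gpt2"},
            },
            "eod_token": "<|endoftext|>",
            "jq_pattern": ".text",
        },
    },
    "writer_worker_settings": {"dst_path": Path("data/train.pbin"), "index_start": 0},
    "tokenizer_q_maxsize": 100,
    "writer_q_maxsize": 100,
    "index_start": 0,
    "num_samples": 1_000_000,
    "batch_size": 512,
    "logging_interval": 10,
    "in_q_timeout": 5,
    "out_q_timeout": 5,
}

# TokenizationInstantiationModel(**config_dict) would then validate the nested settings.
```
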