Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate ModularTokenizerOp with Hugging Face remote 🤗 #368

Merged
merged 6 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion fuse/data/tokenizers/modular_tokenizer/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from fuse.data.tokenizers.modular_tokenizer.inject_utils import (
InjectorToModularTokenizerLib,
)
from huggingface_hub import snapshot_download, HfApi
from huggingface_hub.utils import validate_hf_hub_args, SoftTemporaryDirectory


from warnings import warn
from pathlib import Path
from collections import defaultdict
from typing import Tuple, Optional, Union, Any
from typing import Any, Tuple, Dict, List, Optional, Union
import os
import re

Expand Down Expand Up @@ -506,3 +509,73 @@ def __call__(
)

return sample_dict

@classmethod
def from_pretrained(
    cls,
    identifier: str,
    pad_token: str = "<PAD>",
    max_size: Optional[int] = None,
    **snapshot_download_kwargs: Any,
) -> "ModularTokenizerOp":
    """Build the op from a local tokenizer directory or a Hugging Face Hub repo.

    If ``identifier`` is an existing local directory it is used as the
    tokenizer path as-is. Otherwise it is treated as a Hub repo id: the
    repo's ``tokenizer/`` folder is fetched with
    ``huggingface_hub.snapshot_download`` and the op is loaded from the
    ``tokenizer`` sub-directory of the downloaded snapshot.

    Args:
        identifier: path of a local tokenizer directory, or a Hub repo id.
        pad_token: forwarded unchanged to the class constructor.
        max_size: forwarded unchanged to the class constructor.
        **snapshot_download_kwargs: extra keyword arguments forwarded
            verbatim to ``snapshot_download`` (for example ``revision``,
            ``cache_dir``, ``token`` or ``local_files_only``). See the
            ``huggingface_hub.snapshot_download`` documentation for the
            details of each argument. Not used when ``identifier`` is a
            local directory. ``repo_id`` and ``allow_patterns`` are set
            here and must not be passed.

    Returns:
        A new instance of ``cls`` created with ``tokenizer_path`` pointing
        at the resolved tokenizer directory.

    Raises:
        Exception: if ``identifier`` is not a local directory and the
            download from the Hub fails; the original error is chained.
    """
    if not os.path.isdir(identifier):
        # Not a local directory: try to download from the hub.
        try:
            # Only the "tokenizer/" folder of the repo is requested.
            snapshot_dir = snapshot_download(
                repo_id=str(identifier),
                allow_patterns="tokenizer/",
                **snapshot_download_kwargs,
            )
        except Exception as e:
            raise Exception(
                f"Couldn't find the checkpoint path nor download from HF hub! {identifier}"
            ) from e
        # The tokenizer files live in the "tokenizer" folder of the snapshot.
        identifier = os.path.join(snapshot_dir, "tokenizer")

    tokenizer_op = cls(
        tokenizer_path=identifier, pad_token=pad_token, max_size=max_size
    )
    return tokenizer_op

def save_pretrained(self, save_directory: Union[str, Path]) -> None:
    """Save the wrapped tokenizer under ``save_directory``.

    Prints the destination, then delegates to the ``save`` method of
    ``self._tokenizer``.

    Args:
        save_directory: destination given as a string or ``Path``; it is
            converted to ``str`` before being handed to the tokenizer.
    """
    print(f"Saving @ {save_directory=}")
    # The tokenizer's save() is called with a plain string path.
    target_path = str(save_directory)
    self._tokenizer.save(path=target_path)
Comment on lines +562 to +564
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, all of the information is stored in the tokenizer, which is why saving the tokenizer alone is enough.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct


@validate_hf_hub_args
def push_to_hub(
    self,
    repo_id: str,
    *,
    commit_message: str = "Push model using huggingface_hub.",
    private: bool = False,
    token: Optional[str] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
    model_card_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Upload the tokenizer to a Hugging Face Hub model repo in one commit.

    The repo is created first (an already existing repo is accepted), the
    tokenizer is saved into a ``tokenizer`` folder inside a temporary
    directory, and that directory is uploaded with ``HfApi.upload_folder``.

    Args:
        repo_id: id of the target repo on the Hub.
        commit_message: message of the upload commit.
        private: passed to ``HfApi.create_repo``.
        token: access token used to build the ``HfApi`` client.
        branch: passed to ``upload_folder`` as ``revision``.
        create_pr: passed to ``upload_folder``.
        allow_patterns: passed to ``upload_folder``.
        ignore_patterns: passed to ``upload_folder``.
        delete_patterns: passed to ``upload_folder``.
        model_card_kwargs: accepted but currently not used by this method.

    Returns:
        Whatever ``HfApi.upload_folder`` returns for the upload commit.
    """
    api = HfApi(token=token)
    # Continue with the repo id reported back by create_repo.
    repo_id = api.create_repo(
        repo_id=repo_id, private=private, exist_ok=True
    ).repo_id
    # Push the files to the repo in a single commit
    with SoftTemporaryDirectory() as tmp:
        saved_path = Path(tmp) / repo_id
        # from_pretrained() looks for the files under "tokenizer/".
        tokenizer_dirpath = saved_path / "tokenizer"
        self.save_pretrained(tokenizer_dirpath)
        return api.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=saved_path,
            commit_message=commit_message,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ exclude =


[mypy]
python_version = 3.7
python_version = 3.9
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed due to mypy pre-commit errors.

warn_return_any = True
warn_unused_configs = True
disallow_untyped_defs = True
Expand Down
Loading