-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate ModularTokenizerOp
with Hugging Face remote 🤗
#368
Changes from 2 commits
59b45f3
37fe63d
c60edb3
7ccb6d0
eb2139f
e3bff42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,14 @@ | |
from fuse.data.tokenizers.modular_tokenizer.inject_utils import ( | ||
InjectorToModularTokenizerLib, | ||
) | ||
from huggingface_hub import snapshot_download, HfApi | ||
from huggingface_hub.utils import validate_hf_hub_args, SoftTemporaryDirectory | ||
|
||
|
||
from warnings import warn | ||
from pathlib import Path | ||
from collections import defaultdict | ||
from typing import Tuple, Optional, Union, Any | ||
from typing import Any, Tuple, Dict, List, Optional, Union | ||
import os | ||
import re | ||
|
||
|
@@ -506,3 +509,73 @@ def __call__( | |
) | ||
|
||
return sample_dict | ||
|
||
@classmethod
def from_pretrained(
    cls, identifier: str, pad_token: str = "<PAD>", max_size: Optional[int] = None
) -> "ModularTokenizerOp":
    """Create a tokenizer op from a local directory or a Hugging Face Hub repo.

    If ``identifier`` is an existing local directory it is used directly as
    the tokenizer path. Otherwise it is treated as a Hub ``repo_id``: the repo
    snapshot is fetched with ``huggingface_hub.snapshot_download`` (limited by
    ``allow_patterns="tokenizer/"``) and the ``tokenizer`` sub-directory of
    the downloaded snapshot is used.

    See the ``snapshot_download`` documentation for details on how the
    download itself behaves.

    Args:
        identifier: local directory path, or a Hub repo id.
        pad_token: forwarded unchanged to the class constructor.
        max_size: forwarded unchanged to the class constructor.

    Returns:
        A new instance of ``cls`` built from the resolved tokenizer path.

    Raises:
        Exception: if ``identifier`` is not a local directory and the
            download fails; the underlying error is chained as the cause.
    """
    tokenizer_path = identifier

    if not os.path.isdir(tokenizer_path):
        # Not a local directory, so fall back to the Hub.
        # NOTE: other snapshot_download options (revision, cache_dir,
        # force_download, proxies, token, local_files_only) are not
        # forwarded here; library defaults apply.
        try:
            snapshot_dir = snapshot_download(
                repo_id=str(identifier),
                allow_patterns="tokenizer/",
            )
            tokenizer_path = os.path.join(snapshot_dir, "tokenizer")
        except Exception as e:
            raise Exception(
                f"Couldn't find the checkpoint path nor download from HF hub! {identifier}"
            ) from e

    return cls(tokenizer_path=tokenizer_path, pad_token=pad_token, max_size=max_size)
|
||
def save_pretrained(self, save_directory: Union[str, Path]) -> None:
    """Save the wrapped tokenizer under ``save_directory``.

    Prints the destination, then delegates to ``self._tokenizer.save`` with
    the directory converted to a plain string.

    Args:
        save_directory: destination directory, as a string or ``Path``.
    """
    print(f"Saving @ {save_directory=}")
    destination = str(save_directory)
    self._tokenizer.save(path=destination)
Comment on lines
+562
to
+564
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as I understand, all of the information is being stored in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct |
||
|
||
@validate_hf_hub_args
def push_to_hub(
    self,
    repo_id: str,
    *,
    commit_message: str = "Push model using huggingface_hub.",
    private: bool = False,
    token: Optional[str] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
    model_card_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """Upload the tokenizer to a Hugging Face Hub model repo in one commit.

    The repo is created if it does not exist yet (``exist_ok=True``). The
    tokenizer is saved with ``save_pretrained`` into a ``tokenizer``
    sub-directory of a temporary folder, and that folder is uploaded with
    ``HfApi.upload_folder``. See the ``upload_folder`` documentation for the
    exact semantics of the forwarded arguments.

    Args:
        repo_id: target repo id; replaced by the id returned by ``create_repo``.
        commit_message: forwarded to ``upload_folder``.
        private: forwarded to ``create_repo``.
        token: used to build the ``HfApi`` client.
        branch: forwarded to ``upload_folder`` as ``revision``.
        create_pr: forwarded to ``upload_folder``.
        allow_patterns: forwarded to ``upload_folder``.
        ignore_patterns: forwarded to ``upload_folder``.
        delete_patterns: forwarded to ``upload_folder``.
        model_card_kwargs: accepted for signature compatibility but not used
            by this method.

    Returns:
        Whatever ``HfApi.upload_folder`` returns for the commit.
        NOTE(review): annotated as ``str`` on the assumption that the result
        is a URL string or a ``str``-compatible commit-info object; confirm
        against the installed huggingface_hub version.
    """
    api = HfApi(token=token)
    repo_id = api.create_repo(
        repo_id=repo_id, private=private, exist_ok=True
    ).repo_id

    # Stage the files in a temporary folder and push them in a single commit.
    with SoftTemporaryDirectory() as tmp:
        saved_path = Path(tmp) / repo_id
        # from_pretrained looks for the tokenizer under "tokenizer/".
        tokenizer_dirpath = saved_path / "tokenizer"
        self.save_pretrained(tokenizer_dirpath)
        return api.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=saved_path,
            commit_message=commit_message,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,7 +58,7 @@ exclude = | |
|
||
|
||
[mypy] | ||
python_version = 3.7 | ||
python_version = 3.9 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. due |
||
warn_return_any = True | ||
warn_unused_configs = True | ||
disallow_untyped_defs = True | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment directing readers to `snapshot_download` for details on the arguments.