Parametrising fast inference so that finetuned models can be used (#113)
lucapericlp authored Mar 26, 2024
1 parent a26ed91 commit 01e3bc0
Showing 2 changed files with 14 additions and 4 deletions.
11 changes: 8 additions & 3 deletions README.md
@@ -6,7 +6,7 @@
<a target="_blank" style="display: inline-block; vertical-align: middle" href="https://colab.research.google.com/github/metavoiceio/metavoice-src/blob/main/colab_demo.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
[![](https://dcbadge.vercel.app/api/server/Cpy6U3na8Z?style=flat&compact=True)](https://discord.gg/tbTbkGEgJM)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/OnusFM.svg?style=social&label=@metavoiceio)](https://twitter.com/metavoiceio)


@@ -69,7 +69,7 @@ poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1
## Usage
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py)
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for reasons not yet debugged. If you want speed, try int4, which is roughly 2x faster than bf16/fp16.
poetry run python -i fam/llm/fast_inference.py

@@ -82,7 +82,7 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for reasons not yet debugged. If you want speed, try int4, which is roughly 2x faster than bf16/fp16.
poetry run python serving.py
poetry run python app.py
@@ -108,6 +108,11 @@ Try it out using our sample datasets via:
poetry run finetune --train ./datasets/sample_dataset.csv --val ./datasets/sample_val_dataset.csv
```

Once you've trained your model, you can use it for inference via:
```bash
poetry run python -i fam/llm/fast_inference.py --first_stage_path ./my-finetuned_model.pt
```
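
If you launch inference from a script, the command above can be assembled programmatically. The sketch below is only an illustration: `build_inference_command` is a hypothetical helper that does not exist in this repository, and it assumes that `--first_stage_path` and `--quantisation_mode` can be passed together, which this README does not confirm.

```python
import shlex
from typing import List, Optional


def build_inference_command(
    first_stage_path: Optional[str] = None,
    quantisation_mode: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper: build the argument list for the README command."""
    cmd = ["poetry", "run", "python", "-i", "fam/llm/fast_inference.py"]
    if quantisation_mode is not None:
        # The README only documents int4 and int8 as quantisation modes.
        if quantisation_mode not in ("int4", "int8"):
            raise ValueError(f"unsupported quantisation mode: {quantisation_mode}")
        cmd += ["--quantisation_mode", quantisation_mode]
    if first_stage_path:
        # Only add the override when a finetuned checkpoint is given.
        cmd += ["--first_stage_path", first_stage_path]
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command line for a finetuned checkpoint.
    print(shlex.join(build_inference_command("./my-finetuned_model.pt")))
```

The resulting list can be handed to `subprocess.run` as-is; omitting `first_stage_path` yields the plain command that uses the default checkpoint.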

### Configuration

In order to set hyperparameters such as learning rate, what to freeze, etc, you
7 changes: 6 additions & 1 deletion fam/llm/fast_inference.py
@@ -41,6 +41,7 @@ def __init__(
seed: int = 1337,
output_dir: str = "outputs",
quantisation_mode: Optional[Literal["int4", "int8"]] = None,
first_stage_path: Optional[str] = None,
):
"""
Initialise the TTS model.
@@ -54,6 +55,7 @@ def __init__(
- None for no quantisation (bf16 or fp16 based on device),
- int4 for int4 weight-only quantisation,
- int8 for int8 weight-only quantisation.
first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`.
"""

# NOTE: this needs to come first so that we don't change global state when we want to use
@@ -64,6 +66,9 @@ def __init__(
self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
if first_stage_path:
print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"

second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
@@ -85,7 +90,7 @@ def __init__(
self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
self.model, self.tokenizer, self.smodel, self.model_size = build_model(
precision=self.precision,
checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
checkpoint_path=Path(first_stage_ckpt),
spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
device=self._device,
compile=True,
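
The checkpoint selection introduced in this commit can be read in isolation as the sketch below. It mirrors the `first_stage_path or f"{self._model_dir}/first_stage.pt"` expression from the diff; the standalone function `resolve_first_stage_ckpt` and its `model_dir` parameter are hypothetical names used only for illustration, not part of `fast_inference.py`.

```python
from pathlib import Path
from typing import Optional


def resolve_first_stage_ckpt(model_dir: str, first_stage_path: Optional[str] = None) -> Path:
    """Hypothetical standalone version of the fallback logic in the diff."""
    if first_stage_path:
        print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
    # `or` falls back to the default for both None and the empty string.
    return Path(first_stage_path or f"{model_dir}/first_stage.pt")


if __name__ == "__main__":
    print(resolve_first_stage_ckpt("/models/metavoice"))
    print(resolve_first_stage_ckpt("/models/metavoice", "./my-finetuned_model.pt"))
```

Because the fallback uses `or` rather than an `is None` check, an empty string behaves like no override at all. Note also that, as in the diff, nothing here verifies that the path exists, so a mistyped path would presumably only surface when the checkpoint is loaded.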