Skip to content

Commit

Permalink
fix: tts reference and config (#713)
Browse files: browse the repository at this point in the history
* fix: tts reference and config

* fix streaming

* fix bug

* fix wav header
  • Loading branch information
AnyaCoder authored Dec 7, 2024
1 parent 9f881ed commit b11bcf8
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ dependencies = [
"ormsgpack",
"tiktoken>=0.8.0",
"pydantic==2.9.2",
"cachetools",
]

[project.optional-dependencies]
stable = [
"torch<=2.4.1",
"torchaudio",
"cachetools",
]

[build-system]
Expand Down
2 changes: 1 addition & 1 deletion tools/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def parse_args():
parser.add_argument(
"--max_new_tokens",
type=int,
default=0,
default=1024,
help="Maximum new tokens to generate. \n0 means no limit.",
)
parser.add_argument(
Expand Down
10 changes: 5 additions & 5 deletions tools/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def check_and_download_files(repo_id, file_list, local_dir):


# 1st
repo_id_1 = "fishaudio/fish-speech-1.4"
local_dir_1 = "./checkpoints/fish-speech-1.4"
repo_id_1 = "fishaudio/fish-speech-1.5"
local_dir_1 = "./checkpoints/fish-speech-1.5"
files_1 = [
"gitattributes",
"model.pth",
"README.md",
"special_tokens_map.json",
"tokenizer_config.json",
"tokenizer.json",
"special_tokens.json",
"tokenizer.tiktoken",
"config.json",
"firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
]
Expand Down
3 changes: 1 addition & 2 deletions tools/inference_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, No
audio=(sample_rate, segment),
error=None,
)
else:
segments.append(segment)
segments.append(segment)
else:
break

Expand Down
1 change: 0 additions & 1 deletion tools/inference_engine/reference_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def load_by_hash(
# If the references are not already loaded, encode them
prompt_tokens.append(
self.encode_reference(
decoder_model=self.decoder_model,
reference_audio=ref.audio,
enable_reference_audio=True,
)
Expand Down
9 changes: 3 additions & 6 deletions tools/inference_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@dataclass
class InferenceResult:
code: Literal["header", "segment", "error", "final"]
audio: Optional[Tuple[int, np.ndarray]]
audio: Optional[Tuple[int, np.ndarray | bytes]]
error: Optional[Exception]


Expand All @@ -25,7 +25,7 @@ def normalize_text(user_input: str, use_normalization: bool) -> str:

def wav_chunk_header(
sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
) -> np.ndarray:
) -> bytes:
buffer = io.BytesIO()

with wave.open(buffer, "wb") as wav_file:
Expand All @@ -36,7 +36,4 @@ def wav_chunk_header(
wav_header_bytes = buffer.getvalue()
buffer.close()

# Convert to numpy array
wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)

return wav_header
return wav_header_bytes
2 changes: 1 addition & 1 deletion tools/run_webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def parse_args():
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=0,
max_new_tokens=1024,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
Expand Down
12 changes: 8 additions & 4 deletions tools/server/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
Wrapper for the inference function.
Used in the API server.
"""
count = 0
for result in engine.inference(req):
match result.code:
case "header":
Expand All @@ -27,15 +28,18 @@ def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
)

case "segment":
count += 1
if isinstance(result.audio, tuple):
yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()

case "final":
count += 1
if isinstance(result.audio, tuple):
yield result.audio[1]
return None # Stop the generator

raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content="No audio generated, please check the input text.",
)
if count == 0:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content="No audio generated, please check the input text.",
)
4 changes: 2 additions & 2 deletions tools/server/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def warm_up(self, tts_inference_engine) -> None:
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=0,
max_new_tokens=1024,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
repetition_penalty=1.2,
temperature=0.7,
format="wav",
)
Expand Down

0 comments on commit b11bcf8

Please sign in to comment.