Skip to content

Commit

Permalink
[fix]fix problems to let version 1.5 support sft (#774)
Browse files Browse the repository at this point in the history
* [docs]Add docs of Fish Agent.

* [docs]:Fix some issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [docs]Add Chinese docs for Fish Agent

* [docs]fix some issue

* [docs]fix the bug that chinese page display wrong

* [docs]Fix bugs in Chinese docs and add translated docs of agent for other language.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feature]: change conversation.visual color and semantic encoded method

* [feature]:change collate_fn tokenizer to FishTokenzier

* [fix]fix some dimension problem in semantic.py

* [feature]change conf to tiktoken

* [fix]:fix ddp training problem

* [feature]use conversation to replace manully tokens and labels generate

* [fix]fix embedding calculate in BaseTransformer forward

* [fix]use einops to operate tensor to avoid bugs

* [fix]fix bugs in generate and llama for sft

* [fix]delete unused codes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: whaledolphin <[email protected]>
  • Loading branch information
3 people authored Dec 21, 2024
1 parent 0b48e78 commit 40665e1
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 215 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ asr-label*
/example
/faster_whisper
/.gradio
*log
13 changes: 8 additions & 5 deletions fish_speech/configs/text2semantic_finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@ defaults:

project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.4
pretrained_ckpt_path: checkpoints/fish-speech-1.5

# Lightning Trainer
trainer:
accumulate_grad_batches: 1
gradient_clip_val: 1.0
gradient_clip_algorithm: "norm"
max_steps: 1000
max_steps: 10000
precision: bf16-true
limit_val_batches: 10
val_check_interval: 100
# strategy:
# find_unused_parameters: true
# static_graph: true

# Dataset Configuration
tokenizer:
_target_: transformers.AutoTokenizer.from_pretrained
pretrained_model_name_or_path: ${pretrained_ckpt_path}
_target_: fish_speech.tokenizer.FishTokenizer
model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken

# Dataset Configuration
train_dataset:
Expand Down Expand Up @@ -47,7 +50,7 @@ data:
train_dataset: ${train_dataset}
val_dataset: ${val_dataset}
num_workers: 4
batch_size: 8
batch_size: 4
tokenizer: ${tokenizer}
max_length: ${max_length}

Expand Down
33 changes: 16 additions & 17 deletions fish_speech/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,35 +207,34 @@ def visualize(
tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
)

# Colors for alternating tokens
colors = {
"blue": "\033[94m", # Light blue
"cyan": "\033[96m", # Cyan
"green": "\033[92m", # Light green
"dark_green": "\033[32m", # Dark green
"purple": "\033[95m",
"yellow": "\033[93m",
"red": "\033[91m",
"cyan": "\033[96m",
}
blue_idx = 0
green_idx = 0
first_idx = 0
second_idx = 0

def print_in_blue(x):
nonlocal blue_idx
color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
def print_first_group(x):
nonlocal first_idx
color = colors["purple"] if first_idx % 2 == 0 else colors["yellow"]
print(f"{color}{x}\033[0m", end="")
blue_idx += 1
first_idx += 1

def print_in_green(x):
nonlocal green_idx
color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
def print_second_group(x):
nonlocal second_idx
color = colors["red"] if second_idx % 2 == 0 else colors["cyan"]
print(f"{color}{x}\033[0m", end="")
green_idx += 1
second_idx += 1

for tok, lab in zip(encoded.tokens, encoded.labels):
val = tokenizer.decode([tok])

if lab == -100:
print_in_green(val)
print_second_group(val)
else:
print_in_blue(val)
print_first_group(val)

print()

Expand Down
Loading

0 comments on commit 40665e1

Please sign in to comment.