deepspeed-chat: filter stage3 too long prompts (#782)
When stage3 prompts are too long, the prompts are still used, but they are
arbitrarily sliced at the start to fit into the configured max prompt length.
This arbitrary slicing sometimes makes the prompts less meaningful, which in
turn causes the generator to produce garbage. This phenomenon was observed to
destabilize RLHF stage3. To fix it, we filter out prompts that are too long.
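
Illustratively, the new behavior reduces to a keep/drop decision per prompt.
A minimal sketch, assuming a HuggingFace tokenizer and the configured max
prompt length; keep_prompt is a hypothetical helper, not code from this
commit:

    def keep_prompt(prompt, tokenizer, max_seq_len):
        prompt_token = tokenizer(prompt, return_tensors="pt")
        # Previously, over-long prompts were sliced to their last
        # (max_seq_len - 1) tokens; now they are simply dropped.
        return prompt_token["input_ids"].size()[-1] <= max_seq_len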

In addition, the dataset rebuild flag is propagated to the other consumers
that require it. Note that since the generated datasets are cached on disk,
this commit takes effect only after the cached step3 datasets are cleaned up.
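
For example, a cleanup along these lines forces the datasets to be rebuilt
(a sketch only: the *.pt / *.npy cache file patterns and the /tmp/data_files
default are assumptions about the configured --data_output_path):

    import glob
    import os

    data_output_path = "/tmp/data_files"  # assumed default; use your own path
    for cached in glob.glob(os.path.join(data_output_path, "*.pt")) + \
                  glob.glob(os.path.join(data_output_path, "*.npy")):
        os.remove(cached)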

Change-Id: I440f09decf0784e4c2c8167a893006dff312281b

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
mosheisland and mosheisland authored Nov 21, 2023
1 parent 09af71a commit 8c551d2
Showing 1 changed file with 53 additions and 24 deletions.
77 changes: 53 additions & 24 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
@@ -92,12 +92,19 @@ def get_shuffle_idx(seed, size):
     return shuffle_idx
 
 
-def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
-                                split_name, data_split, split_index,
-                                data_size):
+def get_raw_dataset_split_index(local_rank,
+                                output_path,
+                                dataset_name,
+                                seed,
+                                split_name,
+                                data_split,
+                                split_index,
+                                data_size,
+                                rebuild=False):
     index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
     # reindex each time when using local jsonfile since it's more likely to get modified
-    if (not os.path.isfile(index_file_name)) or (dataset_name == 'jsonfile'):
+    if rebuild or (not os.path.isfile(index_file_name)) or (dataset_name
+                                                            == 'jsonfile'):
         splits = [float(s) for s in data_split.split(',')]
         splits_sum = sum(splits)
         splits = [split / splits_sum for split in splits]
@@ -176,6 +183,9 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                 chosen_token["attention_mask"] = chosen_token[
                     "attention_mask"].squeeze(0)
                 chosen_dataset.append(chosen_token)
+        print(
+            f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
+        )
 
     elif train_phase == 2:
         for i, tmp_data in enumerate(current_dataset):
@@ -204,39 +214,41 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                 reject_token["input_ids"] = reject_token["input_ids"]
                 reject_token["attention_mask"] = reject_token["attention_mask"]
                 reject_dataset.append(reject_token)
+        print(
+            f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
+        )
 
     elif train_phase == 3:
+        filtered = 0
         for i, tmp_data in enumerate(current_dataset):
             # tokenize the text
             prompt = raw_dataset.get_prompt(tmp_data)
             if prompt is not None:
                 prompt_token = tokenizer(prompt, return_tensors="pt")
-                prompt_token["input_ids"] = prompt_token["input_ids"]
-                prompt_token["attention_mask"] = prompt_token["attention_mask"]
-                for key_word in ["input_ids", "attention_mask"]:
-                    length = prompt_token[key_word].size()[-1]
-                    if length > max_seq_len:
-                        y = prompt_token[key_word].squeeze(0)[length -
-                                                              (max_seq_len -
-                                                               1):].flip(0)
-                    else:
-                        y = prompt_token[key_word].squeeze(0).flip(0)
-                    prompt_token[key_word] = y
-                prompt_dataset.append(prompt_token)
+                if prompt_token["input_ids"].size()[-1] <= max_seq_len:
+                    for key_word in ["input_ids", "attention_mask"]:
+                        prompt_token[key_word] = prompt_token[
+                            key_word].squeeze(0).flip(0)
+                    prompt_dataset.append(prompt_token)
+                else:
+                    filtered += 1
+        print(f'Creating dataset {raw_dataset.dataset_name_clean} '
+              f'for {train_phase=} size={len(prompt_dataset)} {filtered=}')
 
     return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
                          tokenizer.pad_token_id, train_phase)
 
 
 def create_dataset(local_rank, dataset_name, data_split, output_path,
                    train_phase, seed, tokenizer, end_of_conversation_token,
-                   max_seq_len):
+                   max_seq_len, rebuild):
     raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank)
     train_dataset = raw_dataset.get_train_data()
     train_index = get_raw_dataset_split_index(local_rank, output_path,
                                               raw_dataset.dataset_name_clean,
                                               seed, "train", data_split,
                                               train_phase - 1,
-                                              len(train_dataset))
+                                              len(train_dataset), rebuild)
     train_dataset = Subset(train_dataset, train_index)
     train_dataset = create_dataset_split(train_dataset, raw_dataset,
                                          train_phase, tokenizer,
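
Note that the kept prompts are still stored reversed (flip(0)), following the
existing convention in this file: the prompt collator later pads the reversed
prompts on the right and flips them back, so the padding ends up on the left.
A minimal sketch under that assumption; collate_prompts is a hypothetical
stand-in for the repo's prompt collator:

    import torch
    import torch.nn.functional as F
    from torch.nn.utils.rnn import pad_sequence

    def collate_prompts(reversed_prompts, pad_token_id, max_token_len):
        # reversed_prompts: list of 1-D input_ids tensors as produced above
        batch = pad_sequence(reversed_prompts,
                             padding_value=pad_token_id,
                             batch_first=True)
        length = batch.size(-1)
        if length < max_token_len:
            batch = F.pad(batch, (0, max_token_len - length),
                          mode="constant", value=pad_token_id)
        # flip back so the padding sits on the left of each prompt
        return batch.flip(1)
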
@@ -248,7 +260,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
                                              raw_dataset.dataset_name_clean,
                                              seed, "eval",
                                              data_split, train_phase - 1,
-                                             len(eval_dataset))
+                                             len(eval_dataset), rebuild)
     eval_dataset = Subset(eval_dataset, eval_index)
     eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase,
                                         tokenizer, end_of_conversation_token,
@@ -287,19 +299,36 @@ def create_prompt_dataset(local_rank,
     torch.distributed.all_reduce(buf_create_cache)
 
     if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
+        print(f'Creating prompt dataset {data_path}, {reload=}')
         if len(data_path) == 1:  # Single dataset.
             train_dataset, eval_dataset = create_dataset(
-                local_rank, data_path[0], data_split, output_path, train_phase,
-                seed, tokenizer, end_of_conversation_token, max_seq_len)
+                local_rank,
+                data_path[0],
+                data_split,
+                output_path,
+                train_phase,
+                seed,
+                tokenizer,
+                end_of_conversation_token,
+                max_seq_len,
+                rebuild=reload)
         else:  # Blending datasets.
             train_datasets = []
             eval_datasets = []
             train_size = 0
             eval_size = 0
             for d_path in data_path:
                 train_dataset, eval_dataset = create_dataset(
-                    local_rank, d_path, data_split, output_path, train_phase,
-                    seed, tokenizer, end_of_conversation_token, max_seq_len)
+                    local_rank,
+                    d_path,
+                    data_split,
+                    output_path,
+                    train_phase,
+                    seed,
+                    tokenizer,
+                    end_of_conversation_token,
+                    max_seq_len,
+                    rebuild=reload)
                 train_datasets.append(train_dataset)
                 eval_datasets.append(eval_dataset)
                 train_size += len(train_dataset)
@@ -328,7 +357,7 @@ def create_prompt_dataset(
                     tokenizer,
                     end_of_conversation_token,
                     max_seq_len,
-                )
+                    rebuild=reload)
                 sft_train_datasets.append(sft_train_dataset)
                 sft_eval_datasets.append(sft_eval_dataset)
                 sft_train_size += len(sft_train_dataset)
