Skip to content

Commit

Permalink
Use dynamic number of beams depending on the prompt's token length. W…
Browse files Browse the repository at this point in the history
…e scale it down approximately quadratically due to the quadratic nature of attention. We now no longer need the statements to explicitly deal with torch memory before the generate statement.

Update prompt to be the same as sql-coder.
Add tests.
  • Loading branch information
wongjingping committed Oct 18, 2023
1 parent 6189cf5 commit cf939dc
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 33 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: psf/black@stable
test:
runs-on: ubuntu-latest
Expand All @@ -16,7 +16,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.11'
cache: 'pip'
- name: Install pip dependencies
run: |
Expand Down
36 changes: 18 additions & 18 deletions eval/hf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ def generate_prompt(prompt_file, question, db_name, public_data):
return prompt


def dynamic_num_beams(prompt: str, tokenizer) -> int:
tokens = len(tokenizer.encode(prompt))
print(tokens)
if tokens <= 1024:
return 4
elif tokens <= 1536:
return 2
else:
return 1


def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]):
"""
Load a HuggingFace tokenizer and model.
Expand Down Expand Up @@ -116,35 +127,24 @@ def run_hf_eval(args):
total_correct = 0
output_rows = []

if model_name is None or "llama" not in model_name.lower():
pipeline_config = {
"max_new_tokens": 300,
"do_sample": False,
"num_beams": 4,
}
else:
pipeline_config = {
"max_new_tokens": 300,
"do_sample": False,
"num_beams": 3,
}

with tqdm(total=len(df)) as pbar:
for row in df.to_dict("records"):
total_tried += 1
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
start_time = time()
num_beams = dynamic_num_beams(row["prompt"], tokenizer)
# we set return_full_text to False so that we don't get the prompt text in the generated text
# this simplifies our postprocessing to deal with just the truncation of the end of the query
generated_query = (
pipe(
row["prompt"],
max_new_tokens=300,
do_sample=False,
num_beams=num_beams,
num_return_sequences=1,
return_full_text=False,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
**pipeline_config,
)[0]["generated_text"]
.split("```sql")[-1]
.split("```")[0]
.split(";")[0]
.strip()
Expand Down
22 changes: 11 additions & 11 deletions prompts/prompt.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
### Instructions:
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
### Task
Generate a SQL query to answer the following question:
`{user_question}`

### Input:
Generate a SQL query that answers the question `{user_question}`.
This query will run on a database whose schema is represented in this string:
### Database Schema
The query will run on a database with the following schema:
{table_metadata_string}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{user_question}`:
### SQL
Follow these steps to create the SQL Query:
1. Only use the columns and tables present in the database schema
2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.

Given the database schema, here is the SQL query that answers `{user_question}`:
```sql
22 changes: 22 additions & 0 deletions tests/test_hf_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
from transformers import AutoTokenizer

from eval.hf_runner import dynamic_num_beams


@pytest.fixture
def tokenizer():
return AutoTokenizer.from_pretrained("bigcode/starcoder")


def test_dynamic_num_beams_ranges(tokenizer):
prompt = "word "
prompt_4 = prompt * 1023
num_beams_4 = dynamic_num_beams(prompt_4, tokenizer)
assert num_beams_4 == 4
prompt_2 = prompt * 1535
num_beams_2 = dynamic_num_beams(prompt_2, tokenizer)
assert num_beams_2 == 2
prompt_1 = prompt * 2048
num_beams_1 = dynamic_num_beams(prompt_1, tokenizer)
assert num_beams_1 == 1
2 changes: 0 additions & 2 deletions utils/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,8 @@ def prune_metadata_str(question, db_name, public_data=True):
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
emb_path = os.path.join(root_dir, "data", "embeddings.pkl")
if public_data:
print("Loading public data")
import defog_data.supplementary as sup
else:
print("Loading private data")
import defog_data_private.supplementary as sup
emb, csv_descriptions = sup.load_embeddings(emb_path)
table_metadata_csv = get_md_emb(
Expand Down

0 comments on commit cf939dc

Please sign in to comment.