Skip to content

Commit

Permalink
Use dynamic number of beams depending on the prompt's token length. W…
Browse files Browse the repository at this point in the history
…e scale it down approximately quadratically due to the quadratic nature of attention. We now no longer need the statements to explicitly deal with torch memory before the generate statement.

Update prompt to be the same as sql-coder.
Add tests.
Updated requirements.txt to support peft.
  • Loading branch information
wongjingping committed Oct 18, 2023
1 parent 6189cf5 commit f4ff7cd
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 37 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: psf/black@stable
test:
runs-on: ubuntu-latest
Expand All @@ -16,7 +16,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.11'
cache: 'pip'
- name: Install pip dependencies
run: |
Expand Down
36 changes: 18 additions & 18 deletions eval/hf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ def generate_prompt(prompt_file, question, db_name, public_data):
return prompt


def dynamic_num_beams(prompt: str, tokenizer) -> int:
tokens = len(tokenizer.encode(prompt))
print(tokens)
if tokens <= 1024:
return 4
elif tokens <= 1536:
return 2
else:
return 1


def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]):
"""
Load a HuggingFace tokenizer and model.
Expand Down Expand Up @@ -116,35 +127,24 @@ def run_hf_eval(args):
total_correct = 0
output_rows = []

if model_name is None or "llama" not in model_name.lower():
pipeline_config = {
"max_new_tokens": 300,
"do_sample": False,
"num_beams": 4,
}
else:
pipeline_config = {
"max_new_tokens": 300,
"do_sample": False,
"num_beams": 3,
}

with tqdm(total=len(df)) as pbar:
for row in df.to_dict("records"):
total_tried += 1
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
start_time = time()
num_beams = dynamic_num_beams(row["prompt"], tokenizer)
# we set return_full_text to False so that we don't get the prompt text in the generated text
# this simplifies our postprocessing to deal with just the truncation of the end of the query
generated_query = (
pipe(
row["prompt"],
max_new_tokens=300,
do_sample=False,
num_beams=num_beams,
num_return_sequences=1,
return_full_text=False,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
**pipeline_config,
)[0]["generated_text"]
.split("```sql")[-1]
.split("```")[0]
.split(";")[0]
.strip()
Expand Down
22 changes: 11 additions & 11 deletions prompts/prompt.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
### Instructions:
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
### Task
Generate a SQL query to answer the following question:
`{user_question}`

### Input:
Generate a SQL query that answers the question `{user_question}`.
This query will run on a database whose schema is represented in this string:
### Database Schema
The query will run on a database with the following schema:
{table_metadata_string}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{user_question}`:
### SQL
Follow these steps to create the SQL Query:
1. Only use the columns and tables present in the database schema
2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.

Given the database schema, here is the SQL query that answers `{user_question}`:
```sql
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ defog-data==0.1.1
func_timeout
openai
pandas
peft
psycopg2-binary
pytest
pyyaml
sentence-transformers
spacy
sqlalchemy
tiktoken
pyyaml
sentence-transformers
torch
tqdm
transformers
22 changes: 22 additions & 0 deletions tests/test_hf_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
from transformers import AutoTokenizer

from eval.hf_runner import dynamic_num_beams


@pytest.fixture
def tokenizer():
return AutoTokenizer.from_pretrained("bigcode/starcoder")


def test_dynamic_num_beams_ranges(tokenizer):
prompt = "word "
prompt_4 = prompt * 1023
num_beams_4 = dynamic_num_beams(prompt_4, tokenizer)
assert num_beams_4 == 4
prompt_2 = prompt * 1535
num_beams_2 = dynamic_num_beams(prompt_2, tokenizer)
assert num_beams_2 == 2
prompt_1 = prompt * 2048
num_beams_1 = dynamic_num_beams(prompt_1, tokenizer)
assert num_beams_1 == 1
94 changes: 94 additions & 0 deletions tests/test_utils_pruning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import pytest
from utils.pruning import encoder, get_entity_types, get_md_emb


@pytest.fixture
def test_metadata():
column_csv = [
"country.name,text,country name",
"country.capital,text,country capital",
"country.id,integer,unique id for country, not iso code",
"airport.country_id,integer,unique id for country where airport is located in",
"airport.airport_name,text,name of airport",
"flight.pilot_name,text,name of the pilot",
"flight.airport_name,text,name of the airport",
"flight.flight_code,text,flight code",
]
column_emb = encoder.encode(column_csv, convert_to_tensor=True)
column_ner = {
"GPE": [
"country.name,text,country name",
"country.capital,text,country capital",
"airport.country_name,text,name of the country where the airport is located in",
],
"ORG": [
"country.name,text,name of the country",
"airport.airport_name,text,name of airport",
"flight.airport_name,text,name of the airport",
],
"PERSON": ["flight.pilot_name,text,name of the pilot"],
}
column_join = {("airport", "country"): [("airport.country_id", "country.id")]}
return column_emb, column_csv, column_ner, column_join


# test embedding results + ner + join columns for sql
def test_get_md_emb(test_metadata):
column_emb, column_csv, column_ner, column_join = test_metadata
question = "How many flights start from Los Angeles Airport (LAX)?"
assert get_entity_types(question) == {"GPE", "ORG"}
k = 3
threshold = 0.0

# Call the function and get the result
result = get_md_emb(
question,
column_emb,
column_csv,
column_ner,
column_join,
k,
threshold,
)
print(f"result\n{result}")
expected = """```
CREATE TABLE flight (
airport_name text, --name of the airport
flight_code text, --flight code
);
CREATE TABLE airport (
airport_name text, --name of airport
country_name text, --name of the country where the airport is located in
country_id integer, --unique id for country where airport is located in
);
CREATE TABLE country (
name text, --country name
capital text, --country capital
id integer, --unique id for country, not iso code
);
```
Additionally, the following are tables/column pairs that can be joined in this database:
```
airport.country_id can be joined with country.id
```"""
assert result == expected


def test_get_md_emb_sql_emb_empty(test_metadata):
column_emb, column_csv, column_ner, column_join = test_metadata
question = "Who ate my homework?"
k = 3
threshold = 1.0 # arbitrarily high threshold to test empty results

# Call the function and get the result
result = get_md_emb(
question,
column_emb,
column_csv,
column_ner,
column_join,
k,
threshold,
)
assert result == ""
8 changes: 4 additions & 4 deletions utils/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def get_entity_types(sentence, verbose: bool = False):
def format_topk_sql(
topk_table_columns: Dict[str, List[Tuple[str, str, str]]],
) -> str:
if len(topk_table_columns) == 0:
return ""
md_str = "```\n"
for table_name in topk_table_columns:
columns_str = ""
Expand All @@ -66,7 +68,7 @@ def format_topk_sql(
)
else:
columns_str += f"\n {column_tuple[0]} {column_tuple[1]}, "
md_str += f"CREATE TABLE {table_name} ({columns_str}\n)\n-----------\n"
md_str += f"CREATE TABLE {table_name} ({columns_str}\n);\n"
return md_str


Expand Down Expand Up @@ -155,7 +157,7 @@ def get_md_emb(
# 4) format metadata string
md_str = format_topk_sql(topk_table_columns)

if join_list:
if len(join_list) > 0:
md_str += "```\n\nAdditionally, the following are tables/column pairs that can be joined in this database:\n```\n"
md_str += "\n".join(join_list)
md_str += "\n```"
Expand All @@ -167,10 +169,8 @@ def prune_metadata_str(question, db_name, public_data=True):
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
emb_path = os.path.join(root_dir, "data", "embeddings.pkl")
if public_data:
print("Loading public data")
import defog_data.supplementary as sup
else:
print("Loading private data")
import defog_data_private.supplementary as sup
emb, csv_descriptions = sup.load_embeddings(emb_path)
table_metadata_csv = get_md_emb(
Expand Down

0 comments on commit f4ff7cd

Please sign in to comment.