Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use dynamic number of beams depending on the prompt's token length.
We scale the number of beams down approximately quadratically because attention cost grows quadratically with sequence length. We no longer need the statements that explicitly manage torch memory before the generate call. Update the prompt to match sql-coder. Add tests. Update requirements.txt to support peft.
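The quadratic scaling can be illustrated with a little arithmetic. The sketch below is not the code from this commit: `quadratic_beam_budget` is a hypothetical name, and the reference length of 1024 tokens and the maximum of 4 beams are assumptions taken from the values used in this commit's tests.

```python
def quadratic_beam_budget(
    num_tokens: int, max_beams: int = 4, ref_tokens: int = 1024
) -> float:
    """Beam budget that keeps beams * tokens**2 roughly constant.

    Hypothetical helper for illustration only; max_beams and ref_tokens
    are assumed values, not taken from the commit's implementation.
    """
    if num_tokens <= ref_tokens:
        # Short prompts get the full beam count.
        return float(max_beams)
    # Attention cost grows with tokens**2, so shrink the beams by the
    # inverse square of the length ratio.
    return max_beams * (ref_tokens / num_tokens) ** 2


for n in (1024, 1536, 2048):
    print(n, round(quadratic_beam_budget(n), 2))
# prints:
# 1024 4.0
# 1536 1.78
# 2048 1.0
```

Rounding these budgets to whole beams gives 4, 2 and 1, which are the values asserted in this commit's test for `dynamic_num_beams`.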
1 parent 6189cf5, commit f4ff7cd
Showing 7 changed files with 154 additions and 37 deletions.
@@ -1,15 +1,15 @@
-### Instructions:
-Your task is to convert a question into a SQL query, given a Postgres database schema.
-Adhere to these rules:
-- **Deliberately go through the question and database schema word by word** to appropriately answer the question
-- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
-- When creating a ratio, always cast the numerator as float
+### Task
+Generate a SQL query to answer the following question:
+`{user_question}`
 
-### Input:
-Generate a SQL query that answers the question `{user_question}`.
-This query will run on a database whose schema is represented in this string:
+### Database Schema
+The query will run on a database with the following schema:
 {table_metadata_string}
 
-### Response:
-Based on your instructions, here is the SQL query I have generated to answer the question `{user_question}`:
+### SQL
+Follow these steps to create the SQL Query:
+1. Only use the columns and tables present in the database schema
+2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
+
+Given the database schema, here is the SQL query that answers `{user_question}`:
 ```sql
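The template has two placeholders, `{user_question}` (used twice) and `{table_metadata_string}`. A minimal sketch of filling them with `str.format` follows. `build_prompt` is a hypothetical helper name, the template text is an abbreviated copy (the numbered rules under `### SQL` are left out), and the commit does not show how the runner actually loads or fills the file.

```python
FENCE = "`" * 3  # built indirectly so this example does not nest code fences

# Abbreviated copy of the new template; the numbered rules are omitted.
PROMPT_TEMPLATE = (
    "### Task\n"
    "Generate a SQL query to answer the following question:\n"
    "`{user_question}`\n"
    "\n"
    "### Database Schema\n"
    "The query will run on a database with the following schema:\n"
    "{table_metadata_string}\n"
    "\n"
    "### SQL\n"
    "Given the database schema, here is the SQL query that answers `{user_question}`:\n"
    + FENCE + "sql\n"
)


def build_prompt(user_question: str, table_metadata_string: str) -> str:
    """Fill both placeholders of the template (hypothetical helper)."""
    return PROMPT_TEMPLATE.format(
        user_question=user_question,
        table_metadata_string=table_metadata_string,
    )


prompt = build_prompt(
    "How many flights start from Los Angeles Airport (LAX)?",
    "CREATE TABLE flight (\n  flight_code text\n);",
)
print(prompt)
```

Because the template ends on an opening `sql` fence, the model's continuation is presumably expected to be the query itself.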
@@ -0,0 +1,22 @@
import pytest
from transformers import AutoTokenizer

from eval.hf_runner import dynamic_num_beams


@pytest.fixture
def tokenizer():
    return AutoTokenizer.from_pretrained("bigcode/starcoder")


def test_dynamic_num_beams_ranges(tokenizer):
    prompt = "word "
    prompt_4 = prompt * 1023
    num_beams_4 = dynamic_num_beams(prompt_4, tokenizer)
    assert num_beams_4 == 4
    prompt_2 = prompt * 1535
    num_beams_2 = dynamic_num_beams(prompt_2, tokenizer)
    assert num_beams_2 == 2
    prompt_1 = prompt * 2048
    num_beams_1 = dynamic_num_beams(prompt_1, tokenizer)
    assert num_beams_1 == 1
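The body of `dynamic_num_beams` is not part of the diff shown here. The following is a sketch of a step function that would satisfy the three assertions above, not the commit's implementation. The thresholds (1024 and 1536 tokens), the `max_beams` parameter, the `_sketch` suffix and the `WhitespaceTokenizer` stand-in are all assumptions made so the example runs without downloading a model.

```python
class WhitespaceTokenizer:
    """Stand-in tokenizer: one token per whitespace-separated word.

    The real test loads the "bigcode/starcoder" tokenizer instead.
    """

    def encode(self, text: str) -> list:
        return text.split()


def dynamic_num_beams_sketch(prompt: str, tokenizer, max_beams: int = 4) -> int:
    """Pick a beam count from the prompt's token length.

    The 1024/1536 thresholds are assumptions chosen to agree with the
    test cases; they step the beam count down roughly quadratically.
    """
    num_tokens = len(tokenizer.encode(prompt))
    if num_tokens <= 1024:
        return max_beams
    if num_tokens <= 1536:
        return max(1, max_beams // 2)
    return max(1, max_beams // 4)


tok = WhitespaceTokenizer()
for repeats in (1023, 1535, 2048):
    print(repeats, dynamic_num_beams_sketch("word " * repeats, tok))
# prints:
# 1023 4
# 1535 2
# 2048 1
```

The returned value would then be passed as `num_beams` when generating, so that long prompts use fewer beams and stay within memory.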
@@ -0,0 +1,94 @@
import pytest
from utils.pruning import encoder, get_entity_types, get_md_emb


@pytest.fixture
def test_metadata():
    column_csv = [
        "country.name,text,country name",
        "country.capital,text,country capital",
        "country.id,integer,unique id for country, not iso code",
        "airport.country_id,integer,unique id for country where airport is located in",
        "airport.airport_name,text,name of airport",
        "flight.pilot_name,text,name of the pilot",
        "flight.airport_name,text,name of the airport",
        "flight.flight_code,text,flight code",
    ]
    column_emb = encoder.encode(column_csv, convert_to_tensor=True)
    column_ner = {
        "GPE": [
            "country.name,text,country name",
            "country.capital,text,country capital",
            "airport.country_name,text,name of the country where the airport is located in",
        ],
        "ORG": [
            "country.name,text,name of the country",
            "airport.airport_name,text,name of airport",
            "flight.airport_name,text,name of the airport",
        ],
        "PERSON": ["flight.pilot_name,text,name of the pilot"],
    }
    column_join = {("airport", "country"): [("airport.country_id", "country.id")]}
    return column_emb, column_csv, column_ner, column_join


# test embedding results + ner + join columns for sql
def test_get_md_emb(test_metadata):
    column_emb, column_csv, column_ner, column_join = test_metadata
    question = "How many flights start from Los Angeles Airport (LAX)?"
    assert get_entity_types(question) == {"GPE", "ORG"}
    k = 3
    threshold = 0.0

    # Call the function and get the result
    result = get_md_emb(
        question,
        column_emb,
        column_csv,
        column_ner,
        column_join,
        k,
        threshold,
    )
    print(f"result\n{result}")
    expected = """```
CREATE TABLE flight (
airport_name text, --name of the airport
flight_code text, --flight code
);
CREATE TABLE airport (
airport_name text, --name of airport
country_name text, --name of the country where the airport is located in
country_id integer, --unique id for country where airport is located in
);
CREATE TABLE country (
name text, --country name
capital text, --country capital
id integer, --unique id for country, not iso code
);
```
Additionally, the following are tables/column pairs that can be joined in this database:
```
airport.country_id can be joined with country.id
```"""
    assert result == expected


def test_get_md_emb_sql_emb_empty(test_metadata):
    column_emb, column_csv, column_ner, column_join = test_metadata
    question = "Who ate my homework?"
    k = 3
    threshold = 1.0  # arbitrarily high threshold to test empty results

    # Call the function and get the result
    result = get_md_emb(
        question,
        column_emb,
        column_csv,
        column_ner,
        column_join,
        k,
        threshold,
    )
    assert result == ""
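The expected string in `test_get_md_emb` implies a rendering step that turns `table.column,type,description` rows into `CREATE TABLE` text and join pairs into "can be joined with" lines. The sketch below illustrates only that rendering step, not the embedding or NER selection in `get_md_emb`. `format_schema` is a hypothetical helper, and the two-space column indent is an assumption, since the page does not preserve leading whitespace.

```python
from collections import defaultdict


def format_schema(column_csv, column_join):
    """Render metadata rows and join pairs as text (hypothetical helper).

    Each row has the form 'table.column,type,description'. The two-space
    indent before each column is an assumption.
    """
    tables = defaultdict(list)
    for row in column_csv:
        # maxsplit=2 keeps commas inside the description intact.
        qualified_name, col_type, description = row.split(",", 2)
        table, column = qualified_name.split(".", 1)
        tables[table].append(f"  {column} {col_type}, --{description}")

    schema = "\n".join(
        f"CREATE TABLE {table} (\n" + "\n".join(cols) + "\n);"
        for table, cols in tables.items()
    )
    joins = "\n".join(
        f"{left} can be joined with {right}"
        for pairs in column_join.values()
        for left, right in pairs
    )
    return schema, joins


rows = [
    "country.name,text,country name",
    "country.id,integer,unique id for country, not iso code",
]
join_map = {("airport", "country"): [("airport.country_id", "country.id")]}
schema_text, join_text = format_schema(rows, join_map)
print(schema_text)
print(join_text)
```

Splitting with a maximum of two splits matters here because descriptions such as "unique id for country, not iso code" contain commas of their own.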