-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use dynamic number of beams depending on the prompt's token length. W…
…e scale it down approximately quadratically due to the quadratic nature of attention. We now no longer need the statements to explicitly deal with torch memory before the generate statement. Update prompt to be the same as sql-coder. Add tests.
- Loading branch information
1 parent
6189cf5
commit cf939dc
Showing
5 changed files
with
53 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
### Instructions: | ||
Your task is to convert a question into a SQL query, given a Postgres database schema. | ||
Adhere to these rules: | ||
- **Deliberately go through the question and database schema word by word** to appropriately answer the question | ||
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. | ||
- When creating a ratio, always cast the numerator as float | ||
### Task | ||
Generate a SQL query to answer the following question: | ||
`{user_question}` | ||
|
||
### Input: | ||
Generate a SQL query that answers the question `{user_question}`. | ||
This query will run on a database whose schema is represented in this string: | ||
### Database Schema | ||
The query will run on a database with the following schema: | ||
{table_metadata_string} | ||
|
||
### Response: | ||
Based on your instructions, here is the SQL query I have generated to answer the question `{user_question}`: | ||
### SQL | ||
Follow these steps to create the SQL Query: | ||
1. Only use the columns and tables present in the database schema | ||
2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. | ||
|
||
Given the database schema, here is the SQL query that answers `{user_question}`: | ||
```sql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pytest | ||
from transformers import AutoTokenizer | ||
|
||
from eval.hf_runner import dynamic_num_beams | ||
|
||
|
||
@pytest.fixture | ||
def tokenizer(): | ||
return AutoTokenizer.from_pretrained("bigcode/starcoder") | ||
|
||
|
||
def test_dynamic_num_beams_ranges(tokenizer): | ||
prompt = "word " | ||
prompt_4 = prompt * 1023 | ||
num_beams_4 = dynamic_num_beams(prompt_4, tokenizer) | ||
assert num_beams_4 == 4 | ||
prompt_2 = prompt * 1535 | ||
num_beams_2 = dynamic_num_beams(prompt_2, tokenizer) | ||
assert num_beams_2 == 2 | ||
prompt_1 = prompt * 2048 | ||
num_beams_1 = dynamic_num_beams(prompt_1, tokenizer) | ||
assert num_beams_1 == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters