diff --git a/evals/registry.py b/evals/registry.py index 48a62116f7..41a9a37cdb 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -32,9 +32,10 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]: - """Returns n_ctx for a given API model name. Model list last updated 2023-06-16.""" + """Returns n_ctx for a given API model name. Model list last updated 2023-10-24.""" # note that for most models, the max tokens is n_ctx + 1 PREFIX_AND_N_CTX: list[tuple[str, int]] = [ + ("gpt-3.5-turbo-16k-", 16384), ("gpt-3.5-turbo-", 4096), ("gpt-4-32k-", 32768), ("gpt-4-", 8192), @@ -52,6 +53,7 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]: "text-davinci-002": 4096, "text-davinci-003": 4096, "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-16k": 16384, "gpt-4": 8192, "gpt-4-32k": 32768, "gpt-4-base": 8192, @@ -74,13 +76,15 @@ def is_chat_model(model_name: str) -> bool: if model_name in {"gpt-4-base"}: return False - CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-4", "gpt-4-32k"} + CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"} + if model_name in CHAT_MODEL_NAMES: return True - for model_prefix in {"gpt-3.5-turbo-", "gpt-4-", "gpt-4-32k-"}: + for model_prefix in {"gpt-3.5-turbo-", "gpt-4-"}: if model_name.startswith(model_prefix): return True + return False diff --git a/evals/registry_test.py b/evals/registry_test.py index 1b6c475ba0..2ff9e16a08 100644 --- a/evals/registry_test.py +++ b/evals/registry_test.py @@ -2,22 +2,25 @@ def test_n_ctx_from_model_name(): + assert n_ctx_from_model_name("gpt-3.5-turbo") == 4096 + assert n_ctx_from_model_name("gpt-3.5-turbo-0613") == 4096 + assert n_ctx_from_model_name("gpt-3.5-turbo-16k") == 16384 + assert n_ctx_from_model_name("gpt-3.5-turbo-16k-0613") == 16384 assert n_ctx_from_model_name("gpt-4") == 8192 - assert n_ctx_from_model_name("gpt-4-0314") == 8192 assert n_ctx_from_model_name("gpt-4-0613") == 8192 assert n_ctx_from_model_name("gpt-4-32k") == 32768 - assert n_ctx_from_model_name("gpt-4-32k-0314") == 32768 assert n_ctx_from_model_name("gpt-4-32k-0613") == 32768 def test_is_chat_model(): assert is_chat_model("gpt-3.5-turbo") - assert is_chat_model("gpt-3.5-turbo-0314") assert is_chat_model("gpt-3.5-turbo-0613") + assert is_chat_model("gpt-3.5-turbo-16k") + assert is_chat_model("gpt-3.5-turbo-16k-0613") assert is_chat_model("gpt-4") - assert is_chat_model("gpt-4-0314") assert is_chat_model("gpt-4-0613") assert is_chat_model("gpt-4-32k") - assert is_chat_model("gpt-4-32k-0314") assert is_chat_model("gpt-4-32k-0613") assert not is_chat_model("text-davinci-003") + assert not is_chat_model("gpt4-base") + assert not is_chat_model("code-davinci-002")