From 6e8a6e45c0833654135d9c3d38d618fa36bab7c9 Mon Sep 17 00:00:00 2001
From: patcher99
Date: Sat, 16 Mar 2024 12:59:48 +0530
Subject: [PATCH] fix completions in openai

---
 src/dokumetry/async_openai.py | 20 +++++++++++---------
 src/dokumetry/openai.py       | 24 +++++++++++++-----------
 2 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/src/dokumetry/async_openai.py b/src/dokumetry/async_openai.py
index 343f12e..23bdac2 100644
--- a/src/dokumetry/async_openai.py
+++ b/src/dokumetry/async_openai.py
@@ -48,11 +48,12 @@ async def llm_chat_completions(*args, **kwargs):
         async def stream_generator():
             accumulated_content = ""
             async for chunk in await original_chat_create(*args, **kwargs):
-                #pylint: disable=line-too-long
-                if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
-                    content = chunk.choices[0].delta.content
-                    if content:
-                        accumulated_content += content
+                if len(chunk.choices) > 0:
+                    #pylint: disable=line-too-long
+                    if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
+                        content = chunk.choices[0].delta.content
+                        if content:
+                            accumulated_content += content
                 yield chunk
                 response_id = chunk.id
             end_time = time.time()
@@ -171,10 +172,11 @@ async def llm_completions(*args, **kwargs):
         async def stream_generator():
             accumulated_content = ""
             async for chunk in await original_completions_create(*args, **kwargs):
-                if hasattr(chunk.choices[0].text, 'content'):
-                    content = chunk.choices[0].text
-                    if content:
-                        accumulated_content += content
+                if len(chunk.choices) > 0:
+                    if hasattr(chunk.choices[0], 'text'):
+                        content = chunk.choices[0].text
+                        if content:
+                            accumulated_content += content
                 yield chunk
                 response_id = chunk.id
             end_time = time.time()
diff --git a/src/dokumetry/openai.py b/src/dokumetry/openai.py
index 0a7835f..dbdaae6 100644
--- a/src/dokumetry/openai.py
+++ b/src/dokumetry/openai.py
@@ -48,11 +48,12 @@ def llm_chat_completions(*args, **kwargs):
         def stream_generator():
             accumulated_content = ""
             for chunk in original_chat_create(*args, **kwargs):
-                #pylint: disable=line-too-long
-                if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
-                    content = chunk.choices[0].delta.content
-                    if content:
-                        accumulated_content += content
+                if len(chunk.choices) > 0:
+                    #pylint: disable=line-too-long
+                    if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
+                        content = chunk.choices[0].delta.content
+                        if content:
+                            accumulated_content += content
                 yield chunk
                 response_id = chunk.id
             end_time = time.time()
@@ -170,11 +171,12 @@ def llm_completions(*args, **kwargs):
     if streaming:
         def stream_generator():
             accumulated_content = ""
-            for chunk in original_chat_create(*args, **kwargs):
-                if hasattr(chunk.choices[0].text, 'content'):
-                    content = chunk.choices[0].text
-                    if content:
-                        accumulated_content += content
+            for chunk in original_completions_create(*args, **kwargs):
+                if len(chunk.choices) > 0:
+                    if hasattr(chunk.choices[0], 'text'):
+                        content = chunk.choices[0].text
+                        if content:
+                            accumulated_content += content
                 yield chunk
                 response_id = chunk.id
             end_time = time.time()
@@ -258,7 +260,7 @@ def patched_embeddings_create(*args, **kwargs):
         end_time = time.time()
         duration = end_time - start_time
         model = kwargs.get('model', "No Model provided")
-        prompt = kwargs.get('input', "No prompt provided")
+        prompt = ', '.join(kwargs.get('input', []))
         data = {
             "environment": environment,