Skip to content

Commit

Permalink
feat: add support for Google Palm (#126)
Browse files Browse the repository at this point in the history
* feat: Add support for Google Palm

* Fix linting

* rename test

* add generative-ai dependency

* Refactor with Base class

* fix linting issues
  • Loading branch information
WaseemSabir authored May 19, 2023
1 parent a899d1a commit baedcbc
Show file tree
Hide file tree
Showing 6 changed files with 506 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ pandasai.egg-info
/venv

# command line
/pandasai_cli.egg-info
/pandasai_cli.egg-info

# pycharm
.idea/
71 changes: 71 additions & 0 deletions pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import openai
import requests
from google import generativeai

from ..constants import END_CODE_TAG, START_CODE_TAG
from ..exceptions import (
Expand Down Expand Up @@ -238,3 +239,73 @@ def call(self, instruction: str, value: str, suffix: str = "") -> str:
# replace instruction + value from the inputs to avoid showing it in the output
output = response.replace(instruction + value + suffix, "")
return output


class BaseGoogle(LLM):
"""Base class to implement a new Google LLM"""

genai: Any
temperature: Optional[float] = 0
top_p: Optional[float] = None
top_k: Optional[float] = None
max_output_tokens: Optional[int] = None

def _configure(self, api_key: str):
if not api_key:
raise APIKeyNotFoundError("Google Palm API key is required")

generativeai.configure(api_key=api_key)
self.genai = generativeai

def _valid_params(self):
return ["temperature", "top_p", "top_k", "max_output_tokens"]

def _set_params(self, **kwargs):
valid_params = self._valid_params()
for key, value in kwargs.items():
if key in valid_params:
setattr(self, key, value)

def _validate(self):
"""Validates the parameters for Google"""

if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")

if self.top_k is not None and not 0 <= self.top_k <= 1:
raise ValueError("top_k must be in the range [0.0, 1.0]")

if self.max_output_tokens is not None and self.max_output_tokens <= 0:
raise ValueError("max_output_tokens must be greater than zero")

@abstractmethod
def _generate_text(self, prompt: str) -> str:
"""
Generates text for prompt, specific to implementation.
Args:
prompt (str): Prompt
Returns:
str: LLM response
"""
raise MethodNotImplementedError("method has not been implemented")

def call(self, instruction: str, value: str, suffix: str = "") -> str:
"""
Call the Google LLM.
Args:
instruction (str): Instruction to pass
value (str): Value to pass
suffix (str): Suffix to pass
Returns:
str: Response
"""
self.last_prompt = str(instruction) + str(value)
prompt = str(instruction) + str(value) + suffix
return self._generate_text(prompt)
46 changes: 46 additions & 0 deletions pandasai/llm/google_palm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Google Palm LLM"""
from .base import BaseGoogle


class GooglePalm(BaseGoogle):
"""Google Palm LLM"""

model: str = "models/text-bison-001"

def __init__(self, api_key: str, **kwargs):
self._configure(api_key=api_key)
self._set_params(**kwargs)

def _valid_params(self):
return super()._valid_params() + ["model"]

def _validate(self):
super()._validate()

if not self.model:
raise ValueError("model is required.")

def _generate_text(self, prompt: str) -> str:
"""
Generates text for prompt
Args:
prompt (str): Prompt
Returns:
str: LLM response
"""
self._validate()
completion = self.genai.generate_text(
model=self.model,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
)
return completion.result

@property
def type(self) -> str:
return "google-palm"
Loading

0 comments on commit baedcbc

Please sign in to comment.