vet

Mirror of Vet, an AI code review tool
git clone git://git.laack.co/vet.git
Log | Files | Refs | README | LICENSE

openai_compatible_api.py (12223B)


      1 import math
      2 from contextlib import contextmanager
      3 from typing import AsyncGenerator
      4 from typing import Iterator
      5 
      6 import httpx
      7 from loguru import logger
      8 from openai import AsyncOpenAI
      9 from openai import AsyncStream
     10 from openai import InternalServerError
     11 from openai import NOT_GIVEN
     12 from openai import NotGiven
     13 from openai._exceptions import APIConnectionError
     14 from openai._exceptions import BadRequestError
     15 from openai._exceptions import RateLimitError
     16 from openai.types.chat import ChatCompletion
     17 
     18 from vet.imbue_core.agents.llm_apis.api_utils import convert_prompt_to_openai_messages
     19 from vet.imbue_core.agents.llm_apis.constants import approximate_token_count
     20 from vet.imbue_core.agents.llm_apis.data_types import CachingInfo
     21 from vet.imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
     22 from vet.imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
     23 from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponse
     24 from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponseUsage
     25 from vet.imbue_core.agents.llm_apis.data_types import ResponseStopReason
     26 from vet.imbue_core.agents.llm_apis.errors import BadAPIRequestError
     27 from vet.imbue_core.agents.llm_apis.errors import PromptTooLongError
     28 from vet.imbue_core.agents.llm_apis.errors import TransientLanguageModelError
     29 from vet.imbue_core.agents.llm_apis.language_model_api import LanguageModelAPI
     30 from vet.imbue_core.agents.llm_apis.models import ModelInfo
     31 from vet.imbue_core.agents.llm_apis.openai_data_types import OpenAICachingInfo
     32 from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamDeltaEvent
     33 from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamEndEvent
     34 from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamEvent
     35 from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamStartEvent
     36 from vet.imbue_core.frozen_utils import FrozenDict
     37 from vet.imbue_core.frozen_utils import FrozenMapping
     38 from vet.imbue_core.itertools import only
     39 from vet.imbue_core.secrets_utils import get_secret
     40 
     41 _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON: FrozenMapping[str, ResponseStopReason] = FrozenDict(
     42     {
     43         "stop": ResponseStopReason.END_TURN,
     44         "length": ResponseStopReason.MAX_TOKENS,
     45         "tool_calls": ResponseStopReason.TOOL_CALLS,
     46         "function_call": ResponseStopReason.FUNCTION_CALL,
     47         "content_filter": ResponseStopReason.CONTENT_FILTER,
     48         "None": ResponseStopReason.NONE,
     49     }
     50 )
     51 
     52 
     53 # TODO: Should the pre-defined OpenAI model class inherit from this?
     54 class OpenAICompatibleAPI(LanguageModelAPI):
     55     model_name: str
     56     base_url: str = "https://api.openai.com/v1"
     57     api_key_env: str = "OPENAI_API_KEY"
     58     context_window: int | None = None
     59     max_output_tokens: int | None = None
     60     is_conversational: bool = True
     61     presence_penalty: float = 0.0
     62     supports_temperature: bool = True
     63     # this shouldn't really ever even be used, but just in case
     64     stop_token_log_probability: float = math.log(0.9999)
     65 
     66     @property
     67     def model_info(self) -> ModelInfo:
     68         if self.context_window is None or self.max_output_tokens is None:
     69             raise ValueError("Must provide context_window and max_output_tokens, or subclass must override model_info")
     70         return ModelInfo(
     71             model_name=self.model_name,
     72             cost_per_input_token=0.0,
     73             cost_per_output_token=0.0,
     74             max_input_tokens=self.context_window,
     75             max_output_tokens=self.max_output_tokens,
     76             rate_limit_req=None,
     77         )
     78 
     79     def _get_client(self) -> AsyncOpenAI:
     80         api_key = get_secret(self.api_key_env) if self.api_key_env else ""
     81         if not api_key:
     82             api_key = "not-required"
     83             logger.debug("API key not set, attempting to use API without key.")
     84 
     85         return AsyncOpenAI(
     86             api_key=api_key,
     87             base_url=self.base_url,
     88         )
     89 
     90     @contextmanager
     91     def _exception_handler(self, prompt: str) -> Iterator[None]:
     92         try:
     93             yield
     94         except BadRequestError as e:
     95             if e.code == "context_length_exceeded":
     96                 prompt_len = self.count_tokens(prompt)
     97                 max_prompt_len = self.model_info.max_input_tokens
     98                 logger.debug(
     99                     "PromptTooLongError max_prompt_len={max_prompt_len} prompt_len={prompt_len}",
    100                     max_prompt_len=max_prompt_len,
    101                     prompt_len=prompt_len,
    102                 )
    103                 raise PromptTooLongError(prompt_len, max_prompt_len) from e
    104             logger.debug("BadAPIRequestError {e}", e=e)
    105             raise BadAPIRequestError(str(e)) from e
    106         except APIConnectionError as e:
    107             logger.debug("API connection error: {e}", e=e)
    108             raise TransientLanguageModelError("APIConnectionError") from e
    109         except RateLimitError as e:
    110             if e.code == "insufficient_quota":
    111                 raise
    112             logger.debug("Rate limited: {e}", e=e)
    113             raise TransientLanguageModelError("RateLimitError") from e
    114         except httpx.RemoteProtocolError as e:
    115             logger.debug("httpx.RemoteProtocolError {e}", e=e)
    116             raise TransientLanguageModelError("httpx.RemoteProtocolError") from e
    117         except InternalServerError as e:
    118             logger.debug("InternalServerError {e}", e=e)
    119             raise TransientLanguageModelError("InternalServerError") from e
    120 
    121     async def _call_api(
    122         self,
    123         prompt: str,
    124         params: LanguageModelGenerationParams,
    125         network_failure_count: int = 0,
    126     ) -> CostedLanguageModelResponse:
    127         messages = convert_prompt_to_openai_messages(prompt)
    128 
    129         with self._exception_handler(prompt):
    130             client = self._get_client()
    131 
    132             temperature: NotGiven | float = params.temperature if self.supports_temperature else NOT_GIVEN
    133 
    134             api_result = await client.chat.completions.create(
    135                 model=self.model_name,
    136                 messages=messages,
    137                 max_completion_tokens=params.max_tokens,
    138                 n=params.count,
    139                 temperature=temperature,
    140                 stream=False,
    141                 seed=params.seed,
    142                 stop=params.stop,
    143                 presence_penalty=self.presence_penalty,
    144             )
    145             assert isinstance(api_result, ChatCompletion)
    146 
    147             usage = api_result.usage
    148             if usage is not None:
    149                 completion_tokens = usage.completion_tokens
    150                 prompt_tokens = usage.prompt_tokens
    151                 cached_tokens = (
    152                     usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details is not None else 0
    153                 ) or 0
    154                 caching_info = CachingInfo(
    155                     read_from_cache=cached_tokens,
    156                     provider_specific_data=OpenAICachingInfo(),
    157                 )
    158             else:
    159                 completion_tokens = 0
    160                 prompt_tokens = self.count_tokens(prompt)
    161                 cached_tokens = None
    162                 caching_info = None
    163 
    164             results = self._parse_response(
    165                 api_result,
    166                 prompt_tokens=prompt_tokens,
    167                 stop=params.stop,
    168                 network_failure_count=network_failure_count,
    169             )
    170 
    171             logger.trace("text: {text}", text=results[0].text)
    172             dollars_used = self.calculate_cost(prompt_tokens, completion_tokens)
    173             logger.trace("dollars used: {dollars_used}", dollars_used=dollars_used)
    174 
    175             return CostedLanguageModelResponse(
    176                 usage=LanguageModelResponseUsage(
    177                     prompt_tokens_used=prompt_tokens,
    178                     completion_tokens_used=completion_tokens,
    179                     dollars_used=dollars_used,
    180                     caching_info=caching_info,
    181                 ),
    182                 responses=tuple(results),
    183             )
    184 
    185     async def _get_api_stream(
    186         self,
    187         prompt: str,
    188         params: LanguageModelGenerationParams,
    189     ) -> AsyncGenerator[LanguageModelStreamEvent, None]:
    190         messages = convert_prompt_to_openai_messages(prompt)
    191 
    192         with self._exception_handler(prompt):
    193             client = self._get_client()
    194 
    195             temperature: NotGiven | float = params.temperature if self.supports_temperature else NOT_GIVEN
    196 
    197             api_result = await client.chat.completions.create(
    198                 model=self.model_name,
    199                 messages=messages,
    200                 max_completion_tokens=params.max_tokens,
    201                 n=1,
    202                 temperature=temperature,
    203                 stop=params.stop,
    204                 seed=params.seed,
    205                 stream=True,
    206                 stream_options={"include_usage": True},
    207                 presence_penalty=self.presence_penalty,
    208             )
    209             assert isinstance(api_result, AsyncStream)
    210 
    211             yield LanguageModelStreamStartEvent()
    212 
    213             usage = None
    214             finish_reason: str | None = None
    215             async for chunk in api_result:
    216                 if hasattr(chunk, "usage") and chunk.usage is not None:
    217                     usage = chunk.usage
    218                     continue
    219 
    220                 if chunk.choices:
    221                     assert len(chunk.choices) == 1, "Currently only count=1 supported for streaming API."
    222                     data = only(chunk.choices)
    223                     delta = data.delta.content
    224                     if delta is not None:
    225                         yield LanguageModelStreamDeltaEvent(delta=delta)
    226                     if data.finish_reason:
    227                         finish_reason = str(data.finish_reason)
    228 
    229             stop_reason = _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON.get(str(finish_reason), ResponseStopReason.NONE)
    230             if params.stop is not None and stop_reason == ResponseStopReason.END_TURN:
    231                 yield LanguageModelStreamDeltaEvent(delta=params.stop)
    232 
    233             if usage is not None:
    234                 completion_tokens = usage.completion_tokens
    235                 prompt_tokens = usage.prompt_tokens
    236                 dollars_used = self.calculate_cost(prompt_tokens, completion_tokens)
    237                 cached_tokens = (
    238                     usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details is not None else 0
    239                 ) or 0
    240                 caching_info = CachingInfo(
    241                     read_from_cache=cached_tokens,
    242                     provider_specific_data=OpenAICachingInfo(),
    243                 )
    244             else:
    245                 completion_tokens = -1
    246                 prompt_tokens = -1
    247                 dollars_used = -1
    248                 caching_info = None
    249             logger.trace("dollars used: {dollars_used}", dollars_used=dollars_used)
    250 
    251             yield LanguageModelStreamEndEvent(
    252                 usage=LanguageModelResponseUsage(
    253                     prompt_tokens_used=prompt_tokens,
    254                     completion_tokens_used=completion_tokens,
    255                     dollars_used=dollars_used,
    256                     caching_info=caching_info,
    257                 ),
    258                 stop_reason=stop_reason,
    259             )
    260 
    261     def count_tokens(self, text: str) -> int:
    262         return approximate_token_count(text)
    263 
    264     def _parse_response(
    265         self,
    266         response: ChatCompletion,
    267         prompt_tokens: int,
    268         stop: str | None,
    269         network_failure_count: int,
    270     ) -> tuple[LanguageModelResponse, ...]:
    271         results = []
    272         for data in response.choices:
    273             assert data.message.content is not None
    274             text = data.message.content
    275             token_count = self.count_tokens(text) + prompt_tokens
    276             stop_reason = _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON.get(
    277                 str(data.finish_reason), ResponseStopReason.NONE
    278             )
    279             if stop is not None and stop_reason == ResponseStopReason.END_TURN:
    280                 text += stop
    281             result = LanguageModelResponse(
    282                 text=text,
    283                 token_count=token_count,
    284                 stop_reason=stop_reason,
    285                 network_failure_count=network_failure_count,
    286             )
    287             results.append(result)
    288         return tuple(results)