openai_compatible_api.py (12223B)
import math
from contextlib import contextmanager
from typing import AsyncGenerator
from typing import Iterator

import httpx
from loguru import logger
from openai import AsyncOpenAI
from openai import AsyncStream
from openai import InternalServerError
from openai import NOT_GIVEN
from openai import NotGiven
from openai._exceptions import APIConnectionError
from openai._exceptions import BadRequestError
from openai._exceptions import RateLimitError
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion

from vet.imbue_core.agents.llm_apis.api_utils import convert_prompt_to_openai_messages
from vet.imbue_core.agents.llm_apis.constants import approximate_token_count
from vet.imbue_core.agents.llm_apis.data_types import CachingInfo
from vet.imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponse
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponseUsage
from vet.imbue_core.agents.llm_apis.data_types import ResponseStopReason
from vet.imbue_core.agents.llm_apis.errors import BadAPIRequestError
from vet.imbue_core.agents.llm_apis.errors import PromptTooLongError
from vet.imbue_core.agents.llm_apis.errors import TransientLanguageModelError
from vet.imbue_core.agents.llm_apis.language_model_api import LanguageModelAPI
from vet.imbue_core.agents.llm_apis.models import ModelInfo
from vet.imbue_core.agents.llm_apis.openai_data_types import OpenAICachingInfo
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamDeltaEvent
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamEndEvent
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamEvent
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamStartEvent
from vet.imbue_core.frozen_utils import FrozenDict
from vet.imbue_core.frozen_utils import FrozenMapping
from vet.imbue_core.itertools import only
from vet.imbue_core.secrets_utils import get_secret

# Maps the ``finish_reason`` strings reported by OpenAI-compatible chat-completion
# endpoints to the provider-agnostic ResponseStopReason. Lookups below are done on
# ``str(finish_reason)``, which is why a missing finish reason appears as the literal
# key "None". Unknown values fall back to ResponseStopReason.NONE at the call sites.
_OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON: FrozenMapping[str, ResponseStopReason] = FrozenDict(
    {
        "stop": ResponseStopReason.END_TURN,
        "length": ResponseStopReason.MAX_TOKENS,
        "tool_calls": ResponseStopReason.TOOL_CALLS,
        "function_call": ResponseStopReason.FUNCTION_CALL,
        "content_filter": ResponseStopReason.CONTENT_FILTER,
        "None": ResponseStopReason.NONE,
    }
)


def _caching_info_from_usage(usage: CompletionUsage) -> CachingInfo:
    """Build CachingInfo from a provider-reported usage block.

    ``prompt_tokens_details`` may be absent, and its ``cached_tokens`` may be null, on
    OpenAI-compatible servers; both cases are reported as zero tokens read from cache.
    """
    details = usage.prompt_tokens_details
    cached_tokens = (details.cached_tokens if details is not None else 0) or 0
    return CachingInfo(
        read_from_cache=cached_tokens,
        provider_specific_data=OpenAICachingInfo(),
    )


# TODO: Should the pre-defined OpenAI model class inherit from this?
class OpenAICompatibleAPI(LanguageModelAPI):
    """Language model backend for any server exposing the OpenAI chat-completions API.

    Point ``base_url`` at the server and ``api_key_env`` at the name of the secret that
    holds its API key. ``model_info`` requires ``context_window`` and
    ``max_output_tokens`` to be set unless a subclass overrides it.
    """

    # Sent as ``model`` in every request.
    model_name: str
    base_url: str = "https://api.openai.com/v1"
    # Name of the secret read via get_secret(). If this is empty or the secret is unset,
    # a placeholder key is sent instead (for servers that do not require one).
    api_key_env: str = "OPENAI_API_KEY"
    context_window: int | None = None
    max_output_tokens: int | None = None
    # NOTE(review): not read in this class; presumably consumed by LanguageModelAPI — verify.
    is_conversational: bool = True
    # Sent as ``presence_penalty`` with every request.
    presence_penalty: float = 0.0
    # When False, no temperature is sent, for models that do not accept one.
    supports_temperature: bool = True
    # this shouldn't really ever even be used, but just in case
    stop_token_log_probability: float = math.log(0.9999)

    @property
    def model_info(self) -> ModelInfo:
        """Describe the model's limits; token costs are fixed at 0.0 here.

        Raises:
            ValueError: if ``context_window`` or ``max_output_tokens`` is unset.
        """
        if self.context_window is None or self.max_output_tokens is None:
            raise ValueError("Must provide context_window and max_output_tokens, or subclass must override model_info")
        return ModelInfo(
            model_name=self.model_name,
            cost_per_input_token=0.0,
            cost_per_output_token=0.0,
            max_input_tokens=self.context_window,
            max_output_tokens=self.max_output_tokens,
            rate_limit_req=None,
        )

    def _get_client(self) -> AsyncOpenAI:
        """Create a new AsyncOpenAI client for ``base_url`` (a fresh client on every call)."""
        api_key = get_secret(self.api_key_env) if self.api_key_env else ""
        if not api_key:
            # The SDK requires some key value; keyless servers simply ignore it.
            api_key = "not-required"
            logger.debug("API key not set, attempting to use API without key.")

        return AsyncOpenAI(
            api_key=api_key,
            base_url=self.base_url,
        )

    @contextmanager
    def _exception_handler(self, prompt: str) -> Iterator[None]:
        """Translate OpenAI client errors raised inside the block into this package's errors.

        - BadRequestError with code ``context_length_exceeded`` -> PromptTooLongError
          (``prompt`` is only used to report its approximate length).
        - Any other BadRequestError -> BadAPIRequestError.
        - APIConnectionError, RateLimitError, httpx.RemoteProtocolError and
          InternalServerError -> TransientLanguageModelError.
        - RateLimitError with code ``insufficient_quota`` is re-raised unchanged, since
          it is not transient.

        Any other exception propagates unchanged.
        """
        try:
            yield
        except BadRequestError as e:
            if e.code == "context_length_exceeded":
                prompt_len = self.count_tokens(prompt)
                max_prompt_len = self.model_info.max_input_tokens
                logger.debug(
                    "PromptTooLongError max_prompt_len={max_prompt_len} prompt_len={prompt_len}",
                    max_prompt_len=max_prompt_len,
                    prompt_len=prompt_len,
                )
                raise PromptTooLongError(prompt_len, max_prompt_len) from e
            logger.debug("BadAPIRequestError {e}", e=e)
            raise BadAPIRequestError(str(e)) from e
        except APIConnectionError as e:
            logger.debug("API connection error: {e}", e=e)
            raise TransientLanguageModelError("APIConnectionError") from e
        except RateLimitError as e:
            if e.code == "insufficient_quota":
                raise
            logger.debug("Rate limited: {e}", e=e)
            raise TransientLanguageModelError("RateLimitError") from e
        except httpx.RemoteProtocolError as e:
            logger.debug("httpx.RemoteProtocolError {e}", e=e)
            raise TransientLanguageModelError("httpx.RemoteProtocolError") from e
        except InternalServerError as e:
            logger.debug("InternalServerError {e}", e=e)
            raise TransientLanguageModelError("InternalServerError") from e

    async def _call_api(
        self,
        prompt: str,
        params: LanguageModelGenerationParams,
        network_failure_count: int = 0,
    ) -> CostedLanguageModelResponse:
        """Run one non-streaming chat completion requesting ``params.count`` choices.

        If the server reports no usage, the prompt size is estimated locally with
        ``count_tokens`` and completion tokens are reported as 0. Provider errors are
        translated by ``_exception_handler``.
        """
        messages = convert_prompt_to_openai_messages(prompt)
        # NOT_GIVEN makes the SDK omit the field from the request entirely.
        temperature: NotGiven | float = params.temperature if self.supports_temperature else NOT_GIVEN

        with self._exception_handler(prompt):
            client = self._get_client()
            api_result = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_completion_tokens=params.max_tokens,
                n=params.count,
                temperature=temperature,
                stream=False,
                seed=params.seed,
                stop=params.stop,
                presence_penalty=self.presence_penalty,
            )
        assert isinstance(api_result, ChatCompletion)

        usage = api_result.usage
        caching_info: CachingInfo | None
        if usage is not None:
            completion_tokens = usage.completion_tokens
            prompt_tokens = usage.prompt_tokens
            caching_info = _caching_info_from_usage(usage)
        else:
            completion_tokens = 0
            prompt_tokens = self.count_tokens(prompt)
            caching_info = None

        results = self._parse_response(
            api_result,
            prompt_tokens=prompt_tokens,
            stop=params.stop,
            network_failure_count=network_failure_count,
        )

        logger.trace("text: {text}", text=results[0].text)
        dollars_used = self.calculate_cost(prompt_tokens, completion_tokens)
        logger.trace("dollars used: {dollars_used}", dollars_used=dollars_used)

        return CostedLanguageModelResponse(
            usage=LanguageModelResponseUsage(
                prompt_tokens_used=prompt_tokens,
                completion_tokens_used=completion_tokens,
                dollars_used=dollars_used,
                caching_info=caching_info,
            ),
            responses=tuple(results),
        )

    async def _get_api_stream(
        self,
        prompt: str,
        params: LanguageModelGenerationParams,
    ) -> AsyncGenerator[LanguageModelStreamEvent, None]:
        """Stream a single completion as a start event, text deltas, and an end event.

        Only one choice (n=1) is requested. Errors raised both while sending the request
        and while reading the streamed body are translated by ``_exception_handler``.
        If the server reports no usage, the end event carries -1 for token counts and
        cost. The underlying HTTP stream is closed even if the consumer stops early.
        """
        messages = convert_prompt_to_openai_messages(prompt)
        # NOT_GIVEN makes the SDK omit the field from the request entirely.
        temperature: NotGiven | float = params.temperature if self.supports_temperature else NOT_GIVEN

        with self._exception_handler(prompt):
            client = self._get_client()
            api_result = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_completion_tokens=params.max_tokens,
                n=1,
                temperature=temperature,
                stop=params.stop,
                seed=params.seed,
                stream=True,
                stream_options={"include_usage": True},
                presence_penalty=self.presence_penalty,
            )
        assert isinstance(api_result, AsyncStream)

        usage: CompletionUsage | None = None
        finish_reason: str | None = None
        try:
            yield LanguageModelStreamStartEvent()

            # Transport failures (e.g. a dropped connection) can surface while the body is
            # being read, not only when the request is sent, so translate those too.
            with self._exception_handler(prompt):
                async for chunk in api_result:
                    # OpenAI sends usage on a final chunk with no choices, but other
                    # compatible servers may attach it to chunks that also carry text or the
                    # finish reason, so record it without skipping the rest of the chunk.
                    chunk_usage = getattr(chunk, "usage", None)
                    if chunk_usage is not None:
                        usage = chunk_usage

                    if not chunk.choices:
                        continue
                    assert len(chunk.choices) == 1, "Currently only count=1 supported for streaming API."
                    data = only(chunk.choices)
                    delta = data.delta.content
                    if delta is not None:
                        yield LanguageModelStreamDeltaEvent(delta=delta)
                    if data.finish_reason:
                        finish_reason = str(data.finish_reason)
        finally:
            # Release the HTTP connection on early exit or error; closing an already
            # fully-read stream is harmless.
            await api_result.close()

        stop_reason = _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON.get(str(finish_reason), ResponseStopReason.NONE)
        # OpenAI-style APIs omit a matched stop sequence from the output and report "stop"
        # for both a natural end and a stop-sequence hit; re-append the stop string here,
        # matching _parse_response.
        if params.stop is not None and stop_reason == ResponseStopReason.END_TURN:
            yield LanguageModelStreamDeltaEvent(delta=params.stop)

        caching_info: CachingInfo | None
        if usage is not None:
            completion_tokens = usage.completion_tokens
            prompt_tokens = usage.prompt_tokens
            dollars_used = self.calculate_cost(prompt_tokens, completion_tokens)
            caching_info = _caching_info_from_usage(usage)
        else:
            # -1 marks the counts and cost as unknown: the server reported no usage.
            completion_tokens = -1
            prompt_tokens = -1
            dollars_used = -1
            caching_info = None
        logger.trace("dollars used: {dollars_used}", dollars_used=dollars_used)

        yield LanguageModelStreamEndEvent(
            usage=LanguageModelResponseUsage(
                prompt_tokens_used=prompt_tokens,
                completion_tokens_used=completion_tokens,
                dollars_used=dollars_used,
                caching_info=caching_info,
            ),
            stop_reason=stop_reason,
        )

    def count_tokens(self, text: str) -> int:
        """Approximate the token count of ``text`` (no model-specific tokenizer is used)."""
        return approximate_token_count(text)

    def _parse_response(
        self,
        response: ChatCompletion,
        prompt_tokens: int,
        stop: str | None,
        network_failure_count: int,
    ) -> tuple[LanguageModelResponse, ...]:
        """Convert each choice of ``response`` into a LanguageModelResponse.

        ``token_count`` is the approximate size of the choice's text (before any stop
        string is appended) plus ``prompt_tokens``. When ``stop`` is given and the choice
        finished with "stop", ``stop`` is appended to the text, because the API omits a
        matched stop sequence from its output.
        """
        results = []
        for data in response.choices:
            assert data.message.content is not None
            text = data.message.content
            token_count = self.count_tokens(text) + prompt_tokens
            stop_reason = _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON.get(
                str(data.finish_reason), ResponseStopReason.NONE
            )
            if stop is not None and stop_reason == ResponseStopReason.END_TURN:
                text += stop
            result = LanguageModelResponse(
                text=text,
                token_count=token_count,
                stop_reason=stop_reason,
                network_failure_count=network_failure_count,
            )
            results.append(result)
        return tuple(results)