"""OpenAI chat-completions backend (``openai_api.py``).

Defines the supported OpenAI model names with their pricing/limit metadata,
tiktoken-based token counting, translation of OpenAI client exceptions into
this project's error types, and ``OpenAIChatAPI``, the concrete
``OpenAICompatibleAPI`` implementation for OpenAI-hosted chat models.
"""

import asyncio
import enum
import re
from collections import defaultdict
from contextlib import contextmanager
from functools import lru_cache
from typing import AsyncGenerator
from typing import Iterator
from typing import Mapping

import httpx
import tiktoken
from loguru import logger
from openai import AsyncStream
from openai import InternalServerError
from openai import NOT_GIVEN
from openai import NotGiven
from openai._client import AsyncOpenAI
from openai._exceptions import APIConnectionError
from openai._exceptions import BadRequestError
from openai._exceptions import RateLimitError
from openai.types.chat import ChatCompletion
from pydantic.functional_validators import field_validator

from vet.imbue_core.agents.llm_apis.api_utils import convert_prompt_to_openai_messages
from vet.imbue_core.agents.llm_apis.data_types import CachingInfo
from vet.imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponse
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponseUsage
from vet.imbue_core.agents.llm_apis.data_types import LanguageModelResponseWithLogits
from vet.imbue_core.agents.llm_apis.data_types import ResponseStopReason
from vet.imbue_core.agents.llm_apis.data_types import TokenProbability
from vet.imbue_core.agents.llm_apis.errors import BadAPIRequestError
from vet.imbue_core.agents.llm_apis.errors import LanguageModelInvalidModelNameError
from vet.imbue_core.agents.llm_apis.errors import MissingAPIKeyError
from vet.imbue_core.agents.llm_apis.errors import PromptTooLongError
from vet.imbue_core.agents.llm_apis.errors import TransientLanguageModelError
from vet.imbue_core.agents.llm_apis.models import ModelInfo
from vet.imbue_core.agents.llm_apis.openai_compatible_api import OpenAICompatibleAPI
from vet.imbue_core.agents.llm_apis.openai_compatible_api import _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON
from vet.imbue_core.agents.llm_apis.openai_data_types import OpenAICachingInfo
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamDeltaEvent
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamEndEvent
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamEvent
from vet.imbue_core.agents.llm_apis.stream import LanguageModelStreamStartEvent
from vet.imbue_core.frozen_utils import FrozenDict
from vet.imbue_core.frozen_utils import FrozenMapping
from vet.imbue_core.itertools import only
from vet.imbue_core.secrets_utils import get_secret

FINE_TUNED_GPT4O_MINI_2024_07_18_PREFIX = "ft:gpt-4o-mini-2024-07-18"
FINE_TUNED_GPT4O_2024_08_06_PREFIX = "ft:gpt-4o-2024-08-06"

# Both fine-tuned prefixes as a tuple so they can be passed straight to ``str.startswith``.
_FINE_TUNED_MODEL_PREFIXES: tuple[str, ...] = (
    FINE_TUNED_GPT4O_MINI_2024_07_18_PREFIX,
    FINE_TUNED_GPT4O_2024_08_06_PREFIX,
)


class OpenAIModelName(enum.StrEnum):
    """Base (non-fine-tuned) OpenAI model identifiers known to this module."""

    GPT_4_1 = "gpt-4.1"
    GPT_4_1_MINI = "gpt-4.1-mini"
    O3 = "o3"
    O3_MINI = "o3-mini"
    O4_MINI = "o4-mini"
    GPT_5 = "gpt-5"
    GPT_5_MINI = "gpt-5-mini"
    GPT_5_1 = "gpt-5.1"
    GPT_5_2 = "gpt-5.2"
    GPT_5_4 = "gpt-5.4"
    GPT_5_4_PRO = "gpt-5.4-pro"


# Using Tier 5 rate limits
# https://platform.openai.com/settings/organization/limits

OPENAI_MODEL_INFO_BY_NAME: FrozenMapping[OpenAIModelName, ModelInfo] = FrozenDict(
    {
        OpenAIModelName.GPT_4_1: ModelInfo(
            model_name=str(OpenAIModelName.GPT_4_1),
            cost_per_input_token=2 / 1_000_000,
            cost_per_output_token=8 / 1_000_000,
            max_input_tokens=1_047_576,
            max_output_tokens=32_768,
            rate_limit_req=10000 / 60,  # 10000 RPM = 166.67 RPS
        ),
        OpenAIModelName.GPT_4_1_MINI: ModelInfo(
            model_name=str(OpenAIModelName.GPT_4_1_MINI),
            cost_per_input_token=0.4 / 1_000_000,
            cost_per_output_token=1.6 / 1_000_000,
            max_input_tokens=1_047_576,
            max_output_tokens=32_768,
            rate_limit_req=30000 / 60,  # 30000 RPM = 500 RPS
        ),
        OpenAIModelName.O3: ModelInfo(
            model_name=str(OpenAIModelName.O3),
            cost_per_input_token=2 / 1_000_000,
            cost_per_output_token=8 / 1_000_000,
            max_input_tokens=200_000,
            max_output_tokens=100_000,
            rate_limit_req=10000 / 60,  # 10000 RPM = 166.67 RPS
        ),
        OpenAIModelName.O3_MINI: ModelInfo(
            model_name=str(OpenAIModelName.O3_MINI),
            cost_per_input_token=1.1 / 1_000_000,
            cost_per_output_token=4.4 / 1_000_000,
            max_input_tokens=200_000,
            max_output_tokens=100_000,
            rate_limit_req=30000 / 60,  # 30000 RPM = 500 RPS
        ),
        OpenAIModelName.O4_MINI: ModelInfo(
            model_name=str(OpenAIModelName.O4_MINI),
            cost_per_input_token=1.1 / 1_000_000,
            cost_per_output_token=4.4 / 1_000_000,
            max_input_tokens=200_000,
            max_output_tokens=100_000,
            rate_limit_req=30000 / 60,  # 30000 RPM = 500 RPS
        ),
        OpenAIModelName.GPT_5: ModelInfo(
            model_name=str(OpenAIModelName.GPT_5),
            cost_per_input_token=1.25 / 1_000_000,
            cost_per_output_token=10 / 1_000_000,
            max_input_tokens=400_000,
            max_output_tokens=128_000,
            rate_limit_req=15000 / 60,  # 15000 RPM = 250 RPS
        ),
        OpenAIModelName.GPT_5_MINI: ModelInfo(
            model_name=str(OpenAIModelName.GPT_5_MINI),
            cost_per_input_token=0.25 / 1_000_000,
            cost_per_output_token=2.00 / 1_000_000,
            max_input_tokens=400_000,
            max_output_tokens=128_000,
            rate_limit_req=30000 / 60,  # 30000 RPM = 500 RPS
        ),
        OpenAIModelName.GPT_5_1: ModelInfo(
            model_name=str(OpenAIModelName.GPT_5_1),
            cost_per_input_token=1.25 / 1_000_000,
            cost_per_output_token=10 / 1_000_000,
            max_input_tokens=400_000,
            max_output_tokens=128_000,
            rate_limit_req=15000 / 60,  # 15000 RPM = 250 RPS
        ),
        OpenAIModelName.GPT_5_2: ModelInfo(
            model_name=str(OpenAIModelName.GPT_5_2),
            cost_per_input_token=1.75 / 1_000_000,
            cost_per_output_token=14 / 1_000_000,
            max_input_tokens=400_000,
            max_output_tokens=128_000,
            rate_limit_req=15000 / 60,  # 15000 RPM = 250 RPS
        ),
        OpenAIModelName.GPT_5_4: ModelInfo(
            model_name=str(OpenAIModelName.GPT_5_4),
            cost_per_input_token=2.50 / 1_000_000,
            cost_per_output_token=15 / 1_000_000,
            max_input_tokens=1_050_000,
            max_output_tokens=128_000,
            rate_limit_req=15000 / 60,  # 15000 RPM = 250 RPS
        ),
        OpenAIModelName.GPT_5_4_PRO: ModelInfo(
            model_name=str(OpenAIModelName.GPT_5_4_PRO),
            cost_per_input_token=30 / 1_000_000,
            cost_per_output_token=180 / 1_000_000,
            max_input_tokens=1_050_000,
            max_output_tokens=128_000,
            rate_limit_req=10000 / 60,  # 10000 RPM = 166.67 RPS
        ),
    }
)


# Pricing for fine-tuned models taken from here: https://platform.openai.com/docs/pricing
def get_model_info(model_name: OpenAIModelName) -> ModelInfo:
    """Return the ``ModelInfo`` (pricing, token limits, rate limit) for ``model_name``.

    Names starting with one of the fine-tuned prefixes get a freshly built
    ``ModelInfo`` carrying the fine-tuned pricing; every other name is looked
    up in ``OPENAI_MODEL_INFO_BY_NAME``.

    Args:
        model_name: A base model name, or a fine-tuned model id beginning with
            one of the ``FINE_TUNED_*_PREFIX`` constants.

    Returns:
        The matching ``ModelInfo``.

    Raises:
        Whatever ``OPENAI_MODEL_INFO_BY_NAME`` raises for a missing key
        (presumably ``KeyError`` -- confirm against ``FrozenDict``).
    """
    # Check for the family of fine-tuned models.
    if model_name.startswith(FINE_TUNED_GPT4O_MINI_2024_07_18_PREFIX):
        return ModelInfo(
            model_name=str(model_name),
            cost_per_input_token=0.3 / 1_000_000,
            cost_per_output_token=1.2 / 1_000_000,
            max_input_tokens=128_000,
            max_output_tokens=16_384,
            rate_limit_req=30000 / 60,  # 30000 RPM = 500 RPS (same as base model)
        )
    if model_name.startswith(FINE_TUNED_GPT4O_2024_08_06_PREFIX):
        return ModelInfo(
            model_name=str(model_name),
            cost_per_input_token=3.75 / 1_000_000,
            cost_per_output_token=15.0 / 1_000_000,
            max_input_tokens=128_000,
            max_output_tokens=16_384,
            rate_limit_req=10000 / 60,  # 10000 RPM = 166.67 RPS (same as base model)
        )
    # Otherwise, return the model info for the base model.
    return OPENAI_MODEL_INFO_BY_NAME[model_name]


# One semaphore per model name, created lazily on first access with 20 permits.
_CAPACITY_SEMAPHOR_BY_MODEL_NAME: Mapping[OpenAIModelName, asyncio.Semaphore] = defaultdict(
    lambda: asyncio.Semaphore(20),
)


def _get_capacity_semaphor(model_name: OpenAIModelName) -> asyncio.Semaphore:
    """Return the semaphore that bounds concurrent API calls for ``model_name``."""
    # Fine-tuned models share rate limits with the base model.
    # Note: fine-tuned model prefixes fall through to the defaultdict default.
    return _CAPACITY_SEMAPHOR_BY_MODEL_NAME[model_name]


# Models for which `temperature` is omitted and logprobs are refused (see OpenAIChatAPI).
_OPENAI_REASONING_MODEL_NAMES: frozenset[str] = frozenset(
    {
        OpenAIModelName.O3,
        OpenAIModelName.O3_MINI,
        OpenAIModelName.O4_MINI,
        OpenAIModelName.GPT_5,
        OpenAIModelName.GPT_5_MINI,
        OpenAIModelName.GPT_5_1,
        OpenAIModelName.GPT_5_2,
        OpenAIModelName.GPT_5_4,
        OpenAIModelName.GPT_5_4_PRO,
    }
)


def is_openai_reasoning_model(model_name: str) -> bool:
    """Return True if ``model_name`` is one of the listed reasoning models."""
    return model_name in _OPENAI_REASONING_MODEL_NAMES


def is_fine_tuned_openai_model(model_name: OpenAIModelName | str) -> bool:
    """Return True if ``model_name`` starts with one of the fine-tuned model prefixes.

    Fine-tuned ids (``ft:...``) are not members of ``OpenAIModelName``, so the
    name is compared via ``str()`` rather than ``.value``; this accepts both
    plain strings and enum members.
    """
    return str(model_name).startswith(_FINE_TUNED_MODEL_PREFIXES)


_OPENAI_COMPLETION_ERROR_PATTERN = re.compile(
    r".*This model's maximum context length is (\d+) tokens, however you requested (\d+) tokens \((\d+) in your prompt; (\d+) for the completion\). Please reduce your prompt; or completion length.*"
)

_OPENAI_STOP_REASON_TO_STOP_REASON = _OPENAI_COMPATIBLE_STOP_REASON_TO_STOP_REASON


# The key space is the handful of model names in use; a bound above 1 avoids
# evicting the entry every time two different models are used alternately.
@lru_cache(maxsize=32)
def get_openai_tokenizer(model_name: str) -> tiktoken.Encoding:
    """Get the appropriate tiktoken tokenizer for an OpenAI model.

    Args:
        model_name: The OpenAI model name (e.g., "gpt-4.1").

    Returns:
        The tiktoken Encoding for the model.
    """
    # Only the legacy GPT-4 family ("gpt-4", "gpt-4-turbo", "gpt-4-0613", ...) uses the
    # "gpt-4" tokenizer. "gpt-4o*" and "gpt-4.1*" must NOT match here: tiktoken maps
    # them to the newer encoding, which the default branch below selects.
    if model_name == "gpt-4" or model_name.startswith("gpt-4-"):
        fixed_model_name = "gpt-4"
    elif model_name.startswith("gpt-3.5"):
        fixed_model_name = "gpt-3.5"
    else:
        # Just default to `gpt-4o` for now, since this seems to be the most recent tokenizer
        # and we are only using it for estimating token usage
        fixed_model_name = "gpt-4o"
    return tiktoken.encoding_for_model(fixed_model_name)


def count_openai_tokens(text: str, model_name: str) -> int:
    """Return the number of tokens ``text`` encodes to under ``model_name``'s tokenizer.

    ``disallowed_special=()`` lets special-token text in ``text`` be encoded
    as ordinary text instead of raising.
    """
    return len(get_openai_tokenizer(model_name).encode(text, disallowed_special=()))


@contextmanager
def _openai_exception_manager() -> Iterator[None]:
    """Simple context manager for parsing OpenAI API exceptions.

    Raises:
        PromptTooLongError: on a ``BadRequestError`` whose message matches the
            context-length pattern.
        BadAPIRequestError: on any other ``BadRequestError``.
        TransientLanguageModelError: on connection errors, rate limits (other
            than exhausted quota), remote protocol errors and internal server errors.
        RateLimitError: re-raised unchanged when its code is ``"insufficient_quota"``.
    """
    try:
        yield
    except BadRequestError as e:
        error_text_match = _OPENAI_COMPLETION_ERROR_PATTERN.search(str(e))
        if error_text_match is not None:
            max_prompt_len = int(error_text_match.group(1))
            # Group 2 is the total requested token count from the error message.
            prompt_len = int(error_text_match.group(2))
            logger.debug(
                "PromptTooLongError max_prompt_len={max_prompt_len} prompt_len={prompt_len}",
                max_prompt_len=max_prompt_len,
                prompt_len=prompt_len,
            )
            raise PromptTooLongError(prompt_len, max_prompt_len) from e
        logger.debug("BadAPIRequestError {e}", e=e)
        raise BadAPIRequestError(str(e)) from e
    except APIConnectionError as e:
        logger.debug("Rate limited? Received APIConnectionError {e}", e=e)
        raise TransientLanguageModelError("APIConnectionError") from e
    except RateLimitError as e:
        # Exhausted quota is not transient, so it is propagated unchanged.
        if e.code == "insufficient_quota":
            raise
        logger.debug("Rate limited? {e}", e=e)
        raise TransientLanguageModelError("RateLimitError") from e
    except httpx.RemoteProtocolError as e:
        logger.debug("httpx.RemoteProtocolError {e}", e=e)
        raise TransientLanguageModelError("httpx.RemoteProtocolError") from e
    except InternalServerError as e:
        logger.debug("InternalServerError {e}", e=e)
        raise TransientLanguageModelError("InternalServerError") from e


class OpenAIChatAPI(OpenAICompatibleAPI):
    """``OpenAICompatibleAPI`` implementation that talks to OpenAI's chat-completions endpoint."""

    model_name: OpenAIModelName = OpenAIModelName.GPT_4_1

    @field_validator("model_name")  # pyre-ignore[56]: pyre doesn't understand pydantic
    @classmethod
    def validate_model_name(cls, v: str) -> str:
        """Reject model names that have no entry in ``OPENAI_MODEL_INFO_BY_NAME``."""
        if v not in OPENAI_MODEL_INFO_BY_NAME:
            raise LanguageModelInvalidModelNameError(v, cls.__name__, list(OPENAI_MODEL_INFO_BY_NAME))
        return v

    @property
    def model_info(self) -> ModelInfo:
        """Pricing and limit metadata for ``self.model_name``."""
        return get_model_info(self.model_name)

    def _get_client(self) -> AsyncOpenAI:
        """Build an ``AsyncOpenAI`` client from the ``OPENAI_API_KEY`` secret.

        Raises:
            MissingAPIKeyError: if the secret is missing or empty.
        """
        api_key = get_secret("OPENAI_API_KEY")
        if not api_key:
            raise MissingAPIKeyError("OPENAI_API_KEY environment variable is not set")
        return AsyncOpenAI(  # pyre-ignore[16]: pyre doesn't understand the auto-generated openai._client
            api_key=api_key
        )

    async def _call_api(
        self,
        prompt: str,
        params: LanguageModelGenerationParams,
        network_failure_count: int = 0,
    ) -> CostedLanguageModelResponse:
        """Run one non-streaming chat completion and return the parsed, costed responses.

        Args:
            prompt: Prompt text, converted to OpenAI messages via
                ``convert_prompt_to_openai_messages``.
            params: Generation parameters (max tokens, count, temperature, seed, stop).
            network_failure_count: Copied onto every returned response.

        Returns:
            A ``CostedLanguageModelResponse`` with one response per returned choice.

        Raises:
            The errors documented on ``_openai_exception_manager``, plus
            ``MissingAPIKeyError`` from ``_get_client``.
        """
        messages = convert_prompt_to_openai_messages(prompt)
        # NOTE(review): the nesting of this `with` body was reconstructed; confirm that the
        # response post-processing below was not meant to sit inside it.
        with _openai_exception_manager():
            client = self._get_client()

            is_reasoning_model = is_openai_reasoning_model(self.model_name)

            top_logprobs: NotGiven | int
            if self.is_using_logprobs:
                assert not is_reasoning_model, "Logprobs are not supported for reasoning models."
                top_logprobs = 5
            else:
                top_logprobs = NOT_GIVEN

            # Temperature is not sent at all for reasoning models.
            temperature: NotGiven | float = NOT_GIVEN if is_reasoning_model else params.temperature

            async with _get_capacity_semaphor(self.model_name):
                api_result = await client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,  # type: ignore
                    max_completion_tokens=params.max_tokens,
                    n=params.count,
                    temperature=temperature,
                    stream=False,
                    seed=params.seed,
                    stop=params.stop,
                    presence_penalty=self.presence_penalty,
                    logprobs=self.is_using_logprobs,
                    top_logprobs=top_logprobs,
                )
            assert isinstance(api_result, ChatCompletion)

        usage = api_result.usage
        if usage is not None:
            completion_tokens = usage.completion_tokens
            prompt_tokens = usage.prompt_tokens
            # Both the details object and its cached_tokens field may be None; treat as 0.
            cached_tokens = (
                usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details is not None else 0
            ) or 0
            caching_info = CachingInfo(
                read_from_cache=cached_tokens,
                provider_specific_data=OpenAICachingInfo(),
            )
        else:
            # No usage reported: estimate the prompt size locally, count no completion tokens.
            completion_tokens = 0
            prompt_tokens = self.count_tokens(prompt)
            caching_info = None

        results: tuple[LanguageModelResponse | LanguageModelResponseWithLogits, ...]
        if self.is_using_logprobs:
            results = self._parse_response_with_logprobs(
                api_result,
                prompt_tokens=prompt_tokens,
                stop=params.stop,
                network_failure_count=network_failure_count,
            )
        else:
            results = self._parse_response_without_logprobs(
                api_result,
                prompt_tokens=prompt_tokens,
                stop=params.stop,
                network_failure_count=network_failure_count,
            )

        logger.trace("text: {text}", text=results[0].text)
        dollars_used = self.calculate_cost(prompt_tokens, completion_tokens)
        logger.trace("dollars used: {dollars_used}", dollars_used=dollars_used)
        return CostedLanguageModelResponse(
            usage=LanguageModelResponseUsage(
                prompt_tokens_used=prompt_tokens,
                completion_tokens_used=completion_tokens,
                dollars_used=dollars_used,
                caching_info=caching_info,
            ),
            responses=tuple(results),
        )

    async def _get_api_stream(
        self,
        prompt: str,
        params: LanguageModelGenerationParams,
    ) -> AsyncGenerator[LanguageModelStreamEvent, None]:
        """Stream one chat completion as start / delta / end events.

        Yields a ``LanguageModelStreamStartEvent``, then one
        ``LanguageModelStreamDeltaEvent`` per content delta, and finally a
        ``LanguageModelStreamEndEvent`` carrying usage and the stop reason.
        When the API reports no usage, token counts and cost are ``-1``.

        Args:
            prompt: Prompt text, converted to OpenAI messages.
            params: Generation parameters; only a single completion (``n=1``) is requested.
        """
        messages = convert_prompt_to_openai_messages(prompt)
        # NOTE(review): nesting reconstructed. The stream is consumed inside the exception
        # manager so errors raised mid-stream are translated too; the capacity semaphore is
        # held only while the request is created. Confirm both against the intended design.
        with _openai_exception_manager():
            client = self._get_client()

            is_reasoning_model = is_openai_reasoning_model(self.model_name)
            temperature: NotGiven | float = NOT_GIVEN if is_reasoning_model else params.temperature

            async with _get_capacity_semaphor(self.model_name):
                api_result = await client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,  # type: ignore
                    max_completion_tokens=params.max_tokens,
                    n=1,
                    temperature=temperature,
                    stop=params.stop,
                    seed=params.seed,
                    stream=True,
                    stream_options={"include_usage": True},
                    presence_penalty=self.presence_penalty,
                    logprobs=False,  # not used when streaming
                    top_logprobs=NOT_GIVEN,  # only allowed when logprobs=True
                )
            assert isinstance(api_result, AsyncStream)

            yield LanguageModelStreamStartEvent()

            usage = None
            finish_reason: str | None = None
            async for chunk in api_result:
                if hasattr(chunk, "usage") and chunk.usage is not None:
                    # final chunk containing usage info after all streaming is done
                    usage = chunk.usage
                    continue

                if chunk.choices:
                    assert len(chunk.choices) == 1, "Currently only count=1 supported for streaming API."
                    data = only(chunk.choices)
                    delta = data.delta.content
                    if delta is not None:
                        yield LanguageModelStreamDeltaEvent(delta=delta)
                    if data.finish_reason:
                        finish_reason = str(data.finish_reason)

            stop_reason = _OPENAI_STOP_REASON_TO_STOP_REASON[str(finish_reason)]
            # Note, OpenAI API treats end turn and stop sequence the same
            # Here we assume it is stop sequence if user has specified a stop sequence
            if params.stop is not None and stop_reason == ResponseStopReason.END_TURN:
                yield LanguageModelStreamDeltaEvent(delta=params.stop)

            if usage is not None:
                completion_tokens = usage.completion_tokens
                prompt_tokens = usage.prompt_tokens
                dollars_used = self.calculate_cost(prompt_tokens, completion_tokens)
                # Same guard as `_call_api`: the details object and its cached_tokens
                # field may each be None, in which case no cache reads are reported.
                prompt_tokens_details = usage.prompt_tokens_details
                cached_tokens = (
                    prompt_tokens_details.cached_tokens if prompt_tokens_details is not None else 0
                ) or 0
                logger.trace(
                    "Used this many cached read tokens: {cached_tokens}",
                    cached_tokens=cached_tokens,
                )
                caching_info = CachingInfo(
                    read_from_cache=cached_tokens,
                    provider_specific_data=OpenAICachingInfo(),
                )
            else:
                # -1 marks "unknown" when the stream carried no usage chunk.
                completion_tokens = -1
                prompt_tokens = -1
                dollars_used = -1
                caching_info = None
            logger.trace("dollars used: {dollars_used}", dollars_used=dollars_used)

            yield LanguageModelStreamEndEvent(
                usage=LanguageModelResponseUsage(
                    prompt_tokens_used=prompt_tokens,
                    completion_tokens_used=completion_tokens,
                    dollars_used=dollars_used,
                    caching_info=caching_info,
                ),
                stop_reason=stop_reason,
            )

    def count_tokens(self, text: str) -> int:
        """Count the tokens in ``text`` using the tokenizer chosen for ``self.model_name``."""
        return count_openai_tokens(text, self.model_name)

    def _parse_response_without_logprobs(
        self,
        response: ChatCompletion,
        prompt_tokens: int,
        stop: str | None,
        network_failure_count: int,
    ) -> tuple[LanguageModelResponse, ...]:
        """Convert each choice of ``response`` into a ``LanguageModelResponse``.

        Args:
            response: The completed chat completion.
            prompt_tokens: Prompt token count, added to each response's ``token_count``.
            stop: The requested stop sequence, if any; appended to the text when the
                choice ended with ``END_TURN``.
            network_failure_count: Copied onto every response.

        Returns:
            One response per choice, in the order the API returned them.
        """
        results = []
        for data in response.choices:
            assert data.message.content is not None
            text = data.message.content
            # Counted before the stop sequence is appended, so it covers only generated text.
            token_count = self.count_tokens(text) + prompt_tokens
            stop_reason = _OPENAI_STOP_REASON_TO_STOP_REASON[str(data.finish_reason)]
            # Note, OpenAI API treats end turn and stop sequence the same
            # Here we assume it is stop sequence if user has specified a stop sequence
            if stop is not None and stop_reason == ResponseStopReason.END_TURN:
                text += stop
            results.append(
                LanguageModelResponse(
                    text=text,
                    token_count=token_count,
                    stop_reason=stop_reason,
                    network_failure_count=network_failure_count,
                )
            )
        return tuple(results)

    def _parse_response_with_logprobs(
        self,
        response: ChatCompletion,
        prompt_tokens: int,
        stop: str | None,
        network_failure_count: int,
    ) -> tuple[LanguageModelResponseWithLogits, ...]:
        """Convert each choice of ``response`` into a ``LanguageModelResponseWithLogits``.

        For every generated token the result holds a tuple whose first entry is the
        selected token, followed by the remaining top alternatives. When a stop
        sequence was requested and the choice ended with ``END_TURN``, the stop
        sequence is appended to the text with a synthetic ``is_stop=True`` entry.

        Args:
            response: The completed chat completion; each choice must carry logprobs.
            prompt_tokens: Prompt token count, added to each response's ``token_count``.
            stop: The requested stop sequence, if any.
            network_failure_count: Copied onto every response.

        Returns:
            One response per choice, in the order the API returned them.
        """
        results = []
        for data in response.choices:
            assert data.message.content is not None
            logprobs = data.logprobs
            assert logprobs is not None
            logprobs_content = logprobs.content
            assert logprobs_content is not None
            text = data.message.content

            token_probabilities = []
            for logprob_token_entry in logprobs_content:
                top_entries = [
                    TokenProbability(
                        token=top_logprob_obj.token,
                        log_probability=top_logprob_obj.logprob,
                        is_stop=False,
                    )
                    for top_logprob_obj in logprob_token_entry.top_logprobs
                ]
                selected_entry = TokenProbability(
                    token=logprob_token_entry.token,
                    log_probability=logprob_token_entry.logprob,
                    is_stop=False,
                )
                # Keep the selected token first and avoid listing it twice.
                if selected_entry in top_entries:
                    top_entries.remove(selected_entry)
                token_probabilities.append((selected_entry, *top_entries))

            stop_reason = _OPENAI_STOP_REASON_TO_STOP_REASON[str(data.finish_reason)]

            # Note, OpenAI API treats end turn and stop sequence the same
            # Here we assume it is stop sequence if user has specified a stop sequence
            if stop is not None and stop_reason == ResponseStopReason.END_TURN:
                text += stop
                token_probabilities.append(
                    (
                        TokenProbability(
                            token=stop,
                            log_probability=self.stop_token_log_probability,
                            is_stop=True,
                        ),
                    )
                )
            results.append(
                LanguageModelResponseWithLogits(
                    text=text,
                    token_probabilities=tuple(token_probabilities),
                    # Uses the API's per-token entries, so the synthetic stop entry is not counted.
                    token_count=len(logprobs_content) + prompt_tokens,
                    stop_reason=stop_reason,
                    network_failure_count=network_failure_count,
                )
            )
        return tuple(results)