common.py (2489B)
from vet.imbue_core.agents.llm_apis.anthropic_api import ANTHROPIC_MODEL_INFO_BY_NAME
from vet.imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
from vet.imbue_core.agents.llm_apis.gemini_api import GEMINI_MODEL_INFO_BY_NAME
from vet.imbue_core.agents.llm_apis.gemini_api import GeminiModelName
from vet.imbue_core.agents.llm_apis.mock_api import MY_MOCK_MODEL_INFO
from vet.imbue_core.agents.llm_apis.models import ModelInfo
from vet.imbue_core.agents.llm_apis.openai_api import OpenAIModelName
from vet.imbue_core.agents.llm_apis.openai_api import get_model_info as get_openai_model_info

# Union of the provider-specific model-name enums handled by this module.
ModelName = AnthropicModelName | OpenAIModelName | GeminiModelName


def get_model_info_from_name(model_name: str) -> ModelInfo:
    """Return the ModelInfo registered for ``model_name``.

    The mock model is checked first; after that the Anthropic, OpenAI and
    Gemini name enums are consulted in that order and the first match wins.

    Raises:
        Exception: if ``model_name`` matches neither the mock model nor any
            member of the three provider enums.
    """
    if model_name == MY_MOCK_MODEL_INFO.model_name:
        return MY_MOCK_MODEL_INFO

    # Membership is tested by equality against the materialized members rather
    # than with `model_name in EnumClass`, because Enum.__contains__ handles
    # non-member values differently across Python versions.
    # NOTE(review): this relies on the enum members comparing equal to plain
    # strings (presumably str-based enums) -- confirm in the provider modules.
    if model_name in tuple(AnthropicModelName):
        return ANTHROPIC_MODEL_INFO_BY_NAME[AnthropicModelName(model_name)]
    if model_name in tuple(OpenAIModelName):
        return get_openai_model_info(OpenAIModelName(model_name))
    if model_name in tuple(GeminiModelName):
        return GEMINI_MODEL_INFO_BY_NAME[GeminiModelName(model_name)]

    # Deliberately left as a plain Exception so existing callers that catch it
    # keep seeing the exact same exception type.
    raise Exception(f"Unknown model: {model_name}")


def get_model_max_context_length(model_name: str) -> int:
    """Return ``max_input_tokens`` of the model named ``model_name``.

    Raises:
        Exception: if the model name is unknown (see get_model_info_from_name).
    """
    return get_model_info_from_name(model_name).max_input_tokens


def get_model_max_output_tokens(model_name: str) -> int:
    """Return ``max_output_tokens`` of the model named ``model_name``.

    Raises:
        Exception: if the model name is unknown (see get_model_info_from_name).
        ValueError: if the model's info has ``max_output_tokens`` set to None.
    """
    max_output_tokens = get_model_info_from_name(model_name).max_output_tokens
    if max_output_tokens is None:
        raise ValueError(f"Model {model_name} does not have max_output_tokens defined")
    return max_output_tokens


def get_all_model_names() -> list[str]:
    """Return every Anthropic, OpenAI and Gemini model name, in that order.

    The mock model's name is not part of the result.
    """
    return [*AnthropicModelName, *OpenAIModelName, *GeminiModelName]


def get_formatted_model_name(model_name: str) -> str:
    """Get a nicely formatted model name.

    Does things like removing generic prefixes like 'models/' and forward slashes (which can interfere with file names).

    Some examples:

    - 'models/gemini-2.5-flash' -> 'gemini-2.5-flash'
    - 'groq/llama-3.3-70b-versatile' -> 'groq-llama-3.3-70b-versatile'
    - 'claude-opus-4-6' -> 'claude-opus-4-6'

    """
    # Only a single leading 'models/' is stripped; every remaining slash is
    # replaced with a dash.
    return model_name.removeprefix("models/").replace("/", "-")