vet

Mirror of Vet, an AI code review tool
git clone git://git.laack.co/vet.git
Log | Files | Refs | README | LICENSE

common.py (2489B)


      1 from vet.imbue_core.agents.llm_apis.anthropic_api import ANTHROPIC_MODEL_INFO_BY_NAME
      2 from vet.imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
      3 from vet.imbue_core.agents.llm_apis.gemini_api import GEMINI_MODEL_INFO_BY_NAME
      4 from vet.imbue_core.agents.llm_apis.gemini_api import GeminiModelName
      5 from vet.imbue_core.agents.llm_apis.mock_api import MY_MOCK_MODEL_INFO
      6 from vet.imbue_core.agents.llm_apis.models import ModelInfo
      7 from vet.imbue_core.agents.llm_apis.openai_api import OpenAIModelName
      8 from vet.imbue_core.agents.llm_apis.openai_api import get_model_info as get_openai_model_info
      9 
# Union of the provider-specific model-name enums (Anthropic, OpenAI, Gemini).
# Note: the mock model that get_model_info_from_name special-cases is not part of this union.
ModelName = AnthropicModelName | OpenAIModelName | GeminiModelName
     11 
     12 
     13 def get_model_info_from_name(model_name: str) -> ModelInfo:
     14     if model_name == MY_MOCK_MODEL_INFO.model_name:
     15         return MY_MOCK_MODEL_INFO
     16     if model_name in (v for v in AnthropicModelName):
     17         return ANTHROPIC_MODEL_INFO_BY_NAME[AnthropicModelName(model_name)]
     18     elif model_name in (v for v in OpenAIModelName):
     19         return get_openai_model_info(OpenAIModelName(model_name))
     20     elif model_name in (v for v in GeminiModelName):
     21         return GEMINI_MODEL_INFO_BY_NAME[GeminiModelName(model_name)]
     22     else:
     23         raise Exception(f"Unknown model: {model_name}")
     24 
     25 
     26 def get_model_max_context_length(model_name: str) -> int:
     27     model_info = get_model_info_from_name(model_name)
     28     return model_info.max_input_tokens
     29 
     30 
     31 def get_model_max_output_tokens(model_name: str) -> int:
     32     model_info = get_model_info_from_name(model_name)
     33     if model_info.max_output_tokens is None:
     34         raise ValueError(f"Model {model_name} does not have max_output_tokens defined")
     35     return model_info.max_output_tokens
     36 
     37 
     38 def get_all_model_names() -> list[str]:
     39     names = []
     40     names.extend(list(v for v in AnthropicModelName))
     41     names.extend(list(v for v in OpenAIModelName))
     42     names.extend(list(v for v in GeminiModelName))
     43     return names
     44 
     45 
     46 def get_formatted_model_name(model_name: str) -> str:
     47     """Get a nicely formatted model name.
     48 
     49         Does things like removing generic prefixes like 'models/' and forward slashes (which can interfere with file names).
     50 
     51         Some examples:
     52 
     53     - 'models/gemini-2.5-flash' -> 'gemini-2.5-flash'
     54     - 'groq/llama-3.3-70b-versatile' -> 'groq-llama-3.3-70b-versatile'
     55     - 'claude-opus-4-6' -> 'claude-opus-4-6'
     56 
     57     """
     58     if model_name.startswith("models/"):
     59         model_name = model_name[len("models/") :]
     60     return model_name.replace("/", "-")