vet

Mirror of Vet, an AI code review tool
git clone git://git.laack.co/vet.git
Log | Files | Refs | README | LICENSE

models_test.py (8513B)


      1 from __future__ import annotations
      2 
      3 import pytest
      4 
      5 from vet.cli.config.schema import ModelConfig
      6 from vet.cli.config.schema import ModelsConfig
      7 from vet.cli.config.schema import ProviderConfig
      8 from vet.cli.models import DEFAULT_MODEL_ID
      9 from vet.cli.models import get_all_model_ids
     10 from vet.cli.models import get_builtin_model_ids
     11 from vet.cli.models import get_builtin_models_by_provider
     12 from vet.cli.models import get_models_by_provider
     13 from vet.cli.models import is_valid_model_id
     14 from vet.cli.models import validate_model_id
     15 
     16 SAMPLE_USER_CONFIG = ModelsConfig(
     17     providers={
     18         "custom": ProviderConfig(
     19             base_url="http://localhost:8080/v1",
     20             api_key_env="CUSTOM_KEY",
     21             models={
     22                 "my-custom-model": ModelConfig(
     23                     context_window=128000,
     24                     max_output_tokens=16384,
     25                     supports_temperature=True,
     26                 ),
     27                 "another-model": ModelConfig(
     28                     context_window=128000,
     29                     max_output_tokens=16384,
     30                     supports_temperature=True,
     31                 ),
     32             },
     33         )
     34     }
     35 )
     36 
     37 
     38 def test_default_model_is_in_builtin_models() -> None:
     39     assert DEFAULT_MODEL_ID in get_builtin_model_ids()
     40 
     41 
     42 def test_get_builtin_model_ids_returns_strings() -> None:
     43     model_ids = get_builtin_model_ids()
     44     assert all(isinstance(m, str) for m in model_ids)
     45 
     46 
     47 def test_get_all_model_ids_returns_builtin_models_when_no_config() -> None:
     48     all_ids = get_all_model_ids(user_config=None)
     49     builtin_ids = get_builtin_model_ids()
     50     assert all_ids == builtin_ids
     51 
     52 
     53 def test_get_all_model_ids_includes_user_defined_models() -> None:
     54     all_ids = get_all_model_ids(SAMPLE_USER_CONFIG)
     55 
     56     assert "my-custom-model" in all_ids
     57     assert "another-model" in all_ids
     58     assert DEFAULT_MODEL_ID in all_ids
     59 
     60 
     61 @pytest.mark.parametrize(
     62     ("model_id", "user_config", "expected"),
     63     [
     64         (DEFAULT_MODEL_ID, None, True),
     65         ("nonexistent-model-xyz", None, False),
     66         ("my-custom-model", SAMPLE_USER_CONFIG, True),
     67     ],
     68 )
     69 def test_is_valid_model_id(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
     70     assert is_valid_model_id(model_id, user_config) is expected
     71 
     72 
     73 def test_validate_model_id_returns_model_id_when_valid() -> None:
     74     result = validate_model_id(DEFAULT_MODEL_ID)
     75     assert result == DEFAULT_MODEL_ID
     76 
     77 
     78 def test_validate_model_id_raises_for_invalid_model() -> None:
     79     with pytest.raises(ValueError) as exc_info:
     80         validate_model_id("nonexistent-model-xyz")
     81 
     82     assert "Unknown model: nonexistent-model-xyz" in str(exc_info.value)
     83     assert "--list-models" in str(exc_info.value)
     84 
     85 
     86 def test_validate_model_id_validates_user_defined_model() -> None:
     87     user_config = ModelsConfig(
     88         providers={
     89             "custom": ProviderConfig(
     90                 base_url="http://localhost:8080/v1",
     91                 api_key_env="CUSTOM_KEY",
     92                 models={
     93                     "my-custom-model": ModelConfig(
     94                         context_window=128000,
     95                         max_output_tokens=16384,
     96                         supports_temperature=True,
     97                     )
     98                 },
     99             )
    100         }
    101     )
    102 
    103     result = validate_model_id("my-custom-model", user_config)
    104     assert result == "my-custom-model"
    105 
    106 
    107 def test_get_builtin_models_by_provider_returns_dict_with_expected_providers() -> None:
    108     providers = get_builtin_models_by_provider()
    109 
    110     assert "anthropic" in providers
    111     assert "openai" in providers
    112     assert "gemini" in providers
    113     assert "groq" not in providers
    114 
    115 
    116 def test_get_builtin_models_by_provider_all_values_are_lists_of_strings() -> None:
    117     providers = get_builtin_models_by_provider()
    118 
    119     for provider_name, models in providers.items():
    120         assert isinstance(models, list), f"{provider_name} should have a list of models"
    121         assert all(isinstance(m, str) for m in models), f"{provider_name} models should all be strings"
    122 
    123 
    124 def test_get_models_by_provider_returns_builtin_providers_when_no_config() -> None:
    125     providers = get_models_by_provider(user_config=None)
    126     builtin_providers = get_builtin_models_by_provider()
    127 
    128     assert providers == builtin_providers
    129 
    130 
    131 def test_get_models_by_provider_includes_user_defined_providers() -> None:
    132     user_config = ModelsConfig(
    133         providers={
    134             "ollama": ProviderConfig(
    135                 name="Ollama Local",
    136                 base_url="http://localhost:11434/v1",
    137                 api_key_env="OLLAMA_KEY",
    138                 models={
    139                     "llama3.2:latest": ModelConfig(
    140                         context_window=128000,
    141                         max_output_tokens=16384,
    142                         supports_temperature=True,
    143                     ),
    144                     "qwen:7b": ModelConfig(
    145                         context_window=32768,
    146                         max_output_tokens=8192,
    147                         supports_temperature=True,
    148                     ),
    149                 },
    150             )
    151         }
    152     )
    153 
    154     providers = get_models_by_provider(user_config)
    155 
    156     assert "Ollama Local" in providers
    157     assert set(providers["Ollama Local"]) == {"llama3.2:latest", "qwen:7b"}
    158     assert "anthropic" in providers
    159     assert "openai" in providers
    160 
    161 
    162 def test_get_models_by_provider_user_provider_merges_with_builtin_same_name() -> None:
    163     user_config = ModelsConfig(
    164         providers={
    165             "custom": ProviderConfig(
    166                 name="anthropic",
    167                 base_url="http://localhost:8080/v1",
    168                 api_key_env="CUSTOM_KEY",
    169                 models={
    170                     "custom-model": ModelConfig(
    171                         context_window=128000,
    172                         max_output_tokens=16384,
    173                         supports_temperature=True,
    174                     )
    175                 },
    176             )
    177         }
    178     )
    179 
    180     providers = get_models_by_provider(user_config)
    181 
    182     assert "custom-model" in providers["anthropic"]
    183     assert DEFAULT_MODEL_ID in providers["anthropic"]
    184 
    185 
    186 SAMPLE_REGISTRY_CONFIG = ModelsConfig(
    187     providers={
    188         "registry-provider": ProviderConfig(
    189             name="Registry Provider",
    190             base_url="http://registry:8080/v1",
    191             api_key_env="REGISTRY_KEY",
    192             models={
    193                 "registry-model": ModelConfig(
    194                     context_window=128000,
    195                     max_output_tokens=16384,
    196                     supports_temperature=True,
    197                 ),
    198             },
    199         )
    200     }
    201 )
    202 
    203 
    204 def test_get_all_model_ids_includes_registry_models() -> None:
    205     all_ids = get_all_model_ids(
    206         user_config=SAMPLE_USER_CONFIG,
    207         registry_config=SAMPLE_REGISTRY_CONFIG,
    208     )
    209 
    210     assert "my-custom-model" in all_ids
    211     assert DEFAULT_MODEL_ID in all_ids
    212     assert "registry-model" in all_ids
    213 
    214 
    215 def test_validate_model_id_accepts_registry_model() -> None:
    216     result = validate_model_id(
    217         "registry-model",
    218         user_config=None,
    219         registry_config=SAMPLE_REGISTRY_CONFIG,
    220     )
    221     assert result == "registry-model"
    222 
    223 
    224 def test_validate_model_id_rejects_unknown_even_with_registry() -> None:
    225     with pytest.raises(ValueError):
    226         validate_model_id(
    227             "totally-unknown",
    228             user_config=SAMPLE_USER_CONFIG,
    229             registry_config=SAMPLE_REGISTRY_CONFIG,
    230         )
    231 
    232 
    233 def test_get_models_by_provider_includes_registry_providers() -> None:
    234     providers = get_models_by_provider(
    235         user_config=None,
    236         registry_config=SAMPLE_REGISTRY_CONFIG,
    237     )
    238 
    239     assert "Registry Provider" in providers
    240     assert "registry-model" in providers["Registry Provider"]
    241     assert "anthropic" in providers
    242     assert "openai" in providers
    243 
    244 
    245 def test_get_models_by_provider_registry_merges_with_builtin_same_name() -> None:
    246     registry_config = ModelsConfig(
    247         providers={
    248             "anthropic-override": ProviderConfig(
    249                 name="anthropic",
    250                 base_url="http://registry:8080/v1",
    251                 models={
    252                     "registry-claude": ModelConfig(
    253                         context_window=128000,
    254                         max_output_tokens=16384,
    255                         supports_temperature=True,
    256                     )
    257                 },
    258             )
    259         }
    260     )
    261 
    262     providers = get_models_by_provider(user_config=None, registry_config=registry_config)
    263 
    264     assert "registry-claude" in providers["anthropic"]
    265     assert DEFAULT_MODEL_ID in providers["anthropic"]