models_test.py (8513B)
"""Tests for model-ID lookup and validation helpers in ``vet.cli.models``."""

from __future__ import annotations

import pytest

from vet.cli.config.schema import ModelConfig
from vet.cli.config.schema import ModelsConfig
from vet.cli.config.schema import ProviderConfig
from vet.cli.models import DEFAULT_MODEL_ID
from vet.cli.models import get_all_model_ids
from vet.cli.models import get_builtin_model_ids
from vet.cli.models import get_builtin_models_by_provider
from vet.cli.models import get_models_by_provider
from vet.cli.models import is_valid_model_id
from vet.cli.models import validate_model_id

# User config with one custom provider exposing two models; shared by the
# tests below so each one does not rebuild an equivalent config inline.
SAMPLE_USER_CONFIG = ModelsConfig(
    providers={
        "custom": ProviderConfig(
            base_url="http://localhost:8080/v1",
            api_key_env="CUSTOM_KEY",
            models={
                "my-custom-model": ModelConfig(
                    context_window=128000,
                    max_output_tokens=16384,
                    supports_temperature=True,
                ),
                "another-model": ModelConfig(
                    context_window=128000,
                    max_output_tokens=16384,
                    supports_temperature=True,
                ),
            },
        )
    }
)


def test_default_model_is_in_builtin_models() -> None:
    """The default model ID must be one of the built-in model IDs."""
    assert DEFAULT_MODEL_ID in get_builtin_model_ids()


def test_get_builtin_model_ids_returns_strings() -> None:
    """Every built-in model ID is a ``str``."""
    model_ids = get_builtin_model_ids()
    assert all(isinstance(m, str) for m in model_ids)


def test_get_all_model_ids_returns_builtin_models_when_no_config() -> None:
    """Without a user config, the full ID list equals the built-in ID list."""
    all_ids = get_all_model_ids(user_config=None)
    builtin_ids = get_builtin_model_ids()
    assert all_ids == builtin_ids


def test_get_all_model_ids_includes_user_defined_models() -> None:
    """User-defined models are listed alongside the built-in ones."""
    all_ids = get_all_model_ids(SAMPLE_USER_CONFIG)

    assert "my-custom-model" in all_ids
    assert "another-model" in all_ids
    assert DEFAULT_MODEL_ID in all_ids


@pytest.mark.parametrize(
    ("model_id", "user_config", "expected"),
    [
        (DEFAULT_MODEL_ID, None, True),
        ("nonexistent-model-xyz", None, False),
        ("my-custom-model", SAMPLE_USER_CONFIG, True),
        ("another-model", SAMPLE_USER_CONFIG, True),
        # Built-in models stay valid when a user config is supplied.
        (DEFAULT_MODEL_ID, SAMPLE_USER_CONFIG, True),
        # A user-defined model is only known when its config is passed in.
        ("my-custom-model", None, False),
        # Supplying a user config must not make arbitrary IDs valid.
        ("nonexistent-model-xyz", SAMPLE_USER_CONFIG, False),
    ],
)
def test_is_valid_model_id(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
    """``is_valid_model_id`` accepts built-in and configured IDs and nothing else."""
    assert is_valid_model_id(model_id, user_config) is expected


def test_validate_model_id_returns_model_id_when_valid() -> None:
    """A valid ID is returned unchanged."""
    result = validate_model_id(DEFAULT_MODEL_ID)
    assert result == DEFAULT_MODEL_ID


def test_validate_model_id_raises_for_invalid_model() -> None:
    """An unknown ID raises ``ValueError`` naming the ID and the listing flag."""
    with pytest.raises(ValueError) as exc_info:
        validate_model_id("nonexistent-model-xyz")

    message = str(exc_info.value)
    assert "Unknown model: nonexistent-model-xyz" in message
    # The message should point users at the flag that lists valid IDs.
    assert "--list-models" in message


def test_validate_model_id_validates_user_defined_model() -> None:
    """A model defined only in the user config passes validation."""
    result = validate_model_id("my-custom-model", SAMPLE_USER_CONFIG)
    assert result == "my-custom-model"


def test_get_builtin_models_by_provider_returns_dict_with_expected_providers() -> None:
    """Built-ins are a dict keyed by provider and include the core providers."""
    providers = get_builtin_models_by_provider()

    assert isinstance(providers, dict)
    assert "anthropic" in providers
    assert "openai" in providers
    assert "gemini" in providers
    assert "groq" not in providers


def test_get_builtin_models_by_provider_all_values_are_lists_of_strings() -> None:
    """Each built-in provider maps to a list of string model IDs."""
    providers = get_builtin_models_by_provider()

    for provider_name, models in providers.items():
        assert isinstance(models, list), f"{provider_name} should have a list of models"
        assert all(isinstance(m, str) for m in models), f"{provider_name} models should all be strings"


def test_get_models_by_provider_returns_builtin_providers_when_no_config() -> None:
    """Without a user config, the grouping equals the built-in grouping."""
    providers = get_models_by_provider(user_config=None)
    builtin_providers = get_builtin_models_by_provider()

    assert providers == builtin_providers


def test_get_models_by_provider_includes_user_defined_providers() -> None:
    """User providers are grouped under their ``name`` next to built-in providers."""
    user_config = ModelsConfig(
        providers={
            "ollama": ProviderConfig(
                name="Ollama Local",
                base_url="http://localhost:11434/v1",
                api_key_env="OLLAMA_KEY",
                models={
                    "llama3.2:latest": ModelConfig(
                        context_window=128000,
                        max_output_tokens=16384,
                        supports_temperature=True,
                    ),
                    "qwen:7b": ModelConfig(
                        context_window=32768,
                        max_output_tokens=8192,
                        supports_temperature=True,
                    ),
                },
            )
        }
    )

    providers = get_models_by_provider(user_config)

    # The grouping key is the provider's display ``name`` ("Ollama Local").
    assert "Ollama Local" in providers
    assert set(providers["Ollama Local"]) == {"llama3.2:latest", "qwen:7b"}
    assert "anthropic" in providers
    assert "openai" in providers


def test_get_models_by_provider_user_provider_merges_with_builtin_same_name() -> None:
    """A user provider whose ``name`` matches a built-in one extends that group."""
    user_config = ModelsConfig(
        providers={
            "custom": ProviderConfig(
                name="anthropic",
                base_url="http://localhost:8080/v1",
                api_key_env="CUSTOM_KEY",
                models={
                    "custom-model": ModelConfig(
                        context_window=128000,
                        max_output_tokens=16384,
                        supports_temperature=True,
                    )
                },
            )
        }
    )

    providers = get_models_by_provider(user_config)

    # Both the user model and the pre-existing built-in model must survive.
    assert "custom-model" in providers["anthropic"]
    assert DEFAULT_MODEL_ID in providers["anthropic"]


# Registry-sourced config with one provider and one model; shared by the
# registry tests below.
SAMPLE_REGISTRY_CONFIG = ModelsConfig(
    providers={
        "registry-provider": ProviderConfig(
            name="Registry Provider",
            base_url="http://registry:8080/v1",
            api_key_env="REGISTRY_KEY",
            models={
                "registry-model": ModelConfig(
                    context_window=128000,
                    max_output_tokens=16384,
                    supports_temperature=True,
                ),
            },
        )
    }
)


def test_get_all_model_ids_includes_registry_models() -> None:
    """Built-in, user and registry model IDs are all listed together."""
    all_ids = get_all_model_ids(
        user_config=SAMPLE_USER_CONFIG,
        registry_config=SAMPLE_REGISTRY_CONFIG,
    )

    assert "my-custom-model" in all_ids
    assert DEFAULT_MODEL_ID in all_ids
    assert "registry-model" in all_ids


def test_validate_model_id_accepts_registry_model() -> None:
    """A model defined only in the registry config passes validation."""
    result = validate_model_id(
        "registry-model",
        user_config=None,
        registry_config=SAMPLE_REGISTRY_CONFIG,
    )
    assert result == "registry-model"


def test_validate_model_id_rejects_unknown_even_with_registry() -> None:
    """Unknown IDs are still rejected, with the usual message, when a registry is given."""
    with pytest.raises(ValueError, match="Unknown model: totally-unknown"):
        validate_model_id(
            "totally-unknown",
            user_config=SAMPLE_USER_CONFIG,
            registry_config=SAMPLE_REGISTRY_CONFIG,
        )


def test_get_models_by_provider_includes_registry_providers() -> None:
    """Registry providers are grouped under their ``name`` next to built-ins."""
    providers = get_models_by_provider(
        user_config=None,
        registry_config=SAMPLE_REGISTRY_CONFIG,
    )

    assert "Registry Provider" in providers
    assert "registry-model" in providers["Registry Provider"]
    assert "anthropic" in providers
    assert "openai" in providers


def test_get_models_by_provider_registry_merges_with_builtin_same_name() -> None:
    """A registry provider whose ``name`` matches a built-in one extends that group."""
    registry_config = ModelsConfig(
        providers={
            "anthropic-override": ProviderConfig(
                name="anthropic",
                base_url="http://registry:8080/v1",
                models={
                    "registry-claude": ModelConfig(
                        context_window=128000,
                        max_output_tokens=16384,
                        supports_temperature=True,
                    )
                },
            )
        }
    )

    providers = get_models_by_provider(user_config=None, registry_config=registry_config)

    # Both the registry model and the pre-existing built-in model must survive.
    assert "registry-claude" in providers["anthropic"]
    assert DEFAULT_MODEL_ID in providers["anthropic"]