commit fbbdfff866778b0d63665286d0ed927c3496dcb1
parent 3389c1286b23550c23449588aecdd028e62b4ac1
Author: andrewlaack-collab <andrew.laack@imbue.com>
Date: Sun, 1 Mar 2026 22:03:57 +0000
Added updating model registry (#156)
* Added updating model registry
* Format loader.py and loader_test.py with black
* Implement model precedence: user-defined > built-in > registry
* Fix vet-identified issues in remote registry feature
- Add logging to silent except block in _refresh_remote_registry_cache
- Remove unused urllib.error import
- Wrap load_registry_config() in try/except in main() to handle corrupted cache
- Rename get_user_defined_model_ids to get_model_ids_from_config (generic usage)
- Reuse get_builtin_model_ids() in _is_builtin_model instead of reimplementing
- Add tests for validate_api_key_for_model and get_max_output_tokens_for_model with registry_config
* No auto-updating, hidden cli argument by default
* Added additional tests
* Refactor
* Refactoring + added contribution guidelines for what models should be in the registry
* Fix time based tests. Slows down non-llm options too much to do the other way
* Updated wording
* Refactoring
* Messaging
---------
Co-authored-by: Andrew Laack <andrew@laack.co>
Co-authored-by: OpenCode <opencode@users.noreply.github.com>
Diffstat:
10 files changed, 834 insertions(+), 121 deletions(-)
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
@@ -104,6 +104,10 @@ Based on your needs, instead of the above, you can also extend one of the existi
In that case you would simply expand the rubric in the prompt. That is actually the preferred way to catch issues at the moment due to efficiency.
Refer to the source code for more details.
+### Model registry
+
+The `registry/models.json` file contains model definitions distributed via the `--update-models` CLI option. See [`registry/CONTRIBUTING.md`](registry/CONTRIBUTING.md) for expectations about what models should be added to the registry.
+
## CI / CD
### GitHub Actions naming conventions
diff --git a/registry/CONTRIBUTING.md b/registry/CONTRIBUTING.md
@@ -0,0 +1,13 @@
+# Registry Model Standards
+
+Models added to `registry/models.json` should meet the following bar:
+
+1. **Forward compatibility.** This registry exists primarily so that older versions of vet can access new models. When built-in models are added, they should also be added here if endpoint compatibility allows it.
+
+2. **Produces useful output.** A model should be able to complete a vet run and produce at least some actionable findings. It does not need to catch everything, but it should not consistently misidentify issues.
+
+3. **Runs reliably.** The API endpoint must be stable with no consistent failures, timeouts, or malformed responses during normal usage.
+
+## Limitations
+
+Registry models are routed through a generic OpenAI-compatible API layer, not the native provider-specific API classes used by built-in models. This means features like cost tracking, rate limiting, and provider-specific error handling are not available for registry models. The registry is a lightweight "tide you over" mechanism for making new model IDs available before they are added as built-ins, not a full-featured alternative.
diff --git a/registry/models.json b/registry/models.json
@@ -0,0 +1,18 @@
+{
+ "providers": {
+ "groq": {
+ "name": "Groq",
+ "api_type": "openai_compatible",
+ "base_url": "https://api.groq.com/openai/v1",
+ "api_key_env": "GROQ_API_KEY",
+ "models": {
+ "kimi-k2": {
+ "model_id": "moonshotai/kimi-k2-instruct-0905",
+ "context_window": 262144,
+ "max_output_tokens": 16384,
+ "supports_temperature": true
+ }
+ }
+ }
+ }
+}
diff --git a/vet/cli/config/loader.py b/vet/cli/config/loader.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import os
import tomllib
+import urllib.request
from pathlib import Path
from pydantic import ValidationError
@@ -20,15 +21,8 @@ class ConfigLoadError(Exception):
pass
-class MissingAPIKeyError(Exception):
- def __init__(self, env_var: str, provider_name: str, model_id: str) -> None:
- self.env_var = env_var
- self.provider_name = provider_name
- self.model_id = model_id
- super().__init__(
- f"API key not found: environment variable '{env_var}' is not set. "
- + f"This is required for model '{model_id}' from provider '{provider_name}'."
- )
+_DEFAULT_REGISTRY_URL = "https://raw.githubusercontent.com/imbue-ai/vet/main/registry/models.json"
+_REGISTRY_FETCH_TIMEOUT_SECONDS = 5
def get_xdg_config_home() -> Path:
@@ -38,6 +32,32 @@ def get_xdg_config_home() -> Path:
return Path.home() / ".config"
+def _get_xdg_cache_home() -> Path:
+ xdg_cache = os.environ.get("XDG_CACHE_HOME")
+ if xdg_cache:
+ return Path(xdg_cache)
+ return Path.home() / ".cache"
+
+
+def _get_registry_cache_path() -> Path:
+ return _get_xdg_cache_home() / "vet" / "remote_models.json"
+
+
+def update_remote_registry_cache() -> tuple[Path, ModelsConfig]:
+ url = os.environ.get("VET_REGISTRY_URL", _DEFAULT_REGISTRY_URL)
+ req = urllib.request.Request(url, headers={"User-Agent": "vet"})
+ with urllib.request.urlopen(req, timeout=_REGISTRY_FETCH_TIMEOUT_SECONDS) as resp:
+ data = resp.read()
+ try:
+ config = ModelsConfig.model_validate_json(data)
+ except ValidationError as e:
+ raise ConfigLoadError(f"Remote registry at {url} returned invalid data: {e}") from e
+ cache_path = _get_registry_cache_path()
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
+ cache_path.write_bytes(data)
+ return cache_path, config
+
+
def find_git_repo_root(start_path: Path) -> Path | None:
current = start_path.resolve()
while current != current.parent:
@@ -90,11 +110,15 @@ def load_models_config(repo_path: Path | None = None) -> ModelsConfig:
return ModelsConfig(providers=merged_providers)
-def get_user_defined_model_ids(config: ModelsConfig) -> set[str]:
- model_ids: set[str] = set()
- for provider in config.providers.values():
- model_ids.update(provider.models.keys())
- return model_ids
+def load_registry_config() -> ModelsConfig:
+ cache_path = _get_registry_cache_path()
+ if cache_path.exists():
+ return _load_single_config_file(cache_path)
+ return ModelsConfig(providers={})
+
+
+def get_model_ids_from_config(config: ModelsConfig) -> set[str]:
+ return {mid for provider in config.providers.values() for mid in provider.models}
def get_provider_for_model(model_id: str, config: ModelsConfig) -> ProviderConfig | None:
@@ -104,25 +128,6 @@ def get_provider_for_model(model_id: str, config: ModelsConfig) -> ProviderConfi
return None
-def validate_api_key_for_model(model_id: str, config: ModelsConfig) -> None:
- provider = get_provider_for_model(model_id, config)
- if provider is None:
- return
-
- api_key_env = provider.api_key_env
- if api_key_env is None:
- return
-
- api_key = os.environ.get(api_key_env, "")
- if not api_key:
- provider_name = provider.name or "unknown provider"
- raise MissingAPIKeyError(
- env_var=api_key_env,
- provider_name=provider_name,
- model_id=model_id,
- )
-
-
def get_models_by_provider_from_config(config: ModelsConfig) -> dict[str, list[str]]:
result: dict[str, list[str]] = {}
for provider_id, provider in config.providers.items():
@@ -131,40 +136,6 @@ def get_models_by_provider_from_config(config: ModelsConfig) -> dict[str, list[s
return result
-def get_max_output_tokens_for_model(model_id: str, config: ModelsConfig) -> int | None:
- provider = get_provider_for_model(model_id, config)
- if provider is not None:
- return provider.models[model_id].max_output_tokens
-
- try:
- from vet.imbue_core.agents.llm_apis.common import get_model_max_output_tokens
-
- return get_model_max_output_tokens(model_id)
- except Exception:
- return None
-
-
-def build_language_model_config(model_id: str, user_config: ModelsConfig):
- from vet.imbue_core.agents.configs import LanguageModelGenerationConfig
- from vet.imbue_core.agents.configs import OpenAICompatibleModelConfig
-
- provider = get_provider_for_model(model_id, user_config)
- if provider is None:
- return LanguageModelGenerationConfig(model_name=model_id)
-
- model_config = provider.models[model_id]
- actual_model_name = model_config.model_id or model_id
-
- return OpenAICompatibleModelConfig(
- model_name=actual_model_name,
- custom_base_url=provider.base_url,
- custom_api_key_env=provider.api_key_env or "",
- custom_context_window=model_config.context_window,
- custom_max_output_tokens=model_config.max_output_tokens,
- custom_supports_temperature=model_config.supports_temperature,
- )
-
-
def get_cli_config_file_paths(repo_path: Path | None = None) -> list[Path]:
return _get_config_file_paths("vet", "configs.toml", "configs.toml", repo_path)
diff --git a/vet/cli/config/loader_test.py b/vet/cli/config/loader_test.py
@@ -8,19 +8,23 @@ from unittest.mock import patch
import pytest
from vet.cli.config.loader import ConfigLoadError
-from vet.cli.config.loader import MissingAPIKeyError
from vet.cli.config.loader import _load_single_config_file
from vet.cli.config.loader import find_git_repo_root
from vet.cli.config.loader import get_config_file_paths
+from vet.cli.config.loader import get_model_ids_from_config
from vet.cli.config.loader import get_models_by_provider_from_config
from vet.cli.config.loader import get_provider_for_model
-from vet.cli.config.loader import get_user_defined_model_ids
from vet.cli.config.loader import get_xdg_config_home
from vet.cli.config.loader import load_models_config
-from vet.cli.config.loader import validate_api_key_for_model
+from vet.cli.config.loader import load_registry_config
+from vet.cli.config.loader import update_remote_registry_cache
from vet.cli.config.schema import ModelConfig
from vet.cli.config.schema import ModelsConfig
from vet.cli.config.schema import ProviderConfig
+from vet.cli.models import MissingProviderAPIKeyError
+from vet.cli.models import build_language_model_config
+from vet.cli.models import get_max_output_tokens_for_model
+from vet.cli.models import validate_api_key_for_model
def test_get_xdg_config_home_uses_env_var(tmp_path: Path) -> None:
@@ -246,7 +250,7 @@ def test_load_models_config_project_overrides_global(tmp_path: Path) -> None:
assert result.providers["shared-provider"].base_url == "http://project:8080/v1"
-def test_get_user_defined_model_ids_extracts_all_ids() -> None:
+def test_get_model_ids_from_config_extracts_all_ids() -> None:
config = ModelsConfig(
providers={
"provider1": ProviderConfig(
@@ -279,7 +283,7 @@ def test_get_user_defined_model_ids_extracts_all_ids() -> None:
}
)
- result = get_user_defined_model_ids(config)
+ result = get_model_ids_from_config(config)
assert result == {"model-a", "model-b", "model-c"}
@@ -382,7 +386,7 @@ def test_validate_api_key_raises_when_key_not_set() -> None:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("MISSING_KEY", None)
- with pytest.raises(MissingAPIKeyError) as exc_info:
+ with pytest.raises(MissingProviderAPIKeyError) as exc_info:
validate_api_key_for_model("model-a", config)
assert exc_info.value.env_var == "MISSING_KEY"
@@ -436,3 +440,315 @@ def test_get_models_by_provider_groups_models() -> None:
assert "openrouter" in result
assert result["openrouter"] == ["anthropic/claude-3"]
+
+
+_REMOTE_PROVIDER_JSON = json.dumps(
+ {
+ "providers": {
+ "remote-provider": {
+ "base_url": "http://remote:8080/v1",
+ "api_key_env": "REMOTE_KEY",
+ "models": {
+ "remote-model": {
+ "context_window": 128000,
+ "max_output_tokens": 16384,
+ "supports_temperature": True,
+ }
+ },
+ }
+ }
+ }
+)
+
+
+def test_update_remote_registry_cache_fetches_and_writes(tmp_path: Path, make_mock_response) -> None:
+ mock_response = make_mock_response(_REMOTE_PROVIDER_JSON.encode())
+
+ env = {"XDG_CACHE_HOME": str(tmp_path)}
+ with patch.dict(os.environ, env):
+ with patch("vet.cli.config.loader.urllib.request.urlopen", return_value=mock_response):
+ cache_path, config = update_remote_registry_cache()
+
+ assert cache_path.exists()
+ assert json.loads(cache_path.read_text())["providers"]["remote-provider"]
+ assert "remote-provider" in config.providers
+
+
+def test_update_remote_registry_cache_respects_custom_url(tmp_path: Path, make_mock_response) -> None:
+ custom_url = "https://example.com/custom/models.json"
+ mock_response = make_mock_response(_REMOTE_PROVIDER_JSON.encode())
+
+ env = {
+ "XDG_CACHE_HOME": str(tmp_path),
+ "VET_REGISTRY_URL": custom_url,
+ }
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ) as mock_urlopen:
+ update_remote_registry_cache()
+
+ call_args = mock_urlopen.call_args
+ assert call_args[0][0].full_url == custom_url
+
+
+def test_update_remote_registry_cache_raises_on_network_error(tmp_path: Path) -> None:
+ env = {"XDG_CACHE_HOME": str(tmp_path)}
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ side_effect=OSError("no network"),
+ ):
+ with pytest.raises(OSError):
+ update_remote_registry_cache()
+
+
+def test_update_remote_registry_cache_rejects_invalid_json(tmp_path: Path, make_mock_response) -> None:
+ mock_response = make_mock_response(b"<html>Not Found</html>")
+
+ env = {"XDG_CACHE_HOME": str(tmp_path)}
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ):
+ with pytest.raises(ConfigLoadError, match="invalid data"):
+ update_remote_registry_cache()
+
+ cache_file = tmp_path / "vet" / "remote_models.json"
+ assert not cache_file.exists()
+
+
+def test_load_models_config_does_not_include_registry(tmp_path: Path) -> None:
+ cache_dir = tmp_path / "cache" / "vet"
+ cache_dir.mkdir(parents=True)
+ cache_file = cache_dir / "remote_models.json"
+ cache_file.write_text(
+ json.dumps(
+ {
+ "providers": {
+ "remote-only": {
+ "base_url": "http://remote:8080/v1",
+ "models": {
+ "remote-model": {
+ "context_window": 64000,
+ "max_output_tokens": 8192,
+ "supports_temperature": True,
+ }
+ },
+ }
+ }
+ }
+ )
+ )
+
+ env = {
+ "XDG_CACHE_HOME": str(tmp_path / "cache"),
+ "XDG_CONFIG_HOME": str(tmp_path / "nonexistent"),
+ }
+ with patch.dict(os.environ, env):
+ result = load_models_config(repo_path=None)
+
+ assert "remote-only" not in result.providers
+
+
+def test_load_registry_config_returns_registry_providers(tmp_path: Path) -> None:
+ cache_dir = tmp_path / "vet"
+ cache_dir.mkdir(parents=True)
+ cache_file = cache_dir / "remote_models.json"
+ cache_file.write_text(_REMOTE_PROVIDER_JSON)
+
+ env = {"XDG_CACHE_HOME": str(tmp_path)}
+ with patch.dict(os.environ, env):
+ result = load_registry_config()
+
+ assert "remote-provider" in result.providers
+
+
+def test_load_registry_config_returns_empty_when_no_cache(tmp_path: Path) -> None:
+ env = {"XDG_CACHE_HOME": str(tmp_path)}
+ with patch.dict(os.environ, env):
+ result = load_registry_config()
+
+ assert result.providers == {}
+
+
+def test_load_registry_config_raises_on_corrupt_cache(tmp_path: Path) -> None:
+ cache_dir = tmp_path / "vet"
+ cache_dir.mkdir(parents=True)
+ cache_file = cache_dir / "remote_models.json"
+ cache_file.write_text("not valid json at all")
+
+ env = {"XDG_CACHE_HOME": str(tmp_path)}
+ with patch.dict(os.environ, env):
+ with pytest.raises(ConfigLoadError):
+ load_registry_config()
+
+
+def _make_provider(base_url: str, model_id: str, api_key_env: str | None = None) -> ProviderConfig:
+ return ProviderConfig(
+ base_url=base_url,
+ api_key_env=api_key_env,
+ models={
+ model_id: ModelConfig(
+ context_window=128000,
+ max_output_tokens=16384,
+ supports_temperature=True,
+ )
+ },
+ )
+
+
+def test_build_config_user_defined_wins_over_builtin_and_registry() -> None:
+ from vet.imbue_core.agents.configs import OpenAICompatibleModelConfig
+ from vet.imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
+
+ builtin_id = AnthropicModelName.CLAUDE_4_6_OPUS.value
+ user_config = ModelsConfig(providers={"custom": _make_provider("http://custom:8080/v1", builtin_id)})
+ registry_config = ModelsConfig(providers={"registry": _make_provider("http://registry:8080/v1", builtin_id)})
+
+ result = build_language_model_config(builtin_id, user_config, registry_config)
+ assert isinstance(result, OpenAICompatibleModelConfig)
+ assert result.custom_base_url == "http://custom:8080/v1"
+
+
+def test_build_config_builtin_wins_over_registry() -> None:
+ from vet.imbue_core.agents.configs import LanguageModelGenerationConfig
+ from vet.imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
+
+ builtin_id = AnthropicModelName.CLAUDE_4_6_OPUS.value
+ user_config = ModelsConfig(providers={})
+ registry_config = ModelsConfig(providers={"registry": _make_provider("http://registry:8080/v1", builtin_id)})
+
+ result = build_language_model_config(builtin_id, user_config, registry_config)
+ assert isinstance(result, LanguageModelGenerationConfig)
+ assert result.model_name == builtin_id
+
+
+def test_build_config_registry_used_for_unknown_model() -> None:
+ from vet.imbue_core.agents.configs import OpenAICompatibleModelConfig
+
+ user_config = ModelsConfig(providers={})
+ registry_config = ModelsConfig(
+ providers={"registry": _make_provider("http://registry:8080/v1", "registry-only-model")}
+ )
+
+ result = build_language_model_config("registry-only-model", user_config, registry_config)
+ assert isinstance(result, OpenAICompatibleModelConfig)
+ assert result.custom_base_url == "http://registry:8080/v1"
+
+
+def test_build_config_no_registry_falls_through() -> None:
+ from vet.imbue_core.agents.configs import LanguageModelGenerationConfig
+
+ user_config = ModelsConfig(providers={})
+ result = build_language_model_config("totally-unknown", user_config)
+ assert isinstance(result, LanguageModelGenerationConfig)
+ assert result.model_name == "totally-unknown"
+
+
+def test_validate_api_key_for_registry_model_passes_when_key_is_set() -> None:
+ user_config = ModelsConfig(providers={})
+ registry_config = ModelsConfig(
+ providers={
+ "registry": ProviderConfig(
+ name="Registry",
+ base_url="http://registry:8080/v1",
+ api_key_env="REGISTRY_API_KEY",
+ models={
+ "registry-model": ModelConfig(
+ context_window=128000,
+ max_output_tokens=16384,
+ supports_temperature=True,
+ )
+ },
+ )
+ }
+ )
+
+ with patch.dict(os.environ, {"REGISTRY_API_KEY": "secret"}):
+ validate_api_key_for_model("registry-model", user_config, registry_config)
+
+
+def test_validate_api_key_for_registry_model_raises_when_key_missing() -> None:
+ user_config = ModelsConfig(providers={})
+ registry_config = ModelsConfig(
+ providers={
+ "registry": ProviderConfig(
+ name="Registry",
+ base_url="http://registry:8080/v1",
+ api_key_env="MISSING_REGISTRY_KEY",
+ models={
+ "registry-model": ModelConfig(
+ context_window=128000,
+ max_output_tokens=16384,
+ supports_temperature=True,
+ )
+ },
+ )
+ }
+ )
+
+ with patch.dict(os.environ, {}, clear=True):
+ os.environ.pop("MISSING_REGISTRY_KEY", None)
+ with pytest.raises(MissingProviderAPIKeyError) as exc_info:
+ validate_api_key_for_model("registry-model", user_config, registry_config)
+
+ assert exc_info.value.env_var == "MISSING_REGISTRY_KEY"
+ assert exc_info.value.model_id == "registry-model"
+
+
+def test_get_max_output_tokens_for_registry_model() -> None:
+ user_config = ModelsConfig(providers={})
+ registry_config = ModelsConfig(
+ providers={
+ "registry": ProviderConfig(
+ base_url="http://registry:8080/v1",
+ models={
+ "registry-model": ModelConfig(
+ context_window=128000,
+ max_output_tokens=32768,
+ supports_temperature=True,
+ )
+ },
+ )
+ }
+ )
+
+ result = get_max_output_tokens_for_model("registry-model", user_config, registry_config)
+ assert result == 32768
+
+
+def test_get_max_output_tokens_user_config_wins_over_registry() -> None:
+ user_config = ModelsConfig(
+ providers={
+ "user": ProviderConfig(
+ base_url="http://user:8080/v1",
+ models={
+ "shared-model": ModelConfig(
+ context_window=128000,
+ max_output_tokens=65536,
+ supports_temperature=True,
+ )
+ },
+ )
+ }
+ )
+ registry_config = ModelsConfig(
+ providers={
+ "registry": ProviderConfig(
+ base_url="http://registry:8080/v1",
+ models={
+ "shared-model": ModelConfig(
+ context_window=128000,
+ max_output_tokens=16384,
+ supports_temperature=True,
+ )
+ },
+ )
+ }
+ )
+
+ result = get_max_output_tokens_for_model("shared-model", user_config, registry_config)
+ assert result == 65536
diff --git a/vet/cli/conftest.py b/vet/cli/conftest.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import pytest
+
+
+@pytest.fixture()
+def make_mock_response():
+ """Factory fixture that creates mock urllib response objects."""
+
+ def _make(data: bytes):
+ return type(
+ "Response",
+ (),
+ {
+ "read": lambda self: data,
+ "__enter__": lambda self: self,
+ "__exit__": lambda *a: None,
+ },
+ )()
+
+ return _make
diff --git a/vet/cli/main.py b/vet/cli/main.py
@@ -20,7 +20,8 @@ from vet.cli.config.loader import get_config_preset
from vet.cli.config.loader import load_cli_config
from vet.cli.config.loader import load_custom_guides_config
from vet.cli.config.loader import load_models_config
-from vet.cli.config.loader import validate_api_key_for_model
+from vet.cli.config.loader import load_registry_config
+from vet.cli.config.loader import update_remote_registry_cache
from vet.cli.config.schema import ModelsConfig
from vet.formatters import OUTPUT_FIELDS
from vet.formatters import OUTPUT_FORMATS
@@ -156,6 +157,11 @@ def create_parser() -> argparse.ArgumentParser:
help="List all available models",
)
model_group.add_argument(
+ "--update-models",
+ action="store_true",
+ help=argparse.SUPPRESS,
+ )
+ model_group.add_argument(
"--temperature",
type=float,
default=CLI_DEFAULTS.temperature,
@@ -276,6 +282,7 @@ _HARNESS_ISSUE_URLS: dict[AgentHarnessType, str] = {
def list_models(
user_config: ModelsConfig | None = None,
*,
+ registry_config: ModelsConfig | None = None,
agentic: bool = False,
agent_harness: AgentHarnessType | None = None,
) -> None:
@@ -301,7 +308,7 @@ def list_models(
print("Available models:")
print()
- models_by_provider = get_models_by_provider(user_config)
+ models_by_provider = get_models_by_provider(user_config, registry_config)
for provider, model_ids in sorted(models_by_provider.items()):
print(f" {provider}:")
for model_id in sorted(model_ids):
@@ -426,6 +433,28 @@ def main(argv: list[str] | None = None) -> int:
parser = create_parser()
args = parser.parse_args(argv)
+ # Handle subcommands that don't need config loading.
+ if args.update_models:
+ try:
+ cache_path, updated_config = update_remote_registry_cache()
+ model_count = sum(len(p.models) for p in updated_config.providers.values())
+ provider_count = len(updated_config.providers)
+ print(f"Updated model registry ({model_count} models from {provider_count} providers).")
+ print(f"Cache written to {cache_path}")
+ except Exception as e:
+ print(f"vet: failed to update model registry: {e}", file=sys.stderr)
+ return 1
+ return 0
+
+ if args.list_issue_codes:
+ list_issue_codes()
+ return 0
+
+ if args.list_fields:
+ list_fields()
+ return 0
+
+ # Load configs needed by the remaining commands.
goal = args.goal or ""
repo_path = args.repo
@@ -437,27 +466,26 @@ def main(argv: list[str] | None = None) -> int:
return 2
try:
+ registry_config = load_registry_config()
+ except ConfigLoadError as e:
+ logger.warning("Could not load remote registry: {}", e)
+ registry_config = ModelsConfig(providers={})
+
+ try:
custom_guides_config = load_custom_guides_config(repo_path)
except ConfigLoadError as e:
print(f"vet: could not load custom guides: {e}", file=sys.stderr)
return 2
- if args.list_issue_codes:
- list_issue_codes()
- return 0
-
if args.list_models:
list_models(
user_config,
+ registry_config=registry_config,
agentic=args.agentic,
agent_harness=args.agent_harness if args.agentic else None,
)
return 0
- if args.list_fields:
- list_fields()
- return 0
-
try:
cli_configs = load_cli_config(repo_path)
except ConfigLoadError as e:
@@ -523,10 +551,15 @@ def main(argv: list[str] | None = None) -> int:
configure_logging(args.verbose, args.log_file)
+ # Lazy imports: vet.cli.models transitively imports the LLM SDK provider
+ # modules (~1s), so it must NOT be imported at module level. Lightweight
+ # subcommands (--version, --list-issue-codes, --list-fields, --update-models)
+ # exit before reaching this point. See startup_time_test.py.
from vet.api import find_issues
- from vet.cli.config.loader import build_language_model_config
- from vet.cli.config.loader import get_max_output_tokens_for_model
from vet.cli.models import DEFAULT_MODEL_ID
+ from vet.cli.models import build_language_model_config
+ from vet.cli.models import get_max_output_tokens_for_model
+ from vet.cli.models import validate_api_key_for_model
from vet.cli.models import validate_model_id
from vet.formatters import format_github_review
from vet.formatters import format_issue_text
@@ -584,20 +617,20 @@ def main(argv: list[str] | None = None) -> int:
model_id = args.model or DEFAULT_MODEL_ID
try:
- model_id = validate_model_id(model_id, user_config)
+ model_id = validate_model_id(model_id, user_config, registry_config)
except ValueError as e:
print(f"vet: {e}", file=sys.stderr)
return 2
try:
- validate_api_key_for_model(model_id, user_config)
+ validate_api_key_for_model(model_id, user_config, registry_config)
except Exception as e:
print(f"vet: {e}", file=sys.stderr)
return 2
# TODO: Support OFFLINE, UPDATE_SNAPSHOT, and MOCKED modes.
- language_model_config = build_language_model_config(model_id, user_config)
- max_output_tokens = get_max_output_tokens_for_model(model_id, user_config)
+ language_model_config = build_language_model_config(model_id, user_config, registry_config)
+ max_output_tokens = get_max_output_tokens_for_model(model_id, user_config, registry_config)
config = VetConfig(
enabled_identifiers=enabled_identifiers,
diff --git a/vet/cli/main_test.py b/vet/cli/main_test.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+from vet.cli.main import main
+
+_REMOTE_PROVIDER_JSON = json.dumps(
+ {
+ "providers": {
+ "remote-provider": {
+ "base_url": "http://remote:8080/v1",
+ "api_key_env": "REMOTE_KEY",
+ "models": {
+ "remote-model-a": {
+ "context_window": 128000,
+ "max_output_tokens": 16384,
+ "supports_temperature": True,
+ },
+ "remote-model-b": {
+ "context_window": 64000,
+ "max_output_tokens": 8192,
+ "supports_temperature": False,
+ },
+ },
+ }
+ }
+ }
+)
+
+
+def _env_for_isolated_config(tmp_path: Path) -> dict[str, str]:
+ """Return env overrides that isolate XDG dirs to tmp_path."""
+ return {
+ "XDG_CONFIG_HOME": str(tmp_path / "config"),
+ "XDG_CACHE_HOME": str(tmp_path / "cache"),
+ }
+
+
+class TestUpdateModels:
+ """CLI integration tests for the --update-models flag."""
+
+ def test_update_models_success(self, tmp_path: Path, capsys, make_mock_response) -> None:
+ mock_response = make_mock_response(_REMOTE_PROVIDER_JSON.encode())
+ env = _env_for_isolated_config(tmp_path)
+
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ):
+ exit_code = main(["--update-models"])
+
+ assert exit_code == 0
+
+ captured = capsys.readouterr()
+ assert "Updated model registry" in captured.out
+ assert "2 models from 1 providers" in captured.out
+ assert "Cache written to" in captured.out
+
+ def test_update_models_writes_cache_file(self, tmp_path: Path, make_mock_response) -> None:
+ mock_response = make_mock_response(_REMOTE_PROVIDER_JSON.encode())
+ env = _env_for_isolated_config(tmp_path)
+
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ):
+ main(["--update-models"])
+
+ cache_file = tmp_path / "cache" / "vet" / "remote_models.json"
+ assert cache_file.exists()
+ data = json.loads(cache_file.read_text())
+ assert "remote-provider" in data["providers"]
+
+ def test_update_models_network_error_returns_1(self, tmp_path: Path, capsys) -> None:
+ env = _env_for_isolated_config(tmp_path)
+
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ side_effect=OSError("connection refused"),
+ ):
+ exit_code = main(["--update-models"])
+
+ assert exit_code == 1
+
+ captured = capsys.readouterr()
+ assert "failed to update model registry" in captured.err
+ assert "connection refused" in captured.err
+
+ def test_update_models_invalid_remote_data_returns_1(self, tmp_path: Path, capsys, make_mock_response) -> None:
+ mock_response = make_mock_response(b"<html>Not Found</html>")
+ env = _env_for_isolated_config(tmp_path)
+
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ):
+ exit_code = main(["--update-models"])
+
+ assert exit_code == 1
+
+ captured = capsys.readouterr()
+ assert "failed to update model registry" in captured.err
+
+ def test_update_models_does_not_write_cache_on_invalid_data(self, tmp_path: Path, make_mock_response) -> None:
+ mock_response = make_mock_response(b"not json at all")
+ env = _env_for_isolated_config(tmp_path)
+
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ):
+ main(["--update-models"])
+
+ cache_file = tmp_path / "cache" / "vet" / "remote_models.json"
+ assert not cache_file.exists()
+
+
+class TestListModels:
+ """CLI integration tests for the --list-models flag."""
+
+ def test_list_models_shows_registry_models(self, tmp_path: Path, capsys, make_mock_response) -> None:
+ """Registry models should appear in --list-models output after --update-models."""
+ mock_response = make_mock_response(_REMOTE_PROVIDER_JSON.encode())
+ env = _env_for_isolated_config(tmp_path)
+
+ with patch.dict(os.environ, env):
+ with patch(
+ "vet.cli.config.loader.urllib.request.urlopen",
+ return_value=mock_response,
+ ):
+ main(["--update-models"])
+
+ exit_code = main(["--list-models"])
+
+ assert exit_code == 0
+
+ captured = capsys.readouterr()
+ assert "remote-model-a" in captured.out
+ assert "remote-model-b" in captured.out
diff --git a/vet/cli/models.py b/vet/cli/models.py
@@ -1,8 +1,12 @@
from __future__ import annotations
+import os
+
+from vet.cli.config.loader import get_model_ids_from_config
from vet.cli.config.loader import get_models_by_provider_from_config
-from vet.cli.config.loader import get_user_defined_model_ids
+from vet.cli.config.loader import get_provider_for_model
from vet.cli.config.schema import ModelsConfig
+from vet.cli.config.schema import ProviderConfig
from vet.imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
from vet.imbue_core.agents.llm_apis.common import get_all_model_names
from vet.imbue_core.agents.llm_apis.gemini_api import GeminiModelName
@@ -11,31 +15,50 @@ from vet.imbue_core.agents.llm_apis.openai_api import OpenAIModelName
DEFAULT_MODEL_ID = AnthropicModelName.CLAUDE_4_6_OPUS.value
+class MissingProviderAPIKeyError(Exception):
+ def __init__(self, env_var: str, provider_name: str, model_id: str) -> None:
+ self.env_var = env_var
+ self.provider_name = provider_name
+ self.model_id = model_id
+ super().__init__(
+ f"API key not found: environment variable '{env_var}' is not set. "
+ + f"This is required for model '{model_id}' from provider '{provider_name}'."
+ )
+
+
def get_builtin_model_ids() -> set[str]:
+ """Return the string form of every model name yielded by get_all_model_names()."""
return {str(name) for name in get_all_model_names()}
-def get_all_model_ids(user_config: ModelsConfig | None = None) -> set[str]:
+def get_all_model_ids(
+ user_config: ModelsConfig | None = None,
+ registry_config: ModelsConfig | None = None,
+) -> set[str]:
+ """Return the union of built-in model IDs and the IDs found in each supplied config.
+
+ Either config may be omitted (or falsy), in which case it contributes nothing.
+ """
model_ids = get_builtin_model_ids()
if user_config:
- model_ids.update(get_user_defined_model_ids(user_config))
-
- return model_ids
+ model_ids.update(get_model_ids_from_config(user_config))
+ if registry_config:
+ model_ids.update(get_model_ids_from_config(registry_config))
-def is_valid_model_id(model_id: str, user_config: ModelsConfig | None = None) -> bool:
- return model_id in get_all_model_ids(user_config)
+ return model_ids
-def is_user_defined_model(model_id: str, user_config: ModelsConfig | None = None) -> bool:
- if user_config is None:
- return False
- return model_id in get_user_defined_model_ids(user_config)
+def is_valid_model_id(
+ model_id: str,
+ user_config: ModelsConfig | None = None,
+ registry_config: ModelsConfig | None = None,
+) -> bool:
+ """Return True when model_id is among the IDs reported by get_all_model_ids()."""
+ return model_id in get_all_model_ids(user_config, registry_config)
-def validate_model_id(model_id: str, user_config: ModelsConfig | None = None) -> str:
- if not is_valid_model_id(model_id, user_config):
+def validate_model_id(
+ model_id: str,
+ user_config: ModelsConfig | None = None,
+ registry_config: ModelsConfig | None = None,
+) -> str:
+ """Return model_id unchanged if it is a known model.
+
+ Raises:
+ ValueError: if is_valid_model_id() rejects model_id for the given configs.
+ """
+ if not is_valid_model_id(model_id, user_config, registry_config):
raise ValueError(f"Unknown model: {model_id}. Use --list-models to see available models.")
return model_id
@@ -50,12 +73,109 @@ def get_builtin_models_by_provider() -> dict[str, list[str]]:
def get_models_by_provider(
user_config: ModelsConfig | None = None,
+ registry_config: ModelsConfig | None = None,
) -> dict[str, list[str]]:
+ """Group model IDs by provider name across registry, built-in and user sources.
+
+ Sources are merged in the order registry -> built-in -> user. When a provider
+ name is already present, only model IDs not yet listed under it are appended,
+ so earlier sources determine the leading entries of each list.
+ """
- providers = get_builtin_models_by_provider()
+ providers: dict[str, list[str]] = {}
+
+ def _merge(source: dict[str, list[str]]) -> None:
+ # Fold one source mapping into `providers`, skipping IDs already listed.
+ for name, models in source.items():
+ if name in providers:
+ # NOTE(review): `seen` is a snapshot taken before extending, so a
+ # repeated ID inside `models` itself would be appended more than once.
+ seen = set(providers[name])
+ providers[name].extend(m for m in models if m not in seen)
+ else:
+ # Copy so later extends never mutate the source's own list.
+ providers[name] = list(models)
+
+ if registry_config:
+ _merge(get_models_by_provider_from_config(registry_config))
+
+ _merge(get_builtin_models_by_provider())
if user_config:
- user_providers = get_models_by_provider_from_config(user_config)
- for provider_name, model_ids in user_providers.items():
- providers[provider_name] = model_ids
+ _merge(get_models_by_provider_from_config(user_config))
return providers
+
+
+def _resolve_provider(
+    model_id: str,
+    user_config: ModelsConfig,
+    registry_config: ModelsConfig | None = None,
+) -> ProviderConfig | None:
+    """Pick the provider config that should serve ``model_id``.
+
+    Lookup order is user config, then built-in models, then registry config.
+
+    Returns:
+        The matching ProviderConfig from the user config if there is one;
+        otherwise None for built-in model IDs; otherwise whatever the registry
+        lookup yields, or None when no registry config was supplied.
+    """
+    user_provider = get_provider_for_model(model_id, user_config)
+    if user_provider is not None:
+        return user_provider
+
+    # A built-in ID is never resolved through the registry, so the registry
+    # cannot shadow a built-in model.
+    is_builtin = model_id in get_builtin_model_ids()
+    if is_builtin or registry_config is None:
+        return None
+
+    return get_provider_for_model(model_id, registry_config)
+
+
+def validate_api_key_for_model(
+    model_id: str,
+    user_config: ModelsConfig,
+    registry_config: ModelsConfig | None = None,
+) -> None:
+    """Check that the API-key environment variable for ``model_id`` has a value.
+
+    Nothing is checked when the model resolves to no provider config or when
+    the resolved provider declares no ``api_key_env``.
+
+    Raises:
+        MissingProviderAPIKeyError: the provider names an environment variable
+            that is unset or set to an empty string.
+    """
+    provider = _resolve_provider(model_id, user_config, registry_config)
+    if provider is None:
+        return
+
+    env_var_name = provider.api_key_env
+    # Either no key is required, or the variable holds a non-empty value.
+    if env_var_name is None or os.environ.get(env_var_name, ""):
+        return
+
+    raise MissingProviderAPIKeyError(
+        env_var=env_var_name,
+        provider_name=provider.name or "unknown provider",
+        model_id=model_id,
+    )
+
+
+def get_max_output_tokens_for_model(
+    model_id: str,
+    user_config: ModelsConfig,
+    registry_config: ModelsConfig | None = None,
+) -> int | None:
+    """Look up the maximum output token count for ``model_id``.
+
+    Returns:
+        The ``max_output_tokens`` of the model's entry when the model resolves
+        to a provider config; otherwise the value reported by
+        ``get_model_max_output_tokens``, or None if that import or call raises.
+    """
+    provider = _resolve_provider(model_id, user_config, registry_config)
+    if provider is None:
+        # No provider config: fall back to the shared lookup. Best effort --
+        # any exception from the import or the call is turned into None.
+        try:
+            from vet.imbue_core.agents.llm_apis.common import get_model_max_output_tokens
+
+            return get_model_max_output_tokens(model_id)
+        except Exception:
+            return None
+
+    return provider.models[model_id].max_output_tokens
+
+
+def build_language_model_config(
+    model_id: str,
+    user_config: ModelsConfig,
+    registry_config: ModelsConfig | None = None,
+):
+    """Build the generation config object used to call ``model_id``.
+
+    Returns:
+        A ``LanguageModelGenerationConfig`` carrying just the model name when
+        the model resolves to no provider config; otherwise an
+        ``OpenAICompatibleModelConfig`` filled from the provider and its entry
+        for this model.
+    """
+    # Imported lazily, inside the function, as in the original code.
+    from vet.imbue_core.agents.configs import LanguageModelGenerationConfig
+    from vet.imbue_core.agents.configs import OpenAICompatibleModelConfig
+
+    provider = _resolve_provider(model_id, user_config, registry_config)
+    if provider is not None:
+        entry = provider.models[model_id]
+        return OpenAICompatibleModelConfig(
+            # The entry's own model_id, when set, replaces the lookup key as
+            # the model name.
+            model_name=entry.model_id or model_id,
+            custom_base_url=provider.base_url,
+            # A provider without api_key_env is passed as an empty string.
+            custom_api_key_env=provider.api_key_env or "",
+            custom_context_window=entry.context_window,
+            custom_max_output_tokens=entry.max_output_tokens,
+            custom_supports_temperature=entry.supports_temperature,
+        )
+
+    return LanguageModelGenerationConfig(model_name=model_id)
diff --git a/vet/cli/models_test.py b/vet/cli/models_test.py
@@ -10,7 +10,6 @@ from vet.cli.models import get_all_model_ids
from vet.cli.models import get_builtin_model_ids
from vet.cli.models import get_builtin_models_by_provider
from vet.cli.models import get_models_by_provider
-from vet.cli.models import is_user_defined_model
from vet.cli.models import is_valid_model_id
from vet.cli.models import validate_model_id
@@ -71,18 +70,6 @@ def test_is_valid_model_id(model_id: str, user_config: ModelsConfig | None, expe
assert is_valid_model_id(model_id, user_config) is expected
-@pytest.mark.parametrize(
- ("model_id", "user_config", "expected"),
- [
- ("any-model", None, False),
- ("my-custom-model", SAMPLE_USER_CONFIG, True),
- (DEFAULT_MODEL_ID, SAMPLE_USER_CONFIG, False),
- ],
-)
-def test_is_user_defined_model(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
- assert is_user_defined_model(model_id, user_config) is expected
-
-
def test_validate_model_id_returns_model_id_when_valid() -> None:
result = validate_model_id(DEFAULT_MODEL_ID)
assert result == DEFAULT_MODEL_ID
@@ -172,7 +159,7 @@ def test_get_models_by_provider_includes_user_defined_providers() -> None:
assert "openai" in providers
-def test_get_models_by_provider_user_provider_overrides_builtin_with_same_name() -> None:
+def test_get_models_by_provider_user_provider_merges_with_builtin_same_name() -> None:
user_config = ModelsConfig(
providers={
"custom": ProviderConfig(
@@ -192,4 +179,87 @@ def test_get_models_by_provider_user_provider_overrides_builtin_with_same_name()
providers = get_models_by_provider(user_config)
- assert providers["anthropic"] == ["custom-model"]
+ assert "custom-model" in providers["anthropic"]
+ assert DEFAULT_MODEL_ID in providers["anthropic"]
+
+
+# Fixture config passed as `registry_config` in the tests below: one provider
+# (key "registry-provider", display name "Registry Provider") that requires the
+# REGISTRY_KEY environment variable and exposes a single model, "registry-model".
+SAMPLE_REGISTRY_CONFIG = ModelsConfig(
+ providers={
+ "registry-provider": ProviderConfig(
+ name="Registry Provider",
+ base_url="http://registry:8080/v1",
+ api_key_env="REGISTRY_KEY",
+ models={
+ "registry-model": ModelConfig(
+ context_window=128000,
+ max_output_tokens=16384,
+ supports_temperature=True,
+ ),
+ },
+ )
+ }
+)
+
+
+def test_get_all_model_ids_includes_registry_models() -> None:
+    """User-defined, built-in and registry model IDs all appear in the combined set."""
+    combined_ids = get_all_model_ids(SAMPLE_USER_CONFIG, SAMPLE_REGISTRY_CONFIG)
+
+    assert "my-custom-model" in combined_ids
+    assert DEFAULT_MODEL_ID in combined_ids
+    assert "registry-model" in combined_ids
+
+
+def test_validate_model_id_accepts_registry_model() -> None:
+    """A model known only to the registry config validates and is returned as-is."""
+    validated = validate_model_id("registry-model", None, SAMPLE_REGISTRY_CONFIG)
+
+    assert validated == "registry-model"
+
+
+def test_validate_model_id_rejects_unknown_even_with_registry() -> None:
+    """An ID absent from every source still raises ValueError when a registry is supplied."""
+    unknown_id = "totally-unknown"
+
+    with pytest.raises(ValueError):
+        validate_model_id(unknown_id, SAMPLE_USER_CONFIG, SAMPLE_REGISTRY_CONFIG)
+
+
+def test_get_models_by_provider_includes_registry_providers() -> None:
+    """Registry providers are listed alongside the built-in providers."""
+    by_provider = get_models_by_provider(None, SAMPLE_REGISTRY_CONFIG)
+
+    # The registry provider is keyed by its display name and carries its model.
+    assert "Registry Provider" in by_provider
+    assert "registry-model" in by_provider["Registry Provider"]
+
+    # Built-in providers are still present.
+    for builtin_provider in ("anthropic", "openai"):
+        assert builtin_provider in by_provider
+
+
+def test_get_models_by_provider_registry_merges_with_builtin_same_name() -> None:
+    """A registry provider named like a built-in one adds to that provider's models."""
+    registry_claude = ModelConfig(
+        context_window=128000,
+        max_output_tokens=16384,
+        supports_temperature=True,
+    )
+    # The provider key differs from the built-in name; only `name` collides.
+    anthropic_like = ProviderConfig(
+        name="anthropic",
+        base_url="http://registry:8080/v1",
+        models={"registry-claude": registry_claude},
+    )
+    registry_config = ModelsConfig(providers={"anthropic-override": anthropic_like})
+
+    by_provider = get_models_by_provider(user_config=None, registry_config=registry_config)
+
+    anthropic_models = by_provider["anthropic"]
+    assert "registry-claude" in anthropic_models
+    assert DEFAULT_MODEL_ID in anthropic_models