commit 91088763849a9c1db036598089fbfe0f6e725e09
parent b25a32e27d730e051b4451ad757c297b39f7097e
Author: andrewlaack-collab <andrew.laack@imbue.com>
Date: Thu, 29 Jan 2026 20:35:02 +0000
Rename imbue_verify -> Vet (#3)
* Large renaming
Diffstat:
79 files changed, 5204 insertions(+), 5194 deletions(-)
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
@@ -1,24 +1,24 @@
-# imbue_verify
+# vet
-Imbue verify is a library and CLI tool for verifying code quality and correctness.
+Vet is a library and CLI tool for verifying code quality and correctness.
## Installation
From the repository root:
```bash
-uv sync --project imbue-verify
+uv sync --project vet
```
## Usage
```bash
-uv run imbue-verify --help
-uv run imbue-verify "description of what the code change accomplishes"
-uv run imbue-verify --list-models
-uv run imbue-verify "description of what the code change accomplishes" --model claude-opus-4-5-20251101
-uv run imbue-verify --model claude-opus-4-5-20251101 # no goal specified
-uv run imbue-verify --model claude-opus-4-5-20251101 --base-commit main # default is HEAD
+uv run vet --help
+uv run vet "description of what the code change accomplishes"
+uv run vet --list-models
+uv run vet "description of what the code change accomplishes" --model claude-opus-4-5-20251101
+uv run vet --model claude-opus-4-5-20251101 # no goal specified
+uv run vet --model claude-opus-4-5-20251101 --base-commit main # default is HEAD
```
## Custom Models
@@ -62,7 +62,7 @@ Example configuration:
## Exit Status
-The following are the **expected** exit status codes for imbue-verify:
+The following are the **expected** exit status codes for vet:
- `0` - Success, no issues found
- `1` - Issues were found in the code
@@ -80,17 +80,17 @@ Issue identifiers are pieces of logic capable of finding issues in code. We fore
- To check for the quality of a single commit.
- "Assuming that we can treat the commit message as a requirement, how well does the commit implement it?"
-By default, `imbue_verify` runs all the registered issue identifiers and outputs all the found issues on the standard output in JSON format.
+By default, `vet` runs all the registered issue identifiers and outputs all the found issues on the standard output in JSON format.
#### Adding new Issue Identifiers
If you want to add a new issue identifier, you need to:
1. Implement the `IssueIdentifier` protocol from `imbue_tools.repo_utils.data_types`.
-2. Register the new issue identifier by adding it to `IDENTIFIERS` in `imbue_verify.issue_identifiers.registry`.
+2. Register the new issue identifier by adding it to `IDENTIFIERS` in `vet.issue_identifiers.registry`.
Based on your needs, instead of the above, you can also extend one of the existing batched zero-shot issue identifiers:
- - `imbue_verify/issue_identifiers/batched_commit_check.py`
+ - `vet/issue_identifiers/batched_commit_check.py`
(for commit checking)
In that case you would simply expand the rubric in the prompt. That is actually the preferred way to catch issues at the moment due to efficiency.
Refer to the source code for more details.
@@ -99,7 +99,7 @@ Refer to the source code for more details.
### Logging Configuration
-When creating a new entrypoint into imbue_verify, you must call `ensure_core_log_levels_configured()` to register the custom log levels used throughout the codebase.
+When creating a new entrypoint into vet, you must call `ensure_core_log_levels_configured()` to register the custom log levels used throughout the codebase.
```python
from imbue_core.log_utils import ensure_core_log_levels_configured
diff --git a/README.md b/README.md
@@ -1,8 +1,8 @@
-# VET : Verify EveryThing
+# Vet : Verify EveryThing
-VET is a standalone verification tool for **code changes** and **coding agent behavior**.
+Vet is a standalone verification tool for **code changes** and **coding agent behavior**.
-It reviews git diffs, and optionally an agent's conversation history, to find issues that tests and linters often miss. VET is optimized for use by humans, CI, and coding agents.
+It reviews git diffs, and optionally an agent's conversation history, to find issues that tests and linters often miss. Vet is optimized for use by humans, CI, and coding agents.
## Installation
@@ -12,7 +12,7 @@ pip install vet
## Quickstart
-Run VET in the current repo:
+Run Vet in the current repo:
```bash
vet "Implement X without breaking Y"
@@ -26,7 +26,7 @@ vet "Refactor storage layer" --base-commit main
## How it works
-VET snapshots the repo and diff, optionally adds a goal and agent conversation, runs LLM checks, then filters/deduplicates findings into a final list of issues.
+Vet snapshots the repo and diff, optionally adds a goal and agent conversation, runs LLM checks, then filters/deduplicates findings into a final list of issues.
TODO: Create rendering pipeline for this. GitHub MD doesn't directly support Graphviz.
@@ -58,12 +58,12 @@ digraph VET_DataFlow {
}
```
-## Why VET
+## Why Vet
- **Verification for agentic workflows**: "the agent said it ran tests" is not the same as "all tests ran successfully".
- **CI-friendly safety net**: catches classes of problems that may not be covered by existing tests.
- **Bring-your-own-model**: can run against hosted providers or local/self-hosted OpenAI-compatible endpoints.
-- **No telemetry collected by us**: VET does not collect any user data.
+- **No telemetry collected by us**: Vet does not collect any user data.
## Output & exit codes
@@ -77,7 +77,7 @@ Output formats:
## CI usage
-Recommended CI usage is to run VET with JSON output and display a warning if any issues are found.
+Recommended CI usage is to run Vet with JSON output and display a warning if any issues are found.
Example:
@@ -85,15 +85,15 @@ Example:
vet --base-commit main --output-format json > vet-report.json
```
-- If VET exits `0`, no issues were found.
-- If VET exits `1`, issues were found (treat as a failing check).
-- If VET exits `2`, the invocation/config is invalid (treat as a failing check).
+- If Vet exits `0`, no issues were found.
+- If Vet exits `1`, issues were found (treat as a failing check).
+- If Vet exits `2`, the invocation/config is invalid (treat as a failing check).
## Configuration
### Model configuration
-VET supports custom model definitions using OpenAI-compatible endpoints via JSON config files searched in:
+Vet supports custom model definitions using OpenAI-compatible endpoints via JSON config files searched in:
- `$XDG_CONFIG_HOME/imbue/models.json` (or `~/.config/imbue/models.json`)
- `models.json` at your repo root
@@ -138,7 +138,7 @@ vet "Harden error handling" --model gpt-4o-mini
### Configuration profiles (TOML)
-VET supports named profiles so teams can standardize CI usage without long CLI invocations.
+Vet supports named profiles so teams can standardize CI usage without long CLI invocations.
Profiles set defaults like model choice, enabled issue codes, output format, and thresholds.
@@ -146,7 +146,7 @@ Profiles set defaults like model choice, enabled issue codes, output format, and
### Conversation history
-VET can **optionally** ingest agent conversation history via a **history loader command**.
+Vet can **optionally** ingest agent conversation history via a **history loader command**.
#### History loader contract
@@ -155,8 +155,8 @@ VET can **optionally** ingest agent conversation history via a **history loader
Security note: this executes a command on your machine. Only run history loader commands you trust.
- Output format: **any text**
-- VET treats this as an opaque transcript (it may include user/assistant messages, tool calls, tool results, logs, etc.)
-- If you want VET to catch “claimed to run tests” style issues reliably, ensure your transcript includes tool invocations/results (or other evidence), not just prose.
+- Vet treats this as an opaque transcript (it may include user/assistant messages, tool calls, tool results, logs, etc.)
+- If you want Vet to catch “claimed to run tests” style issues reliably, ensure your transcript includes tool invocations/results (or other evidence), not just prose.
Example:
@@ -178,6 +178,6 @@ vet --history-loader "vet-history gemini-cli --latest"
## Privacy / telemetry
-VET does **not** collect telemetry and does not send usage data to external services.
+Vet does **not** collect telemetry and does not send usage data to external services.
-If you configure VET to use a hosted inference provider, that provider may log requests; selecting a provider is the user’s responsibility.
+If you configure Vet to use a hosted inference provider, that provider may log requests; selecting a provider is the user’s responsibility.
diff --git a/imbue_tools/README.md b/imbue_tools/README.md
@@ -1,5 +1,5 @@
# Purpose
-Shared functionality for imbue-cli tools like imbue-verify, imbue-retrieve, etc.
+Shared functionality for imbue-cli tools like vet, imbue-retrieve, etc.
# Contents
- formatting git repos as LLM input
diff --git a/imbue_tools/imbue_tools/repo_utils/project_context.py b/imbue_tools/imbue_tools/repo_utils/project_context.py
@@ -92,7 +92,7 @@ class LazyProjectContext(SerializableModel):
diff: str,
language_model_name: str,
repo_path: Path,
- # How many tokens to keep for the imbue_verify specific prompt and any output tokens.
+ # How many tokens to keep for the vet specific prompt and any output tokens.
tokens_to_reserve: int,
context_window: int | None = None,
is_custom_model: bool = False,
diff --git a/imbue_tools/imbue_tools/types/imbue_verify_config.py b/imbue_tools/imbue_tools/types/imbue_verify_config.py
@@ -1,100 +0,0 @@
-from pathlib import Path
-
-from imbue_core.agents.configs import LanguageModelGenerationConfig
-from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
-from imbue_core.data_types import IssueCode
-from imbue_core.pydantic_serialization import SerializableModel
-
-DEFAULT_CONFIDENCE_THRESHOLD = 0.8
-
-
-class ImbueVerifyConfig(SerializableModel):
- """Configuration for the imbue_verify system."""
-
- # If none, all registered identifiers are used.
- # Otherwise, only the identifiers in this tuple are used.
- enabled_identifiers: tuple[str, ...] | None = None
-
- # Issue identifiers that are disabled are never used.
- disabled_identifiers: tuple[str, ...] | None = None
-
- # Similar to the above, but for reporting specific types of issues.
- # (Use the values from the imbue_verify.data_types.IssueCode enum.)
- enabled_issue_codes: tuple[IssueCode, ...] | None = None
- disabled_issue_codes: tuple[IssueCode, ...] | None = ()
-
- # Todo: Different models for different issue identifiers
- language_model_generation_config: LanguageModelGenerationConfig = LanguageModelGenerationConfig(
- model_name=AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01
- )
- max_identifier_spend_dollars: float = 5.0
- max_output_tokens: int = 20000
- enable_parallel_agentic_issue_identification: bool = False
- max_identify_workers: int | None = None
- temperature: float = 0.5
-
- # If True, apply an additional LLM-based filtering stage, where each identified issue is evaluated
- # according to a number of quality criteria. Only issues that pass the evaluation are returned.
- filter_issues: bool = True
- filter_issues_through_llm_evaluator: bool = True
- filter_issues_below_confidence: float | None = DEFAULT_CONFIDENCE_THRESHOLD
-
- enable_deduplication: bool = True
- enable_collation: bool = True
-
- # If True, we attempt to cache the full prompts including specific inputs with the LLM provider.
- # There can be an additional cost for such a cache write, but it can help save cost in evaluation
- # contexts (such as black_box_evals) where the same inputs are being evaluated multiple times.
- cache_full_prompt: bool = False
-
- extra_context: str | None = None
-
- @classmethod
- def build(
- cls,
- language_model_name: str | None = None,
- language_model_cache_path: Path | None = None,
- enabled_identifiers: tuple[str, ...] | None = None,
- enable_parallel_agentic_issue_identification: bool = False,
- max_identify_workers: int | None = None,
- filter_issues: bool = True,
- filter_issues_below_confidence: float | None = DEFAULT_CONFIDENCE_THRESHOLD,
- enable_deduplication: bool = True,
- enable_collation: bool = True,
- enabled_issue_codes: tuple[IssueCode, ...] | None = None,
- temperature: float = 0.5,
- retry_jitter_factor: float = 0.0,
- cache_full_prompt: bool = False,
- ) -> "ImbueVerifyConfig":
- if not language_model_name:
- language_model_name = AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01
- language_model_generation_config = LanguageModelGenerationConfig(
- model_name=language_model_name,
- cache_path=language_model_cache_path,
- retry_jitter_factor=retry_jitter_factor,
- )
- return cls(
- language_model_generation_config=language_model_generation_config,
- enabled_identifiers=enabled_identifiers,
- enable_parallel_agentic_issue_identification=enable_parallel_agentic_issue_identification,
- max_identify_workers=max_identify_workers,
- filter_issues=filter_issues,
- filter_issues_below_confidence=filter_issues_below_confidence,
- enable_deduplication=enable_deduplication,
- enable_collation=enable_collation,
- enabled_issue_codes=enabled_issue_codes,
- temperature=temperature,
- cache_full_prompt=cache_full_prompt,
- )
-
-
-def get_enabled_issue_codes(config: ImbueVerifyConfig) -> set[IssueCode]:
- all_issue_code_values = {item.value for item in IssueCode}
- explicitly_enabled = config.enabled_issue_codes or tuple()
- explicitly_disabled = config.disabled_issue_codes or tuple()
- for code in explicitly_enabled + explicitly_disabled:
- if code not in all_issue_code_values:
- raise ValueError(f"Bad config: unknown issue code: {code}")
- possibly_enabled_values = set(explicitly_enabled) if len(explicitly_enabled) > 0 else set(v for v in IssueCode)
- disabled_values = set(explicitly_disabled)
- return possibly_enabled_values - disabled_values
diff --git a/imbue_tools/imbue_tools/types/vet_config.py b/imbue_tools/imbue_tools/types/vet_config.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+
+from imbue_core.agents.configs import LanguageModelGenerationConfig
+from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
+from imbue_core.data_types import IssueCode
+from imbue_core.pydantic_serialization import SerializableModel
+
+DEFAULT_CONFIDENCE_THRESHOLD = 0.8
+
+
+class VetConfig(SerializableModel):
+ """Configuration for the vet system."""
+
+ # If none, all registered identifiers are used.
+ # Otherwise, only the identifiers in this tuple are used.
+ enabled_identifiers: tuple[str, ...] | None = None
+
+ # Issue identifiers that are disabled are never used.
+ disabled_identifiers: tuple[str, ...] | None = None
+
+ # Similar to the above, but for reporting specific types of issues.
+ # (Use the values from the vet.data_types.IssueCode enum.)
+ enabled_issue_codes: tuple[IssueCode, ...] | None = None
+ disabled_issue_codes: tuple[IssueCode, ...] | None = ()
+
+ # Todo: Different models for different issue identifiers
+ language_model_generation_config: LanguageModelGenerationConfig = LanguageModelGenerationConfig(
+ model_name=AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01
+ )
+ max_identifier_spend_dollars: float = 5.0
+ max_output_tokens: int = 20000
+ enable_parallel_agentic_issue_identification: bool = False
+ max_identify_workers: int | None = None
+ temperature: float = 0.5
+
+ # If True, apply an additional LLM-based filtering stage, where each identified issue is evaluated
+ # according to a number of quality criteria. Only issues that pass the evaluation are returned.
+ filter_issues: bool = True
+ filter_issues_through_llm_evaluator: bool = True
+ filter_issues_below_confidence: float | None = DEFAULT_CONFIDENCE_THRESHOLD
+
+ enable_deduplication: bool = True
+ enable_collation: bool = True
+
+ # If True, we attempt to cache the full prompts including specific inputs with the LLM provider.
+ # There can be an additional cost for such a cache write, but it can help save cost in evaluation
+ # contexts (such as black_box_evals) where the same inputs are being evaluated multiple times.
+ cache_full_prompt: bool = False
+
+ extra_context: str | None = None
+
+ @classmethod
+ def build(
+ cls,
+ language_model_name: str | None = None,
+ language_model_cache_path: Path | None = None,
+ enabled_identifiers: tuple[str, ...] | None = None,
+ enable_parallel_agentic_issue_identification: bool = False,
+ max_identify_workers: int | None = None,
+ filter_issues: bool = True,
+ filter_issues_below_confidence: float | None = DEFAULT_CONFIDENCE_THRESHOLD,
+ enable_deduplication: bool = True,
+ enable_collation: bool = True,
+ enabled_issue_codes: tuple[IssueCode, ...] | None = None,
+ temperature: float = 0.5,
+ retry_jitter_factor: float = 0.0,
+ cache_full_prompt: bool = False,
+ ) -> "VetConfig":
+ if not language_model_name:
+ language_model_name = AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01
+ language_model_generation_config = LanguageModelGenerationConfig(
+ model_name=language_model_name,
+ cache_path=language_model_cache_path,
+ retry_jitter_factor=retry_jitter_factor,
+ )
+ return cls(
+ language_model_generation_config=language_model_generation_config,
+ enabled_identifiers=enabled_identifiers,
+ enable_parallel_agentic_issue_identification=enable_parallel_agentic_issue_identification,
+ max_identify_workers=max_identify_workers,
+ filter_issues=filter_issues,
+ filter_issues_below_confidence=filter_issues_below_confidence,
+ enable_deduplication=enable_deduplication,
+ enable_collation=enable_collation,
+ enabled_issue_codes=enabled_issue_codes,
+ temperature=temperature,
+ cache_full_prompt=cache_full_prompt,
+ )
+
+
+def get_enabled_issue_codes(config: VetConfig) -> set[IssueCode]:
+ all_issue_code_values = {item.value for item in IssueCode}
+ explicitly_enabled = config.enabled_issue_codes or tuple()
+ explicitly_disabled = config.disabled_issue_codes or tuple()
+ for code in explicitly_enabled + explicitly_disabled:
+ if code not in all_issue_code_values:
+ raise ValueError(f"Bad config: unknown issue code: {code}")
+ possibly_enabled_values = set(explicitly_enabled) if len(explicitly_enabled) > 0 else set(v for v in IssueCode)
+ disabled_values = set(explicitly_disabled)
+ return possibly_enabled_values - disabled_values
diff --git a/imbue_tools/pyproject.toml b/imbue_tools/pyproject.toml
@@ -29,7 +29,7 @@ requires-python = ">=3.11"
[project.optional-dependencies]
test = [
- "imbue-verify",
+ "vet",
]
[tool.setuptools]
@@ -40,4 +40,4 @@ include = ["imbue_tools*"]
[tool.uv.sources]
imbue_core = { path = "../imbue_core", editable = true }
-imbue-verify = { path = "..", editable = true }
+vet = { path = "..", editable = true }
diff --git a/imbue_tools/uv.lock b/imbue_tools/uv.lock
@@ -210,15 +210,6 @@ wheels = [
]
[[package]]
-name = "backoff"
-version = "2.2.1"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001, upload-time = "2022-10-05T19:19:32.061Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" },
-]
-
-[[package]]
name = "black"
version = "25.12.0"
source = { registry = "https://pypi.org/simple" }
@@ -872,7 +863,6 @@ dependencies = [
{ name = "loguru" },
{ name = "openai" },
{ name = "pathspec" },
- { name = "posthog" },
{ name = "prometheus-client" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
@@ -884,7 +874,6 @@ dependencies = [
{ name = "pytest-asyncio" },
{ name = "pytest-mock" },
{ name = "python-gitlab" },
- { name = "sentry-sdk" },
{ name = "syrupy" },
{ name = "tblib" },
{ name = "tenacity" },
@@ -913,7 +902,6 @@ requires-dist = [
{ name = "loguru" },
{ name = "openai", specifier = ">=1.79.0" },
{ name = "pathspec" },
- { name = "posthog", specifier = "==5.4.0" },
{ name = "prometheus-client", specifier = ">=0.20.0" },
{ name = "pydantic", specifier = ">=2.11.4" },
{ name = "pydantic-settings" },
@@ -925,7 +913,6 @@ requires-dist = [
{ name = "pytest-asyncio" },
{ name = "pytest-mock" },
{ name = "python-gitlab", specifier = ">=4.5.0" },
- { name = "sentry-sdk" },
{ name = "syrupy" },
{ name = "tblib", specifier = "==2.0.0" },
{ name = "tenacity", specifier = ">=8.2.2" },
@@ -965,7 +952,7 @@ dependencies = [
[package.optional-dependencies]
test = [
- { name = "imbue-verify" },
+ { name = "vet" },
]
[package.metadata]
@@ -974,7 +961,6 @@ requires-dist = [
{ name = "async-lru" },
{ name = "attrs" },
{ name = "imbue-core", editable = "../imbue_core" },
- { name = "imbue-verify", marker = "extra == 'test'", editable = "../" },
{ name = "jinja2" },
{ name = "libcst" },
{ name = "loguru" },
@@ -987,43 +973,11 @@ requires-dist = [
{ name = "python-gitlab" },
{ name = "requests" },
{ name = "syrupy" },
+ { name = "vet", marker = "extra == 'test'", editable = "../" },
]
provides-extras = ["test"]
[[package]]
-name = "imbue-verify"
-version = "0.1.0"
-source = { editable = "../" }
-dependencies = [
- { name = "aiohttp" },
- { name = "click" },
- { name = "imbue-core" },
- { name = "imbue-tools" },
- { name = "jinja2" },
- { name = "loguru" },
- { name = "pydantic" },
- { name = "pygments" },
- { name = "pytest" },
- { name = "syrupy" },
- { name = "together" },
-]
-
-[package.metadata]
-requires-dist = [
- { name = "aiohttp", specifier = ">=3.8.0" },
- { name = "click" },
- { name = "imbue-core", editable = "../imbue_core" },
- { name = "imbue-tools", editable = "." },
- { name = "jinja2" },
- { name = "loguru" },
- { name = "pydantic" },
- { name = "pygments", specifier = ">=2.0.0" },
- { name = "pytest" },
- { name = "syrupy" },
- { name = "together", specifier = ">=1.5.35" },
-]
-
-[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
@@ -1693,22 +1647,6 @@ wheels = [
]
[[package]]
-name = "posthog"
-version = "5.4.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "backoff" },
- { name = "distro" },
- { name = "python-dateutil" },
- { name = "requests" },
- { name = "six" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/48/20/60ae67bb9d82f00427946218d49e2e7e80fb41c15dc5019482289ec9ce8d/posthog-5.4.0.tar.gz", hash = "sha256:701669261b8d07cdde0276e5bc096b87f9e200e3b9589c5ebff14df658c5893c", size = 88076, upload-time = "2025-06-20T23:19:23.485Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/4f/98/e480cab9a08d1c09b1c59a93dade92c1bb7544826684ff2acbfd10fcfbd4/posthog-5.4.0-py3-none-any.whl", hash = "sha256:284dfa302f64353484420b52d4ad81ff5c2c2d1d607c4e2db602ac72761831bd", size = 105364, upload-time = "2025-06-20T23:19:22.001Z" },
-]
-
-[[package]]
name = "prometheus-client"
version = "0.24.1"
source = { registry = "https://pypi.org/simple" }
@@ -2500,19 +2438,6 @@ wheels = [
]
[[package]]
-name = "sentry-sdk"
-version = "2.51.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "certifi" },
- { name = "urllib3" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/6f/9f/094bbb6be5cf218ab6712c6528310687f3d3fe8818249fcfe1d74192f7c5/sentry_sdk-2.51.0.tar.gz", hash = "sha256:b89d64577075fd8c13088bc3609a2ce77a154e5beb8cba7cc16560b0539df4f7", size = 407447, upload-time = "2026-01-28T10:29:50.962Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/a0/da/df379404d484ca9dede4ad8abead5de828cdcff35623cd44f0351cf6869c/sentry_sdk-2.51.0-py2.py3-none-any.whl", hash = "sha256:e21016d318a097c2b617bb980afd9fc737e1efc55f9b4f0cdc819982c9717d5f", size = 431426, upload-time = "2026-01-28T10:29:48.868Z" },
-]
-
-[[package]]
name = "shellingham"
version = "1.5.4"
source = { registry = "https://pypi.org/simple" }
@@ -2806,6 +2731,61 @@ wheels = [
]
[[package]]
+name = "vet"
+version = "0.1.0"
+source = { editable = "../" }
+dependencies = [
+ { name = "aiohttp" },
+ { name = "click" },
+ { name = "imbue-core" },
+ { name = "imbue-tools" },
+ { name = "jinja2" },
+ { name = "loguru" },
+ { name = "pydantic" },
+ { name = "pygments" },
+ { name = "pytest" },
+ { name = "syrupy" },
+ { name = "together" },
+ { name = "vet-types" },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "aiohttp", specifier = ">=3.8.0" },
+ { name = "click" },
+ { name = "imbue-core", editable = "../imbue_core" },
+ { name = "imbue-tools", editable = "." },
+ { name = "jinja2" },
+ { name = "loguru" },
+ { name = "pydantic" },
+ { name = "pygments", specifier = ">=2.0.0" },
+ { name = "pytest" },
+ { name = "syrupy" },
+ { name = "together", specifier = ">=1.5.35" },
+ { name = "vet-types", editable = "../vet_types" },
+]
+
+[package.metadata.requires-dev]
+dev = [{ name = "black" }]
+
+[[package]]
+name = "vet-types"
+version = "0.1.0"
+source = { editable = "../vet_types" }
+dependencies = [
+ { name = "imbue-core" },
+ { name = "pydantic" },
+ { name = "typeid-python" },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "imbue-core" },
+ { name = "pydantic" },
+ { name = "typeid-python" },
+]
+
+[[package]]
name = "websockets"
version = "15.0.1"
source = { registry = "https://pypi.org/simple" }
diff --git a/imbue_verify/api.py b/imbue_verify/api.py
@@ -1,125 +0,0 @@
-"""Public API for imbue_verify.
-
-This module provides functions to identify issues in code changes. Issue identifiers are pieces of logic capable of finding issues in code.
-By default, imbue_verify runs all registered issue identifiers and returns all found issues.
-"""
-
-from pathlib import Path
-
-from loguru import logger
-
-from imbue_core.data_types import IdentifiedVerifyIssue
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from vet_types.messages import ConversationMessageUnion
-from imbue_tools.get_conversation_history.get_conversation_history import (
- ConversationLoadingError,
-)
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.repo_utils.project_context import LazyProjectContext
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_tools.util_prompts.goal_from_conversation import get_goal_from_conversation
-from imbue_verify.issue_identifiers import registry
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-from imbue_verify.repo_utils import IMBUE_VERIFY_MAX_PROMPT_TOKENS
-from imbue_verify.repo_utils import get_code_to_check
-
-
-def get_issues_with_raw_responses(
- base_commit: str,
- diff: str,
- diff_no_binary: str,
- goal: str,
- config: ImbueVerifyConfig,
- repo_path: Path,
- conversation_history: tuple[ConversationMessageUnion, ...] | None = None,
-) -> tuple[tuple[IdentifiedVerifyIssue, ...], IssueIdentificationDebugInfo, ProjectContext]:
- if not goal or not goal.strip():
- logger.info("No goal was provided, generating one from conversation history")
- # should be not None and not empty
- if conversation_history:
- try:
- # TODO: we use the imbue verify config here, but we may want to configure this separately
- goal = get_goal_from_conversation(conversation_history, config.language_model_generation_config)
- logger.info("Generated goal from conversation history: {}", goal)
- except Exception as e:
- raise ConversationLoadingError(
- f"No goal was provided and generating one from conversation history failed: {e}"
- )
- else:
- # TODO: Consider which CLI options we should show this for (quiet, normal, verbose).
- logger.info("No goal or conversation history provided, only goal-independent identifiers will run")
- goal = ""
-
- lm_config = config.language_model_generation_config
- if diff_no_binary:
- diff_no_binary_tokens = lm_config.count_tokens(diff_no_binary)
- else:
- diff_no_binary_tokens = 0
-
- project_context = LazyProjectContext.build(
- base_commit,
- diff,
- language_model_name=lm_config.model_name,
- repo_path=repo_path,
- # This needs to account for the imbue_verify prompt, as well as the max_tokens output tokens.
- tokens_to_reserve=IMBUE_VERIFY_MAX_PROMPT_TOKENS + diff_no_binary_tokens + config.max_output_tokens,
- context_window=lm_config.get_max_context_length(),
- is_custom_model=lm_config.is_custom_model(),
- )
-
- identifier_inputs = IdentifierInputs(
- maybe_diff=diff_no_binary or None,
- maybe_goal=goal,
- maybe_conversation_history=conversation_history,
- )
-
- results_generator = registry.run(
- identifier_inputs=identifier_inputs,
- project_context=project_context,
- config=config,
- )
-
- issues = []
- results_generator_with_capture = ReturnCapturingGenerator(results_generator)
- for result in results_generator_with_capture:
- if result.passes_filtration:
- issues.append(result.issue)
- issue_identification_debug_info = results_generator_with_capture.return_value
-
- return tuple(issues), issue_identification_debug_info, project_context
-
-
-def find_issues(
- repo_path: Path,
- relative_to: str,
- goal: str,
- config: ImbueVerifyConfig,
- conversation_history: tuple[ConversationMessageUnion, ...] | None = None,
-) -> tuple[IdentifiedVerifyIssue, ...]:
- logger.info(
- "Finding issues in {repo_path} relative to commit hash {relative_to}",
- repo_path=repo_path,
- relative_to=relative_to,
- )
-
- base_commit, diff, diff_no_binary = get_code_to_check(relative_to, repo_path)
- if not diff.strip():
- logger.info(
- "No code changes detected in repo {repo_path} since the specified relative_to commit {relative_to}, skipping issue identification",
- repo_path=repo_path,
- relative_to=relative_to,
- )
- # No code changes detected since the specified relative_to commit, so no issues to find.
- return tuple()
-
- issues, _, _ = get_issues_with_raw_responses(
- base_commit=base_commit,
- diff=diff,
- diff_no_binary=diff_no_binary,
- goal=goal,
- config=config,
- repo_path=repo_path,
- conversation_history=conversation_history,
- )
- return issues
diff --git a/imbue_verify/cli/config/cli_config_test.py b/imbue_verify/cli/config/cli_config_test.py
@@ -1,446 +0,0 @@
-from __future__ import annotations
-
-import argparse
-import os
-from pathlib import Path
-from unittest.mock import patch
-
-import pytest
-
-from imbue_verify.cli.config.cli_config_schema import CLI_DEFAULTS
-from imbue_verify.cli.config.cli_config_schema import CliConfigPreset
-from imbue_verify.cli.config.cli_config_schema import merge_presets
-from imbue_verify.cli.config.cli_config_schema import parse_cli_config_from_dict
-from imbue_verify.cli.config.loader import ConfigLoadError
-from imbue_verify.cli.config.loader import _load_cli_config_file
-from imbue_verify.cli.config.loader import get_cli_config_file_paths
-from imbue_verify.cli.config.loader import get_config_preset
-from imbue_verify.cli.config.loader import load_cli_config
-from imbue_verify.cli.main import apply_config_preset
-
-
-def test_parse_cli_config_from_dict_parses_single_config() -> None:
- data = {
- "ci": {
- "confidence_threshold": 0.9,
- "max_workers": 4,
- "quiet": True,
- }
- }
-
- result = parse_cli_config_from_dict(data)
-
- assert "ci" in result
- assert result["ci"].confidence_threshold == 0.9
- assert result["ci"].max_workers == 4
- assert result["ci"].quiet is True
-
-
-def test_parse_cli_config_from_dict_parses_multiple_configs() -> None:
- data = {
- "ci": {"confidence_threshold": 0.9},
- "strict": {"confidence_threshold": 0.6, "model": "claude-4-sonnet"},
- "default": {},
- }
-
- result = parse_cli_config_from_dict(data)
-
- assert len(result) == 3
- assert result["ci"].confidence_threshold == 0.9
- assert result["strict"].confidence_threshold == 0.6
- assert result["strict"].model == "claude-4-sonnet"
- assert result["default"].confidence_threshold is None
-
-
-def test_parse_cli_config_from_dict_handles_all_fields() -> None:
- data = {
- "full": {
- "goal": "Check for security issues",
- "repo": "/path/to/repo",
- "base_commit": "main",
- "history_loader": "cat history.jsonl",
- "extra_context": ["context1.txt", "context2.txt"],
- "enabled_issue_codes": ["correctness", "style"],
- "disabled_issue_codes": ["minor"],
- "model": "test-model",
- "temperature": 0.7,
- "confidence_threshold": 0.85,
- "max_workers": 8,
- "output": "results.json",
- "output_format": "json",
- "output_fields": ["file", "line", "message"],
- "verbose": True,
- "quiet": False,
- }
- }
-
- result = parse_cli_config_from_dict(data)
-
- preset = result["full"]
- assert preset.goal == "Check for security issues"
- assert preset.repo == "/path/to/repo"
- assert preset.base_commit == "main"
- assert preset.history_loader == "cat history.jsonl"
- assert preset.extra_context == ["context1.txt", "context2.txt"]
- assert preset.enabled_issue_codes == ["correctness", "style"]
- assert preset.disabled_issue_codes == ["minor"]
- assert preset.model == "test-model"
- assert preset.temperature == 0.7
- assert preset.confidence_threshold == 0.85
- assert preset.max_workers == 8
- assert preset.output == "results.json"
- assert preset.output_format == "json"
- assert preset.output_fields == ["file", "line", "message"]
- assert preset.verbose is True
- assert preset.quiet is False
-
-
-def test_merge_presets_override_takes_precedence() -> None:
- base = CliConfigPreset(confidence_threshold=0.8, max_workers=2, model="base-model")
- override = CliConfigPreset(confidence_threshold=0.9, max_workers=None, model="override-model")
-
- result = merge_presets(base, override)
-
- assert result.confidence_threshold == 0.9
- assert result.max_workers == 2
- assert result.model == "override-model"
-
-
-def test_merge_presets_preserves_base_when_override_is_none() -> None:
- base = CliConfigPreset(
- confidence_threshold=0.8,
- max_workers=4,
- model="base-model",
- verbose=True,
- )
- override = CliConfigPreset()
-
- result = merge_presets(base, override)
-
- assert result.confidence_threshold == 0.8
- assert result.max_workers == 4
- assert result.model == "base-model"
- assert result.verbose is True
-
-
-def test_cli_defaults_and_cli_config_preset_have_same_fields() -> None:
- """Verify CliDefaults and CliConfigPreset define the same fields.
-
- These two models exist for different purposes:
- - CliDefaults: Holds actual default values for CLI arguments (e.g., temperature=0.0)
- - CliConfigPreset: Used for config file presets where None means "not specified"
-
- They must have identical field names to ensure presets can override any default.
- This test catches drift if a field is added to one model but not the other.
- """
- from imbue_verify.cli.config.cli_config_schema import CliDefaults
-
- defaults_fields = set(CliDefaults.model_fields.keys())
- preset_fields = set(CliConfigPreset.model_fields.keys())
-
- assert defaults_fields == preset_fields, (
- f"Field mismatch between CliDefaults and CliConfigPreset.\n"
- f"Only in CliDefaults: {defaults_fields - preset_fields}\n"
- f"Only in CliConfigPreset: {preset_fields - defaults_fields}"
- )
-
-
-def test_get_cli_config_file_paths_returns_global_path(tmp_path: Path) -> None:
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path)}):
- paths = get_cli_config_file_paths(repo_path=None)
-
- assert len(paths) == 1
- assert paths[0] == tmp_path / "imbue-verify" / "config.toml"
-
-
-def test_get_cli_config_file_paths_includes_project_path(tmp_path: Path) -> None:
- repo_path = tmp_path / "repo"
- repo_path.mkdir()
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "xdg")}):
- paths = get_cli_config_file_paths(repo_path=repo_path)
-
- assert len(paths) == 2
- assert paths[0] == tmp_path / "xdg" / "imbue-verify" / "config.toml"
- assert paths[1] == repo_path / "imbue-verify.toml"
-
-
-def test_get_cli_config_file_paths_finds_git_root(tmp_path: Path) -> None:
- git_root = tmp_path / "repo"
- git_root.mkdir()
- (git_root / ".git").mkdir()
- subdir = git_root / "src" / "deep"
- subdir.mkdir(parents=True)
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "xdg")}):
- paths = get_cli_config_file_paths(repo_path=subdir)
-
- assert paths[1] == git_root / "imbue-verify.toml"
-
-
-def test_load_cli_config_file_loads_valid_toml(tmp_path: Path) -> None:
- config_file = tmp_path / "config.toml"
- config_file.write_text(
- """
-[ci]
-confidence_threshold = 0.9
-max_workers = 4
-quiet = true
-
-[strict]
-confidence_threshold = 0.6
-model = "claude-4-sonnet"
-"""
- )
-
- result = _load_cli_config_file(config_file)
-
- assert "ci" in result
- assert result["ci"].confidence_threshold == 0.9
- assert result["ci"].max_workers == 4
- assert result["ci"].quiet is True
- assert "strict" in result
- assert result["strict"].model == "claude-4-sonnet"
-
-
-def test_load_cli_config_file_raises_on_invalid_toml(tmp_path: Path) -> None:
- config_file = tmp_path / "config.toml"
- config_file.write_text("not = valid = toml")
-
- with pytest.raises(ConfigLoadError) as exc_info:
- _load_cli_config_file(config_file)
-
- assert "Invalid TOML" in str(exc_info.value)
-
-
-def test_load_cli_config_file_raises_on_invalid_schema(tmp_path: Path) -> None:
- config_file = tmp_path / "config.toml"
- config_file.write_text(
- """
-[ci]
-confidence_threshold = "not-a-float"
-"""
- )
-
- with pytest.raises(ConfigLoadError) as exc_info:
- _load_cli_config_file(config_file)
-
- assert "Invalid configuration" in str(exc_info.value)
-
-
-def test_load_cli_config_file_raises_on_unknown_field(tmp_path: Path) -> None:
- config_file = tmp_path / "config.toml"
- config_file.write_text(
- """
-[ci]
-unknown_field = "value"
-"""
- )
-
- with pytest.raises(ConfigLoadError) as exc_info:
- _load_cli_config_file(config_file)
-
- assert "Invalid configuration" in str(exc_info.value)
-
-
-def test_load_cli_config_returns_empty_when_no_files_exist(tmp_path: Path) -> None:
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
- result = load_cli_config(repo_path=tmp_path)
-
- assert result == {}
-
-
-def test_load_cli_config_loads_single_file(tmp_path: Path) -> None:
- repo_path = tmp_path / "repo"
- repo_path.mkdir()
- config_file = repo_path / "imbue-verify.toml"
- config_file.write_text(
- """
-[ci]
-confidence_threshold = 0.9
-"""
- )
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
- result = load_cli_config(repo_path=repo_path)
-
- assert "ci" in result
- assert result["ci"].confidence_threshold == 0.9
-
-
-def test_load_cli_config_merges_global_and_project(tmp_path: Path) -> None:
- xdg_config = tmp_path / "xdg"
- (xdg_config / "imbue-verify").mkdir(parents=True)
- global_config = xdg_config / "imbue-verify" / "config.toml"
- global_config.write_text(
- """
-[ci]
-confidence_threshold = 0.8
-max_workers = 2
-
-[global-only]
-model = "global-model"
-"""
- )
-
- repo_path = tmp_path / "repo"
- repo_path.mkdir()
- project_config = repo_path / "imbue-verify.toml"
- project_config.write_text(
- """
-[ci]
-confidence_threshold = 0.9
-
-[project-only]
-model = "project-model"
-"""
- )
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(xdg_config)}):
- result = load_cli_config(repo_path=repo_path)
-
- assert result["ci"].confidence_threshold == 0.9
- assert result["ci"].max_workers == 2
-
- assert "global-only" in result
- assert result["global-only"].model == "global-model"
- assert "project-only" in result
- assert result["project-only"].model == "project-model"
-
-
-def test_get_config_preset_returns_preset() -> None:
- configs = {
- "ci": CliConfigPreset(confidence_threshold=0.9),
- "strict": CliConfigPreset(confidence_threshold=0.6),
- }
-
- result = get_config_preset("ci", configs)
-
- assert result.confidence_threshold == 0.9
-
-
-def test_get_config_preset_raises_on_unknown_with_available() -> None:
- configs = {
- "ci": CliConfigPreset(),
- "strict": CliConfigPreset(),
- }
-
- with pytest.raises(ConfigLoadError) as exc_info:
- get_config_preset("unknown", configs)
-
- error_msg = str(exc_info.value)
- assert "unknown" in error_msg
- assert "ci" in error_msg
- assert "strict" in error_msg
-
-
-def test_get_config_preset_raises_on_unknown_with_no_configs(tmp_path: Path) -> None:
- configs: dict[str, CliConfigPreset] = {}
- repo_path = tmp_path / "repo"
- repo_path.mkdir()
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "xdg")}):
- with pytest.raises(ConfigLoadError) as exc_info:
- get_config_preset("unknown", configs, repo_path)
-
- error_msg = str(exc_info.value)
- assert "unknown" in error_msg
- assert "No configuration files found" in error_msg
- # Verify the error message contains dynamically generated paths with labels
- assert f"{tmp_path / 'xdg' / 'imbue-verify' / 'config.toml'} (global)" in error_msg
- assert f"{repo_path / 'imbue-verify.toml'} (project)" in error_msg
-
-
-def _create_default_args() -> argparse.Namespace:
- return argparse.Namespace(
- model=CLI_DEFAULTS.model,
- temperature=CLI_DEFAULTS.temperature,
- confidence_threshold=CLI_DEFAULTS.confidence_threshold,
- max_workers=CLI_DEFAULTS.max_workers,
- output_format=CLI_DEFAULTS.output_format,
- output_fields=CLI_DEFAULTS.output_fields,
- verbose=CLI_DEFAULTS.verbose,
- quiet=CLI_DEFAULTS.quiet,
- enabled_issue_codes=CLI_DEFAULTS.enabled_issue_codes,
- disabled_issue_codes=CLI_DEFAULTS.disabled_issue_codes,
- )
-
-
-def test_apply_config_preset_applies_all_values() -> None:
- args = _create_default_args()
- preset = CliConfigPreset(
- model="preset-model",
- temperature=0.7,
- confidence_threshold=0.9,
- max_workers=4,
- output_format="json",
- output_fields=["file", "line"],
- verbose=True,
- quiet=False,
- )
-
- result = apply_config_preset(args, preset)
-
- assert result.model == "preset-model"
- assert result.temperature == 0.7
- assert result.confidence_threshold == 0.9
- assert result.max_workers == 4
- assert result.output_format == "json"
- assert result.output_fields == ["file", "line"]
- assert result.verbose is True
-
-
-def test_apply_config_preset_cli_args_take_precedence() -> None:
- args = argparse.Namespace(
- model="cli-model",
- temperature=0.0,
- confidence_threshold=0.95,
- max_workers=2,
- output_format="text",
- output_fields=None,
- verbose=False,
- quiet=False,
- enabled_issue_codes=None,
- disabled_issue_codes=None,
- )
- preset = CliConfigPreset(
- model="preset-model",
- temperature=0.3,
- confidence_threshold=0.6,
- max_workers=8,
- )
-
- result = apply_config_preset(args, preset)
-
- assert result.model == "cli-model"
- assert result.confidence_threshold == 0.95
-
- assert result.temperature == 0.3
- assert result.max_workers == 8
-
-
-def test_apply_config_preset_leaves_defaults_when_preset_is_none() -> None:
- args = _create_default_args()
- preset = CliConfigPreset()
-
- result = apply_config_preset(args, preset)
-
- assert result.model is None
- assert result.temperature == 0.0
- assert result.confidence_threshold == 0.8
- assert result.max_workers == 2
-
-
-def test_apply_config_preset_handles_issue_codes() -> None:
- args = _create_default_args()
- preset = CliConfigPreset(
- enabled_issue_codes=["incorrect_function_implementation"],
- disabled_issue_codes=["bad_naming"],
- )
-
- result = apply_config_preset(args, preset)
-
- assert len(result.enabled_issue_codes) == 1
- assert result.enabled_issue_codes[0].value == "incorrect_function_implementation"
- assert len(result.disabled_issue_codes) == 1
- assert result.disabled_issue_codes[0].value == "bad_naming"
diff --git a/imbue_verify/cli/config/loader.py b/imbue_verify/cli/config/loader.py
@@ -1,210 +0,0 @@
-from __future__ import annotations
-
-import os
-import tomllib
-from pathlib import Path
-
-from pydantic import ValidationError
-
-from imbue_core.agents.configs import LanguageModelGenerationConfig
-from imbue_core.agents.configs import OpenAICompatibleModelConfig
-from imbue_core.agents.llm_apis.common import get_model_max_output_tokens
-from imbue_verify.cli.config.cli_config_schema import CliConfigPreset
-from imbue_verify.cli.config.cli_config_schema import merge_presets
-from imbue_verify.cli.config.cli_config_schema import parse_cli_config_from_dict
-from imbue_verify.cli.config.schema import ModelsConfig
-from imbue_verify.cli.config.schema import ProviderConfig
-
-
-class ConfigLoadError(Exception):
- pass
-
-
-class MissingAPIKeyError(Exception):
- def __init__(self, env_var: str, provider_name: str, model_id: str) -> None:
- self.env_var = env_var
- self.provider_name = provider_name
- self.model_id = model_id
- super().__init__(
- f"API key not found: environment variable '{env_var}' is not set. "
- + f"This is required for model '{model_id}' from provider '{provider_name}'."
- )
-
-
-def get_xdg_config_home() -> Path:
- xdg_config = os.environ.get("XDG_CONFIG_HOME")
- if xdg_config:
- return Path(xdg_config)
- return Path.home() / ".config"
-
-
-def find_git_repo_root(start_path: Path) -> Path | None:
- current = start_path.resolve()
- while current != current.parent:
- if (current / ".git").exists():
- return current
- current = current.parent
- if (current / ".git").exists():
- return current
- return None
-
-
-def _get_config_file_paths(
- global_subpath: str,
- global_filename: str,
- project_filename: str,
- repo_path: Path | None = None,
-) -> list[Path]:
- paths = [get_xdg_config_home() / global_subpath / global_filename]
-
- if repo_path:
- git_root = find_git_repo_root(repo_path)
- root = git_root if git_root else repo_path
- paths.append(root / project_filename)
-
- return paths
-
-
-def get_config_file_paths(repo_path: Path | None = None) -> list[Path]:
- return _get_config_file_paths("imbue", "models.json", "models.json", repo_path)
-
-
-def _load_single_config_file(config_path: Path) -> ModelsConfig:
- try:
- with open(config_path) as f:
- return ModelsConfig.model_validate_json(f.read())
- except ValidationError as e:
- raise ConfigLoadError(f"Invalid configuration in {config_path}: {e}") from e
- except OSError as e:
- raise ConfigLoadError(f"Cannot read {config_path}: {e}") from e
-
-
-def load_models_config(repo_path: Path | None = None) -> ModelsConfig:
- merged_providers: dict[str, ProviderConfig] = {}
-
- for config_path in get_config_file_paths(repo_path):
- if config_path.exists():
- config = _load_single_config_file(config_path)
- merged_providers.update(config.providers)
-
- return ModelsConfig(providers=merged_providers)
-
-
-def get_user_defined_model_ids(config: ModelsConfig) -> set[str]:
- model_ids: set[str] = set()
- for provider in config.providers.values():
- model_ids.update(provider.models.keys())
- return model_ids
-
-
-def get_provider_for_model(model_id: str, config: ModelsConfig) -> ProviderConfig | None:
- for provider in config.providers.values():
- if model_id in provider.models:
- return provider
- return None
-
-
-def validate_api_key_for_model(model_id: str, config: ModelsConfig) -> None:
- provider = get_provider_for_model(model_id, config)
- if provider is None:
- return
-
- api_key_env = provider.api_key_env
- if api_key_env is None:
- return
-
- api_key = os.environ.get(api_key_env, "")
- if not api_key:
- provider_name = provider.name or "unknown provider"
- raise MissingAPIKeyError(
- env_var=api_key_env,
- provider_name=provider_name,
- model_id=model_id,
- )
-
-
-def get_models_by_provider_from_config(config: ModelsConfig) -> dict[str, list[str]]:
- result: dict[str, list[str]] = {}
- for provider_id, provider in config.providers.items():
- display_name = provider.name or provider_id
- result[display_name] = list(provider.models.keys())
- return result
-
-
-def get_max_output_tokens_for_model(model_id: str, config: ModelsConfig) -> int | None:
- provider = get_provider_for_model(model_id, config)
- if provider is not None:
- return provider.models[model_id].max_output_tokens
-
- try:
- return get_model_max_output_tokens(model_id)
- except Exception:
- return None
-
-
-def build_language_model_config(model_id: str, user_config: ModelsConfig) -> LanguageModelGenerationConfig:
- provider = get_provider_for_model(model_id, user_config)
- if provider is None:
- return LanguageModelGenerationConfig(model_name=model_id)
-
- model_config = provider.models[model_id]
- actual_model_name = model_config.model_id or model_id
-
- return OpenAICompatibleModelConfig(
- model_name=actual_model_name,
- custom_base_url=provider.base_url,
- custom_api_key_env=provider.api_key_env or "",
- custom_context_window=model_config.context_window,
- custom_max_output_tokens=model_config.max_output_tokens,
- )
-
-
-def get_cli_config_file_paths(repo_path: Path | None = None) -> list[Path]:
- return _get_config_file_paths("imbue-verify", "config.toml", "imbue-verify.toml", repo_path)
-
-
-def _load_cli_config_file(config_path: Path) -> dict[str, CliConfigPreset]:
- try:
- with open(config_path, "rb") as f:
- data = tomllib.load(f)
- return parse_cli_config_from_dict(data)
- except tomllib.TOMLDecodeError as e:
- raise ConfigLoadError(f"Invalid TOML in {config_path}: {e}") from e
- except ValidationError as e:
- raise ConfigLoadError(f"Invalid configuration in {config_path}: {e}") from e
- except OSError as e:
- raise ConfigLoadError(f"Cannot read {config_path}: {e}") from e
-
-
-def load_cli_config(repo_path: Path | None = None) -> dict[str, CliConfigPreset]:
- merged_configs: dict[str, CliConfigPreset] = {}
-
- for config_path in get_cli_config_file_paths(repo_path):
- if config_path.exists():
- file_configs = _load_cli_config_file(config_path)
- for name, preset in file_configs.items():
- if name in merged_configs:
- merged_configs[name] = merge_presets(merged_configs[name], preset)
- else:
- merged_configs[name] = preset
-
- return merged_configs
-
-
-def get_config_preset(
- config_name: str,
- cli_configs: dict[str, CliConfigPreset],
- repo_path: Path | None = None,
-) -> CliConfigPreset:
- if config_name not in cli_configs:
- available = sorted(cli_configs.keys())
- if available:
- raise ConfigLoadError(f"Configuration '{config_name}' not found. Available configs: {', '.join(available)}")
- else:
- paths = get_cli_config_file_paths(repo_path)
- paths_list = "\n".join(f" - {p} ({'global' if i == 0 else 'project'})" for i, p in enumerate(paths))
- raise ConfigLoadError(
- f"Configuration '{config_name}' not found.\n\n"
- f"No configuration files found. Create a config at one of these locations:\n{paths_list}"
- )
- return cli_configs[config_name]
diff --git a/imbue_verify/cli/config/loader_test.py b/imbue_verify/cli/config/loader_test.py
@@ -1,377 +0,0 @@
-from __future__ import annotations
-
-import json
-import os
-from pathlib import Path
-from unittest.mock import patch
-
-import pytest
-
-from imbue_verify.cli.config.loader import ConfigLoadError
-from imbue_verify.cli.config.loader import MissingAPIKeyError
-from imbue_verify.cli.config.loader import _load_single_config_file
-from imbue_verify.cli.config.loader import find_git_repo_root
-from imbue_verify.cli.config.loader import get_config_file_paths
-from imbue_verify.cli.config.loader import get_models_by_provider_from_config
-from imbue_verify.cli.config.loader import get_provider_for_model
-from imbue_verify.cli.config.loader import get_user_defined_model_ids
-from imbue_verify.cli.config.loader import get_xdg_config_home
-from imbue_verify.cli.config.loader import load_models_config
-from imbue_verify.cli.config.loader import validate_api_key_for_model
-from imbue_verify.cli.config.schema import ModelConfig
-from imbue_verify.cli.config.schema import ModelsConfig
-from imbue_verify.cli.config.schema import ProviderConfig
-
-
-def test_get_xdg_config_home_uses_env_var(tmp_path: Path) -> None:
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path)}):
- assert get_xdg_config_home() == tmp_path
-
-
-def test_get_xdg_config_home_defaults_to_home_config() -> None:
- with patch.dict(os.environ, {}, clear=True):
- os.environ.pop("XDG_CONFIG_HOME", None)
- result = get_xdg_config_home()
- assert result == Path.home() / ".config"
-
-
-def test_find_git_repo_root_finds_root(tmp_path: Path) -> None:
- git_root = tmp_path / "repo"
- git_root.mkdir()
- (git_root / ".git").mkdir()
- subdir = git_root / "src" / "deep" / "nested"
- subdir.mkdir(parents=True)
-
- result = find_git_repo_root(subdir)
- assert result == git_root
-
-
-def test_find_git_repo_root_returns_none_when_not_in_repo(tmp_path: Path) -> None:
- non_repo = tmp_path / "not_a_repo"
- non_repo.mkdir()
-
- result = find_git_repo_root(non_repo)
- assert result is None
-
-
-def test_get_config_file_paths_returns_global_path(tmp_path: Path) -> None:
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path)}):
- paths = get_config_file_paths(repo_path=None)
- assert len(paths) == 1
- assert paths[0] == tmp_path / "imbue" / "models.json"
-
-
-def test_get_config_file_paths_finds_git_root(tmp_path: Path) -> None:
- xdg_config = tmp_path / "xdg"
- git_root = tmp_path / "repo"
- git_root.mkdir()
- (git_root / ".git").mkdir()
- subdir = git_root / "src" / "submodule"
- subdir.mkdir(parents=True)
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(xdg_config)}):
- paths = get_config_file_paths(repo_path=subdir)
- assert len(paths) == 2
- assert paths[0] == xdg_config / "imbue" / "models.json"
- assert paths[1] == git_root / "models.json"
-
-
-def test_load_single_config_file_loads_valid_config(tmp_path: Path) -> None:
- config_file = tmp_path / "models.json"
- config_data = {
- "providers": {
- "test-provider": {
- "name": "Test Provider",
- "api_type": "openai_compatible",
- "base_url": "http://localhost:8080/v1",
- "api_key_env": "TEST_API_KEY",
- "models": {
- "test-model": {
- "model_id": "test-model-v1",
- "context_window": 128000,
- "max_output_tokens": 16384,
- }
- },
- }
- }
- }
- config_file.write_text(json.dumps(config_data))
-
- result = _load_single_config_file(config_file)
-
- assert "test-provider" in result.providers
- provider = result.providers["test-provider"]
- assert provider.name == "Test Provider"
- assert provider.base_url == "http://localhost:8080/v1"
- assert provider.api_key_env == "TEST_API_KEY"
- assert "test-model" in provider.models
- assert provider.models["test-model"].model_id == "test-model-v1"
-
-
-def test_load_single_config_file_raises_on_invalid_json(tmp_path: Path) -> None:
- config_file = tmp_path / "models.json"
- config_file.write_text("not valid json")
-
- with pytest.raises(ConfigLoadError) as exc_info:
- _load_single_config_file(config_file)
- assert "Invalid JSON" in str(exc_info.value)
-
-
-def test_load_single_config_file_raises_on_invalid_schema(tmp_path: Path) -> None:
- config_file = tmp_path / "models.json"
- config_data = {
- "providers": {
- "test-provider": {
- "name": "Test Provider",
- }
- }
- }
- config_file.write_text(json.dumps(config_data))
-
- with pytest.raises(ConfigLoadError) as exc_info:
- _load_single_config_file(config_file)
- assert "Invalid configuration" in str(exc_info.value)
-
-
-def test_load_single_config_file_raises_on_invalid_api_type(tmp_path: Path) -> None:
- config_file = tmp_path / "models.json"
- config_data = {
- "providers": {
- "test-provider": {
- "name": "Test Provider",
- "api_type": "anthropic",
- "base_url": "http://localhost:8080/v1",
- "api_key_env": "TEST_API_KEY",
- "models": {},
- }
- }
- }
- config_file.write_text(json.dumps(config_data))
-
- with pytest.raises(ConfigLoadError) as exc_info:
- _load_single_config_file(config_file)
- assert "Invalid configuration" in str(exc_info.value)
-
-
-def test_load_models_config_returns_empty_when_no_files_exist(tmp_path: Path) -> None:
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
- result = load_models_config(repo_path=tmp_path)
- assert result.providers == {}
-
-
-def test_load_models_config_loads_project_config(tmp_path: Path) -> None:
- repo_path = tmp_path / "repo"
- repo_path.mkdir()
- config_file = repo_path / "models.json"
- config_data = {
- "providers": {
- "project-provider": {
- "base_url": "http://project:8080/v1",
- "api_key_env": "PROJECT_KEY",
- "models": {
- "project-model": {
- "context_window": 128000,
- "max_output_tokens": 16384,
- }
- },
- }
- }
- }
- config_file.write_text(json.dumps(config_data))
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
- result = load_models_config(repo_path=repo_path)
-
- assert "project-provider" in result.providers
-
-
-def test_load_models_config_project_overrides_global(tmp_path: Path) -> None:
- xdg_config = tmp_path / "xdg"
- (xdg_config / "imbue").mkdir(parents=True)
- global_config = xdg_config / "imbue" / "models.json"
- global_config.write_text(
- json.dumps(
- {
- "providers": {
- "shared-provider": {
- "name": "Global Name",
- "base_url": "http://global:8080/v1",
- "api_key_env": "GLOBAL_KEY",
- "models": {
- "global-model": {
- "context_window": 128000,
- "max_output_tokens": 16384,
- }
- },
- }
- }
- }
- )
- )
-
- repo_path = tmp_path / "repo"
- repo_path.mkdir()
- project_config = repo_path / "models.json"
- project_config.write_text(
- json.dumps(
- {
- "providers": {
- "shared-provider": {
- "name": "Project Name",
- "base_url": "http://project:8080/v1",
- "api_key_env": "PROJECT_KEY",
- "models": {
- "project-model": {
- "context_window": 128000,
- "max_output_tokens": 16384,
- }
- },
- }
- }
- }
- )
- )
-
- with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(xdg_config)}):
- result = load_models_config(repo_path=repo_path)
-
- assert result.providers["shared-provider"].name == "Project Name"
- assert result.providers["shared-provider"].base_url == "http://project:8080/v1"
-
-
-def test_get_user_defined_model_ids_extracts_all_ids() -> None:
- config = ModelsConfig(
- providers={
- "provider1": ProviderConfig(
- base_url="http://localhost:8080/v1",
- api_key_env="KEY1",
- models={
- "model-a": ModelConfig(context_window=128000, max_output_tokens=16384),
- "model-b": ModelConfig(context_window=128000, max_output_tokens=16384),
- },
- ),
- "provider2": ProviderConfig(
- base_url="http://localhost:8081/v1",
- api_key_env="KEY2",
- models={
- "model-c": ModelConfig(context_window=128000, max_output_tokens=16384),
- },
- ),
- }
- )
-
- result = get_user_defined_model_ids(config)
-
- assert result == {"model-a", "model-b", "model-c"}
-
-
-def test_get_provider_for_model_finds_provider() -> None:
- config = ModelsConfig(
- providers={
- "provider1": ProviderConfig(
- base_url="http://localhost:8080/v1",
- api_key_env="KEY1",
- models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
- ),
- "provider2": ProviderConfig(
- base_url="http://localhost:8081/v1",
- api_key_env="KEY2",
- models={"model-b": ModelConfig(context_window=128000, max_output_tokens=16384)},
- ),
- }
- )
-
- result = get_provider_for_model("model-b", config)
-
- assert result is not None
- assert result.api_key_env == "KEY2"
-
-
-def test_get_provider_for_model_returns_none_for_unknown() -> None:
- config = ModelsConfig(
- providers={
- "provider1": ProviderConfig(
- base_url="http://localhost:8080/v1",
- api_key_env="KEY1",
- models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
- ),
- }
- )
-
- result = get_provider_for_model("unknown-model", config)
-
- assert result is None
-
-
-def test_validate_api_key_passes_when_key_is_set() -> None:
- config = ModelsConfig(
- providers={
- "provider1": ProviderConfig(
- name="Test Provider",
- base_url="http://localhost:8080/v1",
- api_key_env="TEST_API_KEY",
- models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
- ),
- }
- )
-
- with patch.dict(os.environ, {"TEST_API_KEY": "secret-key"}):
- validate_api_key_for_model("model-a", config)
-
-
-def test_validate_api_key_raises_when_key_not_set() -> None:
- config = ModelsConfig(
- providers={
- "provider1": ProviderConfig(
- name="Test Provider",
- base_url="http://localhost:8080/v1",
- api_key_env="MISSING_KEY",
- models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
- ),
- }
- )
-
- with patch.dict(os.environ, {}, clear=True):
- os.environ.pop("MISSING_KEY", None)
- with pytest.raises(MissingAPIKeyError) as exc_info:
- validate_api_key_for_model("model-a", config)
-
- assert exc_info.value.env_var == "MISSING_KEY"
- assert exc_info.value.model_id == "model-a"
- assert "MISSING_KEY" in str(exc_info.value)
-
-
-def test_validate_api_key_passes_for_unknown_model() -> None:
- config = ModelsConfig(providers={})
- validate_api_key_for_model("unknown-model", config)
-
-
-def test_get_models_by_provider_groups_models() -> None:
- config = ModelsConfig(
- providers={
- "ollama": ProviderConfig(
- name="Ollama Local",
- base_url="http://localhost:11434/v1",
- api_key_env="OLLAMA_KEY",
- models={
- "llama3.2:latest": ModelConfig(context_window=128000, max_output_tokens=16384),
- "qwen:7b": ModelConfig(context_window=32768, max_output_tokens=8192),
- },
- ),
- "openrouter": ProviderConfig(
- base_url="https://openrouter.ai/api/v1",
- api_key_env="OPENROUTER_KEY",
- models={
- "anthropic/claude-3": ModelConfig(context_window=200000, max_output_tokens=16384),
- },
- ),
- }
- )
-
- result = get_models_by_provider_from_config(config)
-
- assert "Ollama Local" in result
- assert set(result["Ollama Local"]) == {"llama3.2:latest", "qwen:7b"}
-
- assert "openrouter" in result
- assert result["openrouter"] == ["anthropic/claude-3"]
diff --git a/imbue_verify/cli/main.py b/imbue_verify/cli/main.py
@@ -1,502 +0,0 @@
-from __future__ import annotations
-
-# The choice to use argparse was primarily driven by the idea that imbue-verify will be called by agents / llms.
-# Given this, we want to have the most standardized outputs possible.
-import argparse
-import json
-import subprocess
-import sys
-from importlib.metadata import version
-from pathlib import Path
-
-from loguru import logger
-
-from imbue_core.data_types import IssueCode
-from imbue_tools.get_conversation_history.get_conversation_history import (
- parse_conversation_history,
-)
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.api import find_issues
-from imbue_verify.cli.config.cli_config_schema import CLI_DEFAULTS
-from imbue_verify.cli.config.cli_config_schema import CliConfigPreset
-from imbue_verify.cli.config.loader import ConfigLoadError
-from imbue_verify.cli.config.loader import build_language_model_config
-from imbue_verify.cli.config.loader import get_cli_config_file_paths
-from imbue_verify.cli.config.loader import get_config_preset
-from imbue_verify.cli.config.loader import get_max_output_tokens_for_model
-from imbue_verify.cli.config.loader import load_cli_config
-from imbue_verify.cli.config.loader import load_models_config
-from imbue_verify.cli.config.loader import validate_api_key_for_model
-from imbue_verify.cli.config.schema import ModelsConfig
-from imbue_verify.cli.models import DEFAULT_MODEL_ID
-from imbue_verify.cli.models import get_models_by_provider
-from imbue_verify.cli.models import validate_model_id
-from imbue_verify.formatters import OUTPUT_FIELDS
-from imbue_verify.formatters import OUTPUT_FORMATS
-from imbue_verify.formatters import format_issue_text
-from imbue_verify.formatters import issue_to_dict
-from imbue_verify.formatters import validate_output_fields
-
-VERSION = version("imbue_verify")
-
-_ISSUE_CODE_FIELDS = frozenset({"enabled_issue_codes", "disabled_issue_codes"})
-_PATH_FIELDS = frozenset({"repo", "output"})
-_PATH_LIST_FIELDS = frozenset({"extra_context"})
-
-
-def create_parser() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser(
- prog="imbue-verify",
- description="Identify issues in code changes using LLM-based analysis.",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- )
-
- parser.add_argument(
- "goal",
- type=str,
- nargs="?",
- default=CLI_DEFAULTS.goal,
- metavar="GOAL",
- help=(
- "Description of what the code change is trying to accomplish. "
- + "If not provided, only goal-independent issue identifiers will run."
- ),
- )
-
- parser.add_argument(
- "--repo",
- "-r",
- type=Path,
- default=Path.cwd(),
- metavar="PATH",
- help="Path to the repository for analysis (default: current directory)",
- )
-
- parser.add_argument(
- "--version",
- "-V",
- action="version",
- version=f"%(prog)s {VERSION}",
- )
-
- parser.add_argument(
- "--config",
- "-c",
- type=str,
- default=None,
- metavar="NAME",
- help="Name of the configuration to use. Configurations are defined in imbue-verify.toml in your target project's root or ~/.config/imbue-verify/config.toml.",
- )
- parser.add_argument(
- "--list-configs",
- action="store_true",
- help="List all available named configurations",
- )
-
- diff_group = parser.add_argument_group("diff options")
- diff_group.add_argument(
- "--base-commit",
- type=str,
- default=CLI_DEFAULTS.base_commit,
- metavar="REF",
- help=f"Git commit, branch, or ref to use as the base for computing the diff (default: {CLI_DEFAULTS.base_commit})",
- )
-
- context_group = parser.add_argument_group("context options")
- context_group.add_argument(
- "--history-loader",
- type=str,
- default=CLI_DEFAULTS.history_loader,
- metavar="COMMAND",
- help=(
- "Shell command that outputs conversation history as JSON to stdout. "
- + "Used to derive a goal if one is not provided."
- ),
- )
- context_group.add_argument(
- "--extra-context",
- type=Path,
- nargs="*",
- default=CLI_DEFAULTS.extra_context,
- metavar="FILE",
- help="Path(s) to file(s) containing additional context (e.g., library documentation). Content is included in the prompt after the codebase snapshot.",
- )
-
- analysis_group = parser.add_argument_group("analysis options")
- # Valid issue codes are defined in imbue_core.data_types.IssueCode
- analysis_group.add_argument(
- "--enabled-issue-codes",
- type=IssueCode,
- nargs="+",
- default=CLI_DEFAULTS.enabled_issue_codes,
- metavar="CODE",
- help="Only report issues of the given type(s). Use --list-issue-codes to see valid codes.",
- )
- analysis_group.add_argument(
- "--disabled-issue-codes",
- type=IssueCode,
- nargs="+",
- default=CLI_DEFAULTS.disabled_issue_codes,
- metavar="CODE",
- help="Do not report issues of the given type(s). Use --list-issue-codes to see valid codes.",
- )
- analysis_group.add_argument(
- "--list-issue-codes",
- action="store_true",
- help="List all available issue codes",
- )
-
- model_group = parser.add_argument_group("model configuration")
- model_group.add_argument(
- "--model",
- "-m",
- type=str,
- default=CLI_DEFAULTS.model,
- metavar="MODEL",
- help=f"LLM to use for analysis (default: {DEFAULT_MODEL_ID}). ",
- )
- model_group.add_argument(
- "--list-models",
- action="store_true",
- help="List all available models",
- )
- model_group.add_argument(
- "--temperature",
- type=float,
- default=CLI_DEFAULTS.temperature,
- metavar="TEMP",
- help=f"Override the default temperature for the model (default: {CLI_DEFAULTS.temperature}).",
- )
-
- filter_group = parser.add_argument_group("filtering options")
- filter_group.add_argument(
- "--confidence-threshold",
- type=float,
- default=CLI_DEFAULTS.confidence_threshold,
- metavar="THRESHOLD",
- help=f"Minimum confidence score (0.0-1.0) for issues to be reported (default: {CLI_DEFAULTS.confidence_threshold})",
- )
-
- parallel_group = parser.add_argument_group("parallelization options")
- parallel_group.add_argument(
- "--max-workers",
- type=int,
- default=CLI_DEFAULTS.max_workers,
- metavar="N",
- help=f"Maximum number of parallel workers for identification (default: {CLI_DEFAULTS.max_workers})",
- )
-
- output_group = parser.add_argument_group("output options")
- output_group.add_argument(
- "--output",
- "-o",
- type=Path,
- default=CLI_DEFAULTS.output,
- metavar="FILE",
- help="Output file path (default: stdout). Use - to write to stdout.",
- )
- output_group.add_argument(
- "--output-format",
- type=str,
- choices=OUTPUT_FORMATS,
- default=CLI_DEFAULTS.output_format,
- metavar="FORMAT",
- help=f"Output format. Choices: {', '.join(OUTPUT_FORMATS)} (default: {CLI_DEFAULTS.output_format})",
- )
- output_group.add_argument(
- "--output-fields",
- type=str,
- nargs="+",
- default=CLI_DEFAULTS.output_fields,
- metavar="FIELD",
- help="Output fields to include (default: all)",
- )
- output_group.add_argument(
- "--list-fields",
- action="store_true",
- help="List all available output data fields",
- )
- output_group.add_argument(
- "--verbose",
- "-v",
- action="store_true",
- default=CLI_DEFAULTS.verbose,
- help="Show verbose logger messages",
- )
- output_group.add_argument(
- "--quiet",
- "-q",
- action="store_true",
- default=CLI_DEFAULTS.quiet,
- help="Suppress progress indicator and non-essential output",
- )
-
- return parser
-
-
-def _get_available_issue_codes() -> list[IssueCode]:
- return [code for code in IssueCode if not code.name.startswith("_DEPRECATED")]
-
-
-# TODO: There are logical groupings of codes we should consider because some issue_codes are associated with the same prompts / categories of issues.
-# This should likely be used to dictate the ordering instead of sorting.
-def list_issue_codes() -> None:
- print("Available issue codes:")
- print()
- for code in sorted(_get_available_issue_codes(), key=lambda c: c.value):
- print(f" {code.value}")
-
-
-def list_models(user_config: ModelsConfig | None = None) -> None:
- print("Available models:")
- print()
- models_by_provider = get_models_by_provider(user_config)
- for provider, model_ids in sorted(models_by_provider.items()):
- print(f" {provider}:")
- for model_id in sorted(model_ids):
- default_marker = " (default)" if model_id == DEFAULT_MODEL_ID else ""
- print(f" {model_id}{default_marker}")
-
-
-def list_fields() -> None:
- print("Available output fields:")
- print()
- for field in OUTPUT_FIELDS:
- print(f" {field}")
-
-
-def list_configs(cli_configs: dict[str, CliConfigPreset], repo_path: Path) -> None:
- print("Available configurations:")
- print()
-
- if not cli_configs:
- print(" No configurations found.")
- print()
- print("Configuration files are loaded from:")
- for path in get_cli_config_file_paths(repo_path):
- exists_marker = " (exists)" if path.exists() else ""
- print(f" {path}{exists_marker}")
- return
-
- for name, preset in sorted(cli_configs.items()):
- print(f" {name}:")
- preset_dict = preset.model_dump(exclude_none=True)
- if preset_dict:
- for key, value in preset_dict.items():
- print(f" {key}: {value}")
- else:
- print(" (uses all defaults)")
- print()
-
-
-def configure_logging(verbose: bool, quiet: bool) -> None:
- logger.remove()
- if quiet:
- level = "WARNING"
- elif verbose:
- level = "DEBUG"
- else:
- level = "INFO"
- logger.add(sys.stderr, level=level)
-
-
-def load_conversation_from_command(command: str, cwd: Path) -> tuple:
- result = subprocess.run(command, shell=True, capture_output=True, text=True, cwd=cwd)
- if result.returncode != 0:
- logger.warning(f"History loader command failed with exit code {result.returncode}: {result.stderr}")
- return ()
- if not result.stdout.strip():
- return ()
- return parse_conversation_history(result.stdout)
-
-
-def apply_config_preset(args: argparse.Namespace, preset: CliConfigPreset) -> argparse.Namespace:
- preset_dict = preset.model_dump(exclude_none=True)
-
- for field, preset_value in preset_dict.items():
- default_value = getattr(CLI_DEFAULTS, field, None)
- if getattr(args, field) == default_value:
- if field in _ISSUE_CODE_FIELDS:
- preset_value = [IssueCode(code) for code in preset_value]
- elif field in _PATH_LIST_FIELDS:
- preset_value = [Path(p) for p in preset_value]
- elif field in _PATH_FIELDS:
- preset_value = Path(preset_value)
- setattr(args, field, preset_value)
-
- return args
-
-
-def main(argv: list[str] | None = None) -> int:
- parser = create_parser()
- args = parser.parse_args(argv)
-
- goal = args.goal or ""
-
- repo_path = args.repo
-
- try:
- user_config = load_models_config(repo_path)
- except ConfigLoadError as e:
- print(f"Error loading model configuration: {e}", file=sys.stderr)
- return 2
-
- if args.list_issue_codes:
- list_issue_codes()
- return 0
-
- if args.list_models:
- list_models(user_config)
- return 0
-
- if args.list_fields:
- list_fields()
- return 0
-
- try:
- cli_configs = load_cli_config(repo_path)
- except ConfigLoadError as e:
- print(f"Error loading CLI configuration: {e}", file=sys.stderr)
- return 2
-
- if args.list_configs:
- list_configs(cli_configs, repo_path)
- return 0
-
- if args.config is not None:
- try:
- preset = get_config_preset(args.config, cli_configs, repo_path)
- args = apply_config_preset(args, preset)
- except ConfigLoadError as e:
- print(f"Error: {e}", file=sys.stderr)
- return 2
-
- if not repo_path.exists():
- print(f"Error: Repository path does not exist: {repo_path}", file=sys.stderr)
- return 2
-
- if not repo_path.is_dir():
- print(f"Error: Repository path is not a directory: {repo_path}", file=sys.stderr)
- return 2
-
- if args.extra_context:
- for extra_context_file in args.extra_context:
- if not extra_context_file.exists():
- print(
- f"Error: Extra context file does not exist: {extra_context_file}",
- file=sys.stderr,
- )
- return 2
-
- if args.verbose and args.quiet:
- print(
- "Error: --verbose and --quiet are mutually exclusive",
- file=sys.stderr,
- )
- return 2
-
- if not 0.0 <= args.confidence_threshold <= 1.0:
- print(
- f"Error: Confidence threshold must be between 0.0 and 1.0, got: {args.confidence_threshold}",
- file=sys.stderr,
- )
- return 2
-
- if not 0.0 <= args.temperature <= 2.0:
- print(
- f"Error: Temperature must be between 0.0 and 2.0, got: {args.temperature}",
- file=sys.stderr,
- )
- return 2
-
- configure_logging(args.verbose, args.quiet)
-
- conversation_history = None
- if args.history_loader is not None:
- conversation_history = load_conversation_from_command(args.history_loader, repo_path)
-
- extra_context = None
- if args.extra_context:
- extra_context_parts = []
- for context_file in args.extra_context:
- extra_context_parts.append(context_file.read_text())
- extra_context = "\n\n".join(extra_context_parts)
-
- if args.output_fields is not None:
- try:
- validate_output_fields(args.output_fields)
- except ValueError as e:
- print(f"Error: {e}", file=sys.stderr)
- return 2
-
- model_id = args.model or DEFAULT_MODEL_ID
-
- try:
- model_id = validate_model_id(model_id, user_config)
- except ValueError as e:
- print(f"Error: {e}", file=sys.stderr)
- return 2
-
- try:
- validate_api_key_for_model(model_id, user_config)
- except Exception as e:
- print(f"Error: {e}", file=sys.stderr)
- return 2
-
- # TODO: Support OFFLINE, UPDATE_SNAPSHOT, and MOCKED modes.
- language_model_config = build_language_model_config(model_id, user_config)
- max_output_tokens = get_max_output_tokens_for_model(model_id, user_config)
-
- config = ImbueVerifyConfig(
- disabled_identifiers=("agentic_issue_identifier",),
- language_model_generation_config=language_model_config,
- enabled_issue_codes=(tuple(args.enabled_issue_codes) if args.enabled_issue_codes else None),
- disabled_issue_codes=(tuple(args.disabled_issue_codes) if args.disabled_issue_codes else None),
- temperature=args.temperature,
- filter_issues_below_confidence=args.confidence_threshold,
- max_identify_workers=args.max_workers,
- max_output_tokens=max_output_tokens or 20000,
- extra_context=extra_context,
- )
-
- issues = find_issues(
- repo_path=repo_path,
- relative_to=args.base_commit,
- goal=goal,
- config=config,
- conversation_history=conversation_history,
- )
-
- output_fields = args.output_fields if args.output_fields else OUTPUT_FIELDS
-
- output_file = None
- if args.output is not None and str(args.output) != "-":
- output_file = open(args.output, "w")
- output_stream = output_file
- else:
- output_stream = sys.stdout
-
- try:
- if not issues:
- if args.output_format == "json":
- print(json.dumps({"issues": []}, indent=2), file=output_stream)
- elif not args.quiet:
- print("No issues found.", file=output_stream)
- return 0
-
- if args.output_format == "json":
- issues_list = [issue_to_dict(issue, output_fields) for issue in issues]
- print(json.dumps({"issues": issues_list}, indent=2), file=output_stream)
- else:
- for issue in issues:
- print(format_issue_text(issue, output_fields), file=output_stream)
- print(file=output_stream)
-
- return 1
- finally:
- if output_file is not None:
- output_file.close()
-
-
-if __name__ == "__main__":
- sys.exit(main())
diff --git a/imbue_verify/cli/models.py b/imbue_verify/cli/models.py
@@ -1,65 +0,0 @@
-from __future__ import annotations
-
-from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
-from imbue_core.agents.llm_apis.common import get_all_model_names
-from imbue_core.agents.llm_apis.gemini_api import GeminiModelName
-from imbue_core.agents.llm_apis.groq_api import GroqSupportedModelName
-from imbue_core.agents.llm_apis.openai_api import OpenAIModelName
-from imbue_core.agents.llm_apis.together_api import TogetherAIModelName
-from imbue_verify.cli.config.loader import get_models_by_provider_from_config
-from imbue_verify.cli.config.loader import get_user_defined_model_ids
-from imbue_verify.cli.config.schema import ModelsConfig
-
-DEFAULT_MODEL_ID = AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01.value
-
-
-def get_builtin_model_ids() -> set[str]:
- return {str(name) for name in get_all_model_names()}
-
-
-def get_all_model_ids(user_config: ModelsConfig | None = None) -> set[str]:
- model_ids = get_builtin_model_ids()
-
- if user_config:
- model_ids.update(get_user_defined_model_ids(user_config))
-
- return model_ids
-
-
-def is_valid_model_id(model_id: str, user_config: ModelsConfig | None = None) -> bool:
- return model_id in get_all_model_ids(user_config)
-
-
-def is_user_defined_model(model_id: str, user_config: ModelsConfig | None = None) -> bool:
- if user_config is None:
- return False
- return model_id in get_user_defined_model_ids(user_config)
-
-
-def validate_model_id(model_id: str, user_config: ModelsConfig | None = None) -> str:
- if not is_valid_model_id(model_id, user_config):
- raise ValueError(f"Unknown model: {model_id}. Use --list-models to see available models.")
- return model_id
-
-
-def get_builtin_models_by_provider() -> dict[str, list[str]]:
- return {
- "anthropic": [m.value for m in AnthropicModelName],
- "openai": [m.value for m in OpenAIModelName],
- "gemini": [m.value for m in GeminiModelName],
- "groq": [m.value for m in GroqSupportedModelName],
- "together": [m.value for m in TogetherAIModelName],
- }
-
-
-def get_models_by_provider(
- user_config: ModelsConfig | None = None,
-) -> dict[str, list[str]]:
- providers = get_builtin_models_by_provider()
-
- if user_config:
- user_providers = get_models_by_provider_from_config(user_config)
- for provider_name, model_ids in user_providers.items():
- providers[provider_name] = model_ids
-
- return providers
diff --git a/imbue_verify/cli/models_test.py b/imbue_verify/cli/models_test.py
@@ -1,168 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-from imbue_verify.cli.config.schema import ModelConfig
-from imbue_verify.cli.config.schema import ModelsConfig
-from imbue_verify.cli.config.schema import ProviderConfig
-from imbue_verify.cli.models import DEFAULT_MODEL_ID
-from imbue_verify.cli.models import get_all_model_ids
-from imbue_verify.cli.models import get_builtin_model_ids
-from imbue_verify.cli.models import get_builtin_models_by_provider
-from imbue_verify.cli.models import get_models_by_provider
-from imbue_verify.cli.models import is_user_defined_model
-from imbue_verify.cli.models import is_valid_model_id
-from imbue_verify.cli.models import validate_model_id
-
-SAMPLE_USER_CONFIG = ModelsConfig(
- providers={
- "custom": ProviderConfig(
- base_url="http://localhost:8080/v1",
- api_key_env="CUSTOM_KEY",
- models={
- "my-custom-model": ModelConfig(context_window=128000, max_output_tokens=16384),
- "another-model": ModelConfig(context_window=128000, max_output_tokens=16384),
- },
- )
- }
-)
-
-
-def test_default_model_is_in_builtin_models() -> None:
- assert DEFAULT_MODEL_ID in get_builtin_model_ids()
-
-
-def test_get_builtin_model_ids_returns_strings() -> None:
- model_ids = get_builtin_model_ids()
- assert all(isinstance(m, str) for m in model_ids)
-
-
-def test_get_all_model_ids_returns_builtin_models_when_no_config() -> None:
- all_ids = get_all_model_ids(user_config=None)
- builtin_ids = get_builtin_model_ids()
- assert all_ids == builtin_ids
-
-
-def test_get_all_model_ids_includes_user_defined_models() -> None:
- all_ids = get_all_model_ids(SAMPLE_USER_CONFIG)
-
- assert "my-custom-model" in all_ids
- assert "another-model" in all_ids
- assert DEFAULT_MODEL_ID in all_ids
-
-
-@pytest.mark.parametrize(
- ("model_id", "user_config", "expected"),
- [
- (DEFAULT_MODEL_ID, None, True),
- ("nonexistent-model-xyz", None, False),
- ("my-custom-model", SAMPLE_USER_CONFIG, True),
- ],
-)
-def test_is_valid_model_id(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
- assert is_valid_model_id(model_id, user_config) is expected
-
-
-@pytest.mark.parametrize(
- ("model_id", "user_config", "expected"),
- [
- ("any-model", None, False),
- ("my-custom-model", SAMPLE_USER_CONFIG, True),
- (DEFAULT_MODEL_ID, SAMPLE_USER_CONFIG, False),
- ],
-)
-def test_is_user_defined_model(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
- assert is_user_defined_model(model_id, user_config) is expected
-
-
-def test_validate_model_id_returns_model_id_when_valid() -> None:
- result = validate_model_id(DEFAULT_MODEL_ID)
- assert result == DEFAULT_MODEL_ID
-
-
-def test_validate_model_id_raises_for_invalid_model() -> None:
- with pytest.raises(ValueError) as exc_info:
- validate_model_id("nonexistent-model-xyz")
-
- assert "Unknown model: nonexistent-model-xyz" in str(exc_info.value)
- assert "--list-models" in str(exc_info.value)
-
-
-def test_validate_model_id_validates_user_defined_model() -> None:
- user_config = ModelsConfig(
- providers={
- "custom": ProviderConfig(
- base_url="http://localhost:8080/v1",
- api_key_env="CUSTOM_KEY",
- models={"my-custom-model": ModelConfig(context_window=128000, max_output_tokens=16384)},
- )
- }
- )
-
- result = validate_model_id("my-custom-model", user_config)
- assert result == "my-custom-model"
-
-
-def test_get_builtin_models_by_provider_returns_dict_with_expected_providers() -> None:
- providers = get_builtin_models_by_provider()
-
- assert "anthropic" in providers
- assert "openai" in providers
- assert "gemini" in providers
- assert "groq" in providers
- assert "together" in providers
-
-
-def test_get_builtin_models_by_provider_all_values_are_lists_of_strings() -> None:
- providers = get_builtin_models_by_provider()
-
- for provider_name, models in providers.items():
- assert isinstance(models, list), f"{provider_name} should have a list of models"
- assert all(isinstance(m, str) for m in models), f"{provider_name} models should all be strings"
-
-
-def test_get_models_by_provider_returns_builtin_providers_when_no_config() -> None:
- providers = get_models_by_provider(user_config=None)
- builtin_providers = get_builtin_models_by_provider()
-
- assert providers == builtin_providers
-
-
-def test_get_models_by_provider_includes_user_defined_providers() -> None:
- user_config = ModelsConfig(
- providers={
- "ollama": ProviderConfig(
- name="Ollama Local",
- base_url="http://localhost:11434/v1",
- api_key_env="OLLAMA_KEY",
- models={
- "llama3.2:latest": ModelConfig(context_window=128000, max_output_tokens=16384),
- "qwen:7b": ModelConfig(context_window=32768, max_output_tokens=8192),
- },
- )
- }
- )
-
- providers = get_models_by_provider(user_config)
-
- assert "Ollama Local" in providers
- assert set(providers["Ollama Local"]) == {"llama3.2:latest", "qwen:7b"}
- assert "anthropic" in providers
- assert "openai" in providers
-
-
-def test_get_models_by_provider_user_provider_overrides_builtin_with_same_name() -> None:
- user_config = ModelsConfig(
- providers={
- "custom": ProviderConfig(
- name="anthropic",
- base_url="http://localhost:8080/v1",
- api_key_env="CUSTOM_KEY",
- models={"custom-model": ModelConfig(context_window=128000, max_output_tokens=16384)},
- )
- }
- )
-
- providers = get_models_by_provider(user_config)
-
- assert providers["anthropic"] == ["custom-model"]
diff --git a/imbue_verify/issue_identifiers/agentic_issue_collation.py b/imbue_verify/issue_identifiers/agentic_issue_collation.py
@@ -1,183 +0,0 @@
-import json
-from typing import Generator
-from typing import Iterable
-
-import jinja2
-
-from imbue_core.data_types import AgenticPhase
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import LLMResponse
-from imbue_tools.get_conversation_history.input_data_types import CommitInputs
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.get_conversation_history.input_data_types import (
- to_specific_inputs_type,
-)
-from imbue_tools.repo_utils.context_utils import escape_prompt_markers
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import GeneratedResponseSchema
-from imbue_verify.issue_identifiers.common import extract_invocation_info_from_messages
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.common import generate_issues_from_response_texts
-from imbue_verify.issue_identifiers.common import generate_response_from_claude_code
-from imbue_verify.issue_identifiers.common import get_claude_code_options
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-
-COLLATION_PROMPT_TEMPLATE = """You are reviewing the results from parallel code analysis for potential issues.
-Multiple specialized agents analyzed the following code diff, each focusing on a specific type of issue.
-The repository files are available in {{ repo_path }}.
-
-### User request ###
-{% filter indent(width=2) %}
-{{ commit_message }}
-{% endfilter %}
-
-### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ###
-{% filter indent(width=2) %}
-{{ unified_diff }}
-{% endfilter %}
-###
-
-The rubric below outlines the categories of issues we care about:
-{% for issue_code, guide in guides.items() %}
----
-**{{ issue_code }}**:
-{{ guide }}
-{% endfor %}
----
-
-### Parallel Analysis Results ###
-{{ generated_issues }}
-
-Your task is to:
-1. Review all the findings for accuracy and relevance using the category definitions above
-2. Consolidate any duplicate or overlapping issues
-3. Ensure each issue is correctly categorized according to the category definitions and re-categorize any issues if necessary
-4. Return a consolidated set of issues
-
-Guidelines:
-- Merge similar issues that refer to the same underlying problem
-- Do not remove any issues, you may only re-categorize or merge issues
-
-After your analysis, provide your response in JSON format matching this schema:
-
-{{ response_schema | tojson(indent=2) }}
-"""
-
-
-def _get_collation_prompt(
- project_context: ProjectContext,
- identifier_inputs: CommitInputs,
- enabled_issue_codes: tuple[IssueCode, ...],
- generated_issues: str,
-) -> str:
- # Sort issue codes to make the resulting prompts deterministic (for snapshot tests and LLM caching)
- sorted_issue_codes = sorted(enabled_issue_codes)
- formatted_guides = {
- code: format_issue_identification_guide_for_llm(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code])
- for code in sorted_issue_codes
- }
-
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- jinja_template = env.from_string(COLLATION_PROMPT_TEMPLATE)
-
- prompt = jinja_template.render(
- {
- "repo_path": project_context.repo_path,
- "commit_message": escape_prompt_markers(identifier_inputs.goal),
- "unified_diff": escape_prompt_markers(identifier_inputs.diff),
- "guides": formatted_guides,
- "response_schema": GeneratedResponseSchema.model_json_schema(),
- "generated_issues": escape_prompt_markers(generated_issues),
- }
- )
- return prompt
-
-
-def _convert_parsed_issues_to_combined_string(
- all_parsed_issues: Iterable[GeneratedIssueSchema],
-) -> str:
- """Convert all parsed issues from all issue types to a combined string for collation prompt."""
- combined_issues = []
-
- for issue in all_parsed_issues:
- issue_dict = issue.model_dump()
- for key in ("location", "code_part"):
- if key in issue_dict and issue_dict[key] is None:
- del issue_dict[key]
- combined_issues.append(issue_dict)
-
- return json.dumps({"issues": combined_issues}, indent=2)
-
-
-def collate_issues_with_agent(
- issue_generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
- identifier_inputs: IdentifierInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- enabled_issue_codes: tuple[IssueCode, ...],
-) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- """
- Collate issues from multiple issue identifiers.
-
- Args:
- issues: The issues to collate.
- identifier_inputs: The inputs which determine the content provided to the identifiers.
- project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for imbue verify.
- enabled_issue_codes: The issue types used by the issue identifiers.
-
- Returns:
- A generator of collated issues. Returns IssueIdentificationDebugInfo after the generator is exhausted.
-
- Raises:
- IdentifierInputsMissingError: If the identifier inputs are missing the commit message or diff, which are required for collation.
- """
- collation_inputs = to_specific_inputs_type(identifier_inputs, CommitInputs)
-
- all_issues = []
- issue_generator_with_capture = ReturnCapturingGenerator(issue_generator)
- for issue in issue_generator_with_capture:
- all_issues.append(issue)
- issue_generator_debug_info = issue_generator_with_capture.return_value
-
- options = get_claude_code_options(
- cwd=project_context.repo_path,
- model_name=config.language_model_generation_config.model_name,
- )
- combined_issues_string = _convert_parsed_issues_to_combined_string(all_issues)
- collation_prompt = _get_collation_prompt(
- project_context, collation_inputs, enabled_issue_codes, combined_issues_string
- )
- claude_response = generate_response_from_claude_code(collation_prompt, options)
- assert claude_response is not None
- response_text, collation_messages = claude_response
- collation_raw_messages = tuple(json.dumps(message.model_dump()) for message in collation_messages)
- collation_invocation_info = extract_invocation_info_from_messages(collation_messages)
-
- collation_llm_responses = (
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(
- agentic_phase=AgenticPhase.COLLATION,
- issue_type=None,
- ),
- raw_response=collation_raw_messages,
- invocation_info=collation_invocation_info,
- ),
- )
-
- yield from generate_issues_from_response_texts(response_texts=(response_text,))
-
- augmented_debug_info = IssueIdentificationDebugInfo(
- llm_responses=issue_generator_debug_info.llm_responses + collation_llm_responses
- )
-
- return augmented_debug_info
diff --git a/imbue_verify/issue_identifiers/base.py b/imbue_verify/issue_identifiers/base.py
@@ -1,92 +0,0 @@
-import abc
-from typing import Generator
-from typing import Generic
-from typing import TypeVar
-
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.pydantic_serialization import SerializableModel
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.get_conversation_history.input_data_types import (
- to_specific_inputs_type,
-)
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-
-T = TypeVar("T", bound=IdentifierInputs)
-
-
-class IssueIdentifier(SerializableModel, abc.ABC, Generic[T]):
- """
- A protocol for identifying issues given certain inputs.
-
- By implementing this protocol and registering the new instance in `imbue_verify/issue_identifiers/registry.py`,
- one can create a new issue identifier and automatically expand the default abilities of imbue_verify.
-
- """
-
- @abc.abstractmethod
- def identify_issues(
- self,
- identifier_inputs: T,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- """
- Identify issues given the identifier inputs.
-
- Args:
- identifier_inputs: The inputs which determine the content provided to the identifier.
- project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for imbue verify.
-
- Returns:
- A generator of identified issues. When done iterating, returns the debug info.
-
- Raises:
- IdentifierInputsMissingError: If the identifier inputs are missing required information for this identifier.
- """
-
- @abc.abstractmethod
- def input_type(self) -> type[T]:
- """
- The type of inputs that this identifier expects.
- """
-
- def to_required_inputs(self, identifier_inputs: IdentifierInputs) -> T:
- return to_specific_inputs_type(identifier_inputs, self.input_type())
-
- @property
- @abc.abstractmethod
- def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
- """
- The issue codes that this identifier is capable of identifying.
- """
-
- @property
- def requires_agentic_collation(self) -> bool:
- """
- Whether this identifier requires agentic collation of issues.
- """
- return False
-
- @property
- @abc.abstractmethod
- def identifies_code_issues(self) -> bool:
- """
- Whether this identifier identifies code-related issues (as opposed to e.g. conversation-related issues).
- """
- pass
-
- @abc.abstractmethod
- def _get_prompt(
- self,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- identifier_inputs: T,
- ) -> str:
- """
- Get the prompt for this identifier.
- """
- pass
diff --git a/imbue_verify/issue_identifiers/common.py b/imbue_verify/issue_identifiers/common.py
@@ -1,328 +0,0 @@
-"""
-Common components shared between issue identifiers.
-"""
-
-from pathlib import Path
-from typing import Generator
-from typing import Iterable
-
-import jinja2
-from loguru import logger
-from pydantic import Field
-from pydantic import PrivateAttr
-
-from imbue_core.agents.agent_api.api import get_agent_client
-from imbue_core.agents.agent_api.claude.data_types import ClaudeCodeOptions
-from imbue_core.agents.agent_api.data_types import AgentAssistantMessage
-from imbue_core.agents.agent_api.data_types import AgentMessage
-from imbue_core.agents.agent_api.data_types import AgentResultMessage
-from imbue_core.agents.agent_api.data_types import AgentTextBlock
-from imbue_core.agents.agent_api.data_types import AgentToolName
-from imbue_core.agents.agent_api.data_types import READ_ONLY_TOOLS
-from imbue_core.agents.llm_apis.anthropic_data_types import AnthropicCachingInfo
-from imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.data_types import ConfidenceScore
-from imbue_core.data_types import IdentifiedVerifyIssue
-from imbue_core.data_types import InvocationInfo
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentifierResult
-from imbue_core.data_types import IssueLocation
-from imbue_core.data_types import LineRange
-from imbue_core.data_types import SeverityScore
-from imbue_core.pydantic_serialization import SerializableModel
-from imbue_tools.llm_output_parsing.parse_model_json_response import (
- ResponseParsingError,
-)
-from imbue_tools.llm_output_parsing.parse_model_json_response import (
- parse_model_json_response,
-)
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_verify.issue_identifiers.identification_guides import (
- IssueIdentificationGuide,
-)
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-
-
-class GeneratedIssueSchema(SerializableModel):
- """Individual issue from LLM response."""
-
- issue_code: str = Field(description="Category of the issue")
- description: str = Field(description="Specific explanation of what's wrong and why it's incorrect")
- location: str | None = Field(default=None, description="File path where the issue occurs")
- code_part: str | None = Field(default=None, description="Specific code snippet that has the issue")
- # pyre doesn't like the way ints/floats implement ge/le
- severity: int = Field(description="Integer 1-5 (1=minor issue, 5=critical bug)", ge=1, le=5) # pyre-ignore[6]
- confidence: float = Field(description="Confidence in this assessment", ge=0.0, le=1.0) # pyre-ignore[6]
-
- # ----------------------------------------------------------------
- # Internal mutable fields used by the post-identification pipeline for tagging.
- # These fields are mutable, but "monotic", in the sense that they can only be populated once and never
- # be changed again after that.
- # These won't be populated by issue identifiers and are not shown to LLMs.
- # ----------------------------------------------------------------
- _passes_filtration: bool | None = PrivateAttr(default=None)
-
- @property
- def passes_filtration(self) -> bool:
- if self._passes_filtration is None:
- # Default to True if not set
- return True
- else:
- return self._passes_filtration
-
- def set_passes_filtration(self, passes: bool) -> None:
- assert self._passes_filtration is None, "passes_filtration can only be set once"
- self._passes_filtration = passes
-
-
-class GeneratedResponseSchema(SerializableModel):
- """Complete response structure for issue identification."""
-
- issues: list[GeneratedIssueSchema] = Field(default_factory=list, description="List of identified issues")
-
-
-def generate_issues_from_response_texts(
- response_texts: Iterable[str],
-) -> Generator[GeneratedIssueSchema, None, None]:
- """Generate IssueIdentifierResult objects from LLM response text."""
- for response_text in response_texts:
- try:
- parsed_data = parse_model_json_response(response_text, GeneratedResponseSchema)
- except ResponseParsingError:
- logger.warning(f"Failed to parse response text: {response_text}")
- continue
-
- for raw_issue in parsed_data.issues:
- yield raw_issue
-
-
-def line_ranges_to_issue_locations(line_ranges: Iterable[LineRange], file_path: str) -> tuple[IssueLocation, ...]:
- """Convert LineRange objects to IssueLocation objects."""
- return tuple(
- IssueLocation(
- line_start=line_range.start,
- line_end=line_range.end,
- filename=file_path,
- )
- for line_range in line_ranges
- )
-
-
-def convert_generated_issue_to_identified_issue(
- issue_data: GeneratedIssueSchema,
- project_context: ProjectContext,
- enabled_issue_codes: tuple[IssueCode, ...],
-) -> IdentifiedVerifyIssue | None:
- try:
- # Validate issue code
- issue_code = issue_data.issue_code
- if issue_code not in enabled_issue_codes:
- logger.error(
- "Got issue code '{issue_code}', skipping. Expected one of: {enabled_issue_codes}",
- issue_code=issue_code,
- enabled_issue_codes=enabled_issue_codes,
- )
- return None
-
- # Extract location and code part for line ranges
- line_ranges: tuple[LineRange, ...] = ()
- issue_location = issue_data.location
- try:
- issue_location_path = Path(issue_location) if issue_location else None
- if project_context.repo_path and issue_location_path and issue_location_path.is_absolute():
- # Make absolute path relative.
- # This will raise ValueError if issue_location_path is not under repo_path.
- repo_path = project_context.repo_path
- assert repo_path is not None
- issue_location_path = issue_location_path.relative_to(repo_path)
- except ValueError:
- issue_location_path = None
- logger.warning(f"Invalid location '{issue_location}', skipping line range detection.")
- issue_code_part = issue_data.code_part
- if issue_location_path and issue_code_part:
- contents = project_context.file_contents_by_path.get(issue_location_path.as_posix())
- if contents is not None:
- line_ranges = LineRange.build_from_substring(contents, issue_code_part)
- if not line_ranges:
- logger.debug(
- "Could not find code_part in file {location}: {code_part_repr}",
- location=issue_location,
- code_part_repr=repr(issue_code_part),
- )
- else:
- logger.warning(f"Unknown location '{issue_location}', skipping line range detection.")
-
- # Convert severity (1-5) to normalized score (0-1)
- severity_normalized = (issue_data.severity - 1) / 4.0 # Map 1-5 to 0-1
- locations = line_ranges_to_issue_locations(
- line_ranges, issue_location_path.as_posix() if issue_location_path else ""
- )
- return IdentifiedVerifyIssue(
- code=IssueCode(issue_data.issue_code),
- description=issue_data.description,
- severity_score=SeverityScore(raw=issue_data.severity, normalized=severity_normalized),
- confidence_score=ConfidenceScore(raw=issue_data.confidence, normalized=issue_data.confidence),
- location=locations,
- )
-
- except (ValueError, KeyError, TypeError) as e:
- log_exception(
- e,
- "Error processing issue data: {issue_data}, skipping",
- issue_data=issue_data,
- )
- return None
-
-
-def convert_to_issue_identifier_result(
- generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
- project_context: ProjectContext,
- enabled_issue_codes: tuple[IssueCode, ...],
-) -> Generator[IssueIdentifierResult, None, IssueIdentificationDebugInfo]:
- """Convert a generator of GeneratedIssueSchema to IssueIdentifierResult."""
- generator_with_capture = ReturnCapturingGenerator(generator)
- for issue_data in generator_with_capture:
- issue = convert_generated_issue_to_identified_issue(
- issue_data=issue_data,
- project_context=project_context,
- enabled_issue_codes=enabled_issue_codes,
- )
- if issue:
- yield IssueIdentifierResult(issue=issue, passes_filtration=issue_data.passes_filtration)
-
- return generator_with_capture.return_value
-
-
-def get_claude_code_options(cwd: Path | None, model_name: str) -> ClaudeCodeOptions:
- options = ClaudeCodeOptions(
- cwd=cwd,
- permission_mode="bypassPermissions", # Equivalent to --dangerously-skip-permissions
- allowed_tools=list(READ_ONLY_TOOLS) + [AgentToolName.BASH],
- model=model_name,
- )
- return options
-
-
-def generate_response_from_claude_code(
- prompt: str, options: ClaudeCodeOptions
-) -> tuple[str, list[AgentMessage]] | None:
- messages = []
- assistant_messages = []
- result_message = None
- try:
- with get_agent_client(options=options) as client:
- for message in client.process_query(prompt):
- messages.append(message)
- if isinstance(message, AgentAssistantMessage):
- assistant_messages.append(message)
- elif isinstance(message, AgentResultMessage):
- result_message = message
- except Exception as e:
- log_exception(e, "Claude Code API call failed")
- return None
-
- # Try to get response from result message first
- response_text = ""
- if result_message and result_message.result:
- response_text = result_message.result
-
- # If no result message, concatenate assistant messages
- if not response_text and assistant_messages:
- for message in assistant_messages:
- for content_block in message.content:
- if isinstance(content_block, AgentTextBlock):
- response_text += content_block.text.strip() + "\n"
-
- return response_text, messages
-
-
-def extract_invocation_info_from_costed_response(
- response: CostedLanguageModelResponse,
-) -> InvocationInfo:
- usage = response.usage
-
- cache_creation_tokens = None
- cache_read_tokens = None
-
- if usage.caching_info is not None:
- caching_info = usage.caching_info
- cache_read_tokens = caching_info.read_from_cache
-
- if caching_info.provider_specific_data is not None:
- if isinstance(caching_info.provider_specific_data, AnthropicCachingInfo):
- cache_creation_tokens = (
- caching_info.provider_specific_data.written_5m + caching_info.provider_specific_data.written_1h
- )
- else:
- logger.info(
- "Not recording caching info for provider specific data type {}",
- type(caching_info.provider_specific_data),
- )
-
- return InvocationInfo(
- input_tokens=usage.prompt_tokens_used,
- cache_creation_input_tokens=cache_creation_tokens,
- cache_read_input_tokens=cache_read_tokens,
- total_input_tokens=usage.prompt_tokens_used,
- output_tokens=usage.completion_tokens_used,
- cost=usage.dollars_used,
- )
-
-
-def extract_invocation_info_from_messages(
- messages: list[AgentMessage],
-) -> InvocationInfo:
- """Extract invocation information from Agent messages."""
- for message in messages:
- if isinstance(message, AgentResultMessage):
- total_input_tokens = None
- usage = message.usage
- if usage:
- input_tokens = usage.input_tokens
- cached_tokens = usage.cached_tokens
- output_tokens = usage.output_tokens
- else:
- input_tokens = None
- cached_tokens = None
- output_tokens = None
- if usage and input_tokens is not None and cached_tokens is not None:
- total_input_tokens = input_tokens + cached_tokens
- return InvocationInfo(
- input_tokens=input_tokens,
- cache_creation_input_tokens=None,
- cache_read_input_tokens=cached_tokens,
- total_input_tokens=total_input_tokens,
- output_tokens=output_tokens,
- duration_ms=message.duration_ms,
- cost=usage.total_cost_usd if usage else None,
- num_turns=message.num_turns,
- )
- return InvocationInfo()
-
-
-_ISSUE_IDENTIFICATION_LLM_FORMAT = """
-Guidelines:{% filter indent(width=4) %}
-{{ guide }}{% endfilter %}
-{%- if examples %}
-Examples:
-{%- for example in examples %}
- - {{ example }}
-{%- endfor %}
-{%- endif -%}
-{%- if exceptions %}
-Exceptions:
-{%- for exception in exceptions %}
- - {{ exception }}
-{%- endfor %}
-{%- endif -%}
-"""
-
-
-def format_issue_identification_guide_for_llm(guide: IssueIdentificationGuide) -> str:
- formatted_guide = jinja2.Template(_ISSUE_IDENTIFICATION_LLM_FORMAT).render(
- guide=guide.guide, examples=guide.examples, exceptions=guide.exceptions
- )
-
- return formatted_guide.strip()
diff --git a/imbue_verify/issue_identifiers/common_test.py b/imbue_verify/issue_identifiers/common_test.py
@@ -1,375 +0,0 @@
-import json
-from pathlib import Path
-from typing import Iterable
-
-from imbue_core.async_monkey_patches_test import expect_exact_logged_errors
-from imbue_core.data_types import IdentifiedVerifyIssue
-from imbue_core.data_types import IssueCode
-from imbue_core.frozen_utils import FrozenDict
-from imbue_core.itertools import only
-from imbue_tools.llm_output_parsing.parse_model_json_response import (
- ResponseParsingError,
-)
-from imbue_tools.llm_output_parsing.parse_model_json_response import (
- parse_model_json_response,
-)
-from imbue_tools.repo_utils.project_context import BaseProjectContext
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_verify.issue_identifiers.common import GeneratedResponseSchema
-from imbue_verify.issue_identifiers.common import (
- convert_generated_issue_to_identified_issue,
-)
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_CODES_FOR_CORRECTNESS_CHECK,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- IssueIdentificationGuide,
-)
-
-
-def _parse_issues(
- valid_response: str,
- project_context: ProjectContext,
- enabled_issue_codes: Iterable[IssueCode],
-) -> list[IdentifiedVerifyIssue]:
- issues = []
- try:
- issue_data = parse_model_json_response(valid_response, GeneratedResponseSchema)
- except ResponseParsingError:
- return []
- for parsed_issue in issue_data.issues:
- issue = convert_generated_issue_to_identified_issue(
- issue_data=parsed_issue,
- project_context=project_context,
- enabled_issue_codes=tuple(enabled_issue_codes),
- )
- if issue:
- issues.append(issue)
- return issues
-
-
-def test_parse_issues_valid_json() -> None:
- project_context = BaseProjectContext(
- file_contents_by_path=FrozenDict({"test.py": "def test():\n while True:\n pass"}),
- cached_prompt_prefix="test",
- )
-
- valid_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Infinite loop detected",
- "location": "test.py",
- "code_part": "while True:\n pass",
- "severity": 5,
- "confidence": 0.95,
- }
- ]
- }
- )
-
- issues = _parse_issues(valid_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
-
- issue = only(issues)
- assert issue.code == IssueCode.LOGIC_ERROR
- assert issue.description == "Infinite loop detected"
- assert issue.confidence_score is not None
- assert issue.confidence_score.normalized == 0.95
- assert issue.severity_score is not None
- assert issue.severity_score.normalized == 1.0 # severity 5 maps to 1.0
- assert len(issue.location) == 1
- assert issue.location[0].filename == "test.py"
-
-
-def test_parse_response_with_leading_and_trailing_text() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
- valid_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Infinite loop detected",
- "location": "test.py",
- "code_part": "while True:\n pass",
- "severity": 5,
- "confidence": 0.95,
- }
- ]
- }
- )
-
- response_text = "Some leading text\n```json\n" + valid_response + "\n```\nSome trailing text"
- # Note: This logs a warning about "Unknown location" since test.py isn't in the project context
- issues = _parse_issues(response_text, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- issue = only(issues)
- assert issue.code == IssueCode.LOGIC_ERROR
- assert issue.description == "Infinite loop detected"
- assert issue.confidence_score is not None
- assert issue.confidence_score.normalized == 0.95
- assert issue.severity_score is not None
- assert issue.severity_score.normalized == 1.0 # severity 5 maps to 1.0
-
-
-def test_parse_issues_empty_response() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
-
- empty_response = json.dumps({"issues": []})
-
- issues = _parse_issues(empty_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0
-
-
-def test_parse_issues_invalid_json() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
-
- invalid_response = "not json"
-
- with expect_exact_logged_errors(["Response does not match the expected schema"]):
- issues = _parse_issues(invalid_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0
-
-
-def test_parse_issues_with_markdown_formatting() -> None:
- project_context = BaseProjectContext(
- file_contents_by_path=FrozenDict({"test.py": "x = 1"}),
- cached_prompt_prefix="test",
- )
-
- markdown_response = (
- "```json\n"
- + json.dumps(
- {
- "issues": [
- {
- "issue_code": "runtime_error_risk",
- "description": "Test issue",
- "severity": 3,
- "confidence": 0.8,
- }
- ]
- }
- )
- + "\n```"
- )
-
- issues = _parse_issues(markdown_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 1
- assert issues[0].code == IssueCode.RUNTIME_ERROR_RISK
-
-
-def test_parse_issues_invalid_severity() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
-
- invalid_severity_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Test issue",
- "severity": 10, # Invalid - should be 1-5
- "confidence": 0.8,
- }
- ]
- }
- )
-
- with expect_exact_logged_errors(["Response does not match the expected schema"]):
- issues = _parse_issues(
- invalid_severity_response,
- project_context,
- ISSUE_CODES_FOR_CORRECTNESS_CHECK,
- )
- assert len(issues) == 0 # Should be skipped due to invalid severity
-
-
-def test_parse_issues_unknown_issue_code() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
-
- unknown_code_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "unknown_issue", # Not in our defined codes
- "description": "Test issue",
- "severity": 3,
- "confidence": 0.8,
- }
- ]
- }
- )
-
- with expect_exact_logged_errors(["Got issue code"]):
- issues = _parse_issues(unknown_code_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0 # Should be skipped due to unknown code
-
-
-def test_parse_issues_missing_required_fields() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="[ROLE=SYSTEM]\ntest")
-
- # Missing required field 'confidence'
- missing_field_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Test issue",
- "severity": 3,
- # Missing 'confidence' field
- }
- ]
- }
- )
-
- with expect_exact_logged_errors(["Response does not match the expected schema"]):
- issues = _parse_issues(missing_field_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0 # Should be skipped due to missing field
-
-
-def test_parse_issues_invalid_confidence() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="[ROLE=SYSTEM]\ntest")
-
- invalid_confidence_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Test issue",
- "severity": 3,
- "confidence": 1.5, # Invalid - should be 0.0-1.0
- }
- ]
- }
- )
-
- with expect_exact_logged_errors(["Response does not match the expected schema"]):
- issues = _parse_issues(
- invalid_confidence_response,
- project_context,
- ISSUE_CODES_FOR_CORRECTNESS_CHECK,
- )
- assert len(issues) == 0 # Should be skipped due to invalid confidence
-
-
-def test_parse_issues_with_line_ranges() -> None:
- code_content = "def hello():\n print('world')\n return True"
- project_context = BaseProjectContext(
- file_contents_by_path=FrozenDict({"test.py": code_content}),
- cached_prompt_prefix="[ROLE=SYSTEM]\ntest",
- )
-
- response_with_location = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Test issue with location",
- "location": "test.py",
- "code_part": "print('world')",
- "severity": 3,
- "confidence": 0.8,
- }
- ]
- }
- )
-
- issues = _parse_issues(response_with_location, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- issue = only(issues)
- assert issue.location[0].filename == "test.py"
- assert len(issue.location) > 0 # Should have found line ranges
-
-
-def test_parse_issues_malformed_response_structure() -> None:
- project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="[ROLE=SYSTEM]\ntest")
-
- # Test with non-dict response
- non_dict_response = json.dumps(["not", "a", "dict"])
- with expect_exact_logged_errors(["Response does not match the expected schema"]):
- issues = _parse_issues(non_dict_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0
-
- # Test with missing `issues` key
- missing_key_response = json.dumps({"other_key": ["some value", "another value"]})
- # note that this doesn't log an error; the model validation allows "issues" to be missing, and fills in an empty list
- issues = _parse_issues(missing_key_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0
-
- # Test with missing everything
- missing_everything_response = json.dumps({})
- # note that this doesn't log an error; the model validation allows "issues" to be missing, and fills in an empty list
- issues = _parse_issues(missing_everything_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0
-
- # Test with non-list `issues` value
- non_list_response = json.dumps({"issues": "not a list"})
- with expect_exact_logged_errors(["Response does not match the expected schema"]):
- issues = _parse_issues(non_list_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
- assert len(issues) == 0
-
-
-def test_format_issue_identification_guide_for_llm() -> None:
- complete_guide = IssueIdentificationGuide(
- issue_code=IssueCode.LOGIC_ERROR,
- guide="- Guideline 1\n- Guideline 2",
- examples=(
- "Example 1",
- "Example 2",
- ),
- exceptions=(
- "Exception 1",
- "Exception 2",
- ),
- )
-
- expected_formatted_complete_guide = """Guidelines:
- - Guideline 1
- - Guideline 2
-Examples:
- - Example 1
- - Example 2
-Exceptions:
- - Exception 1
- - Exception 2"""
-
- minimal_guide = IssueIdentificationGuide(issue_code=IssueCode.LOGIC_ERROR, guide="Only has a guide.")
- expected_formatted_minimal_guide = """Guidelines:
- Only has a guide."""
-
- formatted_complete_guide = format_issue_identification_guide_for_llm(complete_guide)
- assert formatted_complete_guide == expected_formatted_complete_guide
-
- formatted_minimal_guide = format_issue_identification_guide_for_llm(minimal_guide)
- assert formatted_minimal_guide == expected_formatted_minimal_guide
-
-
-def test_strips_absolute_filenames() -> None:
- project_context = BaseProjectContext(
- file_contents_by_path=FrozenDict({"test.py": "def test():\n while True:\n pass"}),
- cached_prompt_prefix="test",
- repo_path=Path("/code"),
- )
-
- valid_response = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Infinite loop detected",
- "location": "/code/test.py",
- "code_part": "while True:\n pass",
- "severity": 5,
- "confidence": 0.95,
- }
- ]
- }
- )
-
- issues = _parse_issues(valid_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
-
- issue = only(issues)
- assert issue.description == "Infinite loop detected"
- assert len(issue.location) == 1
- assert issue.location[0].filename == "test.py"
diff --git a/imbue_verify/issue_identifiers/harnesses/agentic.py b/imbue_verify/issue_identifiers/harnesses/agentic.py
@@ -1,361 +0,0 @@
-"""
-Agentic harness that checks a given diff for issues using Claude Code agents with tools.
-"""
-
-import concurrent.futures
-import json
-from concurrent.futures import ThreadPoolExecutor
-from functools import cached_property
-from typing import Any
-from typing import Generator
-
-import jinja2
-from loguru import logger
-
-from imbue_core.agents.agent_api.claude.data_types import ClaudeCodeOptions
-from imbue_core.agents.agent_api.data_types import AgentMessage
-from imbue_core.agents.agent_api.data_types import AgentToolName
-from imbue_core.agents.agent_api.data_types import READ_ONLY_TOOLS
-from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.data_types import AgenticPhase
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import LLMResponse
-from imbue_tools.get_conversation_history.input_data_types import CommitInputs
-from imbue_tools.repo_utils.context_utils import escape_prompt_markers
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers.base import IssueIdentifier
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import GeneratedResponseSchema
-from imbue_verify.issue_identifiers.common import extract_invocation_info_from_messages
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.common import generate_issues_from_response_texts
-from imbue_verify.issue_identifiers.common import generate_response_from_claude_code
-from imbue_verify.issue_identifiers.harnesses.base import IssueIdentifierHarness
-from imbue_verify.issue_identifiers.identification_guides import (
- IssueIdentificationGuide,
-)
-
-PROMPT_TEMPLATE = """You are analyzing a code repository for potential issues. The repository files are available in {{ repo_path }}.
-
-Assume that a user requested work to be done and a programmer delivered the diff below.
-The changes from the diff are present in the codebase but are not yet committed.
-
-### User request ###
-{% filter indent(width=2) %}
-{{ commit_message }}
-{% endfilter %}
-
-### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ###
-{% filter indent(width=2) %}
-{{ unified_diff }}
-{% endfilter %}
-###
-
-Your task is to help verify the quality of the diff.
-We care only about specific categories of important issues.
-The rubric below outlines these categories of important issues, and contains guidelines and examples to correctly identify them:
-{% for issue_code, guide in guides.items() %}
----
-**{{ issue_code }}**:
-{{ guide }}
-{% endfor %}
----
-
-Use your standard tools to explore the repository and analyze the code thoroughly.
-Look at the additional guidance section below for more details on how to find issues.
-
-After your analysis, provide your response in JSON format matching this schema:
-
-{{ response_schema | tojson(indent=2) }}
-
-For each issue found, provide:
-- issue_code: Category from the rubric above
-- description: Specific explanation of the issue
-- (if applicable) location: File path where the issue occurs (relative to {{ repo_path }})
-- (if applicable) code_part: Specific code snippet that has the issue. Your code snippet should be the exact same as the original code including whitespace.
-- severity: Integer 1-5 (1=minor, 5=critical)
-- confidence: Float 0.0-1.0 indicating your confidence
-
-Your response should look like:
-```json
-{
- "issues": [
- <list of issues>
- ]
-}
-```
-
-If no issues are found, return: ```json{"issues": []}```
-
-Focus on real issues that impact code quality, correctness, or maintainability.
-You must not return issues that were already present in the code or issues that are fixed by the diff.
-You must only return issues that were introduced by the diff.
-Do not report duplicate issues with the same or equivalent descriptions.
-
-### Additional Guidance for Finding Issues ###
-You should use a Task tool to create a parallel task for each issue type in the rubric.
-You should pass along the exact issue type definition with all details to the task.
-Once all the Tasks have completed you can collate their results.
-You should pass along any relevant information from the guidance below to the task.
-Here is a non-exhaustive list of things that you can do using your tools within the task to find issues:
-{% for issue_code, guidance in additional_guidance.items() %}
----
-**{{ issue_code }}**:
-{{ guidance }}
-{% endfor %}
----
-Note that this is just guidance on how to find issues, please refer to the rubric for the types of issues to find.
-"""
-
-ISSUE_TYPE_PROMPT_TEMPLATE = """You are analyzing a code repository for potential issues of type {{ issue_type }}. The repository files are available in {{ repo_path }}.
-
-Assume that a user requested work to be done and a programmer delivered the diff below.
-The changes from the diff are present in the codebase but are not yet committed.
-
-### User request ###
-{% filter indent(width=2) %}
-{{ commit_message }}
-{% endfilter %}
-
-### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ###
-{% filter indent(width=2) %}
-{{ unified_diff }}
-{% endfilter %}
-###
-
-Your task is to help verify the quality of the diff.
-Here is the definition of the issue type you are looking for:
-**{{ issue_type }}**:
-{{ guide }}
-
-Use your standard tools to explore the repository and analyze the code thoroughly.
-ONLY look for issues related to {{ issue_type }}.
-Do NOT modify any files - this is read-only analysis.
-
-After your analysis, provide your response in JSON format matching this schema:
-
-{{ response_schema | tojson(indent=2) }}
-
-For each issue found, provide:
-- issue_code: Category from the rubric above
-- description: Specific explanation of the issue
-- (if applicable) location: File path where the issue occurs (relative to {{ repo_path }})
-- (if applicable) code_part: Specific code snippet that has the issue. Your code snippet should be the exact same as the original code including whitespace.
-- severity: Integer 1-5 (1=minor, 5=critical)
-- confidence: Float 0.0-1.0 indicating your confidence
-
-Your response should look like:
-```json
-{
- "issues": [
- <list of issues>
- ]
-}
-```
-
-If no issues of this type are found, return: ```json{"issues": []}```
-You must not return issues that were already present in the code or issues that are fixed by the diff.
-You must only return issues that were introduced by the diff.
-Do not report duplicate issues with the same or equivalent descriptions.
-"""
-
-
-MAX_PARALLEL_CLAUDE_CODE_SESSIONS = 5 # TODO: this was arbitrarily chosen
-ResponseText = str
-
-
-def _generate_issues_worker(
- issue_code: IssueCode,
- prompt: str,
- options: ClaudeCodeOptions,
-) -> tuple[IssueCode, ResponseText, list[AgentMessage]] | None:
- issue_result = generate_response_from_claude_code(prompt, options)
- if issue_result is None:
- return None
- return issue_code, issue_result[0], issue_result[1]
-
-
-class _AgenticIssueIdentifier(IssueIdentifier[CommitInputs]):
- _identification_guides: tuple[IssueIdentificationGuide, ...]
-
- def __init__(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> None:
- assert len(identification_guides) > 0, "At least one identification guide must be provided"
- self._identification_guides = identification_guides
-
- @cached_property
- def _response_schema(self) -> dict[str, Any]:
- return GeneratedResponseSchema.model_json_schema()
-
- def _get_prompt(
- self,
- project_context: ProjectContext,
- config: ImbueVerifyConfig, # unused
- identifier_inputs: CommitInputs,
- ) -> str:
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- jinja_template = env.from_string(PROMPT_TEMPLATE)
- additional_guidance_by_issue_code = {
- guide.issue_code: guide.additional_guide_for_agent for guide in self._identification_guides
- }
-
- formatted_guides = {
- guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in self._identification_guides
- }
-
- prompt = jinja_template.render(
- {
- "repo_path": project_context.repo_path,
- "commit_message": escape_prompt_markers(identifier_inputs.goal),
- "unified_diff": escape_prompt_markers(identifier_inputs.diff),
- "guides": formatted_guides,
- "response_schema": self._response_schema,
- "additional_guidance": additional_guidance_by_issue_code,
- }
- )
- return prompt
-
- def _get_prompt_for_issue_type(
- self,
- project_context: ProjectContext,
- identifier_inputs: CommitInputs,
- guide: IssueIdentificationGuide,
- ) -> str:
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- jinja_template = env.from_string(ISSUE_TYPE_PROMPT_TEMPLATE)
-
- formatted_guide = format_issue_identification_guide_for_llm(guide)
-
- prompt = jinja_template.render(
- {
- "repo_path": project_context.repo_path,
- "commit_message": escape_prompt_markers(identifier_inputs.goal),
- "unified_diff": escape_prompt_markers(identifier_inputs.diff),
- "guide": formatted_guide,
- "response_schema": self._response_schema,
- "issue_type": guide.issue_code,
- }
- )
- return prompt
-
- def identify_issues(
- self,
- identifier_inputs: CommitInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- assert project_context.repo_path is not None, "Project context must have a valid repo_path, got None"
-
- config_model_name = config.language_model_generation_config.model_name
- if config_model_name in [anthropic_model.value for anthropic_model in AnthropicModelName]:
- model_name = config_model_name
- else:
- model_name = AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01
- logger.info(
- "Config model_name {config_model_name} is not a valid Anthropic model, using default ({model_name}).",
- config_model_name=config_model_name,
- model_name=model_name,
- )
-
- options = ClaudeCodeOptions(
- cwd=project_context.repo_path,
- permission_mode="bypassPermissions", # Equivalent to --dangerously-skip-permissions
- allowed_tools=list(READ_ONLY_TOOLS) + [AgentToolName.BASH], # Allow read-only tools for analysis
- model=model_name,
- )
-
- if config.enable_parallel_agentic_issue_identification:
- llm_responses = []
-
- issue_prompts = [
- (
- guide.issue_code,
- self._get_prompt_for_issue_type(project_context, identifier_inputs, guide),
- )
- for guide in self._identification_guides
- ]
- with ThreadPoolExecutor(max_workers=MAX_PARALLEL_CLAUDE_CODE_SESSIONS) as executor:
- tasks = [
- executor.submit(_generate_issues_worker, issue_code, prompt, options)
- for issue_code, prompt in issue_prompts
- ]
-
- for task in concurrent.futures.as_completed(tasks):
- try:
- result = task.result()
- except Exception as e:
- log_exception(e, "Error processing issue type: {e}", e=e)
- continue
-
- if result is None:
- continue
-
- issue_code, issue_type_response_text, messages = result
-
- yield from generate_issues_from_response_texts(response_texts=(issue_type_response_text,))
-
- message_dumps = tuple(json.dumps(message.model_dump()) for message in messages)
- invocation_info = extract_invocation_info_from_messages(messages)
-
- llm_responses.append(
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(
- agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION,
- issue_type=issue_code,
- ),
- raw_response=message_dumps,
- invocation_info=invocation_info,
- )
- )
-
- return IssueIdentificationDebugInfo(llm_responses=tuple(llm_responses))
- else:
- prompt = self._get_prompt(project_context, config, identifier_inputs)
- claude_response = generate_response_from_claude_code(prompt, options)
- assert claude_response is not None
- response_text, messages = claude_response
-
- message_dumps = tuple(json.dumps(message.model_dump()) for message in messages)
- invocation_info = extract_invocation_info_from_messages(messages)
-
- llm_responses = [
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(
- agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION,
- issue_type=None,
- ),
- raw_response=message_dumps,
- invocation_info=invocation_info,
- )
- ]
-
- yield from generate_issues_from_response_texts(response_texts=(response_text,))
-
- return IssueIdentificationDebugInfo(llm_responses=tuple(llm_responses))
-
- def input_type(self) -> type[CommitInputs]:
- return CommitInputs
-
- @property
- def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
- return tuple(guide.issue_code for guide in self._identification_guides)
-
- @property
- def requires_agentic_collation(self) -> bool:
- return True
-
- @property
- def identifies_code_issues(self) -> bool:
- return True
-
-
-class AgenticHarness(IssueIdentifierHarness[CommitInputs]):
- def make_issue_identifier(
- self, identification_guides: tuple[IssueIdentificationGuide, ...]
- ) -> IssueIdentifier[CommitInputs]:
- return _AgenticIssueIdentifier(identification_guides=identification_guides)
diff --git a/imbue_verify/issue_identifiers/harnesses/base.py b/imbue_verify/issue_identifiers/harnesses/base.py
@@ -1,17 +0,0 @@
-import abc
-from typing import Generic
-from typing import TypeVar
-
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_verify.issue_identifiers.base import IssueIdentifier
-from imbue_verify.issue_identifiers.identification_guides import (
- IssueIdentificationGuide,
-)
-
-T = TypeVar("T", bound=IdentifierInputs)
-
-
-class IssueIdentifierHarness(abc.ABC, Generic[T]):
- @abc.abstractmethod
- def make_issue_identifier(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> IssueIdentifier[T]:
- """Return an issue identifier based on this harness by binding it to the provided issue identification guides."""
diff --git a/imbue_verify/issue_identifiers/harnesses/conversation_single_prompt.py b/imbue_verify/issue_identifiers/harnesses/conversation_single_prompt.py
@@ -1,153 +0,0 @@
-"""
-Single-prompt issue identification harness that operates on the conversation history.
-
-Currently hard-coded to check for misleading behavior in a conversation.
-"""
-
-from functools import cached_property
-from typing import Any
-from typing import Generator
-
-import jinja2
-
-from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
-from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
-from imbue_core.data_types import AgenticPhase
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import LLMResponse
-from imbue_core.itertools import only
-from imbue_tools.get_conversation_history.get_conversation_history import (
- format_conversation_history_for_prompt,
-)
-from imbue_tools.get_conversation_history.input_data_types import ConversationInputs
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE
-from imbue_verify.issue_identifiers.base import IssueIdentifier
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import GeneratedResponseSchema
-from imbue_verify.issue_identifiers.common import (
- extract_invocation_info_from_costed_response,
-)
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.common import generate_issues_from_response_texts
-from imbue_verify.issue_identifiers.harnesses.base import IssueIdentifierHarness
-from imbue_verify.issue_identifiers.identification_guides import (
- IssueIdentificationGuide,
-)
-
-PROMPT_TEMPLATE = (
- CONVERSATION_PREFIX_TEMPLATE
- + """
-{% if cache_full_prompt %}[ROLE=USER_CACHED]{% else %}[ROLE=USER]{% endif %}{% if instruction_context %}
-Here are the instruction files that were provided to the agent:
-{{ instruction_context }}{% endif %}
-
-Your task is to examine the conversation history to find events of interest.
-These events will be used to generate suggestions for what the agent should do next to best achieve the user's goal.
-We care only about specific categories of events. The rubric below outlines these categories of events, and contains guidelines and examples to correctly identify them:
-{% for guide_name, guide in guides.items() %}
----
-**{{ guide_name }}**:
-{{ guide }}
-
-{% endfor %}
----
-
-Respond with valid JSON that matches this exact schema:
-
-{{ response_schema | tojson(indent=2) }}
-
-[ROLE=ASSISTANT]
-"""
-)
-
-
-class _ConversationSinglePromptIssueIdentifier(IssueIdentifier[ConversationInputs]):
- _identification_guides: tuple[IssueIdentificationGuide, ...]
-
- def __init__(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> None:
- self._identification_guides = identification_guides
-
- @cached_property
- def _response_schema(self) -> dict[str, Any]:
- return GeneratedResponseSchema.model_json_schema()
-
- def _get_prompt(
- self,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- identifier_inputs: ConversationInputs,
- ) -> str:
- # Sort the guides by issue code to ensure prompt caching (and snapshotting in tests) works.
- sorted_guides = sorted(self._identification_guides, key=lambda guide: guide.issue_code)
- formatted_guides = {
- guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides
- }
-
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- jinja_template = env.from_string(PROMPT_TEMPLATE)
- return jinja_template.render(
- cached_prompt_prefix=project_context.cached_prompt_prefix,
- cache_full_prompt=config.cache_full_prompt,
- conversation_history=format_conversation_history_for_prompt(identifier_inputs.conversation_history),
- # pyre-fixme[16]: SubrepoContext need not have a formatted_repo_context, and instruction_context can be None
- instruction_context=project_context.instruction_context.formatted_repo_context,
- response_schema=self._response_schema,
- guides=formatted_guides,
- )
-
- def identify_issues(
- self,
- identifier_inputs: ConversationInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- language_model = build_language_model_from_config(config.language_model_generation_config)
- language_model_params = LanguageModelGenerationParams(
- temperature=config.temperature,
- max_tokens=config.max_output_tokens,
- )
- prompt = self._get_prompt(project_context, config, identifier_inputs)
- costed_response = language_model.complete_with_usage_sync(
- prompt,
- params=language_model_params,
- is_caching_enabled=language_model.cache_path is not None,
- )
-
- response = only(costed_response.responses)
- invocation_info = extract_invocation_info_from_costed_response(costed_response)
-
- llm_responses = (
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION),
- raw_response=(response.text,),
- invocation_info=invocation_info,
- ),
- )
-
- yield from generate_issues_from_response_texts(response_texts=(response.text,))
-
- return IssueIdentificationDebugInfo(llm_responses=llm_responses)
-
- def input_type(self) -> type[ConversationInputs]:
- return ConversationInputs
-
- @property
- def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
- return tuple(guide.issue_code for guide in self._identification_guides)
-
- @property
- def identifies_code_issues(self) -> bool:
- return False
-
-
-class ConversationSinglePromptHarness(IssueIdentifierHarness[ConversationInputs]):
- def make_issue_identifier(
- self, identification_guides: tuple[IssueIdentificationGuide, ...]
- ) -> IssueIdentifier[ConversationInputs]:
- return _ConversationSinglePromptIssueIdentifier(identification_guides=identification_guides)
diff --git a/imbue_verify/issue_identifiers/harnesses/conversation_single_prompt_test.py b/imbue_verify/issue_identifiers/harnesses/conversation_single_prompt_test.py
@@ -1,68 +0,0 @@
-import pytest
-
-from imbue_core.data_types import IssueCode
-from vet_types.chat_state import TextBlock
-from vet_types.ids import AssistantMessageID
-from vet_types.messages import AgentMessageSource
-from vet_types.messages import ChatInputUserMessage
-from vet_types.messages import LLMModel
-from vet_types.messages import ResponseBlockAgentMessage
-from imbue_tools.get_conversation_history.input_data_types import ConversationInputs
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.get_conversation_history.input_data_types import (
- IdentifierInputsMissingError,
-)
-from imbue_verify.issue_identifiers.harnesses.conversation_single_prompt import (
- ConversationSinglePromptHarness,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-
-
-def test_to_required_inputs() -> None:
- harness = ConversationSinglePromptHarness()
- classifier = harness.make_issue_identifier(
- identification_guides=(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[IssueCode.MISLEADING_BEHAVIOR],)
- )
-
- # should support inputs where only the conversation history is present
- conversation_history_inputs = IdentifierInputs(
- maybe_conversation_history=(
- ChatInputUserMessage(
- text="fake content",
- model_name=LLMModel.CLAUDE_4_SONNET,
- ),
- )
- )
- cvi = classifier.to_required_inputs(conversation_history_inputs)
- assert isinstance(cvi, ConversationInputs)
-
- # and inputs where the conversation history and commit message are present
- conversation_history_and_commit_message_inputs = IdentifierInputs(
- maybe_conversation_history=(
- ResponseBlockAgentMessage(
- source=AgentMessageSource.AGENT,
- role="assistant",
- assistant_message_id=AssistantMessageID("fake_message_id"),
- content=(TextBlock(text="fake content"),),
- ),
- ),
- maybe_goal="test",
- maybe_diff="test",
- )
- cvi = classifier.to_required_inputs(conversation_history_and_commit_message_inputs)
- assert isinstance(cvi, ConversationInputs)
- assert cvi.maybe_goal == "test"
- assert cvi.maybe_diff == "test"
-
- # should not support inputs where the conversation history is absent
- commit_inputs = IdentifierInputs(maybe_goal="test", maybe_diff="test")
- with pytest.raises(IdentifierInputsMissingError):
- classifier.to_required_inputs(commit_inputs)
- file_inputs = IdentifierInputs(maybe_files=("test.py",))
- with pytest.raises(IdentifierInputsMissingError):
- classifier.to_required_inputs(file_inputs)
- no_inputs = IdentifierInputs()
- with pytest.raises(IdentifierInputsMissingError):
- classifier.to_required_inputs(no_inputs)
diff --git a/imbue_verify/issue_identifiers/harnesses/single_prompt.py b/imbue_verify/issue_identifiers/harnesses/single_prompt.py
@@ -1,200 +0,0 @@
-"""
-Simple zero-shot issue identification harness that checks a diff for issues in a single prompt.
-"""
-
-from functools import cached_property
-from typing import Any
-from typing import Generator
-
-import jinja2
-
-from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
-from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
-from imbue_core.data_types import AgenticPhase
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import LLMResponse
-from imbue_core.itertools import only
-from imbue_tools.get_conversation_history.input_data_types import CommitInputs
-from imbue_tools.repo_utils.context_utils import escape_prompt_markers
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers.base import IssueIdentifier
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import GeneratedResponseSchema
-from imbue_verify.issue_identifiers.common import (
- extract_invocation_info_from_costed_response,
-)
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.common import generate_issues_from_response_texts
-from imbue_verify.issue_identifiers.harnesses.base import IssueIdentifierHarness
-from imbue_verify.issue_identifiers.identification_guides import (
- IssueIdentificationGuide,
-)
-
-USER_REQUEST_PREFIX_TEMPLATE = """{{cached_prompt_prefix}}
-[ROLE=USER_CACHED]
-I'm working on a project, adding commits one after another. The current state of the project is captured by the codebase snapshot above.
-{% if extra_context %}
-
-=== ADDITIONAL CONTEXT BEGIN ===
-{{ extra_context }}
-=== ADDITIONAL CONTEXT END ===
-{% endif %}
-
-Assume that I asked for a piece of work to be done by specifying the user request and another programmer has delivered the diff.
-{% if include_request_and_diff %}
-Below, you can see the user request, as well as the delivered diff. IMPORTANT: The codebase snapshot already includes the changes made in this diff!
-
-=== USER REQUEST BEGIN ===
-{{ commit_message }}
-=== USER REQUEST END ===
-
-=== DIFF BEGIN (unified; lines starting with `-` are removed and `+` are added) ===
-{{ unified_diff }}
-=== DIFF END ===
-
-{% endif %}{% if not cache_full_prompt %}
-[ROLE=USER]{% endif %}
-"""
-
-
-PROMPT_TEMPLATE = (
- USER_REQUEST_PREFIX_TEMPLATE
- + """Your task is to help me verify the quality of the diff.
-
-We care only about specific categories of issues. The rubric below outlines these categories of issues, and contains guidelines and examples to correctly identify them:
-{% for issue_type_name, guide in guides.items() %}
-[Issue Category {{ loop.index }}: {{ issue_type_name }}]
-{{ guide }}
-[End of issue category: {{ issue_type_name }}]
-{% endfor %}
-
-## Instructions:
-
-1. Look at each category of issues outlined above, one at a time.
-2. For each given category, analyze the diff for issues that match the category.
-3. For each issue found, provide:
- - issue_code: One of the category names above
- - description: Specific explanation of what's wrong and what a better implementation could be. The description should not exceed a few sentences unless absolutely necessary.
- - location: File path where the issue occurs (if applicable)
- - code_part: Specific code snippet that has the issue (if applicable). Must match exactly, including whitespace. If the code part spans multiple lines, include the exact whitespace and newlines. If there are multiple locations that are relevant to the issue, select a single one to represent the issue.
- - severity: Integer 1-5 (1=minor issue, 5=critical issue that will definitely cause problems)
- - confidence: Float 0.0-1.0 indicating your confidence in this issue
-4. When you have identified all issues of the current category, move on to the next category and repeat the process.
-
-Respond with valid JSON that matches this exact schema:
-
-{{ response_schema | tojson(indent=2) }}
-
-Every issue you report must stand on its own, and should not reference other issues in its description.
-Do not report duplicate issues with the same or equivalent descriptions within one issue category.
-Do not output any issues that are merely based on the absence of information in the codebase snapshot.
-Do not speculate about the way a piece of code might get used if that use is not supported by the code included above.
-Only raise issues that were introduced by the diff.
-It is fine to output an empty list if no issues are found!
-
-IMPORTANT: Do not include any additional commentary outside the JSON response, your response should only contain the JSON object:
-
-```json
-{
- "issues": [
- <list of issues>
- ]
-}
-```
-[ROLE=ASSISTANT]
-"""
-)
-
-
-class _SinglePromptIssueIdentifier(IssueIdentifier[CommitInputs]):
- _identification_guides: tuple[IssueIdentificationGuide, ...]
-
- def __init__(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> None:
- self._identification_guides = identification_guides
-
- @cached_property
- def _response_schema(self) -> dict[str, Any]:
- return GeneratedResponseSchema.model_json_schema()
-
- def _get_prompt(
- self,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- identifier_inputs: CommitInputs,
- ) -> str:
- # Sort the guides by issue code to ensure prompt caching (and snapshotting in tests) works.
- sorted_guides = sorted(self._identification_guides, key=lambda guide: guide.issue_code)
- formatted_guides = {
- guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides
- }
-
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- jinja_template = env.from_string(PROMPT_TEMPLATE)
- return jinja_template.render(
- {
- "include_request_and_diff": True,
- "cached_prompt_prefix": project_context.cached_prompt_prefix,
- "cache_full_prompt": config.cache_full_prompt,
- "extra_context": (escape_prompt_markers(config.extra_context) if config.extra_context else None),
- "commit_message": escape_prompt_markers(identifier_inputs.goal),
- "unified_diff": escape_prompt_markers(identifier_inputs.diff),
- "guides": formatted_guides,
- "response_schema": self._response_schema,
- }
- )
-
- def identify_issues(
- self,
- identifier_inputs: CommitInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- prompt = self._get_prompt(project_context, config, identifier_inputs)
- language_model = build_language_model_from_config(config.language_model_generation_config)
- language_model_params = LanguageModelGenerationParams(
- temperature=config.temperature,
- max_tokens=config.max_output_tokens,
- )
- costed_response = language_model.complete_with_usage_sync(
- prompt,
- params=language_model_params,
- is_caching_enabled=language_model.cache_path is not None,
- )
-
- response = only(costed_response.responses)
- invocation_info = extract_invocation_info_from_costed_response(costed_response)
-
- llm_responses = (
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION),
- raw_response=(response.text,),
- invocation_info=invocation_info,
- ),
- )
-
- yield from generate_issues_from_response_texts(response_texts=(response.text,))
-
- return IssueIdentificationDebugInfo(llm_responses=llm_responses)
-
- def input_type(self) -> type[CommitInputs]:
- return CommitInputs
-
- @property
- def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
- return tuple(guide.issue_code for guide in self._identification_guides)
-
- @property
- def identifies_code_issues(self) -> bool:
- return True
-
-
-class SinglePromptHarness(IssueIdentifierHarness[CommitInputs]):
- def make_issue_identifier(
- self, identification_guides: tuple[IssueIdentificationGuide, ...]
- ) -> IssueIdentifier[CommitInputs]:
- return _SinglePromptIssueIdentifier(identification_guides=identification_guides)
diff --git a/imbue_verify/issue_identifiers/harnesses/single_prompt_test.py b/imbue_verify/issue_identifiers/harnesses/single_prompt_test.py
@@ -1,174 +0,0 @@
-"""
-Tests for the SinglePromptHarness.
-"""
-
-import json
-from unittest import mock
-
-import pytest
-
-from imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
-from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
-from imbue_core.agents.llm_apis.data_types import LanguageModelResponseUsage
-from imbue_core.agents.llm_apis.data_types import LanguageModelResponseWithLogits
-from imbue_core.agents.llm_apis.data_types import ResponseStopReason
-from imbue_core.agents.llm_apis.mock_api import LanguageModelMock
-from imbue_core.data_types import IssueCode
-from imbue_core.frozen_utils import FrozenDict
-from imbue_tools.get_conversation_history.input_data_types import CommitInputs
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.get_conversation_history.input_data_types import (
- IdentifierInputsMissingError,
-)
-from imbue_tools.repo_utils.project_context import BaseProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers.base import IssueIdentifier
-from imbue_verify.issue_identifiers.harnesses.single_prompt import SinglePromptHarness
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_CODES_FOR_CORRECTNESS_CHECK,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-
-
-class SinglePromptHarnessMock(LanguageModelMock):
- """Mock language model for testing SinglePromptHarness."""
-
- response_text: str = ""
-
- def complete_with_usage_sync(
- self,
- prompt: str,
- params: LanguageModelGenerationParams,
- is_caching_enabled: bool = True,
- ) -> CostedLanguageModelResponse:
- self.stats.complete_calls += 1
- response = LanguageModelResponseWithLogits(
- text=self.response_text,
- token_count=len(self.response_text.split()),
- stop_reason=ResponseStopReason.END_TURN,
- network_failure_count=0,
- token_probabilities=self._get_token_probabilities(self.response_text),
- )
- usage = LanguageModelResponseUsage(
- prompt_tokens_used=100,
- completion_tokens_used=50,
- dollars_used=0.001,
- caching_info=None,
- )
- return CostedLanguageModelResponse(usage=usage, responses=(response,))
-
-
-def make_identifier() -> IssueIdentifier:
- harness = SinglePromptHarness()
- identifier = harness.make_issue_identifier(
- identification_guides=tuple(
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code] for code in ISSUE_CODES_FOR_CORRECTNESS_CHECK
- )
- )
- return identifier
-
-
-def test_to_required_inputs() -> None:
- identifier = make_identifier()
-
- # Should support inputs where only the commit message and diff are present
- commit_inputs = IdentifierInputs(maybe_goal="test", maybe_diff="test")
- cmi = identifier.to_required_inputs(commit_inputs)
- assert isinstance(cmi, CommitInputs)
-
- # Should support inputs where the commit message and diff are present
- combined_inputs = IdentifierInputs(
- maybe_goal="test",
- maybe_diff="test",
- maybe_files=("test.py",),
- maybe_conversation_history=(),
- )
- cmi = identifier.to_required_inputs(combined_inputs)
- assert isinstance(cmi, CommitInputs)
-
- # Should not support inputs where the commit message and diff are absent
- file_inputs = IdentifierInputs(maybe_files=("test.py",))
- with pytest.raises(IdentifierInputsMissingError):
- identifier.to_required_inputs(file_inputs)
- no_inputs = IdentifierInputs()
- with pytest.raises(IdentifierInputsMissingError):
- identifier.to_required_inputs(no_inputs)
-
- # Should not support inputs where only one of the commit message and diff are present
- commit_message_inputs = IdentifierInputs(maybe_goal="test", maybe_conversation_history=())
- with pytest.raises(IdentifierInputsMissingError):
- identifier.to_required_inputs(commit_message_inputs)
- diff_inputs = IdentifierInputs(maybe_diff="test")
- with pytest.raises(IdentifierInputsMissingError):
- identifier.to_required_inputs(diff_inputs)
-
-
-def test_get_prompt_structure() -> None:
- identifier = make_identifier()
- project_context = BaseProjectContext(
- file_contents_by_path=FrozenDict({"test.py": "print('hello')"}),
- cached_prompt_prefix="[ROLE=SYSTEM]\nSystem context here",
- )
- commit_inputs = CommitInputs(
- maybe_goal="Add hello world function",
- maybe_diff="+def hello():\n+ print('hello')",
- )
- config = ImbueVerifyConfig()
-
- prompt = identifier._get_prompt(project_context, config, commit_inputs)
-
- # Check that prompt contains key elements
- assert "System context here" in prompt
- assert "Add hello world function" in prompt
- assert "+def hello():" in prompt
- assert "logic_error" in prompt
- assert "runtime_error_risk" in prompt
- assert "issues" in prompt
- assert "schema" in prompt.lower() # Should contain schema from pydantic model
-
-
-def test_identify_issues_integration() -> None:
- """Test the full identify_issues flow with mocked LLM."""
- identifier = make_identifier()
-
- # Create mock language model with specific response
- response_text = json.dumps(
- {
- "issues": [
- {
- "issue_code": "logic_error",
- "description": "Test logic error",
- "severity": 4,
- "confidence": 0.9,
- }
- ]
- }
- )
-
- mock_language_model = SinglePromptHarnessMock(response_text=response_text)
- with mock.patch(
- "imbue_verify.issue_identifiers.harnesses.single_prompt.build_language_model_from_config",
- return_value=mock_language_model,
- ):
- project_context = BaseProjectContext(
- file_contents_by_path=FrozenDict({"test.py": "print('hello')"}),
- cached_prompt_prefix="[ROLE=SYSTEM]\nSystem context",
- )
- commit_inputs = IdentifierInputs(maybe_goal="Add hello function", maybe_diff="+print('hello')")
- config = ImbueVerifyConfig()
-
- inputs = identifier.to_required_inputs(commit_inputs)
- raw_issues_generator = identifier.identify_issues(inputs, project_context, config)
- raw_issues = []
- raw_issues_generator_with_capture = ReturnCapturingGenerator(raw_issues_generator)
- for raw_issue in raw_issues_generator_with_capture:
- raw_issues.append(raw_issue)
- llm_responses = raw_issues_generator_with_capture.return_value.llm_responses
-
- assert len(raw_issues) == 1
- assert raw_issues[0].issue_code == IssueCode.LOGIC_ERROR
- assert raw_issues[0].description == "Test logic error"
- assert len(llm_responses) > 0 # Should have LLM responses
diff --git a/imbue_verify/issue_identifiers/issue_deduplication.py b/imbue_verify/issue_identifiers/issue_deduplication.py
@@ -1,190 +0,0 @@
-import json
-from typing import Generator
-from typing import Iterable
-
-import jinja2
-
-from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
-from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
-from imbue_core.data_types import AgenticPhase
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import LLMResponse
-from imbue_core.itertools import only
-from imbue_tools.repo_utils.context_utils import escape_prompt_markers
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import GeneratedResponseSchema
-from imbue_verify.issue_identifiers.common import (
- extract_invocation_info_from_costed_response,
-)
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.common import generate_issues_from_response_texts
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-
-DEDUPLICATION_PROMPT_TEMPLATE = """[ROLE=USER]
-You are reviewing the results from parallel code analysis for potential issues.
-Multiple specialized checks analyzed the work of an automated coding agent, each focusing on checking for a specific type of issue.
-
-The rubric below outlines the categories of issues we care about:
-{% for issue_code, guide in guides.items() %}
----
-**{{ issue_code }}**:
-{{ guide }}
-{% endfor %}
----
-
-### Individual Analysis Results ###
-{{ generated_issues }}
-
-Your task is to:
-1. Consolidate any duplicate issues
-2. If duplicates are categorized as different issue types, pick the most appropriate issue type for the merged issue according to the category definitions above.
-3. Return the consolidated set of issues
-
-Guidelines:
-- Merge issues that refer to the same underlying problem and would be solved by the same fix. Make sure that their locations (if available) are the same, and that their descriptions describe the same underlying problem. The issue_code and other properties can be different.
-- A merged issue should represent a single problem. Never merge multiple distinct problems, even if they are closely related or share the same location.
-- Never merge issues that refer to different locations, functions or files.
-- Do not remove any issues, you may only re-categorize or merge issues
-- When merging issues, pick A SINGLE most relevant location + code_part pair from the issues that you are merging together. NEVER try to combine multiple locations or code_part into one. Just pick one of them. Make sure that you repeat the code part string verbatim (including any whitespaces) in the resulting merged issue.
-- The confidence value of a merged issue should be the highest confidence value among the issues being merged.
-
-After your analysis, provide your response in JSON format matching this schema:
-
-{{ response_schema | tojson(indent=2) }}
-
-Do not output any other JSON, only the consolidated issues in the specified format:
-```json
-{
- "issues": [
- <list of consolidated issues>
- ]
-}
-```
-[ROLE=ASSISTANT]
-"""
-
-
-def _get_deduplication_prompt(
- enabled_issue_codes: Iterable[IssueCode],
- generated_issues: str,
-) -> str:
- # Sort issue codes to make the resulting prompts deterministic (for snapshot tests and LLM caching)
- sorted_issue_codes = sorted(enabled_issue_codes)
- formatted_guides = {
- code: format_issue_identification_guide_for_llm(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code])
- for code in sorted_issue_codes
- }
-
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- jinja_template = env.from_string(DEDUPLICATION_PROMPT_TEMPLATE)
-
- prompt = jinja_template.render(
- {
- "guides": formatted_guides,
- "response_schema": GeneratedResponseSchema.model_json_schema(),
- "generated_issues": escape_prompt_markers(generated_issues),
- }
- )
- return prompt
-
-
-def _convert_parsed_issues_to_combined_string(
- all_parsed_issues: Iterable[GeneratedIssueSchema],
-) -> str:
- """Convert all parsed issues from all issue types to a combined string for the deduplication prompt."""
- combined_issues = []
-
- for issue in all_parsed_issues:
- issue_dict = issue.model_dump()
- combined_issues.append(issue_dict)
-
- return json.dumps({"issues": combined_issues}, indent=2)
-
-
-def deduplicate_issues(
- issue_generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
- config: ImbueVerifyConfig,
- enabled_issue_codes: Iterable[IssueCode],
-) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- """
- Deduplicate issues from multiple issue identifiers.
-
- Args:
- issues: The issues to deduplicate.
- config: Settings for imbue verify.
- enabled_issue_codes: The issue types used by the issue identifiers.
-
- Returns:
- A generator of deduplicated issues. Returns IssueIdentificationDebugInfo after the generator is exhausted.
- """
-
- # This current implementation is not streaming. Rather, we collect all issues, then send them to the LLM for deduplication all at once.
- # In the future, we can consider changing this into a streaming version that performs deduplication as issues come in.
- all_issues = []
- issue_generator_with_capture = ReturnCapturingGenerator(issue_generator)
- for issue in issue_generator_with_capture:
- all_issues.append(issue)
- issue_generator_debug_info = issue_generator_with_capture.return_value
-
- # TODO: This is a bit hacky, since it breaks abstraction boundaries:
- # We need to apply some special handling here around issue filtration.
- # This will go away when in the future, we move the filtration step to after the deduplication step.
- # However, we can't do that yet, because the filtration currently only works for certain issue types.
- # For now, we make the following compromise:
- # - We deduplicate only over issues that pass filtration.
- # (The resulting deduplicated issues will implicitly be set to have passed filtration as well, as per default value of _passes_filtration)
- # - Issues that didn't pass filtration will be yielded out unchanged.
- issues_passing_filtration = [issue for issue in all_issues if issue.passes_filtration]
- issues_not_passing_filtration = [issue for issue in all_issues if not issue.passes_filtration]
-
- if len(issues_passing_filtration) <= 1:
- # None or one issues that pass filtration: nothing to deduplicate, return early
- for issue in all_issues:
- yield issue
- return issue_generator_debug_info
-
- language_model = build_language_model_from_config(config.language_model_generation_config)
-
- # As per above TODO, only deduplicate over issues that passed filtration
- combined_issues_string = _convert_parsed_issues_to_combined_string(issues_passing_filtration)
- prompt = _get_deduplication_prompt(enabled_issue_codes, combined_issues_string)
-
- costed_response = language_model.complete_with_usage_sync(
- prompt,
- params=LanguageModelGenerationParams(temperature=0.0, max_tokens=config.max_output_tokens),
- is_caching_enabled=language_model.cache_path is not None,
- )
-
- response = only(costed_response.responses)
- invocation_info = extract_invocation_info_from_costed_response(costed_response)
-
- yield from generate_issues_from_response_texts(response_texts=(response.text,))
-
- # As per above TODO, now also yield out all issues that didn't pass filtration unchanged (these will keep their passes_filtration=False)
- for issue in issues_not_passing_filtration:
- yield issue
-
- deduplication_llm_responses = (
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(
- agentic_phase=AgenticPhase.DEDUPLICATION,
- issue_type=None,
- ),
- raw_response=(response.text,),
- invocation_info=invocation_info,
- ),
- )
-
- augmented_debug_info = IssueIdentificationDebugInfo(
- llm_responses=issue_generator_debug_info.llm_responses + deduplication_llm_responses
- )
-
- return augmented_debug_info
diff --git a/imbue_verify/issue_identifiers/issue_evaluation.py b/imbue_verify/issue_identifiers/issue_evaluation.py
@@ -1,295 +0,0 @@
-from typing import Generator
-
-import jinja2
-
-from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
-from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
-from imbue_core.data_types import AgenticPhase
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import LLMResponse
-from imbue_core.itertools import only
-from imbue_core.pydantic_serialization import SerializableModel
-from imbue_tools.get_conversation_history.get_conversation_history import (
- format_conversation_history_for_prompt,
-)
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.llm_output_parsing.parse_model_json_response import (
- ResponseParsingError,
-)
-from imbue_tools.llm_output_parsing.parse_model_json_response import (
- parse_model_json_response,
-)
-from imbue_tools.repo_utils.context_utils import escape_prompt_markers
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import DEFAULT_CONFIDENCE_THRESHOLD
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import (
- extract_invocation_info_from_costed_response,
-)
-from imbue_verify.issue_identifiers.common import (
- format_issue_identification_guide_for_llm,
-)
-from imbue_verify.issue_identifiers.harnesses.single_prompt import (
- USER_REQUEST_PREFIX_TEMPLATE,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-
-CODE_BASED_CRITERIA = (
- "1. The issue is based on specific code, and not merely on the absence of information in the codebase snapshot. (true/false)",
- "2. The issue does not speculate about the way a piece of code might get used without having specific knowledge of how it's being used. (true/false)",
- "3. The issues seems important and not overly pedantic. (true/false)",
- "4. The issue was introduced by the diff. (true/false)",
- "5. The issue matches the issue type definition given below. (true/false)",
- "6. The issue flags a piece of code that is already being removed by the diff (line in diff starts with a `-`). (true/false)",
-)
-
-CONVERSATION_BASED_CRITERIA = ("1. The issue matches the issue type definition given below. (true/false)",)
-
-PROMPT_TEMPLATE = """Somebody has reviewed the {% if is_code_based_issue %}diff{% else %}conversation history{% endif %} and flagged an issue with it, which you can see here:
-
-### Issue description ###
-{% filter indent(width=2) %}
-{{ issue_description }}
-{% endfilter %}
-
-Please evaluate the issue and determine whether it matches the following criteria:
-
-{% for criterion in criteria %}
-{{ criterion }}
-{% endfor %}
-
-### Issue type definition ###
-{% filter indent(width=2) %}
-**{{ issue_code }}**:
-{{ guide }}
-{% endfilter %}
-
-Please answer the questions above in the form of a JSON object with this exact JSON schema:
-
-{{ response_schema | tojson(indent=2) }}
-
-The keys correspond to the question numbers ("q1" for question 1, "q2" for question 2, and so on), and the values should be boolean values indicating whether the issue matches the criteria (true or false).
-
-IMPORTANT: Do not include any additional commentary outside the JSON response, your response should only contain the JSON object:
-
-```json
-{
- "q1": <true|false>,
- "q2": <true|false>,
- ...
-}
-```
-[ROLE=ASSISTANT]
-"""
-
-
-def _get_full_prompt_template(is_code_based_issue: bool) -> str:
- """Get the full prompt template with the appropriate prefix."""
- prefix = USER_REQUEST_PREFIX_TEMPLATE if is_code_based_issue else CONVERSATION_PREFIX_TEMPLATE
- return prefix + PROMPT_TEMPLATE
-
-
-class CodeBasedEvaluationResponse(SerializableModel):
- q1: bool
- q2: bool
- q3: bool
- q4: bool
- q5: bool
- q6: bool
-
- def is_passing_result(self) -> bool:
- return all([self.q1, self.q2, self.q3, self.q4, self.q5]) and not self.q6
-
-
-class ConversationBasedEvaluationResponse(SerializableModel):
- q1: bool
-
- def is_passing_result(self) -> bool:
- return self.q1
-
-
-def _format_prompt(
- issue: GeneratedIssueSchema,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- inputs: IdentifierInputs,
- is_code_based_issue: bool,
-) -> str:
- env = jinja2.Environment(undefined=jinja2.StrictUndefined)
- prompt_template = _get_full_prompt_template(is_code_based_issue)
- jinja_template = env.from_string(prompt_template)
- issue_code = IssueCode(issue.issue_code)
- guide = format_issue_identification_guide_for_llm(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[issue_code])
-
- criteria = CODE_BASED_CRITERIA if is_code_based_issue else CONVERSATION_BASED_CRITERIA
- response_class = CodeBasedEvaluationResponse if is_code_based_issue else ConversationBasedEvaluationResponse
-
- template_vars = {
- "cached_prompt_prefix": project_context.cached_prompt_prefix,
- "cache_full_prompt": config.cache_full_prompt,
- "issue_description": issue.description,
- "issue_code": issue_code,
- "guide": guide,
- "criteria": criteria,
- "response_schema": response_class.model_json_schema(),
- "is_code_based_issue": is_code_based_issue,
- }
-
- if is_code_based_issue:
- template_vars["include_request_and_diff"] = True
- template_vars["commit_message"] = escape_prompt_markers(inputs.maybe_goal or "")
- template_vars["unified_diff"] = escape_prompt_markers(inputs.maybe_diff or "")
- template_vars["extra_context"] = escape_prompt_markers(config.extra_context) if config.extra_context else None
- else:
- template_vars["conversation_history"] = format_conversation_history_for_prompt(
- inputs.maybe_conversation_history or ()
- )
-
- return jinja_template.render(template_vars)
-
-
-def _parse_response(
- response_text: str, is_code_based_issue: bool
-) -> CodeBasedEvaluationResponse | ConversationBasedEvaluationResponse:
- # Fallback value of True for now, since we assume that most issues will pass the evaluation.
- if is_code_based_issue:
- FALLBACK_VALUE = CodeBasedEvaluationResponse(q1=True, q2=True, q3=True, q4=True, q5=True, q6=False)
- response_class = CodeBasedEvaluationResponse
- else:
- FALLBACK_VALUE = ConversationBasedEvaluationResponse(q1=True)
- response_class = ConversationBasedEvaluationResponse
-
- try:
- return parse_model_json_response(response_text, response_class)
- except ResponseParsingError:
- return FALLBACK_VALUE
-
-
-def evaluate_code_issue_through_llm(
- issue: GeneratedIssueSchema,
- inputs: IdentifierInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- is_code_based_issue: bool,
-) -> tuple[bool, tuple[LLMResponse, ...]]:
- """
- Args:
- issue: The issue to evaluate.
- inputs: The inputs which determine the content provided to the evaluator.
- project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for the language model used to evaluate the issue.
- is_code_based_issue: Whether this is a code-based issue (vs conversation-based).
-
- Returns:
- A tuple containing a boolean indicating whether the issue passes the evaluation and the LLM responses.
- If evaluation fails because the data to judge the issue is missing, the issue is taken to have passed the evaluation.
- """
- if not config.filter_issues_through_llm_evaluator:
- return True, ()
-
- # Check that we have the required data for evaluation
- if is_code_based_issue:
- if inputs.maybe_goal is None or inputs.maybe_diff is None:
- return True, ()
- else:
- if inputs.maybe_conversation_history is None:
- return True, ()
-
- language_model = build_language_model_from_config(config.language_model_generation_config)
-
- prompt = _format_prompt(issue, project_context, config, inputs, is_code_based_issue)
- costed_response = language_model.complete_with_usage_sync(
- prompt,
- params=LanguageModelGenerationParams(temperature=0.0, max_tokens=config.max_output_tokens),
- is_caching_enabled=language_model.cache_path is not None,
- )
-
- response = only(costed_response.responses)
- invocation_info = extract_invocation_info_from_costed_response(costed_response)
- results = _parse_response(response.text, is_code_based_issue)
-
- llm_responses = (
- LLMResponse(
- metadata=IssueIdentificationLLMResponseMetadata(
- agentic_phase=AgenticPhase.FILTRATION,
- issue_type=None,
- ),
- raw_response=(response.text,),
- invocation_info=invocation_info,
- ),
- )
-
- return results.is_passing_result(), llm_responses
-
-
-MODEL_CONFIDENCE_THRESHOLD_DEFAULTS: dict[str, float] = {
- "gpt-5.1-2025-11-13": 0.0,
-}
-
-
-def get_imbue_verify_confidence_threshold(config: ImbueVerifyConfig) -> float:
- model_name = config.language_model_generation_config.model_name
-
- if model_name in MODEL_CONFIDENCE_THRESHOLD_DEFAULTS:
- return MODEL_CONFIDENCE_THRESHOLD_DEFAULTS[model_name]
-
- if config.filter_issues_below_confidence is not None:
- return config.filter_issues_below_confidence
-
- return DEFAULT_CONFIDENCE_THRESHOLD
-
-
-def evaluate_issue_through_confidence(issue: GeneratedIssueSchema, config: ImbueVerifyConfig) -> bool:
- threshold = get_imbue_verify_confidence_threshold(config)
- return issue.confidence >= threshold
-
-
-def filter_issues(
- issue_generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
- inputs: IdentifierInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
- # Currently, the LLM-based filter only works reliably for code-related issue types.
- is_code_based_issue_generator: bool,
-) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- """
- Filter issues based on the evaluation.
-
- Args:
- results: The issues to filter.
- inputs: The inputs which determine the content provided to the evaluator.
- project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for imbue verify.
-
- Returns:
- A generator of issues with the passes_filtration flag set.
- If evaluation fails because the data to judge the issue is missing, the issue is taken to have passed the evaluation.
- At the end of the generation, returns IssueIdentificationDebugInfo containing the LLM responses.
- """
-
- filter_llm_responses = []
-
- issue_generator_with_capture = ReturnCapturingGenerator(issue_generator)
- for issue in issue_generator_with_capture:
- passes_filtration = evaluate_issue_through_confidence(issue, config)
- if passes_filtration:
- passes_filtration, llm_responses = evaluate_code_issue_through_llm(
- issue, inputs, project_context, config, is_code_based_issue_generator
- )
- filter_llm_responses.extend(llm_responses)
- issue.set_passes_filtration(passes_filtration)
- yield issue
- issue_generator_debug_info = issue_generator_with_capture.return_value
-
- augmented_debug_info = IssueIdentificationDebugInfo(
- llm_responses=issue_generator_debug_info.llm_responses + tuple(filter_llm_responses)
- )
-
- return augmented_debug_info
diff --git a/imbue_verify/issue_identifiers/registry.py b/imbue_verify/issue_identifiers/registry.py
@@ -1,262 +0,0 @@
-"""
-Registry of all the available issue identifiers with a `run` function for running them in an identification pipeline.
-"""
-
-from collections import defaultdict
-from enum import StrEnum
-from typing import Final
-from typing import Generator
-from typing import Iterable
-from typing import TypeVar
-
-from loguru import logger
-
-from imbue_core.agents.primitives.resource_limits import ensure_global_resource_limits
-from imbue_core.data_types import IssueCode
-from imbue_core.data_types import IssueIdentificationDebugInfo
-from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
-from imbue_core.data_types import IssueIdentifierResult
-from imbue_core.data_types import IssueIdentifierType
-from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
-from imbue_tools.get_conversation_history.input_data_types import (
- IdentifierInputsMissingError,
-)
-from imbue_tools.repo_utils.project_context import ProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_tools.types.imbue_verify_config import get_enabled_issue_codes
-from imbue_verify.issue_identifiers.agentic_issue_collation import (
- collate_issues_with_agent,
-)
-from imbue_verify.issue_identifiers.base import IssueIdentifier
-from imbue_verify.issue_identifiers.common import GeneratedIssueSchema
-from imbue_verify.issue_identifiers.common import convert_to_issue_identifier_result
-from imbue_verify.issue_identifiers.harnesses.agentic import AgenticHarness
-from imbue_verify.issue_identifiers.harnesses.base import IssueIdentifierHarness
-from imbue_verify.issue_identifiers.harnesses.conversation_single_prompt import (
- ConversationSinglePromptHarness,
-)
-from imbue_verify.issue_identifiers.harnesses.single_prompt import SinglePromptHarness
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_CODES_FOR_BATCHED_COMMIT_CHECK,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_CODES_FOR_CONVERSATION_HISTORY_CHECK,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_CODES_FOR_CORRECTNESS_CHECK,
-)
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-from imbue_verify.issue_identifiers.issue_deduplication import deduplicate_issues
-from imbue_verify.issue_identifiers.issue_evaluation import filter_issues
-from imbue_verify.issue_identifiers.utils import ReturnCapturingGenerator
-from imbue_verify.issue_identifiers.utils import multiplex_generators
-
-# Issue identifier harnesses together with certain default lists of issue codes.
-# This is intended as a transitionary structure to emulate the previous identifiers setup.
-# Eventually, we'll update ImbueVerifyConfig to no longer enable/disable specific identifiers, but instead
-# enable/disable harnesses and issue codes, and we'll pair up the enabled issue codes with the appropriate enabled
-# harnesses automatically.
-SINGLE_PROMPT_HARNESS = SinglePromptHarness()
-CONVERSATION_SINGLE_PROMPT_HARNESS = ConversationSinglePromptHarness()
-AGENTIC_HARNESS = AgenticHarness()
-HARNESS_PRESETS: Final[list[tuple[IssueIdentifierType, IssueIdentifierHarness, tuple[IssueCode, ...]]]] = [
- (
- IssueIdentifierType.AGENTIC_ISSUE_IDENTIFIER,
- AGENTIC_HARNESS,
- ISSUE_CODES_FOR_BATCHED_COMMIT_CHECK + ISSUE_CODES_FOR_CORRECTNESS_CHECK,
- ),
- (
- IssueIdentifierType.BATCHED_COMMIT_CHECK,
- SINGLE_PROMPT_HARNESS,
- ISSUE_CODES_FOR_BATCHED_COMMIT_CHECK,
- ),
- (
- IssueIdentifierType.CONVERSATION_HISTORY_IDENTIFIER,
- CONVERSATION_SINGLE_PROMPT_HARNESS,
- ISSUE_CODES_FOR_CONVERSATION_HISTORY_CHECK,
- ),
- (
- IssueIdentifierType.CORRECTNESS_COMMIT_CLASSIFIER,
- SINGLE_PROMPT_HARNESS,
- ISSUE_CODES_FOR_CORRECTNESS_CHECK,
- ),
-]
-
-
-def get_all_valid_identifier_names() -> set[IssueIdentifierType]:
- return {name for name, _, _ in HARNESS_PRESETS}
-
-
-EnumT = TypeVar("EnumT", bound=StrEnum)
-
-
-def _convert_all_to_enum(
- enum_strs: Iterable[str], all_enum_strs: Iterable[str], enum_type: type[EnumT]
-) -> tuple[EnumT]:
- results = []
- for enum_str in enum_strs:
- if enum_str not in all_enum_strs:
- raise ValueError(f"Bad config: unknown {enum_type.__name__} name: {enum_str}")
- results.append(enum_type(enum_str))
- return tuple(results)
-
-
-def _get_enabled_identifier_names(
- config: ImbueVerifyConfig,
-) -> set[IssueIdentifierType]:
- all_names = get_all_valid_identifier_names()
- explicitly_enabled = _convert_all_to_enum(config.enabled_identifiers or tuple(), all_names, IssueIdentifierType)
- explicitly_disabled = _convert_all_to_enum(config.disabled_identifiers or tuple(), all_names, IssueIdentifierType)
- enabled = set(explicitly_enabled) if len(explicitly_enabled) > 0 else all_names
- if len(explicitly_disabled) > 0:
- enabled = set(enabled) - set(explicitly_disabled)
- return enabled
-
-
-def _build_identifiers(
- identifiers_to_build: set[IssueIdentifierType], enabled_issue_codes: set[IssueCode]
-) -> list[tuple[str, IssueIdentifier]]:
- # Merge the enabled issue codes for each harness
- enabled_issue_codes_per_harness: defaultdict[IssueIdentifierHarness, set[IssueCode]] = defaultdict(set)
- combined_name_per_harness: defaultdict[IssueIdentifierHarness, list[str]] = defaultdict(list)
-
- for name, harness, default_issue_codes in HARNESS_PRESETS:
- if name in identifiers_to_build:
- enabled_issue_codes_for_harness = enabled_issue_codes & set(default_issue_codes)
- if enabled_issue_codes_for_harness:
- enabled_issue_codes_per_harness[harness].update(enabled_issue_codes_for_harness)
- combined_name_per_harness[harness].append(name.value)
-
- identifiers: list[tuple[str, IssueIdentifier]] = []
- for harness, issue_codes in enabled_issue_codes_per_harness.items():
- combined_name = "+".join(combined_name_per_harness[harness])
- identifiers.append(
- (
- combined_name,
- harness.make_issue_identifier(
- identification_guides=tuple(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code] for code in issue_codes)
- ),
- )
- )
-
- return identifiers
-
-
-def _generate_with_name_in_debug_info(
- name: str,
- generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
-) -> Generator[GeneratedIssueSchema, None, tuple[str, IssueIdentificationDebugInfo]]:
- generator_with_capture = ReturnCapturingGenerator(generator)
- for result in generator_with_capture:
- yield result
- return name, generator_with_capture.return_value
-
-
-def _combine_issue_generator_debug_info(
- generator: Generator[GeneratedIssueSchema, None, tuple[tuple[str, IssueIdentificationDebugInfo], ...]],
-) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
- collected_debug_info: tuple[tuple[str, IssueIdentificationDebugInfo], ...] = (yield from generator)
-
- updated_llm_responses = []
- for identifier_name, debug_info in collected_debug_info:
- for response in debug_info.llm_responses:
- assert isinstance(response.metadata, IssueIdentificationLLMResponseMetadata)
- updated_response = response.evolve(response.ref().metadata.identifier_name, identifier_name)
- updated_llm_responses.append(updated_response)
-
- return IssueIdentificationDebugInfo(llm_responses=tuple(updated_llm_responses))
-
-
-def run(
- identifier_inputs: IdentifierInputs,
- project_context: ProjectContext,
- config: ImbueVerifyConfig,
-) -> Generator[IssueIdentifierResult, None, IssueIdentificationDebugInfo]:
- """
- Run all the registered and configured issue identifiers on the given inputs.
- """
- enabled_issue_codes = get_enabled_issue_codes(config)
- identifiers = _build_identifiers(_get_enabled_identifier_names(config), enabled_issue_codes)
- ensure_global_resource_limits(max_dollars=config.max_identifier_spend_dollars)
-
- issue_generators: list[Generator[GeneratedIssueSchema, None, tuple[str, IssueIdentificationDebugInfo]]] = []
- compatible_enabled_identifier_names: list[str] = []
- # The set of issue codes that can be detected by the compatible identifiers. A subset of enabled_issue_codes.
- detectable_issue_codes: set[IssueCode] = set()
- for identifier_name, identifier in identifiers:
- # 1. Identification
- try:
- inputs = identifier.to_required_inputs(identifier_inputs)
- identified_issues_generator = identifier.identify_issues(inputs, project_context, config)
- compatible_enabled_identifier_names.append(identifier_name)
- detectable_issue_codes.update(identifier.enabled_issue_codes)
- except IdentifierInputsMissingError as e:
- logger.debug(
- "skipping identifier {} because of missing inputs: {}",
- identifier_name,
- e,
- )
- continue
-
- # 2. Collation for agentic identifiers
- if identifier.requires_agentic_collation and config.enable_collation:
- try:
- collated_issues_generator = collate_issues_with_agent(
- identified_issues_generator,
- identifier_inputs,
- project_context,
- config,
- identifier.enabled_issue_codes,
- )
- except IdentifierInputsMissingError as e:
- logger.warning(
- "collate_issues_with_agent requires commit message and diff, skipping: {}",
- e,
- )
- continue
- else:
- collated_issues_generator = identified_issues_generator
-
- # 3. Filtration
- if config.filter_issues:
- filtered_results_generator = filter_issues(
- collated_issues_generator,
- identifier_inputs,
- project_context,
- config,
- is_code_based_issue_generator=identifier.identifies_code_issues,
- )
- else:
- filtered_results_generator = collated_issues_generator
-
- issue_generators.append(_generate_with_name_in_debug_info(identifier_name, filtered_results_generator))
-
- logger.info(
- "Using the following issue identifiers compatible with the input: {}",
- ", ".join([n for n in compatible_enabled_identifier_names]),
- )
-
- multiplexed_generators = multiplex_generators(issue_generators, max_workers=config.max_identify_workers)
- multiplexed_generators_with_combined_debug_info = _combine_issue_generator_debug_info(multiplexed_generators)
-
- # 4. Deduplicate issues across all identifiers
- if config.enable_deduplication:
- deduplicated_generator = deduplicate_issues(
- multiplexed_generators_with_combined_debug_info,
- config,
- tuple(detectable_issue_codes),
- )
- else:
- deduplicated_generator = multiplexed_generators_with_combined_debug_info
-
- # Conversion from GeneratedIssueSchema to IssueIdentifierResult
- converted_issues_generator = convert_to_issue_identifier_result(
- deduplicated_generator, project_context, tuple(enabled_issue_codes)
- )
-
- # Yield out results
- debug_info = yield from converted_issues_generator
-
- return debug_info
diff --git a/imbue_verify/issue_identifiers/test_prompt_lengths.py b/imbue_verify/issue_identifiers/test_prompt_lengths.py
@@ -1,60 +0,0 @@
-from imbue_core.data_types import IssueIdentifierType
-from imbue_core.frozen_utils import FrozenDict
-from imbue_core.itertools import first
-from imbue_tools.get_conversation_history.input_data_types import CommitInputs
-from imbue_tools.repo_utils.project_context import BaseProjectContext
-from imbue_tools.types.imbue_verify_config import ImbueVerifyConfig
-from imbue_verify.issue_identifiers import registry
-from imbue_verify.issue_identifiers.identification_guides import (
- ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
-)
-from imbue_verify.repo_utils import IMBUE_VERIFY_MAX_PROMPT_TOKENS
-
-EMPTY_PROJECT_CONTEXT = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="")
-DEFAULT_IMBUE_VERIFY_CONFIG = ImbueVerifyConfig()
-
-
-# Helper functions to extract a base prompt for different identifier types.
-PROMPT_EXTRACTOR_FUNCTIONS = {
- IssueIdentifierType.BATCHED_COMMIT_CHECK: lambda identifier: identifier._get_prompt(
- EMPTY_PROJECT_CONTEXT,
- DEFAULT_IMBUE_VERIFY_CONFIG,
- CommitInputs(maybe_goal="", maybe_diff=""),
- ),
- IssueIdentifierType.CORRECTNESS_COMMIT_CLASSIFIER: lambda identifier: identifier._get_prompt(
- EMPTY_PROJECT_CONTEXT,
- DEFAULT_IMBUE_VERIFY_CONFIG,
- CommitInputs(maybe_goal="", maybe_diff=""),
- ),
-}
-
-
-def _estimate_tokens(prompt: str) -> int:
- """
- Estimate the number of tokens in a prompt.
- This is a rough estimate and may not be accurate for all models.
- """
- # A factor of 1/4.5 appears to be a reasonable empirical estimate for current models.
- # We use a slighly larger factor (1/4) to have a more conservative estimate.
- return round(len(prompt) / 4)
-
-
-def test_prompt_lengths() -> None:
- """
- Test that the prompt lengths for various issue identifiers do not exceed the maximum allowed length.
- This is important to ensure that the LLM can process the prompts without raising errors.
- """
-
- for identifier_name, extract_prompt in PROMPT_EXTRACTOR_FUNCTIONS.items():
- identifier = first(
- [
- harness.make_issue_identifier(tuple(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[c] for c in codes))
- for name, harness, codes in registry.HARNESS_PRESETS
- if name == identifier_name
- ]
- )
- prompt = extract_prompt(identifier)
- num_tokens = _estimate_tokens(prompt)
- assert (
- num_tokens <= IMBUE_VERIFY_MAX_PROMPT_TOKENS
- ), f"Prompt for {identifier_name} exceeds IMBUE_VERIFY_MAX_PROMPT_TOKENS. Consider increasing IMBUE_VERIFY_MAX_PROMPT_TOKENS or shortening the prompt. "
diff --git a/imbue_verify/issue_identifiers/utils_test.py b/imbue_verify/issue_identifiers/utils_test.py
@@ -1,92 +0,0 @@
-import contextvars
-import threading
-from typing import Generator
-
-from imbue_verify.issue_identifiers.utils import multiplex_generators
-from imbue_verify.issue_identifiers.utils import xml_post_escape
-
-
-def test_xml_post_escape_does_not_escape_if_not_necessary() -> None:
- input_string = "<root><code_part>hello</code_part></root>"
- assert xml_post_escape(input_string, "code_part") == input_string
-
-
-def test_xml_post_escape_properly_escapes_single_line() -> None:
- input_string = "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
- assert xml_post_escape(input_string, "code_part") == "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
-
-
-def test_xml_post_escape_properly_escapes_multi_line() -> None:
- input_string = """
- <root>
- <code_part>
- 1 < 2
- </code_part>
- </root>
- """
- assert (
- xml_post_escape(input_string, "code_part")
- == """
- <root>
- <code_part>
- 1 < 2
- </code_part>
- </root>
- """
- )
-
-
-def test_xml_post_escape_does_not_escape_if_not_asked_to() -> None:
- input_string = "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
- assert xml_post_escape(input_string, "desc") == "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
-
-
-def test_xml_post_escape_does_not_change_case() -> None:
- input_string = "<root><desc>Hey</desc><code_part>1 < 2</CODE_PART></root>"
- assert xml_post_escape(input_string, "code_part") == "<root><desc>Hey</desc><code_part>1 < 2</CODE_PART></root>"
-
-
-def test_xml_post_escape_does_nothing_if_element_not_present() -> None:
- input_string = "<root><greeting>hello</greeting></root>"
- assert xml_post_escape(input_string, "code_part") == input_string
-
-
-def _generator_with_barrier(value: int, count: int, barrier: threading.Barrier) -> Generator[int, None, int]:
- for i in range(count):
- barrier.wait(timeout=1.0)
- yield value + i
- return value * 100
-
-
-def test_multiplex_generators_runs_in_parallel() -> None:
- barrier = threading.Barrier(2)
-
- gen1 = _generator_with_barrier(0, 3, barrier)
- gen2 = _generator_with_barrier(10, 3, barrier)
-
- multiplexed = multiplex_generators([gen1, gen2], max_workers=2)
-
- results = []
- for item in multiplexed:
- results.append(item)
-
- assert len(results) == 6
- assert set(results) == {0, 1, 2, 10, 11, 12}
-
-
-def test_multiple_generators_transfers_contextvars() -> None:
- """Test that existing context variables are transferred to the generator threads."""
- var = contextvars.ContextVar("test_var", default=123)
-
- def _gen_with_contextvar() -> Generator[int, None, None]:
- yield var.get()
-
- gen = _gen_with_contextvar()
-
- multiplexed = multiplex_generators([gen])
-
- results = []
- for item in multiplexed:
- results.append(item)
-
- assert results == [123]
diff --git a/imbue_verify/repo_utils.py b/imbue_verify/repo_utils.py
@@ -1,73 +0,0 @@
-from pathlib import Path
-
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.computing_environment.data_types import RunCommandError
-from imbue_core.simple_git import SyncLocalGitRepo
-from imbue_tools.repo_utils.find_relative_to import find_relative_to_commit_hash
-from imbue_verify.errors import GitException
-
-# Maximum length of LLM prompts used within imbue_verify in tokens, without the repository-specific context.
-# Currently, the prompt is well under 10k tokens, but this value might need to be bumped up if we add a lot of additional
-# identification guides, few-shot examples, or other context.
-IMBUE_VERIFY_MAX_PROMPT_TOKENS = 10000
-
-
-def get_code_to_check(relative_to: str, repo_path: Path) -> tuple[str, str, str]:
- """
- Returns:
- - The commit hash to use as the base commit for the diff.
- - The combined diff including staged, unstaged, and untracked changes. (compatible with `git apply`)
- - The combined diff but with binary diffs shortened. (cannot be applied if binary changes are present)
- """
- try:
- base_commit = find_relative_to_commit_hash(relative_to, repo_path=repo_path)
- except RunCommandError as e:
- raise GitException(f"Unable to determine base commit for code verification: {e}") from e
-
- repo = SyncLocalGitRepo(repo_path)
-
- # Get the combined diff which includes all changes; staged, unstaged, and untracked.
- try:
- combined_diff = repo.get_git_diff(commit_hash=base_commit)
- combined_diff_no_binary = repo.get_git_diff(commit_hash=base_commit, include_binary=False)
- except RunCommandError as e:
- raise GitException(f"Unable to get diff to {base_commit}: {e}") from e
-
- # Get untracked files since we want to include these as part of the unstaged and full changes
- try:
- untracked_files = repo.get_untracked_files()
- except RunCommandError as e:
- raise GitException(f"Unable to get untracked files: {e}") from e
-
- # Create diffs for untracked files (treat them as new files)
- untracked_diffs = []
- untracked_diffs_no_binary = []
- for file_path in untracked_files:
- if file_path: # Skip empty lines
- try:
- untracked_diff = repo.get_untracked_file_diff(file_path, include_binary=True)
- untracked_diffs.append(untracked_diff)
- except RunCommandError as e:
- log_exception(
- e,
- "Skipping untracked file we couldn't diff: {file_path}",
- file_path=file_path,
- )
-
- try:
- untracked_diff_no_binary = repo.get_untracked_file_diff(file_path, include_binary=False)
- untracked_diffs_no_binary.append(untracked_diff_no_binary)
- except RunCommandError as e:
- log_exception(
- e,
- "Skipping untracked file we couldn't diff (no binary): {file_path}",
- file_path=file_path,
- )
-
- # Add untracked files to unstaged changes and the combined diff
- if untracked_diffs:
- combined_diff += "\n" + "\n".join(untracked_diffs)
- if untracked_diffs_no_binary:
- combined_diff_no_binary += "\n" + "\n".join(untracked_diffs_no_binary)
-
- return base_commit, combined_diff, combined_diff_no_binary
diff --git a/imbue_verify/repo_utils_test.py b/imbue_verify/repo_utils_test.py
@@ -1,97 +0,0 @@
-import subprocess
-from pathlib import Path
-
-from syrupy.assertion import SnapshotAssertion
-
-from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
-from imbue_core.nested_evolver import assign
-from imbue_core.nested_evolver import chill
-from imbue_core.nested_evolver import evolver
-from imbue_tools.repo_utils.project_context import LazyProjectContext
-from imbue_verify.repo_utils import get_code_to_check
-
-
-def test_get_code_to_check(simple_test_git_repo: Path) -> None:
- """Test that get_code_to_check correctly handles staged, unstaged, and untracked files"""
- repo_path = simple_test_git_repo
- first_commit = subprocess.run(
- ["git", "rev-parse", "HEAD"],
- cwd=repo_path,
- capture_output=True,
- text=True,
- check=True,
- ).stdout.strip()
-
- # Create an untracked file
- new_file_content = "This is a new untracked file\nwith multiple lines\nof content"
- (repo_path / "new_file.txt").write_text(new_file_content)
- (repo_path / "new_file.bin").write_bytes(b"\x00\x01\x02")
-
- # Create a committed change
- (repo_path / "file1.txt").write_text("committed modified content\n")
- (repo_path / "file1.bin").write_bytes(b"\x00\x01\x02")
- subprocess.run(["git", "add", "file1.txt"], cwd=repo_path, check=True)
- subprocess.run(["git", "commit", "-m", "Modify file1"], cwd=repo_path, check=True)
-
- # Create a staged change
- with open((repo_path / "file1.txt"), "a+") as f:
- # make sure to have multiple newlines to sepearate changes so they don't get
- # picked up in same diff block
- f.write("\nstaged written modified content\n")
- subprocess.run(["git", "add", "file1.txt"], cwd=repo_path, check=True)
-
- # Create an unstaged change
- with open((repo_path / "file1.txt"), "a+") as f:
- f.write("\nunstaged written modified content")
-
- git_hash, diff, diff_no_binary = get_code_to_check(first_commit, repo_path=repo_path)
-
- assert git_hash == first_commit
-
- # Verify the untracked file is included in the diffs
- assert "new_file.txt" in diff
- assert "new_file.bin" in diff
- assert "new_file.txt" in diff_no_binary
- assert "new_file.bin" in diff_no_binary
- assert "Binary files /dev/null and b/new_file.bin differ" in diff_no_binary
-
- # Verify tracked changes are also included
- assert "file1.txt" in diff
- assert "+staged written modified content" in diff
- assert "+unstaged written modified content" in diff
- assert "+committed modified content" in diff
- assert "file1.bin" in diff
-
- assert "file1.txt" in diff_no_binary
- assert "+staged written modified content" in diff_no_binary
- assert "+unstaged written modified content" in diff_no_binary
- assert "+committed modified content" in diff_no_binary
- assert "Binary files /dev/null and b/file1.bin differ" in diff_no_binary
-
-
-def test_build_context(simple_test_git_repo: Path, snapshot: SnapshotAssertion) -> None:
- first_commit = subprocess.run(
- ["git", "rev-parse", "HEAD"],
- cwd=simple_test_git_repo,
- capture_output=True,
- text=True,
- check=True,
- ).stdout.strip()
- git_hash, diff, _diff_no_binary = get_code_to_check(first_commit, repo_path=simple_test_git_repo)
- project_context = LazyProjectContext.build(
- git_hash,
- diff,
- language_model_name=AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01,
- repo_path=simple_test_git_repo,
- tokens_to_reserve=20000,
- ).to_base_project_context()
- assert project_context.repo_path == simple_test_git_repo
-
- # the temp dir isn't the same every time so we need to remove it
- project_context_evolver = evolver(project_context)
- assign(
- project_context_evolver.repo_path,
- lambda: None,
- )
- project_context_without_repo_path = chill(project_context_evolver)
- assert project_context_without_repo_path == snapshot
diff --git a/pyproject.toml b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[project]
-name = "imbue_verify"
+name = "vet"
version = "0.1.0"
readme = "README.md"
dependencies = [
@@ -23,13 +23,13 @@ dependencies = [
requires-python = ">=3.11"
[project.scripts]
-imbue-verify = "imbue_verify.cli.main:main"
+vet = "vet.cli.main:main"
[tool.setuptools]
-package-data.imbue_verify = ["py.typed"]
+package-data.vet = ["py.typed"]
[tool.setuptools.packages.find]
-include = ["imbue_verify*"]
+include = ["vet*"]
[tool.uv.sources]
imbue_core = { path = "./imbue_core", editable = true }
diff --git a/uv.lock b/uv.lock
@@ -956,7 +956,6 @@ requires-dist = [
{ name = "async-lru" },
{ name = "attrs" },
{ name = "imbue-core", editable = "imbue_core" },
- { name = "imbue-verify", marker = "extra == 'test'", editable = "." },
{ name = "jinja2" },
{ name = "libcst" },
{ name = "loguru" },
@@ -969,53 +968,11 @@ requires-dist = [
{ name = "python-gitlab" },
{ name = "requests" },
{ name = "syrupy" },
+ { name = "vet", marker = "extra == 'test'", editable = "." },
]
provides-extras = ["test"]
[[package]]
-name = "imbue-verify"
-version = "0.1.0"
-source = { editable = "." }
-dependencies = [
- { name = "aiohttp" },
- { name = "click" },
- { name = "imbue-core" },
- { name = "imbue-tools" },
- { name = "jinja2" },
- { name = "loguru" },
- { name = "pydantic" },
- { name = "pygments" },
- { name = "pytest" },
- { name = "syrupy" },
- { name = "together" },
- { name = "vet-types" },
-]
-
-[package.dev-dependencies]
-dev = [
- { name = "black" },
-]
-
-[package.metadata]
-requires-dist = [
- { name = "aiohttp", specifier = ">=3.8.0" },
- { name = "click" },
- { name = "imbue-core", editable = "imbue_core" },
- { name = "imbue-tools", editable = "imbue_tools" },
- { name = "jinja2" },
- { name = "loguru" },
- { name = "pydantic" },
- { name = "pygments", specifier = ">=2.0.0" },
- { name = "pytest" },
- { name = "syrupy" },
- { name = "together", specifier = ">=1.5.35" },
- { name = "vet-types", editable = "vet_types" },
-]
-
-[package.metadata.requires-dev]
-dev = [{ name = "black" }]
-
-[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
@@ -2769,6 +2726,49 @@ wheels = [
]
[[package]]
+name = "vet"
+version = "0.1.0"
+source = { editable = "." }
+dependencies = [
+ { name = "aiohttp" },
+ { name = "click" },
+ { name = "imbue-core" },
+ { name = "imbue-tools" },
+ { name = "jinja2" },
+ { name = "loguru" },
+ { name = "pydantic" },
+ { name = "pygments" },
+ { name = "pytest" },
+ { name = "syrupy" },
+ { name = "together" },
+ { name = "vet-types" },
+]
+
+[package.dev-dependencies]
+dev = [
+ { name = "black" },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "aiohttp", specifier = ">=3.8.0" },
+ { name = "click" },
+ { name = "imbue-core", editable = "imbue_core" },
+ { name = "imbue-tools", editable = "imbue_tools" },
+ { name = "jinja2" },
+ { name = "loguru" },
+ { name = "pydantic" },
+ { name = "pygments", specifier = ">=2.0.0" },
+ { name = "pytest" },
+ { name = "syrupy" },
+ { name = "together", specifier = ">=1.5.35" },
+ { name = "vet-types", editable = "vet_types" },
+]
+
+[package.metadata.requires-dev]
+dev = [{ name = "black" }]
+
+[[package]]
name = "vet-types"
version = "0.1.0"
source = { editable = "vet_types" }
diff --git a/imbue-verify.toml b/vet.toml
diff --git a/imbue_verify/__init__.py b/vet/__init__.py
diff --git a/imbue_verify/__snapshots__/repo_utils_test.ambr b/vet/__snapshots__/repo_utils_test.ambr
diff --git a/vet/api.py b/vet/api.py
@@ -0,0 +1,125 @@
+"""Public API for vet.
+
+This module provides functions to identify issues in code changes. Issue identifiers are pieces of logic capable of finding issues in code.
+By default, vet runs all registered issue identifiers and returns all found issues.
+"""
+
+from pathlib import Path
+
+from loguru import logger
+
+from imbue_core.data_types import IdentifiedVerifyIssue
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from vet_types.messages import ConversationMessageUnion
+from imbue_tools.get_conversation_history.get_conversation_history import (
+ ConversationLoadingError,
+)
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.repo_utils.project_context import LazyProjectContext
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from imbue_tools.util_prompts.goal_from_conversation import get_goal_from_conversation
+from vet.issue_identifiers import registry
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+from vet.repo_utils import VET_MAX_PROMPT_TOKENS
+from vet.repo_utils import get_code_to_check
+
+
+def get_issues_with_raw_responses(
+ base_commit: str,
+ diff: str,
+ diff_no_binary: str,
+ goal: str,
+ config: VetConfig,
+ repo_path: Path,
+ conversation_history: tuple[ConversationMessageUnion, ...] | None = None,
+) -> tuple[tuple[IdentifiedVerifyIssue, ...], IssueIdentificationDebugInfo, ProjectContext]:
+ if not goal or not goal.strip():
+ logger.info("No goal was provided, generating one from conversation history")
+ # should be not None and not empty
+ if conversation_history:
+ try:
+ # TODO: we use the imbue verify config here, but we may want to configure this separately
+ goal = get_goal_from_conversation(conversation_history, config.language_model_generation_config)
+ logger.info("Generated goal from conversation history: {}", goal)
+ except Exception as e:
+ raise ConversationLoadingError(
+ f"No goal was provided and generating one from conversation history failed: {e}"
+ )
+ else:
+ # TODO: Consider which CLI options we should show this for (quiet, normal, verbose).
+ logger.info("No goal or conversation history provided, only goal-independent identifiers will run")
+ goal = ""
+
+ lm_config = config.language_model_generation_config
+ if diff_no_binary:
+ diff_no_binary_tokens = lm_config.count_tokens(diff_no_binary)
+ else:
+ diff_no_binary_tokens = 0
+
+ project_context = LazyProjectContext.build(
+ base_commit,
+ diff,
+ language_model_name=lm_config.model_name,
+ repo_path=repo_path,
+ # This needs to account for the vet prompt, as well as the max_tokens output tokens.
+ tokens_to_reserve=VET_MAX_PROMPT_TOKENS + diff_no_binary_tokens + config.max_output_tokens,
+ context_window=lm_config.get_max_context_length(),
+ is_custom_model=lm_config.is_custom_model(),
+ )
+
+ identifier_inputs = IdentifierInputs(
+ maybe_diff=diff_no_binary or None,
+ maybe_goal=goal,
+ maybe_conversation_history=conversation_history,
+ )
+
+ results_generator = registry.run(
+ identifier_inputs=identifier_inputs,
+ project_context=project_context,
+ config=config,
+ )
+
+ issues = []
+ results_generator_with_capture = ReturnCapturingGenerator(results_generator)
+ for result in results_generator_with_capture:
+ if result.passes_filtration:
+ issues.append(result.issue)
+ issue_identification_debug_info = results_generator_with_capture.return_value
+
+ return tuple(issues), issue_identification_debug_info, project_context
+
+
+def find_issues(
+ repo_path: Path,
+ relative_to: str,
+ goal: str,
+ config: VetConfig,
+ conversation_history: tuple[ConversationMessageUnion, ...] | None = None,
+) -> tuple[IdentifiedVerifyIssue, ...]:
+ logger.info(
+ "Finding issues in {repo_path} relative to commit hash {relative_to}",
+ repo_path=repo_path,
+ relative_to=relative_to,
+ )
+
+ base_commit, diff, diff_no_binary = get_code_to_check(relative_to, repo_path)
+ if not diff.strip():
+ logger.info(
+ "No code changes detected in repo {repo_path} since the specified relative_to commit {relative_to}, skipping issue identification",
+ repo_path=repo_path,
+ relative_to=relative_to,
+ )
+ # No code changes detected since the specified relative_to commit, so no issues to find.
+ return tuple()
+
+ issues, _, _ = get_issues_with_raw_responses(
+ base_commit=base_commit,
+ diff=diff,
+ diff_no_binary=diff_no_binary,
+ goal=goal,
+ config=config,
+ repo_path=repo_path,
+ conversation_history=conversation_history,
+ )
+ return issues
diff --git a/imbue_verify/cli/__init__.py b/vet/cli/__init__.py
diff --git a/imbue_verify/cli/config/__init__.py b/vet/cli/config/__init__.py
diff --git a/imbue_verify/cli/config/cli_config_schema.py b/vet/cli/config/cli_config_schema.py
diff --git a/vet/cli/config/cli_config_test.py b/vet/cli/config/cli_config_test.py
@@ -0,0 +1,446 @@
+from __future__ import annotations
+
+import argparse
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from vet.cli.config.cli_config_schema import CLI_DEFAULTS
+from vet.cli.config.cli_config_schema import CliConfigPreset
+from vet.cli.config.cli_config_schema import merge_presets
+from vet.cli.config.cli_config_schema import parse_cli_config_from_dict
+from vet.cli.config.loader import ConfigLoadError
+from vet.cli.config.loader import _load_cli_config_file
+from vet.cli.config.loader import get_cli_config_file_paths
+from vet.cli.config.loader import get_config_preset
+from vet.cli.config.loader import load_cli_config
+from vet.cli.main import apply_config_preset
+
+
+def test_parse_cli_config_from_dict_parses_single_config() -> None:
+ data = {
+ "ci": {
+ "confidence_threshold": 0.9,
+ "max_workers": 4,
+ "quiet": True,
+ }
+ }
+
+ result = parse_cli_config_from_dict(data)
+
+ assert "ci" in result
+ assert result["ci"].confidence_threshold == 0.9
+ assert result["ci"].max_workers == 4
+ assert result["ci"].quiet is True
+
+
+def test_parse_cli_config_from_dict_parses_multiple_configs() -> None:
+ data = {
+ "ci": {"confidence_threshold": 0.9},
+ "strict": {"confidence_threshold": 0.6, "model": "claude-4-sonnet"},
+ "default": {},
+ }
+
+ result = parse_cli_config_from_dict(data)
+
+ assert len(result) == 3
+ assert result["ci"].confidence_threshold == 0.9
+ assert result["strict"].confidence_threshold == 0.6
+ assert result["strict"].model == "claude-4-sonnet"
+ assert result["default"].confidence_threshold is None
+
+
+def test_parse_cli_config_from_dict_handles_all_fields() -> None:
+ data = {
+ "full": {
+ "goal": "Check for security issues",
+ "repo": "/path/to/repo",
+ "base_commit": "main",
+ "history_loader": "cat history.jsonl",
+ "extra_context": ["context1.txt", "context2.txt"],
+ "enabled_issue_codes": ["correctness", "style"],
+ "disabled_issue_codes": ["minor"],
+ "model": "test-model",
+ "temperature": 0.7,
+ "confidence_threshold": 0.85,
+ "max_workers": 8,
+ "output": "results.json",
+ "output_format": "json",
+ "output_fields": ["file", "line", "message"],
+ "verbose": True,
+ "quiet": False,
+ }
+ }
+
+ result = parse_cli_config_from_dict(data)
+
+ preset = result["full"]
+ assert preset.goal == "Check for security issues"
+ assert preset.repo == "/path/to/repo"
+ assert preset.base_commit == "main"
+ assert preset.history_loader == "cat history.jsonl"
+ assert preset.extra_context == ["context1.txt", "context2.txt"]
+ assert preset.enabled_issue_codes == ["correctness", "style"]
+ assert preset.disabled_issue_codes == ["minor"]
+ assert preset.model == "test-model"
+ assert preset.temperature == 0.7
+ assert preset.confidence_threshold == 0.85
+ assert preset.max_workers == 8
+ assert preset.output == "results.json"
+ assert preset.output_format == "json"
+ assert preset.output_fields == ["file", "line", "message"]
+ assert preset.verbose is True
+ assert preset.quiet is False
+
+
+def test_merge_presets_override_takes_precedence() -> None:
+ base = CliConfigPreset(confidence_threshold=0.8, max_workers=2, model="base-model")
+ override = CliConfigPreset(confidence_threshold=0.9, max_workers=None, model="override-model")
+
+ result = merge_presets(base, override)
+
+ assert result.confidence_threshold == 0.9
+ assert result.max_workers == 2
+ assert result.model == "override-model"
+
+
+def test_merge_presets_preserves_base_when_override_is_none() -> None:
+ base = CliConfigPreset(
+ confidence_threshold=0.8,
+ max_workers=4,
+ model="base-model",
+ verbose=True,
+ )
+ override = CliConfigPreset()
+
+ result = merge_presets(base, override)
+
+ assert result.confidence_threshold == 0.8
+ assert result.max_workers == 4
+ assert result.model == "base-model"
+ assert result.verbose is True
+
+
+def test_cli_defaults_and_cli_config_preset_have_same_fields() -> None:
+ """Verify CliDefaults and CliConfigPreset define the same fields.
+
+ These two models exist for different purposes:
+ - CliDefaults: Holds actual default values for CLI arguments (e.g., temperature=0.0)
+ - CliConfigPreset: Used for config file presets where None means "not specified"
+
+ They must have identical field names to ensure presets can override any default.
+ This test catches drift if a field is added to one model but not the other.
+ """
+ from vet.cli.config.cli_config_schema import CliDefaults
+
+ defaults_fields = set(CliDefaults.model_fields.keys())
+ preset_fields = set(CliConfigPreset.model_fields.keys())
+
+ assert defaults_fields == preset_fields, (
+ f"Field mismatch between CliDefaults and CliConfigPreset.\n"
+ f"Only in CliDefaults: {defaults_fields - preset_fields}\n"
+ f"Only in CliConfigPreset: {preset_fields - defaults_fields}"
+ )
+
+
+def test_get_cli_config_file_paths_returns_global_path(tmp_path: Path) -> None:
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path)}):
+ paths = get_cli_config_file_paths(repo_path=None)
+
+ assert len(paths) == 1
+ assert paths[0] == tmp_path / "vet" / "config.toml"
+
+
+def test_get_cli_config_file_paths_includes_project_path(tmp_path: Path) -> None:
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "xdg")}):
+ paths = get_cli_config_file_paths(repo_path=repo_path)
+
+ assert len(paths) == 2
+ assert paths[0] == tmp_path / "xdg" / "vet" / "config.toml"
+ assert paths[1] == repo_path / "vet.toml"
+
+
+def test_get_cli_config_file_paths_finds_git_root(tmp_path: Path) -> None:
+ git_root = tmp_path / "repo"
+ git_root.mkdir()
+ (git_root / ".git").mkdir()
+ subdir = git_root / "src" / "deep"
+ subdir.mkdir(parents=True)
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "xdg")}):
+ paths = get_cli_config_file_paths(repo_path=subdir)
+
+ assert paths[1] == git_root / "vet.toml"
+
+
+def test_load_cli_config_file_loads_valid_toml(tmp_path: Path) -> None:
+ config_file = tmp_path / "config.toml"
+ config_file.write_text(
+ """
+[ci]
+confidence_threshold = 0.9
+max_workers = 4
+quiet = true
+
+[strict]
+confidence_threshold = 0.6
+model = "claude-4-sonnet"
+"""
+ )
+
+ result = _load_cli_config_file(config_file)
+
+ assert "ci" in result
+ assert result["ci"].confidence_threshold == 0.9
+ assert result["ci"].max_workers == 4
+ assert result["ci"].quiet is True
+ assert "strict" in result
+ assert result["strict"].model == "claude-4-sonnet"
+
+
+def test_load_cli_config_file_raises_on_invalid_toml(tmp_path: Path) -> None:
+ config_file = tmp_path / "config.toml"
+ config_file.write_text("not = valid = toml")
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ _load_cli_config_file(config_file)
+
+ assert "Invalid TOML" in str(exc_info.value)
+
+
+def test_load_cli_config_file_raises_on_invalid_schema(tmp_path: Path) -> None:
+ config_file = tmp_path / "config.toml"
+ config_file.write_text(
+ """
+[ci]
+confidence_threshold = "not-a-float"
+"""
+ )
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ _load_cli_config_file(config_file)
+
+ assert "Invalid configuration" in str(exc_info.value)
+
+
+def test_load_cli_config_file_raises_on_unknown_field(tmp_path: Path) -> None:
+ config_file = tmp_path / "config.toml"
+ config_file.write_text(
+ """
+[ci]
+unknown_field = "value"
+"""
+ )
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ _load_cli_config_file(config_file)
+
+ assert "Invalid configuration" in str(exc_info.value)
+
+
+def test_load_cli_config_returns_empty_when_no_files_exist(tmp_path: Path) -> None:
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
+ result = load_cli_config(repo_path=tmp_path)
+
+ assert result == {}
+
+
+def test_load_cli_config_loads_single_file(tmp_path: Path) -> None:
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+ config_file = repo_path / "vet.toml"
+ config_file.write_text(
+ """
+[ci]
+confidence_threshold = 0.9
+"""
+ )
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
+ result = load_cli_config(repo_path=repo_path)
+
+ assert "ci" in result
+ assert result["ci"].confidence_threshold == 0.9
+
+
+def test_load_cli_config_merges_global_and_project(tmp_path: Path) -> None:
+ xdg_config = tmp_path / "xdg"
+ (xdg_config / "vet").mkdir(parents=True)
+ global_config = xdg_config / "vet" / "config.toml"
+ global_config.write_text(
+ """
+[ci]
+confidence_threshold = 0.8
+max_workers = 2
+
+[global-only]
+model = "global-model"
+"""
+ )
+
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+ project_config = repo_path / "vet.toml"
+ project_config.write_text(
+ """
+[ci]
+confidence_threshold = 0.9
+
+[project-only]
+model = "project-model"
+"""
+ )
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(xdg_config)}):
+ result = load_cli_config(repo_path=repo_path)
+
+ assert result["ci"].confidence_threshold == 0.9
+ assert result["ci"].max_workers == 2
+
+ assert "global-only" in result
+ assert result["global-only"].model == "global-model"
+ assert "project-only" in result
+ assert result["project-only"].model == "project-model"
+
+
+def test_get_config_preset_returns_preset() -> None:
+ configs = {
+ "ci": CliConfigPreset(confidence_threshold=0.9),
+ "strict": CliConfigPreset(confidence_threshold=0.6),
+ }
+
+ result = get_config_preset("ci", configs)
+
+ assert result.confidence_threshold == 0.9
+
+
+def test_get_config_preset_raises_on_unknown_with_available() -> None:
+ configs = {
+ "ci": CliConfigPreset(),
+ "strict": CliConfigPreset(),
+ }
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ get_config_preset("unknown", configs)
+
+ error_msg = str(exc_info.value)
+ assert "unknown" in error_msg
+ assert "ci" in error_msg
+ assert "strict" in error_msg
+
+
+def test_get_config_preset_raises_on_unknown_with_no_configs(tmp_path: Path) -> None:
+ configs: dict[str, CliConfigPreset] = {}
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "xdg")}):
+ with pytest.raises(ConfigLoadError) as exc_info:
+ get_config_preset("unknown", configs, repo_path)
+
+ error_msg = str(exc_info.value)
+ assert "unknown" in error_msg
+ assert "No configuration files found" in error_msg
+ # Verify the error message contains dynamically generated paths with labels
+ assert f"{tmp_path / 'xdg' / 'vet' / 'config.toml'} (global)" in error_msg
+ assert f"{repo_path / 'vet.toml'} (project)" in error_msg
+
+
+def _create_default_args() -> argparse.Namespace:
+ return argparse.Namespace(
+ model=CLI_DEFAULTS.model,
+ temperature=CLI_DEFAULTS.temperature,
+ confidence_threshold=CLI_DEFAULTS.confidence_threshold,
+ max_workers=CLI_DEFAULTS.max_workers,
+ output_format=CLI_DEFAULTS.output_format,
+ output_fields=CLI_DEFAULTS.output_fields,
+ verbose=CLI_DEFAULTS.verbose,
+ quiet=CLI_DEFAULTS.quiet,
+ enabled_issue_codes=CLI_DEFAULTS.enabled_issue_codes,
+ disabled_issue_codes=CLI_DEFAULTS.disabled_issue_codes,
+ )
+
+
+def test_apply_config_preset_applies_all_values() -> None:
+ args = _create_default_args()
+ preset = CliConfigPreset(
+ model="preset-model",
+ temperature=0.7,
+ confidence_threshold=0.9,
+ max_workers=4,
+ output_format="json",
+ output_fields=["file", "line"],
+ verbose=True,
+ quiet=False,
+ )
+
+ result = apply_config_preset(args, preset)
+
+ assert result.model == "preset-model"
+ assert result.temperature == 0.7
+ assert result.confidence_threshold == 0.9
+ assert result.max_workers == 4
+ assert result.output_format == "json"
+ assert result.output_fields == ["file", "line"]
+ assert result.verbose is True
+
+
+def test_apply_config_preset_cli_args_take_precedence() -> None:
+ args = argparse.Namespace(
+ model="cli-model",
+ temperature=0.0,
+ confidence_threshold=0.95,
+ max_workers=2,
+ output_format="text",
+ output_fields=None,
+ verbose=False,
+ quiet=False,
+ enabled_issue_codes=None,
+ disabled_issue_codes=None,
+ )
+ preset = CliConfigPreset(
+ model="preset-model",
+ temperature=0.3,
+ confidence_threshold=0.6,
+ max_workers=8,
+ )
+
+ result = apply_config_preset(args, preset)
+
+ assert result.model == "cli-model"
+ assert result.confidence_threshold == 0.95
+
+ assert result.temperature == 0.3
+ assert result.max_workers == 8
+
+
+def test_apply_config_preset_leaves_defaults_when_preset_is_none() -> None:
+ args = _create_default_args()
+ preset = CliConfigPreset()
+
+ result = apply_config_preset(args, preset)
+
+ assert result.model is None
+ assert result.temperature == 0.0
+ assert result.confidence_threshold == 0.8
+ assert result.max_workers == 2
+
+
+def test_apply_config_preset_handles_issue_codes() -> None:
+ args = _create_default_args()
+ preset = CliConfigPreset(
+ enabled_issue_codes=["incorrect_function_implementation"],
+ disabled_issue_codes=["bad_naming"],
+ )
+
+ result = apply_config_preset(args, preset)
+
+ assert len(result.enabled_issue_codes) == 1
+ assert result.enabled_issue_codes[0].value == "incorrect_function_implementation"
+ assert len(result.disabled_issue_codes) == 1
+ assert result.disabled_issue_codes[0].value == "bad_naming"
diff --git a/vet/cli/config/loader.py b/vet/cli/config/loader.py
@@ -0,0 +1,210 @@
+from __future__ import annotations
+
+import os
+import tomllib
+from pathlib import Path
+
+from pydantic import ValidationError
+
+from imbue_core.agents.configs import LanguageModelGenerationConfig
+from imbue_core.agents.configs import OpenAICompatibleModelConfig
+from imbue_core.agents.llm_apis.common import get_model_max_output_tokens
+from vet.cli.config.cli_config_schema import CliConfigPreset
+from vet.cli.config.cli_config_schema import merge_presets
+from vet.cli.config.cli_config_schema import parse_cli_config_from_dict
+from vet.cli.config.schema import ModelsConfig
+from vet.cli.config.schema import ProviderConfig
+
+
+class ConfigLoadError(Exception):
+ pass
+
+
+class MissingAPIKeyError(Exception):
+ def __init__(self, env_var: str, provider_name: str, model_id: str) -> None:
+ self.env_var = env_var
+ self.provider_name = provider_name
+ self.model_id = model_id
+ super().__init__(
+ f"API key not found: environment variable '{env_var}' is not set. "
+ + f"This is required for model '{model_id}' from provider '{provider_name}'."
+ )
+
+
+def get_xdg_config_home() -> Path:
+ xdg_config = os.environ.get("XDG_CONFIG_HOME")
+ if xdg_config:
+ return Path(xdg_config)
+ return Path.home() / ".config"
+
+
+def find_git_repo_root(start_path: Path) -> Path | None:
+ current = start_path.resolve()
+ while current != current.parent:
+ if (current / ".git").exists():
+ return current
+ current = current.parent
+ if (current / ".git").exists():
+ return current
+ return None
+
+
+def _get_config_file_paths(
+ global_subpath: str,
+ global_filename: str,
+ project_filename: str,
+ repo_path: Path | None = None,
+) -> list[Path]:
+ paths = [get_xdg_config_home() / global_subpath / global_filename]
+
+ if repo_path:
+ git_root = find_git_repo_root(repo_path)
+ root = git_root if git_root else repo_path
+ paths.append(root / project_filename)
+
+ return paths
+
+
+def get_config_file_paths(repo_path: Path | None = None) -> list[Path]:
+ return _get_config_file_paths("imbue", "models.json", "models.json", repo_path)
+
+
+def _load_single_config_file(config_path: Path) -> ModelsConfig:
+ try:
+ with open(config_path) as f:
+ return ModelsConfig.model_validate_json(f.read())
+ except ValidationError as e:
+ raise ConfigLoadError(f"Invalid configuration in {config_path}: {e}") from e
+ except OSError as e:
+ raise ConfigLoadError(f"Cannot read {config_path}: {e}") from e
+
+
+def load_models_config(repo_path: Path | None = None) -> ModelsConfig:
+ merged_providers: dict[str, ProviderConfig] = {}
+
+ for config_path in get_config_file_paths(repo_path):
+ if config_path.exists():
+ config = _load_single_config_file(config_path)
+ merged_providers.update(config.providers)
+
+ return ModelsConfig(providers=merged_providers)
+
+
+def get_user_defined_model_ids(config: ModelsConfig) -> set[str]:
+ model_ids: set[str] = set()
+ for provider in config.providers.values():
+ model_ids.update(provider.models.keys())
+ return model_ids
+
+
+def get_provider_for_model(model_id: str, config: ModelsConfig) -> ProviderConfig | None:
+ for provider in config.providers.values():
+ if model_id in provider.models:
+ return provider
+ return None
+
+
+def validate_api_key_for_model(model_id: str, config: ModelsConfig) -> None:
+ provider = get_provider_for_model(model_id, config)
+ if provider is None:
+ return
+
+ api_key_env = provider.api_key_env
+ if api_key_env is None:
+ return
+
+ api_key = os.environ.get(api_key_env, "")
+ if not api_key:
+ provider_name = provider.name or "unknown provider"
+ raise MissingAPIKeyError(
+ env_var=api_key_env,
+ provider_name=provider_name,
+ model_id=model_id,
+ )
+
+
+def get_models_by_provider_from_config(config: ModelsConfig) -> dict[str, list[str]]:
+ result: dict[str, list[str]] = {}
+ for provider_id, provider in config.providers.items():
+ display_name = provider.name or provider_id
+ result[display_name] = list(provider.models.keys())
+ return result
+
+
+def get_max_output_tokens_for_model(model_id: str, config: ModelsConfig) -> int | None:
+ provider = get_provider_for_model(model_id, config)
+ if provider is not None:
+ return provider.models[model_id].max_output_tokens
+
+ try:
+ return get_model_max_output_tokens(model_id)
+ except Exception:
+ return None
+
+
+def build_language_model_config(model_id: str, user_config: ModelsConfig) -> LanguageModelGenerationConfig:
+ provider = get_provider_for_model(model_id, user_config)
+ if provider is None:
+ return LanguageModelGenerationConfig(model_name=model_id)
+
+ model_config = provider.models[model_id]
+ actual_model_name = model_config.model_id or model_id
+
+ return OpenAICompatibleModelConfig(
+ model_name=actual_model_name,
+ custom_base_url=provider.base_url,
+ custom_api_key_env=provider.api_key_env or "",
+ custom_context_window=model_config.context_window,
+ custom_max_output_tokens=model_config.max_output_tokens,
+ )
+
+
+def get_cli_config_file_paths(repo_path: Path | None = None) -> list[Path]:
+ return _get_config_file_paths("vet", "config.toml", "vet.toml", repo_path)
+
+
+def _load_cli_config_file(config_path: Path) -> dict[str, CliConfigPreset]:
+ try:
+ with open(config_path, "rb") as f:
+ data = tomllib.load(f)
+ return parse_cli_config_from_dict(data)
+ except tomllib.TOMLDecodeError as e:
+ raise ConfigLoadError(f"Invalid TOML in {config_path}: {e}") from e
+ except ValidationError as e:
+ raise ConfigLoadError(f"Invalid configuration in {config_path}: {e}") from e
+ except OSError as e:
+ raise ConfigLoadError(f"Cannot read {config_path}: {e}") from e
+
+
+def load_cli_config(repo_path: Path | None = None) -> dict[str, CliConfigPreset]:
+ merged_configs: dict[str, CliConfigPreset] = {}
+
+ for config_path in get_cli_config_file_paths(repo_path):
+ if config_path.exists():
+ file_configs = _load_cli_config_file(config_path)
+ for name, preset in file_configs.items():
+ if name in merged_configs:
+ merged_configs[name] = merge_presets(merged_configs[name], preset)
+ else:
+ merged_configs[name] = preset
+
+ return merged_configs
+
+
+def get_config_preset(
+ config_name: str,
+ cli_configs: dict[str, CliConfigPreset],
+ repo_path: Path | None = None,
+) -> CliConfigPreset:
+ if config_name not in cli_configs:
+ available = sorted(cli_configs.keys())
+ if available:
+ raise ConfigLoadError(f"Configuration '{config_name}' not found. Available configs: {', '.join(available)}")
+ else:
+ paths = get_cli_config_file_paths(repo_path)
+ paths_list = "\n".join(f" - {p} ({'global' if i == 0 else 'project'})" for i, p in enumerate(paths))
+ raise ConfigLoadError(
+ f"Configuration '{config_name}' not found.\n\n"
+ f"No configuration files found. Create a config at one of these locations:\n{paths_list}"
+ )
+ return cli_configs[config_name]
diff --git a/vet/cli/config/loader_test.py b/vet/cli/config/loader_test.py
@@ -0,0 +1,377 @@
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from vet.cli.config.loader import ConfigLoadError
+from vet.cli.config.loader import MissingAPIKeyError
+from vet.cli.config.loader import _load_single_config_file
+from vet.cli.config.loader import find_git_repo_root
+from vet.cli.config.loader import get_config_file_paths
+from vet.cli.config.loader import get_models_by_provider_from_config
+from vet.cli.config.loader import get_provider_for_model
+from vet.cli.config.loader import get_user_defined_model_ids
+from vet.cli.config.loader import get_xdg_config_home
+from vet.cli.config.loader import load_models_config
+from vet.cli.config.loader import validate_api_key_for_model
+from vet.cli.config.schema import ModelConfig
+from vet.cli.config.schema import ModelsConfig
+from vet.cli.config.schema import ProviderConfig
+
+
+def test_get_xdg_config_home_uses_env_var(tmp_path: Path) -> None:
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path)}):
+ assert get_xdg_config_home() == tmp_path
+
+
+def test_get_xdg_config_home_defaults_to_home_config() -> None:
+ with patch.dict(os.environ, {}, clear=True):
+ os.environ.pop("XDG_CONFIG_HOME", None)
+ result = get_xdg_config_home()
+ assert result == Path.home() / ".config"
+
+
+def test_find_git_repo_root_finds_root(tmp_path: Path) -> None:
+ git_root = tmp_path / "repo"
+ git_root.mkdir()
+ (git_root / ".git").mkdir()
+ subdir = git_root / "src" / "deep" / "nested"
+ subdir.mkdir(parents=True)
+
+ result = find_git_repo_root(subdir)
+ assert result == git_root
+
+
+def test_find_git_repo_root_returns_none_when_not_in_repo(tmp_path: Path) -> None:
+ non_repo = tmp_path / "not_a_repo"
+ non_repo.mkdir()
+
+ result = find_git_repo_root(non_repo)
+ assert result is None
+
+
+def test_get_config_file_paths_returns_global_path(tmp_path: Path) -> None:
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path)}):
+ paths = get_config_file_paths(repo_path=None)
+ assert len(paths) == 1
+ assert paths[0] == tmp_path / "imbue" / "models.json"
+
+
+def test_get_config_file_paths_finds_git_root(tmp_path: Path) -> None:
+ xdg_config = tmp_path / "xdg"
+ git_root = tmp_path / "repo"
+ git_root.mkdir()
+ (git_root / ".git").mkdir()
+ subdir = git_root / "src" / "submodule"
+ subdir.mkdir(parents=True)
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(xdg_config)}):
+ paths = get_config_file_paths(repo_path=subdir)
+ assert len(paths) == 2
+ assert paths[0] == xdg_config / "imbue" / "models.json"
+ assert paths[1] == git_root / "models.json"
+
+
+def test_load_single_config_file_loads_valid_config(tmp_path: Path) -> None:
+ config_file = tmp_path / "models.json"
+ config_data = {
+ "providers": {
+ "test-provider": {
+ "name": "Test Provider",
+ "api_type": "openai_compatible",
+ "base_url": "http://localhost:8080/v1",
+ "api_key_env": "TEST_API_KEY",
+ "models": {
+ "test-model": {
+ "model_id": "test-model-v1",
+ "context_window": 128000,
+ "max_output_tokens": 16384,
+ }
+ },
+ }
+ }
+ }
+ config_file.write_text(json.dumps(config_data))
+
+ result = _load_single_config_file(config_file)
+
+ assert "test-provider" in result.providers
+ provider = result.providers["test-provider"]
+ assert provider.name == "Test Provider"
+ assert provider.base_url == "http://localhost:8080/v1"
+ assert provider.api_key_env == "TEST_API_KEY"
+ assert "test-model" in provider.models
+ assert provider.models["test-model"].model_id == "test-model-v1"
+
+
+def test_load_single_config_file_raises_on_invalid_json(tmp_path: Path) -> None:
+ config_file = tmp_path / "models.json"
+ config_file.write_text("not valid json")
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ _load_single_config_file(config_file)
+ assert "Invalid JSON" in str(exc_info.value)
+
+
+def test_load_single_config_file_raises_on_invalid_schema(tmp_path: Path) -> None:
+ config_file = tmp_path / "models.json"
+ config_data = {
+ "providers": {
+ "test-provider": {
+ "name": "Test Provider",
+ }
+ }
+ }
+ config_file.write_text(json.dumps(config_data))
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ _load_single_config_file(config_file)
+ assert "Invalid configuration" in str(exc_info.value)
+
+
+def test_load_single_config_file_raises_on_invalid_api_type(tmp_path: Path) -> None:
+ config_file = tmp_path / "models.json"
+ config_data = {
+ "providers": {
+ "test-provider": {
+ "name": "Test Provider",
+ "api_type": "anthropic",
+ "base_url": "http://localhost:8080/v1",
+ "api_key_env": "TEST_API_KEY",
+ "models": {},
+ }
+ }
+ }
+ config_file.write_text(json.dumps(config_data))
+
+ with pytest.raises(ConfigLoadError) as exc_info:
+ _load_single_config_file(config_file)
+ assert "Invalid configuration" in str(exc_info.value)
+
+
+def test_load_models_config_returns_empty_when_no_files_exist(tmp_path: Path) -> None:
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
+ result = load_models_config(repo_path=tmp_path)
+ assert result.providers == {}
+
+
+def test_load_models_config_loads_project_config(tmp_path: Path) -> None:
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+ config_file = repo_path / "models.json"
+ config_data = {
+ "providers": {
+ "project-provider": {
+ "base_url": "http://project:8080/v1",
+ "api_key_env": "PROJECT_KEY",
+ "models": {
+ "project-model": {
+ "context_window": 128000,
+ "max_output_tokens": 16384,
+ }
+ },
+ }
+ }
+ }
+ config_file.write_text(json.dumps(config_data))
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(tmp_path / "nonexistent")}):
+ result = load_models_config(repo_path=repo_path)
+
+ assert "project-provider" in result.providers
+
+
+def test_load_models_config_project_overrides_global(tmp_path: Path) -> None:
+ xdg_config = tmp_path / "xdg"
+ (xdg_config / "imbue").mkdir(parents=True)
+ global_config = xdg_config / "imbue" / "models.json"
+ global_config.write_text(
+ json.dumps(
+ {
+ "providers": {
+ "shared-provider": {
+ "name": "Global Name",
+ "base_url": "http://global:8080/v1",
+ "api_key_env": "GLOBAL_KEY",
+ "models": {
+ "global-model": {
+ "context_window": 128000,
+ "max_output_tokens": 16384,
+ }
+ },
+ }
+ }
+ }
+ )
+ )
+
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+ project_config = repo_path / "models.json"
+ project_config.write_text(
+ json.dumps(
+ {
+ "providers": {
+ "shared-provider": {
+ "name": "Project Name",
+ "base_url": "http://project:8080/v1",
+ "api_key_env": "PROJECT_KEY",
+ "models": {
+ "project-model": {
+ "context_window": 128000,
+ "max_output_tokens": 16384,
+ }
+ },
+ }
+ }
+ }
+ )
+ )
+
+ with patch.dict(os.environ, {"XDG_CONFIG_HOME": str(xdg_config)}):
+ result = load_models_config(repo_path=repo_path)
+
+ assert result.providers["shared-provider"].name == "Project Name"
+ assert result.providers["shared-provider"].base_url == "http://project:8080/v1"
+
+
+def test_get_user_defined_model_ids_extracts_all_ids() -> None:
+ config = ModelsConfig(
+ providers={
+ "provider1": ProviderConfig(
+ base_url="http://localhost:8080/v1",
+ api_key_env="KEY1",
+ models={
+ "model-a": ModelConfig(context_window=128000, max_output_tokens=16384),
+ "model-b": ModelConfig(context_window=128000, max_output_tokens=16384),
+ },
+ ),
+ "provider2": ProviderConfig(
+ base_url="http://localhost:8081/v1",
+ api_key_env="KEY2",
+ models={
+ "model-c": ModelConfig(context_window=128000, max_output_tokens=16384),
+ },
+ ),
+ }
+ )
+
+ result = get_user_defined_model_ids(config)
+
+ assert result == {"model-a", "model-b", "model-c"}
+
+
+def test_get_provider_for_model_finds_provider() -> None:
+ config = ModelsConfig(
+ providers={
+ "provider1": ProviderConfig(
+ base_url="http://localhost:8080/v1",
+ api_key_env="KEY1",
+ models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ ),
+ "provider2": ProviderConfig(
+ base_url="http://localhost:8081/v1",
+ api_key_env="KEY2",
+ models={"model-b": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ ),
+ }
+ )
+
+ result = get_provider_for_model("model-b", config)
+
+ assert result is not None
+ assert result.api_key_env == "KEY2"
+
+
+def test_get_provider_for_model_returns_none_for_unknown() -> None:
+ config = ModelsConfig(
+ providers={
+ "provider1": ProviderConfig(
+ base_url="http://localhost:8080/v1",
+ api_key_env="KEY1",
+ models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ ),
+ }
+ )
+
+ result = get_provider_for_model("unknown-model", config)
+
+ assert result is None
+
+
+def test_validate_api_key_passes_when_key_is_set() -> None:
+ config = ModelsConfig(
+ providers={
+ "provider1": ProviderConfig(
+ name="Test Provider",
+ base_url="http://localhost:8080/v1",
+ api_key_env="TEST_API_KEY",
+ models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ ),
+ }
+ )
+
+ with patch.dict(os.environ, {"TEST_API_KEY": "secret-key"}):
+ validate_api_key_for_model("model-a", config)
+
+
+def test_validate_api_key_raises_when_key_not_set() -> None:
+ config = ModelsConfig(
+ providers={
+ "provider1": ProviderConfig(
+ name="Test Provider",
+ base_url="http://localhost:8080/v1",
+ api_key_env="MISSING_KEY",
+ models={"model-a": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ ),
+ }
+ )
+
+ with patch.dict(os.environ, {}, clear=True):
+ os.environ.pop("MISSING_KEY", None)
+ with pytest.raises(MissingAPIKeyError) as exc_info:
+ validate_api_key_for_model("model-a", config)
+
+ assert exc_info.value.env_var == "MISSING_KEY"
+ assert exc_info.value.model_id == "model-a"
+ assert "MISSING_KEY" in str(exc_info.value)
+
+
+def test_validate_api_key_passes_for_unknown_model() -> None:
+ config = ModelsConfig(providers={})
+ validate_api_key_for_model("unknown-model", config)
+
+
+def test_get_models_by_provider_groups_models() -> None:
+ config = ModelsConfig(
+ providers={
+ "ollama": ProviderConfig(
+ name="Ollama Local",
+ base_url="http://localhost:11434/v1",
+ api_key_env="OLLAMA_KEY",
+ models={
+ "llama3.2:latest": ModelConfig(context_window=128000, max_output_tokens=16384),
+ "qwen:7b": ModelConfig(context_window=32768, max_output_tokens=8192),
+ },
+ ),
+ "openrouter": ProviderConfig(
+ base_url="https://openrouter.ai/api/v1",
+ api_key_env="OPENROUTER_KEY",
+ models={
+ "anthropic/claude-3": ModelConfig(context_window=200000, max_output_tokens=16384),
+ },
+ ),
+ }
+ )
+
+ result = get_models_by_provider_from_config(config)
+
+ assert "Ollama Local" in result
+ assert set(result["Ollama Local"]) == {"llama3.2:latest", "qwen:7b"}
+
+ assert "openrouter" in result
+ assert result["openrouter"] == ["anthropic/claude-3"]
diff --git a/imbue_verify/cli/config/schema.py b/vet/cli/config/schema.py
diff --git a/vet/cli/main.py b/vet/cli/main.py
@@ -0,0 +1,502 @@
+from __future__ import annotations
+
+# The choice to use argparse was primarily driven by the idea that vet will be called by agents / llms.
+# Given this, we want to have the most standardized outputs possible.
+import argparse
+import json
+import subprocess
+import sys
+from importlib.metadata import version
+from pathlib import Path
+
+from loguru import logger
+
+from imbue_core.data_types import IssueCode
+from imbue_tools.get_conversation_history.get_conversation_history import (
+ parse_conversation_history,
+)
+from imbue_tools.types.vet_config import VetConfig
+from vet.api import find_issues
+from vet.cli.config.cli_config_schema import CLI_DEFAULTS
+from vet.cli.config.cli_config_schema import CliConfigPreset
+from vet.cli.config.loader import ConfigLoadError
+from vet.cli.config.loader import build_language_model_config
+from vet.cli.config.loader import get_cli_config_file_paths
+from vet.cli.config.loader import get_config_preset
+from vet.cli.config.loader import get_max_output_tokens_for_model
+from vet.cli.config.loader import load_cli_config
+from vet.cli.config.loader import load_models_config
+from vet.cli.config.loader import validate_api_key_for_model
+from vet.cli.config.schema import ModelsConfig
+from vet.cli.models import DEFAULT_MODEL_ID
+from vet.cli.models import get_models_by_provider
+from vet.cli.models import validate_model_id
+from vet.formatters import OUTPUT_FIELDS
+from vet.formatters import OUTPUT_FORMATS
+from vet.formatters import format_issue_text
+from vet.formatters import issue_to_dict
+from vet.formatters import validate_output_fields
+
+VERSION = version("vet")
+
+_ISSUE_CODE_FIELDS = frozenset({"enabled_issue_codes", "disabled_issue_codes"})
+_PATH_FIELDS = frozenset({"repo", "output"})
+_PATH_LIST_FIELDS = frozenset({"extra_context"})
+
+
+def create_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ prog="vet",
+ description="Identify issues in code changes using LLM-based analysis.",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument(
+ "goal",
+ type=str,
+ nargs="?",
+ default=CLI_DEFAULTS.goal,
+ metavar="GOAL",
+ help=(
+ "Description of what the code change is trying to accomplish. "
+ + "If not provided, only goal-independent issue identifiers will run."
+ ),
+ )
+
+ parser.add_argument(
+ "--repo",
+ "-r",
+ type=Path,
+ default=Path.cwd(),
+ metavar="PATH",
+ help="Path to the repository for analysis (default: current directory)",
+ )
+
+ parser.add_argument(
+ "--version",
+ "-V",
+ action="version",
+ version=f"%(prog)s {VERSION}",
+ )
+
+ parser.add_argument(
+ "--config",
+ "-c",
+ type=str,
+ default=None,
+ metavar="NAME",
+ help="Name of the configuration to use. Configurations are defined in vet.toml in your target project's root or ~/.config/vet/config.toml.",
+ )
+ parser.add_argument(
+ "--list-configs",
+ action="store_true",
+ help="List all available named configurations",
+ )
+
+ diff_group = parser.add_argument_group("diff options")
+ diff_group.add_argument(
+ "--base-commit",
+ type=str,
+ default=CLI_DEFAULTS.base_commit,
+ metavar="REF",
+ help=f"Git commit, branch, or ref to use as the base for computing the diff (default: {CLI_DEFAULTS.base_commit})",
+ )
+
+ context_group = parser.add_argument_group("context options")
+ context_group.add_argument(
+ "--history-loader",
+ type=str,
+ default=CLI_DEFAULTS.history_loader,
+ metavar="COMMAND",
+ help=(
+ "Shell command that outputs conversation history as JSON to stdout. "
+ + "Used to derive a goal if one is not provided."
+ ),
+ )
+ context_group.add_argument(
+ "--extra-context",
+ type=Path,
+ nargs="*",
+ default=CLI_DEFAULTS.extra_context,
+ metavar="FILE",
+ help="Path(s) to file(s) containing additional context (e.g., library documentation). Content is included in the prompt after the codebase snapshot.",
+ )
+
+ analysis_group = parser.add_argument_group("analysis options")
+ # Valid issue codes are defined in imbue_core.data_types.IssueCode
+ analysis_group.add_argument(
+ "--enabled-issue-codes",
+ type=IssueCode,
+ nargs="+",
+ default=CLI_DEFAULTS.enabled_issue_codes,
+ metavar="CODE",
+ help="Only report issues of the given type(s). Use --list-issue-codes to see valid codes.",
+ )
+ analysis_group.add_argument(
+ "--disabled-issue-codes",
+ type=IssueCode,
+ nargs="+",
+ default=CLI_DEFAULTS.disabled_issue_codes,
+ metavar="CODE",
+ help="Do not report issues of the given type(s). Use --list-issue-codes to see valid codes.",
+ )
+ analysis_group.add_argument(
+ "--list-issue-codes",
+ action="store_true",
+ help="List all available issue codes",
+ )
+
+ model_group = parser.add_argument_group("model configuration")
+ model_group.add_argument(
+ "--model",
+ "-m",
+ type=str,
+ default=CLI_DEFAULTS.model,
+ metavar="MODEL",
+ help=f"LLM to use for analysis (default: {DEFAULT_MODEL_ID}). ",
+ )
+ model_group.add_argument(
+ "--list-models",
+ action="store_true",
+ help="List all available models",
+ )
+ model_group.add_argument(
+ "--temperature",
+ type=float,
+ default=CLI_DEFAULTS.temperature,
+ metavar="TEMP",
+ help=f"Override the default temperature for the model (default: {CLI_DEFAULTS.temperature}).",
+ )
+
+ filter_group = parser.add_argument_group("filtering options")
+ filter_group.add_argument(
+ "--confidence-threshold",
+ type=float,
+ default=CLI_DEFAULTS.confidence_threshold,
+ metavar="THRESHOLD",
+ help=f"Minimum confidence score (0.0-1.0) for issues to be reported (default: {CLI_DEFAULTS.confidence_threshold})",
+ )
+
+ parallel_group = parser.add_argument_group("parallelization options")
+ parallel_group.add_argument(
+ "--max-workers",
+ type=int,
+ default=CLI_DEFAULTS.max_workers,
+ metavar="N",
+ help=f"Maximum number of parallel workers for identification (default: {CLI_DEFAULTS.max_workers})",
+ )
+
+ output_group = parser.add_argument_group("output options")
+ output_group.add_argument(
+ "--output",
+ "-o",
+ type=Path,
+ default=CLI_DEFAULTS.output,
+ metavar="FILE",
+ help="Output file path (default: stdout). Use - to write to stdout.",
+ )
+ output_group.add_argument(
+ "--output-format",
+ type=str,
+ choices=OUTPUT_FORMATS,
+ default=CLI_DEFAULTS.output_format,
+ metavar="FORMAT",
+ help=f"Output format. Choices: {', '.join(OUTPUT_FORMATS)} (default: {CLI_DEFAULTS.output_format})",
+ )
+ output_group.add_argument(
+ "--output-fields",
+ type=str,
+ nargs="+",
+ default=CLI_DEFAULTS.output_fields,
+ metavar="FIELD",
+ help="Output fields to include (default: all)",
+ )
+ output_group.add_argument(
+ "--list-fields",
+ action="store_true",
+ help="List all available output data fields",
+ )
+ output_group.add_argument(
+ "--verbose",
+ "-v",
+ action="store_true",
+ default=CLI_DEFAULTS.verbose,
+ help="Show verbose logger messages",
+ )
+ output_group.add_argument(
+ "--quiet",
+ "-q",
+ action="store_true",
+ default=CLI_DEFAULTS.quiet,
+ help="Suppress progress indicator and non-essential output",
+ )
+
+ return parser
+
+
+def _get_available_issue_codes() -> list[IssueCode]:
+ return [code for code in IssueCode if not code.name.startswith("_DEPRECATED")]
+
+
+# TODO: There are logical groupings of codes we should consider because some issue_codes are associated with the same prompts / categories of issues.
+# This should likely be used to dictate the ordering instead of sorting.
+def list_issue_codes() -> None:
+ print("Available issue codes:")
+ print()
+ for code in sorted(_get_available_issue_codes(), key=lambda c: c.value):
+ print(f" {code.value}")
+
+
+def list_models(user_config: ModelsConfig | None = None) -> None:
+ print("Available models:")
+ print()
+ models_by_provider = get_models_by_provider(user_config)
+ for provider, model_ids in sorted(models_by_provider.items()):
+ print(f" {provider}:")
+ for model_id in sorted(model_ids):
+ default_marker = " (default)" if model_id == DEFAULT_MODEL_ID else ""
+ print(f" {model_id}{default_marker}")
+
+
+def list_fields() -> None:
+ print("Available output fields:")
+ print()
+ for field in OUTPUT_FIELDS:
+ print(f" {field}")
+
+
+def list_configs(cli_configs: dict[str, CliConfigPreset], repo_path: Path) -> None:
+ print("Available configurations:")
+ print()
+
+ if not cli_configs:
+ print(" No configurations found.")
+ print()
+ print("Configuration files are loaded from:")
+ for path in get_cli_config_file_paths(repo_path):
+ exists_marker = " (exists)" if path.exists() else ""
+ print(f" {path}{exists_marker}")
+ return
+
+ for name, preset in sorted(cli_configs.items()):
+ print(f" {name}:")
+ preset_dict = preset.model_dump(exclude_none=True)
+ if preset_dict:
+ for key, value in preset_dict.items():
+ print(f" {key}: {value}")
+ else:
+ print(" (uses all defaults)")
+ print()
+
+
+def configure_logging(verbose: bool, quiet: bool) -> None:
+ logger.remove()
+ if quiet:
+ level = "WARNING"
+ elif verbose:
+ level = "DEBUG"
+ else:
+ level = "INFO"
+ logger.add(sys.stderr, level=level)
+
+
+def load_conversation_from_command(command: str, cwd: Path) -> tuple:
+ result = subprocess.run(command, shell=True, capture_output=True, text=True, cwd=cwd)
+ if result.returncode != 0:
+ logger.warning(f"History loader command failed with exit code {result.returncode}: {result.stderr}")
+ return ()
+ if not result.stdout.strip():
+ return ()
+ return parse_conversation_history(result.stdout)
+
+
+def apply_config_preset(args: argparse.Namespace, preset: CliConfigPreset) -> argparse.Namespace:
+ preset_dict = preset.model_dump(exclude_none=True)
+
+ for field, preset_value in preset_dict.items():
+ default_value = getattr(CLI_DEFAULTS, field, None)
+ if getattr(args, field) == default_value:
+ if field in _ISSUE_CODE_FIELDS:
+ preset_value = [IssueCode(code) for code in preset_value]
+ elif field in _PATH_LIST_FIELDS:
+ preset_value = [Path(p) for p in preset_value]
+ elif field in _PATH_FIELDS:
+ preset_value = Path(preset_value)
+ setattr(args, field, preset_value)
+
+ return args
+
+
+def main(argv: list[str] | None = None) -> int:
+ parser = create_parser()
+ args = parser.parse_args(argv)
+
+ goal = args.goal or ""
+
+ repo_path = args.repo
+
+ try:
+ user_config = load_models_config(repo_path)
+ except ConfigLoadError as e:
+ print(f"Error loading model configuration: {e}", file=sys.stderr)
+ return 2
+
+ if args.list_issue_codes:
+ list_issue_codes()
+ return 0
+
+ if args.list_models:
+ list_models(user_config)
+ return 0
+
+ if args.list_fields:
+ list_fields()
+ return 0
+
+ try:
+ cli_configs = load_cli_config(repo_path)
+ except ConfigLoadError as e:
+ print(f"Error loading CLI configuration: {e}", file=sys.stderr)
+ return 2
+
+ if args.list_configs:
+ list_configs(cli_configs, repo_path)
+ return 0
+
+ if args.config is not None:
+ try:
+ preset = get_config_preset(args.config, cli_configs, repo_path)
+ args = apply_config_preset(args, preset)
+ except ConfigLoadError as e:
+ print(f"Error: {e}", file=sys.stderr)
+ return 2
+
+ if not repo_path.exists():
+ print(f"Error: Repository path does not exist: {repo_path}", file=sys.stderr)
+ return 2
+
+ if not repo_path.is_dir():
+ print(f"Error: Repository path is not a directory: {repo_path}", file=sys.stderr)
+ return 2
+
+ if args.extra_context:
+ for extra_context_file in args.extra_context:
+ if not extra_context_file.exists():
+ print(
+ f"Error: Extra context file does not exist: {extra_context_file}",
+ file=sys.stderr,
+ )
+ return 2
+
+ if args.verbose and args.quiet:
+ print(
+ "Error: --verbose and --quiet are mutually exclusive",
+ file=sys.stderr,
+ )
+ return 2
+
+ if not 0.0 <= args.confidence_threshold <= 1.0:
+ print(
+ f"Error: Confidence threshold must be between 0.0 and 1.0, got: {args.confidence_threshold}",
+ file=sys.stderr,
+ )
+ return 2
+
+ if not 0.0 <= args.temperature <= 2.0:
+ print(
+ f"Error: Temperature must be between 0.0 and 2.0, got: {args.temperature}",
+ file=sys.stderr,
+ )
+ return 2
+
+ configure_logging(args.verbose, args.quiet)
+
+ conversation_history = None
+ if args.history_loader is not None:
+ conversation_history = load_conversation_from_command(args.history_loader, repo_path)
+
+ extra_context = None
+ if args.extra_context:
+ extra_context_parts = []
+ for context_file in args.extra_context:
+ extra_context_parts.append(context_file.read_text())
+ extra_context = "\n\n".join(extra_context_parts)
+
+ if args.output_fields is not None:
+ try:
+ validate_output_fields(args.output_fields)
+ except ValueError as e:
+ print(f"Error: {e}", file=sys.stderr)
+ return 2
+
+ model_id = args.model or DEFAULT_MODEL_ID
+
+ try:
+ model_id = validate_model_id(model_id, user_config)
+ except ValueError as e:
+ print(f"Error: {e}", file=sys.stderr)
+ return 2
+
+ try:
+ validate_api_key_for_model(model_id, user_config)
+ except Exception as e:
+ print(f"Error: {e}", file=sys.stderr)
+ return 2
+
+ # TODO: Support OFFLINE, UPDATE_SNAPSHOT, and MOCKED modes.
+ language_model_config = build_language_model_config(model_id, user_config)
+ max_output_tokens = get_max_output_tokens_for_model(model_id, user_config)
+
+ config = VetConfig(
+ disabled_identifiers=("agentic_issue_identifier",),
+ language_model_generation_config=language_model_config,
+ enabled_issue_codes=(tuple(args.enabled_issue_codes) if args.enabled_issue_codes else None),
+ disabled_issue_codes=(tuple(args.disabled_issue_codes) if args.disabled_issue_codes else None),
+ temperature=args.temperature,
+ filter_issues_below_confidence=args.confidence_threshold,
+ max_identify_workers=args.max_workers,
+ max_output_tokens=max_output_tokens or 20000,
+ extra_context=extra_context,
+ )
+
+ issues = find_issues(
+ repo_path=repo_path,
+ relative_to=args.base_commit,
+ goal=goal,
+ config=config,
+ conversation_history=conversation_history,
+ )
+
+ output_fields = args.output_fields if args.output_fields else OUTPUT_FIELDS
+
+ output_file = None
+ if args.output is not None and str(args.output) != "-":
+ output_file = open(args.output, "w")
+ output_stream = output_file
+ else:
+ output_stream = sys.stdout
+
+ try:
+ if not issues:
+ if args.output_format == "json":
+ print(json.dumps({"issues": []}, indent=2), file=output_stream)
+ elif not args.quiet:
+ print("No issues found.", file=output_stream)
+ return 0
+
+ if args.output_format == "json":
+ issues_list = [issue_to_dict(issue, output_fields) for issue in issues]
+ print(json.dumps({"issues": issues_list}, indent=2), file=output_stream)
+ else:
+ for issue in issues:
+ print(format_issue_text(issue, output_fields), file=output_stream)
+ print(file=output_stream)
+
+ return 1
+ finally:
+ if output_file is not None:
+ output_file.close()
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/vet/cli/models.py b/vet/cli/models.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
+from imbue_core.agents.llm_apis.common import get_all_model_names
+from imbue_core.agents.llm_apis.gemini_api import GeminiModelName
+from imbue_core.agents.llm_apis.groq_api import GroqSupportedModelName
+from imbue_core.agents.llm_apis.openai_api import OpenAIModelName
+from imbue_core.agents.llm_apis.together_api import TogetherAIModelName
+from vet.cli.config.loader import get_models_by_provider_from_config
+from vet.cli.config.loader import get_user_defined_model_ids
+from vet.cli.config.schema import ModelsConfig
+
+DEFAULT_MODEL_ID = AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01.value
+
+
+def get_builtin_model_ids() -> set[str]:
+ return {str(name) for name in get_all_model_names()}
+
+
+def get_all_model_ids(user_config: ModelsConfig | None = None) -> set[str]:
+ model_ids = get_builtin_model_ids()
+
+ if user_config:
+ model_ids.update(get_user_defined_model_ids(user_config))
+
+ return model_ids
+
+
+def is_valid_model_id(model_id: str, user_config: ModelsConfig | None = None) -> bool:
+ return model_id in get_all_model_ids(user_config)
+
+
+def is_user_defined_model(model_id: str, user_config: ModelsConfig | None = None) -> bool:
+ if user_config is None:
+ return False
+ return model_id in get_user_defined_model_ids(user_config)
+
+
+def validate_model_id(model_id: str, user_config: ModelsConfig | None = None) -> str:
+ if not is_valid_model_id(model_id, user_config):
+ raise ValueError(f"Unknown model: {model_id}. Use --list-models to see available models.")
+ return model_id
+
+
+def get_builtin_models_by_provider() -> dict[str, list[str]]:
+ return {
+ "anthropic": [m.value for m in AnthropicModelName],
+ "openai": [m.value for m in OpenAIModelName],
+ "gemini": [m.value for m in GeminiModelName],
+ "groq": [m.value for m in GroqSupportedModelName],
+ "together": [m.value for m in TogetherAIModelName],
+ }
+
+
+def get_models_by_provider(
+ user_config: ModelsConfig | None = None,
+) -> dict[str, list[str]]:
+ providers = get_builtin_models_by_provider()
+
+ if user_config:
+ user_providers = get_models_by_provider_from_config(user_config)
+ for provider_name, model_ids in user_providers.items():
+ providers[provider_name] = model_ids
+
+ return providers
diff --git a/vet/cli/models_test.py b/vet/cli/models_test.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+import pytest
+
+from vet.cli.config.schema import ModelConfig
+from vet.cli.config.schema import ModelsConfig
+from vet.cli.config.schema import ProviderConfig
+from vet.cli.models import DEFAULT_MODEL_ID
+from vet.cli.models import get_all_model_ids
+from vet.cli.models import get_builtin_model_ids
+from vet.cli.models import get_builtin_models_by_provider
+from vet.cli.models import get_models_by_provider
+from vet.cli.models import is_user_defined_model
+from vet.cli.models import is_valid_model_id
+from vet.cli.models import validate_model_id
+
+SAMPLE_USER_CONFIG = ModelsConfig(
+ providers={
+ "custom": ProviderConfig(
+ base_url="http://localhost:8080/v1",
+ api_key_env="CUSTOM_KEY",
+ models={
+ "my-custom-model": ModelConfig(context_window=128000, max_output_tokens=16384),
+ "another-model": ModelConfig(context_window=128000, max_output_tokens=16384),
+ },
+ )
+ }
+)
+
+
+def test_default_model_is_in_builtin_models() -> None:
+ assert DEFAULT_MODEL_ID in get_builtin_model_ids()
+
+
+def test_get_builtin_model_ids_returns_strings() -> None:
+ model_ids = get_builtin_model_ids()
+ assert all(isinstance(m, str) for m in model_ids)
+
+
+def test_get_all_model_ids_returns_builtin_models_when_no_config() -> None:
+ all_ids = get_all_model_ids(user_config=None)
+ builtin_ids = get_builtin_model_ids()
+ assert all_ids == builtin_ids
+
+
+def test_get_all_model_ids_includes_user_defined_models() -> None:
+ all_ids = get_all_model_ids(SAMPLE_USER_CONFIG)
+
+ assert "my-custom-model" in all_ids
+ assert "another-model" in all_ids
+ assert DEFAULT_MODEL_ID in all_ids
+
+
+@pytest.mark.parametrize(
+ ("model_id", "user_config", "expected"),
+ [
+ (DEFAULT_MODEL_ID, None, True),
+ ("nonexistent-model-xyz", None, False),
+ ("my-custom-model", SAMPLE_USER_CONFIG, True),
+ ],
+)
+def test_is_valid_model_id(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
+ assert is_valid_model_id(model_id, user_config) is expected
+
+
+@pytest.mark.parametrize(
+ ("model_id", "user_config", "expected"),
+ [
+ ("any-model", None, False),
+ ("my-custom-model", SAMPLE_USER_CONFIG, True),
+ (DEFAULT_MODEL_ID, SAMPLE_USER_CONFIG, False),
+ ],
+)
+def test_is_user_defined_model(model_id: str, user_config: ModelsConfig | None, expected: bool) -> None:
+ assert is_user_defined_model(model_id, user_config) is expected
+
+
+def test_validate_model_id_returns_model_id_when_valid() -> None:
+ result = validate_model_id(DEFAULT_MODEL_ID)
+ assert result == DEFAULT_MODEL_ID
+
+
+def test_validate_model_id_raises_for_invalid_model() -> None:
+ with pytest.raises(ValueError) as exc_info:
+ validate_model_id("nonexistent-model-xyz")
+
+ assert "Unknown model: nonexistent-model-xyz" in str(exc_info.value)
+ assert "--list-models" in str(exc_info.value)
+
+
+def test_validate_model_id_validates_user_defined_model() -> None:
+ user_config = ModelsConfig(
+ providers={
+ "custom": ProviderConfig(
+ base_url="http://localhost:8080/v1",
+ api_key_env="CUSTOM_KEY",
+ models={"my-custom-model": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ )
+ }
+ )
+
+ result = validate_model_id("my-custom-model", user_config)
+ assert result == "my-custom-model"
+
+
+def test_get_builtin_models_by_provider_returns_dict_with_expected_providers() -> None:
+ providers = get_builtin_models_by_provider()
+
+ assert "anthropic" in providers
+ assert "openai" in providers
+ assert "gemini" in providers
+ assert "groq" in providers
+ assert "together" in providers
+
+
+def test_get_builtin_models_by_provider_all_values_are_lists_of_strings() -> None:
+ providers = get_builtin_models_by_provider()
+
+ for provider_name, models in providers.items():
+ assert isinstance(models, list), f"{provider_name} should have a list of models"
+ assert all(isinstance(m, str) for m in models), f"{provider_name} models should all be strings"
+
+
+def test_get_models_by_provider_returns_builtin_providers_when_no_config() -> None:
+ providers = get_models_by_provider(user_config=None)
+ builtin_providers = get_builtin_models_by_provider()
+
+ assert providers == builtin_providers
+
+
+def test_get_models_by_provider_includes_user_defined_providers() -> None:
+ user_config = ModelsConfig(
+ providers={
+ "ollama": ProviderConfig(
+ name="Ollama Local",
+ base_url="http://localhost:11434/v1",
+ api_key_env="OLLAMA_KEY",
+ models={
+ "llama3.2:latest": ModelConfig(context_window=128000, max_output_tokens=16384),
+ "qwen:7b": ModelConfig(context_window=32768, max_output_tokens=8192),
+ },
+ )
+ }
+ )
+
+ providers = get_models_by_provider(user_config)
+
+ assert "Ollama Local" in providers
+ assert set(providers["Ollama Local"]) == {"llama3.2:latest", "qwen:7b"}
+ assert "anthropic" in providers
+ assert "openai" in providers
+
+
+def test_get_models_by_provider_user_provider_overrides_builtin_with_same_name() -> None:
+ user_config = ModelsConfig(
+ providers={
+ "custom": ProviderConfig(
+ name="anthropic",
+ base_url="http://localhost:8080/v1",
+ api_key_env="CUSTOM_KEY",
+ models={"custom-model": ModelConfig(context_window=128000, max_output_tokens=16384)},
+ )
+ }
+ )
+
+ providers = get_models_by_provider(user_config)
+
+ assert providers["anthropic"] == ["custom-model"]
diff --git a/imbue_verify/conftest.py b/vet/conftest.py
diff --git a/imbue_verify/errors.py b/vet/errors.py
diff --git a/imbue_verify/formatters.py b/vet/formatters.py
diff --git a/imbue_verify/issue_identifiers/__init__.py b/vet/issue_identifiers/__init__.py
diff --git a/vet/issue_identifiers/agentic_issue_collation.py b/vet/issue_identifiers/agentic_issue_collation.py
@@ -0,0 +1,183 @@
+import json
+from typing import Generator
+from typing import Iterable
+
+import jinja2
+
+from imbue_core.data_types import AgenticPhase
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import LLMResponse
+from imbue_tools.get_conversation_history.input_data_types import CommitInputs
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.get_conversation_history.input_data_types import (
+ to_specific_inputs_type,
+)
+from imbue_tools.repo_utils.context_utils import escape_prompt_markers
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import GeneratedResponseSchema
+from vet.issue_identifiers.common import extract_invocation_info_from_messages
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.common import generate_issues_from_response_texts
+from vet.issue_identifiers.common import generate_response_from_claude_code
+from vet.issue_identifiers.common import get_claude_code_options
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+
+COLLATION_PROMPT_TEMPLATE = """You are reviewing the results from parallel code analysis for potential issues.
+Multiple specialized agents analyzed the following code diff, each focusing on a specific type of issue.
+The repository files are available in {{ repo_path }}.
+
+### User request ###
+{% filter indent(width=2) %}
+{{ commit_message }}
+{% endfilter %}
+
+### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ###
+{% filter indent(width=2) %}
+{{ unified_diff }}
+{% endfilter %}
+###
+
+The rubric below outlines the categories of issues we care about:
+{% for issue_code, guide in guides.items() %}
+---
+**{{ issue_code }}**:
+{{ guide }}
+{% endfor %}
+---
+
+### Parallel Analysis Results ###
+{{ generated_issues }}
+
+Your task is to:
+1. Review all the findings for accuracy and relevance using the category definitions above
+2. Consolidate any duplicate or overlapping issues
+3. Ensure each issue is correctly categorized according to the category definitions and re-categorize any issues if necessary
+4. Return a consolidated set of issues
+
+Guidelines:
+- Merge similar issues that refer to the same underlying problem
+- Do not remove any issues, you may only re-categorize or merge issues
+
+After your analysis, provide your response in JSON format matching this schema:
+
+{{ response_schema | tojson(indent=2) }}
+"""
+
+
+def _get_collation_prompt(
+ project_context: ProjectContext,
+ identifier_inputs: CommitInputs,
+ enabled_issue_codes: tuple[IssueCode, ...],
+ generated_issues: str,
+) -> str:
+ # Sort issue codes to make the resulting prompts deterministic (for snapshot tests and LLM caching)
+ sorted_issue_codes = sorted(enabled_issue_codes)
+ formatted_guides = {
+ code: format_issue_identification_guide_for_llm(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code])
+ for code in sorted_issue_codes
+ }
+
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ jinja_template = env.from_string(COLLATION_PROMPT_TEMPLATE)
+
+ prompt = jinja_template.render(
+ {
+ "repo_path": project_context.repo_path,
+ "commit_message": escape_prompt_markers(identifier_inputs.goal),
+ "unified_diff": escape_prompt_markers(identifier_inputs.diff),
+ "guides": formatted_guides,
+ "response_schema": GeneratedResponseSchema.model_json_schema(),
+ "generated_issues": escape_prompt_markers(generated_issues),
+ }
+ )
+ return prompt
+
+
+def _convert_parsed_issues_to_combined_string(
+ all_parsed_issues: Iterable[GeneratedIssueSchema],
+) -> str:
+ """Convert all parsed issues from all issue types to a combined string for collation prompt."""
+ combined_issues = []
+
+ for issue in all_parsed_issues:
+ issue_dict = issue.model_dump()
+ for key in ("location", "code_part"):
+ if key in issue_dict and issue_dict[key] is None:
+ del issue_dict[key]
+ combined_issues.append(issue_dict)
+
+ return json.dumps({"issues": combined_issues}, indent=2)
+
+
+def collate_issues_with_agent(
+ issue_generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
+ identifier_inputs: IdentifierInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+ enabled_issue_codes: tuple[IssueCode, ...],
+) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ """
+ Collate issues from multiple issue identifiers.
+
+ Args:
+ issues: The issues to collate.
+ identifier_inputs: The inputs which determine the content provided to the identifiers.
+ project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
+ config: Settings for imbue verify.
+ enabled_issue_codes: The issue types used by the issue identifiers.
+
+ Returns:
+ A generator of collated issues. Returns IssueIdentificationDebugInfo after the generator is exhausted.
+
+ Raises:
+ IdentifierInputsMissingError: If the identifier inputs are missing the commit message or diff, which are required for collation.
+ """
+ collation_inputs = to_specific_inputs_type(identifier_inputs, CommitInputs)
+
+ all_issues = []
+ issue_generator_with_capture = ReturnCapturingGenerator(issue_generator)
+ for issue in issue_generator_with_capture:
+ all_issues.append(issue)
+ issue_generator_debug_info = issue_generator_with_capture.return_value
+
+ options = get_claude_code_options(
+ cwd=project_context.repo_path,
+ model_name=config.language_model_generation_config.model_name,
+ )
+ combined_issues_string = _convert_parsed_issues_to_combined_string(all_issues)
+ collation_prompt = _get_collation_prompt(
+ project_context, collation_inputs, enabled_issue_codes, combined_issues_string
+ )
+ claude_response = generate_response_from_claude_code(collation_prompt, options)
+ assert claude_response is not None
+ response_text, collation_messages = claude_response
+ collation_raw_messages = tuple(json.dumps(message.model_dump()) for message in collation_messages)
+ collation_invocation_info = extract_invocation_info_from_messages(collation_messages)
+
+ collation_llm_responses = (
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(
+ agentic_phase=AgenticPhase.COLLATION,
+ issue_type=None,
+ ),
+ raw_response=collation_raw_messages,
+ invocation_info=collation_invocation_info,
+ ),
+ )
+
+ yield from generate_issues_from_response_texts(response_texts=(response_text,))
+
+ augmented_debug_info = IssueIdentificationDebugInfo(
+ llm_responses=issue_generator_debug_info.llm_responses + collation_llm_responses
+ )
+
+ return augmented_debug_info
diff --git a/vet/issue_identifiers/base.py b/vet/issue_identifiers/base.py
@@ -0,0 +1,92 @@
+import abc
+from typing import Generator
+from typing import Generic
+from typing import TypeVar
+
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.pydantic_serialization import SerializableModel
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.get_conversation_history.input_data_types import (
+ to_specific_inputs_type,
+)
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers.common import GeneratedIssueSchema
+
+T = TypeVar("T", bound=IdentifierInputs)
+
+
+class IssueIdentifier(SerializableModel, abc.ABC, Generic[T]):
+ """
+ A protocol for identifying issues given certain inputs.
+
+ By implementing this protocol and registering the new instance in `vet/issue_identifiers/registry.py`,
+ one can create a new issue identifier and automatically expand the default abilities of vet.
+
+ """
+
+ @abc.abstractmethod
+ def identify_issues(
+ self,
+ identifier_inputs: T,
+ project_context: ProjectContext,
+ config: VetConfig,
+ ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ """
+ Identify issues given the identifier inputs.
+
+ Args:
+ identifier_inputs: The inputs which determine the content provided to the identifier.
+ project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
+ config: Settings for imbue verify.
+
+ Returns:
+ A generator of identified issues. When done iterating, returns the debug info.
+
+ Raises:
+ IdentifierInputsMissingError: If the identifier inputs are missing required information for this identifier.
+ """
+
+ @abc.abstractmethod
+ def input_type(self) -> type[T]:
+ """
+ The type of inputs that this identifier expects.
+ """
+
+ def to_required_inputs(self, identifier_inputs: IdentifierInputs) -> T:
+ return to_specific_inputs_type(identifier_inputs, self.input_type())
+
+ @property
+ @abc.abstractmethod
+ def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
+ """
+ The issue codes that this identifier is capable of identifying.
+ """
+
+ @property
+ def requires_agentic_collation(self) -> bool:
+ """
+ Whether this identifier requires agentic collation of issues.
+ """
+ return False
+
+ @property
+ @abc.abstractmethod
+ def identifies_code_issues(self) -> bool:
+ """
+ Whether this identifier identifies code-related issues (as opposed to e.g. conversation-related issues).
+ """
+ pass
+
+ @abc.abstractmethod
+ def _get_prompt(
+ self,
+ project_context: ProjectContext,
+ config: VetConfig,
+ identifier_inputs: T,
+ ) -> str:
+ """
+ Get the prompt for this identifier.
+ """
+ pass
diff --git a/vet/issue_identifiers/common.py b/vet/issue_identifiers/common.py
@@ -0,0 +1,328 @@
+"""
+Common components shared between issue identifiers.
+"""
+
+from pathlib import Path
+from typing import Generator
+from typing import Iterable
+
+import jinja2
+from loguru import logger
+from pydantic import Field
+from pydantic import PrivateAttr
+
+from imbue_core.agents.agent_api.api import get_agent_client
+from imbue_core.agents.agent_api.claude.data_types import ClaudeCodeOptions
+from imbue_core.agents.agent_api.data_types import AgentAssistantMessage
+from imbue_core.agents.agent_api.data_types import AgentMessage
+from imbue_core.agents.agent_api.data_types import AgentResultMessage
+from imbue_core.agents.agent_api.data_types import AgentTextBlock
+from imbue_core.agents.agent_api.data_types import AgentToolName
+from imbue_core.agents.agent_api.data_types import READ_ONLY_TOOLS
+from imbue_core.agents.llm_apis.anthropic_data_types import AnthropicCachingInfo
+from imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
+from imbue_core.async_monkey_patches import log_exception
+from imbue_core.data_types import ConfidenceScore
+from imbue_core.data_types import IdentifiedVerifyIssue
+from imbue_core.data_types import InvocationInfo
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentifierResult
+from imbue_core.data_types import IssueLocation
+from imbue_core.data_types import LineRange
+from imbue_core.data_types import SeverityScore
+from imbue_core.pydantic_serialization import SerializableModel
+from imbue_tools.llm_output_parsing.parse_model_json_response import (
+ ResponseParsingError,
+)
+from imbue_tools.llm_output_parsing.parse_model_json_response import (
+ parse_model_json_response,
+)
+from imbue_tools.repo_utils.project_context import ProjectContext
+from vet.issue_identifiers.identification_guides import (
+ IssueIdentificationGuide,
+)
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+
+
+class GeneratedIssueSchema(SerializableModel):
+ """Individual issue from LLM response."""
+
+ issue_code: str = Field(description="Category of the issue")
+ description: str = Field(description="Specific explanation of what's wrong and why it's incorrect")
+ location: str | None = Field(default=None, description="File path where the issue occurs")
+ code_part: str | None = Field(default=None, description="Specific code snippet that has the issue")
+ # pyre doesn't like the way ints/floats implement ge/le
+ severity: int = Field(description="Integer 1-5 (1=minor issue, 5=critical bug)", ge=1, le=5) # pyre-ignore[6]
+ confidence: float = Field(description="Confidence in this assessment", ge=0.0, le=1.0) # pyre-ignore[6]
+
+ # ----------------------------------------------------------------
+ # Internal mutable fields used by the post-identification pipeline for tagging.
+ # These fields are mutable, but "monotic", in the sense that they can only be populated once and never
+ # be changed again after that.
+ # These won't be populated by issue identifiers and are not shown to LLMs.
+ # ----------------------------------------------------------------
+ _passes_filtration: bool | None = PrivateAttr(default=None)
+
+ @property
+ def passes_filtration(self) -> bool:
+ if self._passes_filtration is None:
+ # Default to True if not set
+ return True
+ else:
+ return self._passes_filtration
+
+ def set_passes_filtration(self, passes: bool) -> None:
+ assert self._passes_filtration is None, "passes_filtration can only be set once"
+ self._passes_filtration = passes
+
+
+class GeneratedResponseSchema(SerializableModel):
+ """Complete response structure for issue identification."""
+
+ issues: list[GeneratedIssueSchema] = Field(default_factory=list, description="List of identified issues")
+
+
+def generate_issues_from_response_texts(
+ response_texts: Iterable[str],
+) -> Generator[GeneratedIssueSchema, None, None]:
+ """Generate IssueIdentifierResult objects from LLM response text."""
+ for response_text in response_texts:
+ try:
+ parsed_data = parse_model_json_response(response_text, GeneratedResponseSchema)
+ except ResponseParsingError:
+ logger.warning(f"Failed to parse response text: {response_text}")
+ continue
+
+ for raw_issue in parsed_data.issues:
+ yield raw_issue
+
+
+def line_ranges_to_issue_locations(line_ranges: Iterable[LineRange], file_path: str) -> tuple[IssueLocation, ...]:
+ """Convert LineRange objects to IssueLocation objects."""
+ return tuple(
+ IssueLocation(
+ line_start=line_range.start,
+ line_end=line_range.end,
+ filename=file_path,
+ )
+ for line_range in line_ranges
+ )
+
+
+def convert_generated_issue_to_identified_issue(
+ issue_data: GeneratedIssueSchema,
+ project_context: ProjectContext,
+ enabled_issue_codes: tuple[IssueCode, ...],
+) -> IdentifiedVerifyIssue | None:
+ try:
+ # Validate issue code
+ issue_code = issue_data.issue_code
+ if issue_code not in enabled_issue_codes:
+ logger.error(
+ "Got issue code '{issue_code}', skipping. Expected one of: {enabled_issue_codes}",
+ issue_code=issue_code,
+ enabled_issue_codes=enabled_issue_codes,
+ )
+ return None
+
+ # Extract location and code part for line ranges
+ line_ranges: tuple[LineRange, ...] = ()
+ issue_location = issue_data.location
+ try:
+ issue_location_path = Path(issue_location) if issue_location else None
+ if project_context.repo_path and issue_location_path and issue_location_path.is_absolute():
+ # Make absolute path relative.
+ # This will raise ValueError if issue_location_path is not under repo_path.
+ repo_path = project_context.repo_path
+ assert repo_path is not None
+ issue_location_path = issue_location_path.relative_to(repo_path)
+ except ValueError:
+ issue_location_path = None
+ logger.warning(f"Invalid location '{issue_location}', skipping line range detection.")
+ issue_code_part = issue_data.code_part
+ if issue_location_path and issue_code_part:
+ contents = project_context.file_contents_by_path.get(issue_location_path.as_posix())
+ if contents is not None:
+ line_ranges = LineRange.build_from_substring(contents, issue_code_part)
+ if not line_ranges:
+ logger.debug(
+ "Could not find code_part in file {location}: {code_part_repr}",
+ location=issue_location,
+ code_part_repr=repr(issue_code_part),
+ )
+ else:
+ logger.warning(f"Unknown location '{issue_location}', skipping line range detection.")
+
+ # Convert severity (1-5) to normalized score (0-1)
+ severity_normalized = (issue_data.severity - 1) / 4.0 # Map 1-5 to 0-1
+ locations = line_ranges_to_issue_locations(
+ line_ranges, issue_location_path.as_posix() if issue_location_path else ""
+ )
+ return IdentifiedVerifyIssue(
+ code=IssueCode(issue_data.issue_code),
+ description=issue_data.description,
+ severity_score=SeverityScore(raw=issue_data.severity, normalized=severity_normalized),
+ confidence_score=ConfidenceScore(raw=issue_data.confidence, normalized=issue_data.confidence),
+ location=locations,
+ )
+
+ except (ValueError, KeyError, TypeError) as e:
+ log_exception(
+ e,
+ "Error processing issue data: {issue_data}, skipping",
+ issue_data=issue_data,
+ )
+ return None
+
+
+def convert_to_issue_identifier_result(
+ generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
+ project_context: ProjectContext,
+ enabled_issue_codes: tuple[IssueCode, ...],
+) -> Generator[IssueIdentifierResult, None, IssueIdentificationDebugInfo]:
+ """Convert a generator of GeneratedIssueSchema to IssueIdentifierResult."""
+ generator_with_capture = ReturnCapturingGenerator(generator)
+ for issue_data in generator_with_capture:
+ issue = convert_generated_issue_to_identified_issue(
+ issue_data=issue_data,
+ project_context=project_context,
+ enabled_issue_codes=enabled_issue_codes,
+ )
+ if issue:
+ yield IssueIdentifierResult(issue=issue, passes_filtration=issue_data.passes_filtration)
+
+ return generator_with_capture.return_value
+
+
+def get_claude_code_options(cwd: Path | None, model_name: str) -> ClaudeCodeOptions:
+ options = ClaudeCodeOptions(
+ cwd=cwd,
+ permission_mode="bypassPermissions", # Equivalent to --dangerously-skip-permissions
+ allowed_tools=list(READ_ONLY_TOOLS) + [AgentToolName.BASH],
+ model=model_name,
+ )
+ return options
+
+
+def generate_response_from_claude_code(
+ prompt: str, options: ClaudeCodeOptions
+) -> tuple[str, list[AgentMessage]] | None:
+ messages = []
+ assistant_messages = []
+ result_message = None
+ try:
+ with get_agent_client(options=options) as client:
+ for message in client.process_query(prompt):
+ messages.append(message)
+ if isinstance(message, AgentAssistantMessage):
+ assistant_messages.append(message)
+ elif isinstance(message, AgentResultMessage):
+ result_message = message
+ except Exception as e:
+ log_exception(e, "Claude Code API call failed")
+ return None
+
+ # Try to get response from result message first
+ response_text = ""
+ if result_message and result_message.result:
+ response_text = result_message.result
+
+ # If no result message, concatenate assistant messages
+ if not response_text and assistant_messages:
+ for message in assistant_messages:
+ for content_block in message.content:
+ if isinstance(content_block, AgentTextBlock):
+ response_text += content_block.text.strip() + "\n"
+
+ return response_text, messages
+
+
+def extract_invocation_info_from_costed_response(
+ response: CostedLanguageModelResponse,
+) -> InvocationInfo:
+ usage = response.usage
+
+ cache_creation_tokens = None
+ cache_read_tokens = None
+
+ if usage.caching_info is not None:
+ caching_info = usage.caching_info
+ cache_read_tokens = caching_info.read_from_cache
+
+ if caching_info.provider_specific_data is not None:
+ if isinstance(caching_info.provider_specific_data, AnthropicCachingInfo):
+ cache_creation_tokens = (
+ caching_info.provider_specific_data.written_5m + caching_info.provider_specific_data.written_1h
+ )
+ else:
+ logger.info(
+ "Not recording caching info for provider specific data type {}",
+ type(caching_info.provider_specific_data),
+ )
+
+ return InvocationInfo(
+ input_tokens=usage.prompt_tokens_used,
+ cache_creation_input_tokens=cache_creation_tokens,
+ cache_read_input_tokens=cache_read_tokens,
+ total_input_tokens=usage.prompt_tokens_used,
+ output_tokens=usage.completion_tokens_used,
+ cost=usage.dollars_used,
+ )
+
+
+def extract_invocation_info_from_messages(
+ messages: list[AgentMessage],
+) -> InvocationInfo:
+ """Extract invocation information from Agent messages."""
+ for message in messages:
+ if isinstance(message, AgentResultMessage):
+ total_input_tokens = None
+ usage = message.usage
+ if usage:
+ input_tokens = usage.input_tokens
+ cached_tokens = usage.cached_tokens
+ output_tokens = usage.output_tokens
+ else:
+ input_tokens = None
+ cached_tokens = None
+ output_tokens = None
+ if usage and input_tokens is not None and cached_tokens is not None:
+ total_input_tokens = input_tokens + cached_tokens
+ return InvocationInfo(
+ input_tokens=input_tokens,
+ cache_creation_input_tokens=None,
+ cache_read_input_tokens=cached_tokens,
+ total_input_tokens=total_input_tokens,
+ output_tokens=output_tokens,
+ duration_ms=message.duration_ms,
+ cost=usage.total_cost_usd if usage else None,
+ num_turns=message.num_turns,
+ )
+ return InvocationInfo()
+
+
+_ISSUE_IDENTIFICATION_LLM_FORMAT = """
+Guidelines:{% filter indent(width=4) %}
+{{ guide }}{% endfilter %}
+{%- if examples %}
+Examples:
+{%- for example in examples %}
+ - {{ example }}
+{%- endfor %}
+{%- endif -%}
+{%- if exceptions %}
+Exceptions:
+{%- for exception in exceptions %}
+ - {{ exception }}
+{%- endfor %}
+{%- endif -%}
+"""
+
+
+def format_issue_identification_guide_for_llm(guide: IssueIdentificationGuide) -> str:
+ formatted_guide = jinja2.Template(_ISSUE_IDENTIFICATION_LLM_FORMAT).render(
+ guide=guide.guide, examples=guide.examples, exceptions=guide.exceptions
+ )
+
+ return formatted_guide.strip()
diff --git a/vet/issue_identifiers/common_test.py b/vet/issue_identifiers/common_test.py
@@ -0,0 +1,375 @@
+import json
+from pathlib import Path
+from typing import Iterable
+
+from imbue_core.async_monkey_patches_test import expect_exact_logged_errors
+from imbue_core.data_types import IdentifiedVerifyIssue
+from imbue_core.data_types import IssueCode
+from imbue_core.frozen_utils import FrozenDict
+from imbue_core.itertools import only
+from imbue_tools.llm_output_parsing.parse_model_json_response import (
+ ResponseParsingError,
+)
+from imbue_tools.llm_output_parsing.parse_model_json_response import (
+ parse_model_json_response,
+)
+from imbue_tools.repo_utils.project_context import BaseProjectContext
+from imbue_tools.repo_utils.project_context import ProjectContext
+from vet.issue_identifiers.common import GeneratedResponseSchema
+from vet.issue_identifiers.common import (
+ convert_generated_issue_to_identified_issue,
+)
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+)
+from vet.issue_identifiers.identification_guides import (
+ IssueIdentificationGuide,
+)
+
+
+def _parse_issues(
+ valid_response: str,
+ project_context: ProjectContext,
+ enabled_issue_codes: Iterable[IssueCode],
+) -> list[IdentifiedVerifyIssue]:
+ issues = []
+ try:
+ issue_data = parse_model_json_response(valid_response, GeneratedResponseSchema)
+ except ResponseParsingError:
+ return []
+ for parsed_issue in issue_data.issues:
+ issue = convert_generated_issue_to_identified_issue(
+ issue_data=parsed_issue,
+ project_context=project_context,
+ enabled_issue_codes=tuple(enabled_issue_codes),
+ )
+ if issue:
+ issues.append(issue)
+ return issues
+
+
+def test_parse_issues_valid_json() -> None:
+ project_context = BaseProjectContext(
+ file_contents_by_path=FrozenDict({"test.py": "def test():\n while True:\n pass"}),
+ cached_prompt_prefix="test",
+ )
+
+ valid_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Infinite loop detected",
+ "location": "test.py",
+ "code_part": "while True:\n pass",
+ "severity": 5,
+ "confidence": 0.95,
+ }
+ ]
+ }
+ )
+
+ issues = _parse_issues(valid_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+
+ issue = only(issues)
+ assert issue.code == IssueCode.LOGIC_ERROR
+ assert issue.description == "Infinite loop detected"
+ assert issue.confidence_score is not None
+ assert issue.confidence_score.normalized == 0.95
+ assert issue.severity_score is not None
+ assert issue.severity_score.normalized == 1.0 # severity 5 maps to 1.0
+ assert len(issue.location) == 1
+ assert issue.location[0].filename == "test.py"
+
+
+def test_parse_response_with_leading_and_trailing_text() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
+ valid_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Infinite loop detected",
+ "location": "test.py",
+ "code_part": "while True:\n pass",
+ "severity": 5,
+ "confidence": 0.95,
+ }
+ ]
+ }
+ )
+
+ response_text = "Some leading text\n```json\n" + valid_response + "\n```\nSome trailing text"
+ # Note: This logs a warning about "Unknown location" since test.py isn't in the project context
+ issues = _parse_issues(response_text, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ issue = only(issues)
+ assert issue.code == IssueCode.LOGIC_ERROR
+ assert issue.description == "Infinite loop detected"
+ assert issue.confidence_score is not None
+ assert issue.confidence_score.normalized == 0.95
+ assert issue.severity_score is not None
+ assert issue.severity_score.normalized == 1.0 # severity 5 maps to 1.0
+
+
+def test_parse_issues_empty_response() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
+
+ empty_response = json.dumps({"issues": []})
+
+ issues = _parse_issues(empty_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0
+
+
+def test_parse_issues_invalid_json() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
+
+ invalid_response = "not json"
+
+ with expect_exact_logged_errors(["Response does not match the expected schema"]):
+ issues = _parse_issues(invalid_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0
+
+
+def test_parse_issues_with_markdown_formatting() -> None:
+ project_context = BaseProjectContext(
+ file_contents_by_path=FrozenDict({"test.py": "x = 1"}),
+ cached_prompt_prefix="test",
+ )
+
+ markdown_response = (
+ "```json\n"
+ + json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "runtime_error_risk",
+ "description": "Test issue",
+ "severity": 3,
+ "confidence": 0.8,
+ }
+ ]
+ }
+ )
+ + "\n```"
+ )
+
+ issues = _parse_issues(markdown_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 1
+ assert issues[0].code == IssueCode.RUNTIME_ERROR_RISK
+
+
+def test_parse_issues_invalid_severity() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
+
+ invalid_severity_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Test issue",
+ "severity": 10, # Invalid - should be 1-5
+ "confidence": 0.8,
+ }
+ ]
+ }
+ )
+
+ with expect_exact_logged_errors(["Response does not match the expected schema"]):
+ issues = _parse_issues(
+ invalid_severity_response,
+ project_context,
+ ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+ )
+ assert len(issues) == 0 # Should be skipped due to invalid severity
+
+
+def test_parse_issues_unknown_issue_code() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="test")
+
+ unknown_code_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "unknown_issue", # Not in our defined codes
+ "description": "Test issue",
+ "severity": 3,
+ "confidence": 0.8,
+ }
+ ]
+ }
+ )
+
+ with expect_exact_logged_errors(["Got issue code"]):
+ issues = _parse_issues(unknown_code_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0 # Should be skipped due to unknown code
+
+
+def test_parse_issues_missing_required_fields() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="[ROLE=SYSTEM]\ntest")
+
+ # Missing required field 'confidence'
+ missing_field_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Test issue",
+ "severity": 3,
+ # Missing 'confidence' field
+ }
+ ]
+ }
+ )
+
+ with expect_exact_logged_errors(["Response does not match the expected schema"]):
+ issues = _parse_issues(missing_field_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0 # Should be skipped due to missing field
+
+
+def test_parse_issues_invalid_confidence() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="[ROLE=SYSTEM]\ntest")
+
+ invalid_confidence_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Test issue",
+ "severity": 3,
+ "confidence": 1.5, # Invalid - should be 0.0-1.0
+ }
+ ]
+ }
+ )
+
+ with expect_exact_logged_errors(["Response does not match the expected schema"]):
+ issues = _parse_issues(
+ invalid_confidence_response,
+ project_context,
+ ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+ )
+ assert len(issues) == 0 # Should be skipped due to invalid confidence
+
+
+def test_parse_issues_with_line_ranges() -> None:
+ code_content = "def hello():\n print('world')\n return True"
+ project_context = BaseProjectContext(
+ file_contents_by_path=FrozenDict({"test.py": code_content}),
+ cached_prompt_prefix="[ROLE=SYSTEM]\ntest",
+ )
+
+ response_with_location = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Test issue with location",
+ "location": "test.py",
+ "code_part": "print('world')",
+ "severity": 3,
+ "confidence": 0.8,
+ }
+ ]
+ }
+ )
+
+ issues = _parse_issues(response_with_location, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ issue = only(issues)
+ assert issue.location[0].filename == "test.py"
+ assert len(issue.location) > 0 # Should have found line ranges
+
+
+def test_parse_issues_malformed_response_structure() -> None:
+ project_context = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="[ROLE=SYSTEM]\ntest")
+
+ # Test with non-dict response
+ non_dict_response = json.dumps(["not", "a", "dict"])
+ with expect_exact_logged_errors(["Response does not match the expected schema"]):
+ issues = _parse_issues(non_dict_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0
+
+ # Test with missing `issues` key
+ missing_key_response = json.dumps({"other_key": ["some value", "another value"]})
+ # note that this doesn't log an error; the model validation allows "issues" to be missing, and fills in an empty list
+ issues = _parse_issues(missing_key_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0
+
+ # Test with missing everything
+ missing_everything_response = json.dumps({})
+ # note that this doesn't log an error; the model validation allows "issues" to be missing, and fills in an empty list
+ issues = _parse_issues(missing_everything_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0
+
+ # Test with non-list `issues` value
+ non_list_response = json.dumps({"issues": "not a list"})
+ with expect_exact_logged_errors(["Response does not match the expected schema"]):
+ issues = _parse_issues(non_list_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+ assert len(issues) == 0
+
+
+def test_format_issue_identification_guide_for_llm() -> None:
+ complete_guide = IssueIdentificationGuide(
+ issue_code=IssueCode.LOGIC_ERROR,
+ guide="- Guideline 1\n- Guideline 2",
+ examples=(
+ "Example 1",
+ "Example 2",
+ ),
+ exceptions=(
+ "Exception 1",
+ "Exception 2",
+ ),
+ )
+
+ expected_formatted_complete_guide = """Guidelines:
+ - Guideline 1
+ - Guideline 2
+Examples:
+ - Example 1
+ - Example 2
+Exceptions:
+ - Exception 1
+ - Exception 2"""
+
+ minimal_guide = IssueIdentificationGuide(issue_code=IssueCode.LOGIC_ERROR, guide="Only has a guide.")
+ expected_formatted_minimal_guide = """Guidelines:
+ Only has a guide."""
+
+ formatted_complete_guide = format_issue_identification_guide_for_llm(complete_guide)
+ assert formatted_complete_guide == expected_formatted_complete_guide
+
+ formatted_minimal_guide = format_issue_identification_guide_for_llm(minimal_guide)
+ assert formatted_minimal_guide == expected_formatted_minimal_guide
+
+
+def test_strips_absolute_filenames() -> None:
+ project_context = BaseProjectContext(
+ file_contents_by_path=FrozenDict({"test.py": "def test():\n while True:\n pass"}),
+ cached_prompt_prefix="test",
+ repo_path=Path("/code"),
+ )
+
+ valid_response = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Infinite loop detected",
+ "location": "/code/test.py",
+ "code_part": "while True:\n pass",
+ "severity": 5,
+ "confidence": 0.95,
+ }
+ ]
+ }
+ )
+
+ issues = _parse_issues(valid_response, project_context, ISSUE_CODES_FOR_CORRECTNESS_CHECK)
+
+ issue = only(issues)
+ assert issue.description == "Infinite loop detected"
+ assert len(issue.location) == 1
+ assert issue.location[0].filename == "test.py"
diff --git a/imbue_verify/issue_identifiers/context_providers/__init__.py b/vet/issue_identifiers/context_providers/__init__.py
diff --git a/imbue_verify/issue_identifiers/harnesses/__init__.py b/vet/issue_identifiers/harnesses/__init__.py
diff --git a/vet/issue_identifiers/harnesses/agentic.py b/vet/issue_identifiers/harnesses/agentic.py
@@ -0,0 +1,361 @@
+"""
+Agentic harness that checks a given diff for issues using Claude Code agents with tools.
+"""
+
+import concurrent.futures
+import json
+from concurrent.futures import ThreadPoolExecutor
+from functools import cached_property
+from typing import Any
+from typing import Generator
+
+import jinja2
+from loguru import logger
+
+from imbue_core.agents.agent_api.claude.data_types import ClaudeCodeOptions
+from imbue_core.agents.agent_api.data_types import AgentMessage
+from imbue_core.agents.agent_api.data_types import AgentToolName
+from imbue_core.agents.agent_api.data_types import READ_ONLY_TOOLS
+from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
+from imbue_core.async_monkey_patches import log_exception
+from imbue_core.data_types import AgenticPhase
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import LLMResponse
+from imbue_tools.get_conversation_history.input_data_types import CommitInputs
+from imbue_tools.repo_utils.context_utils import escape_prompt_markers
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers.base import IssueIdentifier
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import GeneratedResponseSchema
+from vet.issue_identifiers.common import extract_invocation_info_from_messages
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.common import generate_issues_from_response_texts
+from vet.issue_identifiers.common import generate_response_from_claude_code
+from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness
+from vet.issue_identifiers.identification_guides import (
+ IssueIdentificationGuide,
+)
+
+PROMPT_TEMPLATE = """You are analyzing a code repository for potential issues. The repository files are available in {{ repo_path }}.
+
+Assume that a user requested work to be done and a programmer delivered the diff below.
+The changes from the diff are present in the codebase but are not yet committed.
+
+### User request ###
+{% filter indent(width=2) %}
+{{ commit_message }}
+{% endfilter %}
+
+### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ###
+{% filter indent(width=2) %}
+{{ unified_diff }}
+{% endfilter %}
+###
+
+Your task is to help verify the quality of the diff.
+We care only about specific categories of important issues.
+The rubric below outlines these categories of important issues, and contains guidelines and examples to correctly identify them:
+{% for issue_code, guide in guides.items() %}
+---
+**{{ issue_code }}**:
+{{ guide }}
+{% endfor %}
+---
+
+Use your standard tools to explore the repository and analyze the code thoroughly.
+Look at the additional guidance section below for more details on how to find issues.
+
+After your analysis, provide your response in JSON format matching this schema:
+
+{{ response_schema | tojson(indent=2) }}
+
+For each issue found, provide:
+- issue_code: Category from the rubric above
+- description: Specific explanation of the issue
+- (if applicable) location: File path where the issue occurs (relative to {{ repo_path }})
+- (if applicable) code_part: Specific code snippet that has the issue. Your code snippet should be the exact same as the original code including whitespace.
+- severity: Integer 1-5 (1=minor, 5=critical)
+- confidence: Float 0.0-1.0 indicating your confidence
+
+Your response should look like:
+```json
+{
+ "issues": [
+ <list of issues>
+ ]
+}
+```
+
+If no issues are found, return: ```json{"issues": []}```
+
+Focus on real issues that impact code quality, correctness, or maintainability.
+You must not return issues that were already present in the code or issues that are fixed by the diff.
+You must only return issues that were introduced by the diff.
+Do not report duplicate issues with the same or equivalent descriptions.
+
+### Additional Guidance for Finding Issues ###
+You should use a Task tool to create a parallel task for each issue type in the rubric.
+You should pass along the exact issue type definition with all details to the task.
+Once all the Tasks have completed you can collate their results.
+You should pass along any relevant information from the guidance below to the task.
+Here is a non-exhaustive list of things that you can do using your tools within the task to find issues:
+{% for issue_code, guidance in additional_guidance.items() %}
+---
+**{{ issue_code }}**:
+{{ guidance }}
+{% endfor %}
+---
+Note that this is just guidance on how to find issues, please refer to the rubric for the types of issues to find.
+"""
+
+ISSUE_TYPE_PROMPT_TEMPLATE = """You are analyzing a code repository for potential issues of type {{ issue_type }}. The repository files are available in {{ repo_path }}.
+
+Assume that a user requested work to be done and a programmer delivered the diff below.
+The changes from the diff are present in the codebase but are not yet committed.
+
+### User request ###
+{% filter indent(width=2) %}
+{{ commit_message }}
+{% endfilter %}
+
+### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ###
+{% filter indent(width=2) %}
+{{ unified_diff }}
+{% endfilter %}
+###
+
+Your task is to help verify the quality of the diff.
+Here is the definition of the issue type you are looking for:
+**{{ issue_type }}**:
+{{ guide }}
+
+Use your standard tools to explore the repository and analyze the code thoroughly.
+ONLY look for issues related to {{ issue_type }}.
+Do NOT modify any files - this is read-only analysis.
+
+After your analysis, provide your response in JSON format matching this schema:
+
+{{ response_schema | tojson(indent=2) }}
+
+For each issue found, provide:
+- issue_code: Category from the rubric above
+- description: Specific explanation of the issue
+- (if applicable) location: File path where the issue occurs (relative to {{ repo_path }})
+- (if applicable) code_part: Specific code snippet that has the issue. Your code snippet should be the exact same as the original code including whitespace.
+- severity: Integer 1-5 (1=minor, 5=critical)
+- confidence: Float 0.0-1.0 indicating your confidence
+
+Your response should look like:
+```json
+{
+ "issues": [
+ <list of issues>
+ ]
+}
+```
+
+If no issues of this type are found, return: ```json{"issues": []}```
+You must not return issues that were already present in the code or issues that are fixed by the diff.
+You must only return issues that were introduced by the diff.
+Do not report duplicate issues with the same or equivalent descriptions.
+"""
+
+
+MAX_PARALLEL_CLAUDE_CODE_SESSIONS = 5 # TODO: this was arbitrarily chosen
+ResponseText = str
+
+
+def _generate_issues_worker(
+ issue_code: IssueCode,
+ prompt: str,
+ options: ClaudeCodeOptions,
+) -> tuple[IssueCode, ResponseText, list[AgentMessage]] | None:
+ issue_result = generate_response_from_claude_code(prompt, options)
+ if issue_result is None:
+ return None
+ return issue_code, issue_result[0], issue_result[1]
+
+
+class _AgenticIssueIdentifier(IssueIdentifier[CommitInputs]):
+ _identification_guides: tuple[IssueIdentificationGuide, ...]
+
+ def __init__(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> None:
+ assert len(identification_guides) > 0, "At least one identification guide must be provided"
+ self._identification_guides = identification_guides
+
+ @cached_property
+ def _response_schema(self) -> dict[str, Any]:
+ return GeneratedResponseSchema.model_json_schema()
+
+ def _get_prompt(
+ self,
+ project_context: ProjectContext,
+ config: VetConfig, # unused
+ identifier_inputs: CommitInputs,
+ ) -> str:
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ jinja_template = env.from_string(PROMPT_TEMPLATE)
+ additional_guidance_by_issue_code = {
+ guide.issue_code: guide.additional_guide_for_agent for guide in self._identification_guides
+ }
+
+ formatted_guides = {
+ guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in self._identification_guides
+ }
+
+ prompt = jinja_template.render(
+ {
+ "repo_path": project_context.repo_path,
+ "commit_message": escape_prompt_markers(identifier_inputs.goal),
+ "unified_diff": escape_prompt_markers(identifier_inputs.diff),
+ "guides": formatted_guides,
+ "response_schema": self._response_schema,
+ "additional_guidance": additional_guidance_by_issue_code,
+ }
+ )
+ return prompt
+
+ def _get_prompt_for_issue_type(
+ self,
+ project_context: ProjectContext,
+ identifier_inputs: CommitInputs,
+ guide: IssueIdentificationGuide,
+ ) -> str:
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ jinja_template = env.from_string(ISSUE_TYPE_PROMPT_TEMPLATE)
+
+ formatted_guide = format_issue_identification_guide_for_llm(guide)
+
+ prompt = jinja_template.render(
+ {
+ "repo_path": project_context.repo_path,
+ "commit_message": escape_prompt_markers(identifier_inputs.goal),
+ "unified_diff": escape_prompt_markers(identifier_inputs.diff),
+ "guide": formatted_guide,
+ "response_schema": self._response_schema,
+ "issue_type": guide.issue_code,
+ }
+ )
+ return prompt
+
+ def identify_issues(
+ self,
+ identifier_inputs: CommitInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+ ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ assert project_context.repo_path is not None, "Project context must have a valid repo_path, got None"
+
+ config_model_name = config.language_model_generation_config.model_name
+ if config_model_name in [anthropic_model.value for anthropic_model in AnthropicModelName]:
+ model_name = config_model_name
+ else:
+ model_name = AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01
+ logger.info(
+ "Config model_name {config_model_name} is not a valid Anthropic model, using default ({model_name}).",
+ config_model_name=config_model_name,
+ model_name=model_name,
+ )
+
+ options = ClaudeCodeOptions(
+ cwd=project_context.repo_path,
+ permission_mode="bypassPermissions", # Equivalent to --dangerously-skip-permissions
+ allowed_tools=list(READ_ONLY_TOOLS) + [AgentToolName.BASH], # Allow read-only tools for analysis
+ model=model_name,
+ )
+
+ if config.enable_parallel_agentic_issue_identification:
+ llm_responses = []
+
+ issue_prompts = [
+ (
+ guide.issue_code,
+ self._get_prompt_for_issue_type(project_context, identifier_inputs, guide),
+ )
+ for guide in self._identification_guides
+ ]
+ with ThreadPoolExecutor(max_workers=MAX_PARALLEL_CLAUDE_CODE_SESSIONS) as executor:
+ tasks = [
+ executor.submit(_generate_issues_worker, issue_code, prompt, options)
+ for issue_code, prompt in issue_prompts
+ ]
+
+ for task in concurrent.futures.as_completed(tasks):
+ try:
+ result = task.result()
+ except Exception as e:
+ log_exception(e, "Error processing issue type: {e}", e=e)
+ continue
+
+ if result is None:
+ continue
+
+ issue_code, issue_type_response_text, messages = result
+
+ yield from generate_issues_from_response_texts(response_texts=(issue_type_response_text,))
+
+ message_dumps = tuple(json.dumps(message.model_dump()) for message in messages)
+ invocation_info = extract_invocation_info_from_messages(messages)
+
+ llm_responses.append(
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(
+ agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION,
+ issue_type=issue_code,
+ ),
+ raw_response=message_dumps,
+ invocation_info=invocation_info,
+ )
+ )
+
+ return IssueIdentificationDebugInfo(llm_responses=tuple(llm_responses))
+ else:
+ prompt = self._get_prompt(project_context, config, identifier_inputs)
+ claude_response = generate_response_from_claude_code(prompt, options)
+ assert claude_response is not None
+ response_text, messages = claude_response
+
+ message_dumps = tuple(json.dumps(message.model_dump()) for message in messages)
+ invocation_info = extract_invocation_info_from_messages(messages)
+
+ llm_responses = [
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(
+ agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION,
+ issue_type=None,
+ ),
+ raw_response=message_dumps,
+ invocation_info=invocation_info,
+ )
+ ]
+
+ yield from generate_issues_from_response_texts(response_texts=(response_text,))
+
+ return IssueIdentificationDebugInfo(llm_responses=tuple(llm_responses))
+
+ def input_type(self) -> type[CommitInputs]:
+ return CommitInputs
+
+ @property
+ def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
+ return tuple(guide.issue_code for guide in self._identification_guides)
+
+ @property
+ def requires_agentic_collation(self) -> bool:
+ return True
+
+ @property
+ def identifies_code_issues(self) -> bool:
+ return True
+
+
+class AgenticHarness(IssueIdentifierHarness[CommitInputs]):
+ def make_issue_identifier(
+ self, identification_guides: tuple[IssueIdentificationGuide, ...]
+ ) -> IssueIdentifier[CommitInputs]:
+ return _AgenticIssueIdentifier(identification_guides=identification_guides)
diff --git a/vet/issue_identifiers/harnesses/base.py b/vet/issue_identifiers/harnesses/base.py
@@ -0,0 +1,17 @@
+import abc
+from typing import Generic
+from typing import TypeVar
+
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from vet.issue_identifiers.base import IssueIdentifier
+from vet.issue_identifiers.identification_guides import (
+ IssueIdentificationGuide,
+)
+
+T = TypeVar("T", bound=IdentifierInputs)
+
+
+class IssueIdentifierHarness(abc.ABC, Generic[T]):
+ @abc.abstractmethod
+ def make_issue_identifier(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> IssueIdentifier[T]:
+ """Return an issue identifier based on this harness by binding it to the provided issue identification guides."""
diff --git a/vet/issue_identifiers/harnesses/conversation_single_prompt.py b/vet/issue_identifiers/harnesses/conversation_single_prompt.py
@@ -0,0 +1,153 @@
+"""
+Single-prompt issue identification harness that operates on the conversation history.
+
+Currently hard-coded to check for misleading behavior in a conversation.
+"""
+
+from functools import cached_property
+from typing import Any
+from typing import Generator
+
+import jinja2
+
+from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
+from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
+from imbue_core.data_types import AgenticPhase
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import LLMResponse
+from imbue_core.itertools import only
+from imbue_tools.get_conversation_history.get_conversation_history import (
+ format_conversation_history_for_prompt,
+)
+from imbue_tools.get_conversation_history.input_data_types import ConversationInputs
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE
+from vet.issue_identifiers.base import IssueIdentifier
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import GeneratedResponseSchema
+from vet.issue_identifiers.common import (
+ extract_invocation_info_from_costed_response,
+)
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.common import generate_issues_from_response_texts
+from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness
+from vet.issue_identifiers.identification_guides import (
+ IssueIdentificationGuide,
+)
+
+PROMPT_TEMPLATE = (
+ CONVERSATION_PREFIX_TEMPLATE
+ + """
+{% if cache_full_prompt %}[ROLE=USER_CACHED]{% else %}[ROLE=USER]{% endif %}{% if instruction_context %}
+Here are the instruction files that were provided to the agent:
+{{ instruction_context }}{% endif %}
+
+Your task is to examine the conversation history to find events of interest.
+These events will be used to generate suggestions for what the agent should do next to best achieve the user's goal.
+We care only about specific categories of events. The rubric below outlines these categories of events, and contains guidelines and examples to correctly identify them:
+{% for guide_name, guide in guides.items() %}
+---
+**{{ guide_name }}**:
+{{ guide }}
+
+{% endfor %}
+---
+
+Respond with valid JSON that matches this exact schema:
+
+{{ response_schema | tojson(indent=2) }}
+
+[ROLE=ASSISTANT]
+"""
+)
+
+
+class _ConversationSinglePromptIssueIdentifier(IssueIdentifier[ConversationInputs]):
+ _identification_guides: tuple[IssueIdentificationGuide, ...]
+
+ def __init__(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> None:
+ self._identification_guides = identification_guides
+
+ @cached_property
+ def _response_schema(self) -> dict[str, Any]:
+ return GeneratedResponseSchema.model_json_schema()
+
+ def _get_prompt(
+ self,
+ project_context: ProjectContext,
+ config: VetConfig,
+ identifier_inputs: ConversationInputs,
+ ) -> str:
+ # Sort the guides by issue code to ensure prompt caching (and snapshotting in tests) works.
+ sorted_guides = sorted(self._identification_guides, key=lambda guide: guide.issue_code)
+ formatted_guides = {
+ guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides
+ }
+
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ jinja_template = env.from_string(PROMPT_TEMPLATE)
+ return jinja_template.render(
+ cached_prompt_prefix=project_context.cached_prompt_prefix,
+ cache_full_prompt=config.cache_full_prompt,
+ conversation_history=format_conversation_history_for_prompt(identifier_inputs.conversation_history),
+ # pyre-fixme[16]: SubrepoContext need not have a formatted_repo_context, and instruction_context can be None
+ instruction_context=project_context.instruction_context.formatted_repo_context,
+ response_schema=self._response_schema,
+ guides=formatted_guides,
+ )
+
+ def identify_issues(
+ self,
+ identifier_inputs: ConversationInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+ ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ language_model = build_language_model_from_config(config.language_model_generation_config)
+ language_model_params = LanguageModelGenerationParams(
+ temperature=config.temperature,
+ max_tokens=config.max_output_tokens,
+ )
+ prompt = self._get_prompt(project_context, config, identifier_inputs)
+ costed_response = language_model.complete_with_usage_sync(
+ prompt,
+ params=language_model_params,
+ is_caching_enabled=language_model.cache_path is not None,
+ )
+
+ response = only(costed_response.responses)
+ invocation_info = extract_invocation_info_from_costed_response(costed_response)
+
+ llm_responses = (
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION),
+ raw_response=(response.text,),
+ invocation_info=invocation_info,
+ ),
+ )
+
+ yield from generate_issues_from_response_texts(response_texts=(response.text,))
+
+ return IssueIdentificationDebugInfo(llm_responses=llm_responses)
+
+ def input_type(self) -> type[ConversationInputs]:
+ return ConversationInputs
+
+ @property
+ def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
+ return tuple(guide.issue_code for guide in self._identification_guides)
+
+ @property
+ def identifies_code_issues(self) -> bool:
+ return False
+
+
+class ConversationSinglePromptHarness(IssueIdentifierHarness[ConversationInputs]):
+ def make_issue_identifier(
+ self, identification_guides: tuple[IssueIdentificationGuide, ...]
+ ) -> IssueIdentifier[ConversationInputs]:
+ return _ConversationSinglePromptIssueIdentifier(identification_guides=identification_guides)
diff --git a/vet/issue_identifiers/harnesses/conversation_single_prompt_test.py b/vet/issue_identifiers/harnesses/conversation_single_prompt_test.py
@@ -0,0 +1,68 @@
+import pytest
+
+from imbue_core.data_types import IssueCode
+from vet_types.chat_state import TextBlock
+from vet_types.ids import AssistantMessageID
+from vet_types.messages import AgentMessageSource
+from vet_types.messages import ChatInputUserMessage
+from vet_types.messages import LLMModel
+from vet_types.messages import ResponseBlockAgentMessage
+from imbue_tools.get_conversation_history.input_data_types import ConversationInputs
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.get_conversation_history.input_data_types import (
+ IdentifierInputsMissingError,
+)
+from vet.issue_identifiers.harnesses.conversation_single_prompt import (
+ ConversationSinglePromptHarness,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+
+
+def test_to_required_inputs() -> None:
+ harness = ConversationSinglePromptHarness()
+ classifier = harness.make_issue_identifier(
+ identification_guides=(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[IssueCode.MISLEADING_BEHAVIOR],)
+ )
+
+ # should support inputs where only the conversation history is present
+ conversation_history_inputs = IdentifierInputs(
+ maybe_conversation_history=(
+ ChatInputUserMessage(
+ text="fake content",
+ model_name=LLMModel.CLAUDE_4_SONNET,
+ ),
+ )
+ )
+ cvi = classifier.to_required_inputs(conversation_history_inputs)
+ assert isinstance(cvi, ConversationInputs)
+
+ # and inputs where the conversation history and commit message are present
+ conversation_history_and_commit_message_inputs = IdentifierInputs(
+ maybe_conversation_history=(
+ ResponseBlockAgentMessage(
+ source=AgentMessageSource.AGENT,
+ role="assistant",
+ assistant_message_id=AssistantMessageID("fake_message_id"),
+ content=(TextBlock(text="fake content"),),
+ ),
+ ),
+ maybe_goal="test",
+ maybe_diff="test",
+ )
+ cvi = classifier.to_required_inputs(conversation_history_and_commit_message_inputs)
+ assert isinstance(cvi, ConversationInputs)
+ assert cvi.maybe_goal == "test"
+ assert cvi.maybe_diff == "test"
+
+ # should not support inputs where the conversation history is absent
+ commit_inputs = IdentifierInputs(maybe_goal="test", maybe_diff="test")
+ with pytest.raises(IdentifierInputsMissingError):
+ classifier.to_required_inputs(commit_inputs)
+ file_inputs = IdentifierInputs(maybe_files=("test.py",))
+ with pytest.raises(IdentifierInputsMissingError):
+ classifier.to_required_inputs(file_inputs)
+ no_inputs = IdentifierInputs()
+ with pytest.raises(IdentifierInputsMissingError):
+ classifier.to_required_inputs(no_inputs)
diff --git a/vet/issue_identifiers/harnesses/single_prompt.py b/vet/issue_identifiers/harnesses/single_prompt.py
@@ -0,0 +1,200 @@
+"""
+Simple zero-shot issue identification harness that checks a diff for issues in a single prompt.
+"""
+
+from functools import cached_property
+from typing import Any
+from typing import Generator
+
+import jinja2
+
+from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
+from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
+from imbue_core.data_types import AgenticPhase
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import LLMResponse
+from imbue_core.itertools import only
+from imbue_tools.get_conversation_history.input_data_types import CommitInputs
+from imbue_tools.repo_utils.context_utils import escape_prompt_markers
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers.base import IssueIdentifier
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import GeneratedResponseSchema
+from vet.issue_identifiers.common import (
+ extract_invocation_info_from_costed_response,
+)
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.common import generate_issues_from_response_texts
+from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness
+from vet.issue_identifiers.identification_guides import (
+ IssueIdentificationGuide,
+)
+
+USER_REQUEST_PREFIX_TEMPLATE = """{{cached_prompt_prefix}}
+[ROLE=USER_CACHED]
+I'm working on a project, adding commits one after another. The current state of the project is captured by the codebase snapshot above.
+{% if extra_context %}
+
+=== ADDITIONAL CONTEXT BEGIN ===
+{{ extra_context }}
+=== ADDITIONAL CONTEXT END ===
+{% endif %}
+
+Assume that I asked for a piece of work to be done by specifying the user request and another programmer has delivered the diff.
+{% if include_request_and_diff %}
+Below, you can see the user request, as well as the delivered diff. IMPORTANT: The codebase snapshot already includes the changes made in this diff!
+
+=== USER REQUEST BEGIN ===
+{{ commit_message }}
+=== USER REQUEST END ===
+
+=== DIFF BEGIN (unified; lines starting with `-` are removed and `+` are added) ===
+{{ unified_diff }}
+=== DIFF END ===
+
+{% endif %}{% if not cache_full_prompt %}
+[ROLE=USER]{% endif %}
+"""
+
+
+PROMPT_TEMPLATE = (
+ USER_REQUEST_PREFIX_TEMPLATE
+ + """Your task is to help me verify the quality of the diff.
+
+We care only about specific categories of issues. The rubric below outlines these categories of issues, and contains guidelines and examples to correctly identify them:
+{% for issue_type_name, guide in guides.items() %}
+[Issue Category {{ loop.index }}: {{ issue_type_name }}]
+{{ guide }}
+[End of issue category: {{ issue_type_name }}]
+{% endfor %}
+
+## Instructions:
+
+1. Look at each category of issues outlined above, one at a time.
+2. For each given category, analyze the diff for issues that match the category.
+3. For each issue found, provide:
+ - issue_code: One of the category names above
+ - description: Specific explanation of what's wrong and what a better implementation could be. The description should not exceed a few sentences unless absolutely necessary.
+ - location: File path where the issue occurs (if applicable)
+ - code_part: Specific code snippet that has the issue (if applicable). Must match exactly, including whitespace. If the code part spans multiple lines, include the exact whitespace and newlines. If there are multiple locations that are relevant to the issue, select a single one to represent the issue.
+ - severity: Integer 1-5 (1=minor issue, 5=critical issue that will definitely cause problems)
+ - confidence: Float 0.0-1.0 indicating your confidence in this issue
+4. When you have identified all issues of the current category, move on to the next category and repeat the process.
+
+Respond with valid JSON that matches this exact schema:
+
+{{ response_schema | tojson(indent=2) }}
+
+Every issue you report must stand on its own, and should not reference other issues in its description.
+Do not report duplicate issues with the same or equivalent descriptions within one issue category.
+Do not output any issues that are merely based on the absence of information in the codebase snapshot.
+Do not speculate about the way a piece of code might get used if that use is not supported by the code included above.
+Only raise issues that were introduced by the diff.
+It is fine to output an empty list if no issues are found!
+
+IMPORTANT: Do not include any additional commentary outside the JSON response, your response should only contain the JSON object:
+
+```json
+{
+ "issues": [
+ <list of issues>
+ ]
+}
+```
+[ROLE=ASSISTANT]
+"""
+)
+
+
+class _SinglePromptIssueIdentifier(IssueIdentifier[CommitInputs]):
+ _identification_guides: tuple[IssueIdentificationGuide, ...]
+
+ def __init__(self, identification_guides: tuple[IssueIdentificationGuide, ...]) -> None:
+ self._identification_guides = identification_guides
+
+ @cached_property
+ def _response_schema(self) -> dict[str, Any]:
+ return GeneratedResponseSchema.model_json_schema()
+
+ def _get_prompt(
+ self,
+ project_context: ProjectContext,
+ config: VetConfig,
+ identifier_inputs: CommitInputs,
+ ) -> str:
+ # Sort the guides by issue code to ensure prompt caching (and snapshotting in tests) works.
+ sorted_guides = sorted(self._identification_guides, key=lambda guide: guide.issue_code)
+ formatted_guides = {
+ guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides
+ }
+
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ jinja_template = env.from_string(PROMPT_TEMPLATE)
+ return jinja_template.render(
+ {
+ "include_request_and_diff": True,
+ "cached_prompt_prefix": project_context.cached_prompt_prefix,
+ "cache_full_prompt": config.cache_full_prompt,
+ "extra_context": (escape_prompt_markers(config.extra_context) if config.extra_context else None),
+ "commit_message": escape_prompt_markers(identifier_inputs.goal),
+ "unified_diff": escape_prompt_markers(identifier_inputs.diff),
+ "guides": formatted_guides,
+ "response_schema": self._response_schema,
+ }
+ )
+
+ def identify_issues(
+ self,
+ identifier_inputs: CommitInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+ ) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ prompt = self._get_prompt(project_context, config, identifier_inputs)
+ language_model = build_language_model_from_config(config.language_model_generation_config)
+ language_model_params = LanguageModelGenerationParams(
+ temperature=config.temperature,
+ max_tokens=config.max_output_tokens,
+ )
+ costed_response = language_model.complete_with_usage_sync(
+ prompt,
+ params=language_model_params,
+ is_caching_enabled=language_model.cache_path is not None,
+ )
+
+ response = only(costed_response.responses)
+ invocation_info = extract_invocation_info_from_costed_response(costed_response)
+
+ llm_responses = (
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(agentic_phase=AgenticPhase.ISSUE_IDENTIFICATION),
+ raw_response=(response.text,),
+ invocation_info=invocation_info,
+ ),
+ )
+
+ yield from generate_issues_from_response_texts(response_texts=(response.text,))
+
+ return IssueIdentificationDebugInfo(llm_responses=llm_responses)
+
+ def input_type(self) -> type[CommitInputs]:
+ return CommitInputs
+
+ @property
+ def enabled_issue_codes(self) -> tuple[IssueCode, ...]:
+ return tuple(guide.issue_code for guide in self._identification_guides)
+
+ @property
+ def identifies_code_issues(self) -> bool:
+ return True
+
+
+class SinglePromptHarness(IssueIdentifierHarness[CommitInputs]):
+ def make_issue_identifier(
+ self, identification_guides: tuple[IssueIdentificationGuide, ...]
+ ) -> IssueIdentifier[CommitInputs]:
+ return _SinglePromptIssueIdentifier(identification_guides=identification_guides)
diff --git a/vet/issue_identifiers/harnesses/single_prompt_test.py b/vet/issue_identifiers/harnesses/single_prompt_test.py
@@ -0,0 +1,174 @@
+"""
+Tests for the SinglePromptHarness.
+"""
+
+import json
+from unittest import mock
+
+import pytest
+
+from imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
+from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
+from imbue_core.agents.llm_apis.data_types import LanguageModelResponseUsage
+from imbue_core.agents.llm_apis.data_types import LanguageModelResponseWithLogits
+from imbue_core.agents.llm_apis.data_types import ResponseStopReason
+from imbue_core.agents.llm_apis.mock_api import LanguageModelMock
+from imbue_core.data_types import IssueCode
+from imbue_core.frozen_utils import FrozenDict
+from imbue_tools.get_conversation_history.input_data_types import CommitInputs
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.get_conversation_history.input_data_types import (
+ IdentifierInputsMissingError,
+)
+from imbue_tools.repo_utils.project_context import BaseProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers.base import IssueIdentifier
+from vet.issue_identifiers.harnesses.single_prompt import SinglePromptHarness
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+
+
+class SinglePromptHarnessMock(LanguageModelMock):
+ """Mock language model for testing SinglePromptHarness."""
+
+ response_text: str = ""
+
+ def complete_with_usage_sync(
+ self,
+ prompt: str,
+ params: LanguageModelGenerationParams,
+ is_caching_enabled: bool = True,
+ ) -> CostedLanguageModelResponse:
+ self.stats.complete_calls += 1
+ response = LanguageModelResponseWithLogits(
+ text=self.response_text,
+ token_count=len(self.response_text.split()),
+ stop_reason=ResponseStopReason.END_TURN,
+ network_failure_count=0,
+ token_probabilities=self._get_token_probabilities(self.response_text),
+ )
+ usage = LanguageModelResponseUsage(
+ prompt_tokens_used=100,
+ completion_tokens_used=50,
+ dollars_used=0.001,
+ caching_info=None,
+ )
+ return CostedLanguageModelResponse(usage=usage, responses=(response,))
+
+
+def make_identifier() -> IssueIdentifier:
+ harness = SinglePromptHarness()
+ identifier = harness.make_issue_identifier(
+ identification_guides=tuple(
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code] for code in ISSUE_CODES_FOR_CORRECTNESS_CHECK
+ )
+ )
+ return identifier
+
+
+def test_to_required_inputs() -> None:
+ identifier = make_identifier()
+
+ # Should support inputs where only the commit message and diff are present
+ commit_inputs = IdentifierInputs(maybe_goal="test", maybe_diff="test")
+ cmi = identifier.to_required_inputs(commit_inputs)
+ assert isinstance(cmi, CommitInputs)
+
+ # Should support inputs where the commit message and diff are present
+ combined_inputs = IdentifierInputs(
+ maybe_goal="test",
+ maybe_diff="test",
+ maybe_files=("test.py",),
+ maybe_conversation_history=(),
+ )
+ cmi = identifier.to_required_inputs(combined_inputs)
+ assert isinstance(cmi, CommitInputs)
+
+ # Should not support inputs where the commit message and diff are absent
+ file_inputs = IdentifierInputs(maybe_files=("test.py",))
+ with pytest.raises(IdentifierInputsMissingError):
+ identifier.to_required_inputs(file_inputs)
+ no_inputs = IdentifierInputs()
+ with pytest.raises(IdentifierInputsMissingError):
+ identifier.to_required_inputs(no_inputs)
+
+ # Should not support inputs where only one of the commit message and diff are present
+ commit_message_inputs = IdentifierInputs(maybe_goal="test", maybe_conversation_history=())
+ with pytest.raises(IdentifierInputsMissingError):
+ identifier.to_required_inputs(commit_message_inputs)
+ diff_inputs = IdentifierInputs(maybe_diff="test")
+ with pytest.raises(IdentifierInputsMissingError):
+ identifier.to_required_inputs(diff_inputs)
+
+
+def test_get_prompt_structure() -> None:
+ identifier = make_identifier()
+ project_context = BaseProjectContext(
+ file_contents_by_path=FrozenDict({"test.py": "print('hello')"}),
+ cached_prompt_prefix="[ROLE=SYSTEM]\nSystem context here",
+ )
+ commit_inputs = CommitInputs(
+ maybe_goal="Add hello world function",
+ maybe_diff="+def hello():\n+ print('hello')",
+ )
+ config = VetConfig()
+
+ prompt = identifier._get_prompt(project_context, config, commit_inputs)
+
+ # Check that prompt contains key elements
+ assert "System context here" in prompt
+ assert "Add hello world function" in prompt
+ assert "+def hello():" in prompt
+ assert "logic_error" in prompt
+ assert "runtime_error_risk" in prompt
+ assert "issues" in prompt
+ assert "schema" in prompt.lower() # Should contain schema from pydantic model
+
+
+def test_identify_issues_integration() -> None:
+ """Test the full identify_issues flow with mocked LLM."""
+ identifier = make_identifier()
+
+ # Create mock language model with specific response
+ response_text = json.dumps(
+ {
+ "issues": [
+ {
+ "issue_code": "logic_error",
+ "description": "Test logic error",
+ "severity": 4,
+ "confidence": 0.9,
+ }
+ ]
+ }
+ )
+
+ mock_language_model = SinglePromptHarnessMock(response_text=response_text)
+ with mock.patch(
+ "vet.issue_identifiers.harnesses.single_prompt.build_language_model_from_config",
+ return_value=mock_language_model,
+ ):
+ project_context = BaseProjectContext(
+ file_contents_by_path=FrozenDict({"test.py": "print('hello')"}),
+ cached_prompt_prefix="[ROLE=SYSTEM]\nSystem context",
+ )
+ commit_inputs = IdentifierInputs(maybe_goal="Add hello function", maybe_diff="+print('hello')")
+ config = VetConfig()
+
+ inputs = identifier.to_required_inputs(commit_inputs)
+ raw_issues_generator = identifier.identify_issues(inputs, project_context, config)
+ raw_issues = []
+ raw_issues_generator_with_capture = ReturnCapturingGenerator(raw_issues_generator)
+ for raw_issue in raw_issues_generator_with_capture:
+ raw_issues.append(raw_issue)
+ llm_responses = raw_issues_generator_with_capture.return_value.llm_responses
+
+ assert len(raw_issues) == 1
+ assert raw_issues[0].issue_code == IssueCode.LOGIC_ERROR
+ assert raw_issues[0].description == "Test logic error"
+ assert len(llm_responses) > 0 # Should have LLM responses
diff --git a/imbue_verify/issue_identifiers/identification_guides.py b/vet/issue_identifiers/identification_guides.py
diff --git a/vet/issue_identifiers/issue_deduplication.py b/vet/issue_identifiers/issue_deduplication.py
@@ -0,0 +1,190 @@
+import json
+from typing import Generator
+from typing import Iterable
+
+import jinja2
+
+from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
+from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
+from imbue_core.data_types import AgenticPhase
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import LLMResponse
+from imbue_core.itertools import only
+from imbue_tools.repo_utils.context_utils import escape_prompt_markers
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import GeneratedResponseSchema
+from vet.issue_identifiers.common import (
+ extract_invocation_info_from_costed_response,
+)
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.common import generate_issues_from_response_texts
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+
+DEDUPLICATION_PROMPT_TEMPLATE = """[ROLE=USER]
+You are reviewing the results from parallel code analysis for potential issues.
+Multiple specialized checks analyzed the work of an automated coding agent, each focusing on checking for a specific type of issue.
+
+The rubric below outlines the categories of issues we care about:
+{% for issue_code, guide in guides.items() %}
+---
+**{{ issue_code }}**:
+{{ guide }}
+{% endfor %}
+---
+
+### Individual Analysis Results ###
+{{ generated_issues }}
+
+Your task is to:
+1. Consolidate any duplicate issues
+2. If duplicates are categorized as different issue types, pick the most appropriate issue type for the merged issue according to the category definitions above.
+3. Return the consolidated set of issues
+
+Guidelines:
+- Merge issues that refer to the same underlying problem and would be solved by the same fix. Make sure that their locations (if available) are the same, and that their descriptions describe the same underlying problem. The issue_code and other properties can be different.
+- A merged issue should represent a single problem. Never merge multiple distinct problems, even if they are closely related or share the same location.
+- Never merge issues that refer to different locations, functions or files.
+- Do not remove any issues, you may only re-categorize or merge issues
+- When merging issues, pick A SINGLE most relevant location + code_part pair from the issues that you are merging together. NEVER try to combine multiple locations or code_part into one. Just pick one of them. Make sure that you repeat the code part string verbatim (including any whitespaces) in the resulting merged issue.
+- The confidence value of a merged issue should be the highest confidence value among the issues being merged.
+
+After your analysis, provide your response in JSON format matching this schema:
+
+{{ response_schema | tojson(indent=2) }}
+
+Do not output any other JSON, only the consolidated issues in the specified format:
+```json
+{
+ "issues": [
+ <list of consolidated issues>
+ ]
+}
+```
+[ROLE=ASSISTANT]
+"""
+
+
+def _get_deduplication_prompt(
+ enabled_issue_codes: Iterable[IssueCode],
+ generated_issues: str,
+) -> str:
+ # Sort issue codes to make the resulting prompts deterministic (for snapshot tests and LLM caching)
+ sorted_issue_codes = sorted(enabled_issue_codes)
+ formatted_guides = {
+ code: format_issue_identification_guide_for_llm(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code])
+ for code in sorted_issue_codes
+ }
+
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ jinja_template = env.from_string(DEDUPLICATION_PROMPT_TEMPLATE)
+
+ prompt = jinja_template.render(
+ {
+ "guides": formatted_guides,
+ "response_schema": GeneratedResponseSchema.model_json_schema(),
+ "generated_issues": escape_prompt_markers(generated_issues),
+ }
+ )
+ return prompt
+
+
+def _convert_parsed_issues_to_combined_string(
+ all_parsed_issues: Iterable[GeneratedIssueSchema],
+) -> str:
+ """Convert all parsed issues from all issue types to a combined string for the deduplication prompt."""
+ combined_issues = []
+
+ for issue in all_parsed_issues:
+ issue_dict = issue.model_dump()
+ combined_issues.append(issue_dict)
+
+ return json.dumps({"issues": combined_issues}, indent=2)
+
+
+def deduplicate_issues(
+ issue_generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
+ config: VetConfig,
+ enabled_issue_codes: Iterable[IssueCode],
+) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ """
+ Deduplicate issues from multiple issue identifiers.
+
+ Args:
+ issues: The issues to deduplicate.
+ config: Settings for imbue verify.
+ enabled_issue_codes: The issue types used by the issue identifiers.
+
+ Returns:
+ A generator of deduplicated issues. Returns IssueIdentificationDebugInfo after the generator is exhausted.
+ """
+
+ # This current implementation is not streaming. Rather, we collect all issues, then send them to the LLM for deduplication all at once.
+ # In the future, we can consider changing this into a streaming version that performs deduplication as issues come in.
+ all_issues = []
+ issue_generator_with_capture = ReturnCapturingGenerator(issue_generator)
+ for issue in issue_generator_with_capture:
+ all_issues.append(issue)
+ issue_generator_debug_info = issue_generator_with_capture.return_value
+
+ # TODO: This is a bit hacky, since it breaks abstraction boundaries:
+ # We need to apply some special handling here around issue filtration.
+ # This will go away when in the future, we move the filtration step to after the deduplication step.
+ # However, we can't do that yet, because the filtration currently only works for certain issue types.
+ # For now, we make the following compromise:
+ # - We deduplicate only over issues that pass filtration.
+ # (The resulting deduplicated issues will implicitly be set to have passed filtration as well, as per default value of _passes_filtration)
+ # - Issues that didn't pass filtration will be yielded out unchanged.
+ issues_passing_filtration = [issue for issue in all_issues if issue.passes_filtration]
+ issues_not_passing_filtration = [issue for issue in all_issues if not issue.passes_filtration]
+
+ if len(issues_passing_filtration) <= 1:
+ # None or one issues that pass filtration: nothing to deduplicate, return early
+ for issue in all_issues:
+ yield issue
+ return issue_generator_debug_info
+
+ language_model = build_language_model_from_config(config.language_model_generation_config)
+
+ # As per above TODO, only deduplicate over issues that passed filtration
+ combined_issues_string = _convert_parsed_issues_to_combined_string(issues_passing_filtration)
+ prompt = _get_deduplication_prompt(enabled_issue_codes, combined_issues_string)
+
+ costed_response = language_model.complete_with_usage_sync(
+ prompt,
+ params=LanguageModelGenerationParams(temperature=0.0, max_tokens=config.max_output_tokens),
+ is_caching_enabled=language_model.cache_path is not None,
+ )
+
+ response = only(costed_response.responses)
+ invocation_info = extract_invocation_info_from_costed_response(costed_response)
+
+ yield from generate_issues_from_response_texts(response_texts=(response.text,))
+
+ # As per above TODO, now also yield out all issues that didn't pass filtration unchanged (these will keep their passes_filtration=False)
+ for issue in issues_not_passing_filtration:
+ yield issue
+
+ deduplication_llm_responses = (
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(
+ agentic_phase=AgenticPhase.DEDUPLICATION,
+ issue_type=None,
+ ),
+ raw_response=(response.text,),
+ invocation_info=invocation_info,
+ ),
+ )
+
+ augmented_debug_info = IssueIdentificationDebugInfo(
+ llm_responses=issue_generator_debug_info.llm_responses + deduplication_llm_responses
+ )
+
+ return augmented_debug_info
diff --git a/vet/issue_identifiers/issue_evaluation.py b/vet/issue_identifiers/issue_evaluation.py
@@ -0,0 +1,295 @@
+from typing import Generator
+
+import jinja2
+
+from imbue_core.agents.llm_apis.build_apis import build_language_model_from_config
+from imbue_core.agents.llm_apis.data_types import LanguageModelGenerationParams
+from imbue_core.data_types import AgenticPhase
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import LLMResponse
+from imbue_core.itertools import only
+from imbue_core.pydantic_serialization import SerializableModel
+from imbue_tools.get_conversation_history.get_conversation_history import (
+ format_conversation_history_for_prompt,
+)
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.llm_output_parsing.parse_model_json_response import (
+ ResponseParsingError,
+)
+from imbue_tools.llm_output_parsing.parse_model_json_response import (
+ parse_model_json_response,
+)
+from imbue_tools.repo_utils.context_utils import escape_prompt_markers
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import DEFAULT_CONFIDENCE_THRESHOLD
+from imbue_tools.types.vet_config import VetConfig
+from imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import (
+ extract_invocation_info_from_costed_response,
+)
+from vet.issue_identifiers.common import (
+ format_issue_identification_guide_for_llm,
+)
+from vet.issue_identifiers.harnesses.single_prompt import (
+ USER_REQUEST_PREFIX_TEMPLATE,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+
+CODE_BASED_CRITERIA = (
+ "1. The issue is based on specific code, and not merely on the absence of information in the codebase snapshot. (true/false)",
+ "2. The issue does not speculate about the way a piece of code might get used without having specific knowledge of how it's being used. (true/false)",
+ "3. The issues seems important and not overly pedantic. (true/false)",
+ "4. The issue was introduced by the diff. (true/false)",
+ "5. The issue matches the issue type definition given below. (true/false)",
+ "6. The issue flags a piece of code that is already being removed by the diff (line in diff starts with a `-`). (true/false)",
+)
+
+CONVERSATION_BASED_CRITERIA = ("1. The issue matches the issue type definition given below. (true/false)",)
+
+PROMPT_TEMPLATE = """Somebody has reviewed the {% if is_code_based_issue %}diff{% else %}conversation history{% endif %} and flagged an issue with it, which you can see here:
+
+### Issue description ###
+{% filter indent(width=2) %}
+{{ issue_description }}
+{% endfilter %}
+
+Please evaluate the issue and determine whether it matches the following criteria:
+
+{% for criterion in criteria %}
+{{ criterion }}
+{% endfor %}
+
+### Issue type definition ###
+{% filter indent(width=2) %}
+**{{ issue_code }}**:
+{{ guide }}
+{% endfilter %}
+
+Please answer the questions above in the form of a JSON object with this exact JSON schema:
+
+{{ response_schema | tojson(indent=2) }}
+
+The keys correspond to the question numbers ("q1" for question 1, "q2" for question 2, and so on), and the values should be boolean values indicating whether the issue matches the criteria (true or false).
+
+IMPORTANT: Do not include any additional commentary outside the JSON response, your response should only contain the JSON object:
+
+```json
+{
+ "q1": <true|false>,
+ "q2": <true|false>,
+ ...
+}
+```
+[ROLE=ASSISTANT]
+"""
+
+
+def _get_full_prompt_template(is_code_based_issue: bool) -> str:
+ """Get the full prompt template with the appropriate prefix."""
+ prefix = USER_REQUEST_PREFIX_TEMPLATE if is_code_based_issue else CONVERSATION_PREFIX_TEMPLATE
+ return prefix + PROMPT_TEMPLATE
+
+
+class CodeBasedEvaluationResponse(SerializableModel):
+ q1: bool
+ q2: bool
+ q3: bool
+ q4: bool
+ q5: bool
+ q6: bool
+
+ def is_passing_result(self) -> bool:
+ return all([self.q1, self.q2, self.q3, self.q4, self.q5]) and not self.q6
+
+
+class ConversationBasedEvaluationResponse(SerializableModel):
+ q1: bool
+
+ def is_passing_result(self) -> bool:
+ return self.q1
+
+
+def _format_prompt(
+ issue: GeneratedIssueSchema,
+ project_context: ProjectContext,
+ config: VetConfig,
+ inputs: IdentifierInputs,
+ is_code_based_issue: bool,
+) -> str:
+ env = jinja2.Environment(undefined=jinja2.StrictUndefined)
+ prompt_template = _get_full_prompt_template(is_code_based_issue)
+ jinja_template = env.from_string(prompt_template)
+ issue_code = IssueCode(issue.issue_code)
+ guide = format_issue_identification_guide_for_llm(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[issue_code])
+
+ criteria = CODE_BASED_CRITERIA if is_code_based_issue else CONVERSATION_BASED_CRITERIA
+ response_class = CodeBasedEvaluationResponse if is_code_based_issue else ConversationBasedEvaluationResponse
+
+ template_vars = {
+ "cached_prompt_prefix": project_context.cached_prompt_prefix,
+ "cache_full_prompt": config.cache_full_prompt,
+ "issue_description": issue.description,
+ "issue_code": issue_code,
+ "guide": guide,
+ "criteria": criteria,
+ "response_schema": response_class.model_json_schema(),
+ "is_code_based_issue": is_code_based_issue,
+ }
+
+ if is_code_based_issue:
+ template_vars["include_request_and_diff"] = True
+ template_vars["commit_message"] = escape_prompt_markers(inputs.maybe_goal or "")
+ template_vars["unified_diff"] = escape_prompt_markers(inputs.maybe_diff or "")
+ template_vars["extra_context"] = escape_prompt_markers(config.extra_context) if config.extra_context else None
+ else:
+ template_vars["conversation_history"] = format_conversation_history_for_prompt(
+ inputs.maybe_conversation_history or ()
+ )
+
+ return jinja_template.render(template_vars)
+
+
+def _parse_response(
+ response_text: str, is_code_based_issue: bool
+) -> CodeBasedEvaluationResponse | ConversationBasedEvaluationResponse:
+ # Fallback value of True for now, since we assume that most issues will pass the evaluation.
+ if is_code_based_issue:
+ FALLBACK_VALUE = CodeBasedEvaluationResponse(q1=True, q2=True, q3=True, q4=True, q5=True, q6=False)
+ response_class = CodeBasedEvaluationResponse
+ else:
+ FALLBACK_VALUE = ConversationBasedEvaluationResponse(q1=True)
+ response_class = ConversationBasedEvaluationResponse
+
+ try:
+ return parse_model_json_response(response_text, response_class)
+ except ResponseParsingError:
+ return FALLBACK_VALUE
+
+
+def evaluate_code_issue_through_llm(
+ issue: GeneratedIssueSchema,
+ inputs: IdentifierInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+ is_code_based_issue: bool,
+) -> tuple[bool, tuple[LLMResponse, ...]]:
+ """
+ Args:
+ issue: The issue to evaluate.
+ inputs: The inputs which determine the content provided to the evaluator.
+ project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
+ config: Settings for the language model used to evaluate the issue.
+ is_code_based_issue: Whether this is a code-based issue (vs conversation-based).
+
+ Returns:
+ A tuple containing a boolean indicating whether the issue passes the evaluation and the LLM responses.
+ If evaluation fails because the data to judge the issue is missing, the issue is taken to have passed the evaluation.
+ """
+ if not config.filter_issues_through_llm_evaluator:
+ return True, ()
+
+ # Check that we have the required data for evaluation
+ if is_code_based_issue:
+ if inputs.maybe_goal is None or inputs.maybe_diff is None:
+ return True, ()
+ else:
+ if inputs.maybe_conversation_history is None:
+ return True, ()
+
+ language_model = build_language_model_from_config(config.language_model_generation_config)
+
+ prompt = _format_prompt(issue, project_context, config, inputs, is_code_based_issue)
+ costed_response = language_model.complete_with_usage_sync(
+ prompt,
+ params=LanguageModelGenerationParams(temperature=0.0, max_tokens=config.max_output_tokens),
+ is_caching_enabled=language_model.cache_path is not None,
+ )
+
+ response = only(costed_response.responses)
+ invocation_info = extract_invocation_info_from_costed_response(costed_response)
+ results = _parse_response(response.text, is_code_based_issue)
+
+ llm_responses = (
+ LLMResponse(
+ metadata=IssueIdentificationLLMResponseMetadata(
+ agentic_phase=AgenticPhase.FILTRATION,
+ issue_type=None,
+ ),
+ raw_response=(response.text,),
+ invocation_info=invocation_info,
+ ),
+ )
+
+ return results.is_passing_result(), llm_responses
+
+
+MODEL_CONFIDENCE_THRESHOLD_DEFAULTS: dict[str, float] = {
+ "gpt-5.1-2025-11-13": 0.0,
+}
+
+
+def get_vet_confidence_threshold(config: VetConfig) -> float:
+ model_name = config.language_model_generation_config.model_name
+
+ if model_name in MODEL_CONFIDENCE_THRESHOLD_DEFAULTS:
+ return MODEL_CONFIDENCE_THRESHOLD_DEFAULTS[model_name]
+
+ if config.filter_issues_below_confidence is not None:
+ return config.filter_issues_below_confidence
+
+ return DEFAULT_CONFIDENCE_THRESHOLD
+
+
+def evaluate_issue_through_confidence(issue: GeneratedIssueSchema, config: VetConfig) -> bool:
+ threshold = get_vet_confidence_threshold(config)
+ return issue.confidence >= threshold
+
+
+def filter_issues(
+ issue_generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
+ inputs: IdentifierInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+ # Currently, the LLM-based filter only works reliably for code-related issue types.
+ is_code_based_issue_generator: bool,
+) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ """
+ Filter issues based on the evaluation.
+
+ Args:
+ results: The issues to filter.
+ inputs: The inputs which determine the content provided to the evaluator.
+ project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
+ config: Settings for imbue verify.
+
+ Returns:
+ A generator of issues with the passes_filtration flag set.
+ If evaluation fails because the data to judge the issue is missing, the issue is taken to have passed the evaluation.
+ At the end of the generation, returns IssueIdentificationDebugInfo containing the LLM responses.
+ """
+
+ filter_llm_responses = []
+
+ issue_generator_with_capture = ReturnCapturingGenerator(issue_generator)
+ for issue in issue_generator_with_capture:
+ passes_filtration = evaluate_issue_through_confidence(issue, config)
+ if passes_filtration:
+ passes_filtration, llm_responses = evaluate_code_issue_through_llm(
+ issue, inputs, project_context, config, is_code_based_issue_generator
+ )
+ filter_llm_responses.extend(llm_responses)
+ issue.set_passes_filtration(passes_filtration)
+ yield issue
+ issue_generator_debug_info = issue_generator_with_capture.return_value
+
+ augmented_debug_info = IssueIdentificationDebugInfo(
+ llm_responses=issue_generator_debug_info.llm_responses + tuple(filter_llm_responses)
+ )
+
+ return augmented_debug_info
diff --git a/vet/issue_identifiers/registry.py b/vet/issue_identifiers/registry.py
@@ -0,0 +1,262 @@
+"""
+Registry of all the available issue identifiers with a `run` function for running them in an identification pipeline.
+"""
+
+from collections import defaultdict
+from enum import StrEnum
+from typing import Final
+from typing import Generator
+from typing import Iterable
+from typing import TypeVar
+
+from loguru import logger
+
+from imbue_core.agents.primitives.resource_limits import ensure_global_resource_limits
+from imbue_core.data_types import IssueCode
+from imbue_core.data_types import IssueIdentificationDebugInfo
+from imbue_core.data_types import IssueIdentificationLLMResponseMetadata
+from imbue_core.data_types import IssueIdentifierResult
+from imbue_core.data_types import IssueIdentifierType
+from imbue_tools.get_conversation_history.input_data_types import IdentifierInputs
+from imbue_tools.get_conversation_history.input_data_types import (
+ IdentifierInputsMissingError,
+)
+from imbue_tools.repo_utils.project_context import ProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from imbue_tools.types.vet_config import get_enabled_issue_codes
+from vet.issue_identifiers.agentic_issue_collation import (
+ collate_issues_with_agent,
+)
+from vet.issue_identifiers.base import IssueIdentifier
+from vet.issue_identifiers.common import GeneratedIssueSchema
+from vet.issue_identifiers.common import convert_to_issue_identifier_result
+from vet.issue_identifiers.harnesses.agentic import AgenticHarness
+from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness
+from vet.issue_identifiers.harnesses.conversation_single_prompt import (
+ ConversationSinglePromptHarness,
+)
+from vet.issue_identifiers.harnesses.single_prompt import SinglePromptHarness
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_CODES_FOR_BATCHED_COMMIT_CHECK,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_CODES_FOR_CONVERSATION_HISTORY_CHECK,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+)
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+from vet.issue_identifiers.issue_deduplication import deduplicate_issues
+from vet.issue_identifiers.issue_evaluation import filter_issues
+from vet.issue_identifiers.utils import ReturnCapturingGenerator
+from vet.issue_identifiers.utils import multiplex_generators
+
+# Issue identifier harnesses together with certain default lists of issue codes.
+# This is intended as a transitionary structure to emulate the previous identifiers setup.
+# Eventually, we'll update VetConfig to no longer enable/disable specific identifiers, but instead
+# enable/disable harnesses and issue codes, and we'll pair up the enabled issue codes with the appropriate enabled
+# harnesses automatically.
+SINGLE_PROMPT_HARNESS = SinglePromptHarness()
+CONVERSATION_SINGLE_PROMPT_HARNESS = ConversationSinglePromptHarness()
+AGENTIC_HARNESS = AgenticHarness()
+HARNESS_PRESETS: Final[list[tuple[IssueIdentifierType, IssueIdentifierHarness, tuple[IssueCode, ...]]]] = [
+ (
+ IssueIdentifierType.AGENTIC_ISSUE_IDENTIFIER,
+ AGENTIC_HARNESS,
+ ISSUE_CODES_FOR_BATCHED_COMMIT_CHECK + ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+ ),
+ (
+ IssueIdentifierType.BATCHED_COMMIT_CHECK,
+ SINGLE_PROMPT_HARNESS,
+ ISSUE_CODES_FOR_BATCHED_COMMIT_CHECK,
+ ),
+ (
+ IssueIdentifierType.CONVERSATION_HISTORY_IDENTIFIER,
+ CONVERSATION_SINGLE_PROMPT_HARNESS,
+ ISSUE_CODES_FOR_CONVERSATION_HISTORY_CHECK,
+ ),
+ (
+ IssueIdentifierType.CORRECTNESS_COMMIT_CLASSIFIER,
+ SINGLE_PROMPT_HARNESS,
+ ISSUE_CODES_FOR_CORRECTNESS_CHECK,
+ ),
+]
+
+
+def get_all_valid_identifier_names() -> set[IssueIdentifierType]:
+ return {name for name, _, _ in HARNESS_PRESETS}
+
+
+EnumT = TypeVar("EnumT", bound=StrEnum)
+
+
+def _convert_all_to_enum(
+ enum_strs: Iterable[str], all_enum_strs: Iterable[str], enum_type: type[EnumT]
+) -> tuple[EnumT]:
+ results = []
+ for enum_str in enum_strs:
+ if enum_str not in all_enum_strs:
+ raise ValueError(f"Bad config: unknown {enum_type.__name__} name: {enum_str}")
+ results.append(enum_type(enum_str))
+ return tuple(results)
+
+
+def _get_enabled_identifier_names(
+ config: VetConfig,
+) -> set[IssueIdentifierType]:
+ all_names = get_all_valid_identifier_names()
+ explicitly_enabled = _convert_all_to_enum(config.enabled_identifiers or tuple(), all_names, IssueIdentifierType)
+ explicitly_disabled = _convert_all_to_enum(config.disabled_identifiers or tuple(), all_names, IssueIdentifierType)
+ enabled = set(explicitly_enabled) if len(explicitly_enabled) > 0 else all_names
+ if len(explicitly_disabled) > 0:
+ enabled = set(enabled) - set(explicitly_disabled)
+ return enabled
+
+
+def _build_identifiers(
+ identifiers_to_build: set[IssueIdentifierType], enabled_issue_codes: set[IssueCode]
+) -> list[tuple[str, IssueIdentifier]]:
+ # Merge the enabled issue codes for each harness
+ enabled_issue_codes_per_harness: defaultdict[IssueIdentifierHarness, set[IssueCode]] = defaultdict(set)
+ combined_name_per_harness: defaultdict[IssueIdentifierHarness, list[str]] = defaultdict(list)
+
+ for name, harness, default_issue_codes in HARNESS_PRESETS:
+ if name in identifiers_to_build:
+ enabled_issue_codes_for_harness = enabled_issue_codes & set(default_issue_codes)
+ if enabled_issue_codes_for_harness:
+ enabled_issue_codes_per_harness[harness].update(enabled_issue_codes_for_harness)
+ combined_name_per_harness[harness].append(name.value)
+
+ identifiers: list[tuple[str, IssueIdentifier]] = []
+ for harness, issue_codes in enabled_issue_codes_per_harness.items():
+ combined_name = "+".join(combined_name_per_harness[harness])
+ identifiers.append(
+ (
+ combined_name,
+ harness.make_issue_identifier(
+ identification_guides=tuple(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[code] for code in issue_codes)
+ ),
+ )
+ )
+
+ return identifiers
+
+
+def _generate_with_name_in_debug_info(
+ name: str,
+ generator: Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo],
+) -> Generator[GeneratedIssueSchema, None, tuple[str, IssueIdentificationDebugInfo]]:
+ generator_with_capture = ReturnCapturingGenerator(generator)
+ for result in generator_with_capture:
+ yield result
+ return name, generator_with_capture.return_value
+
+
+def _combine_issue_generator_debug_info(
+ generator: Generator[GeneratedIssueSchema, None, tuple[tuple[str, IssueIdentificationDebugInfo], ...]],
+) -> Generator[GeneratedIssueSchema, None, IssueIdentificationDebugInfo]:
+ collected_debug_info: tuple[tuple[str, IssueIdentificationDebugInfo], ...] = (yield from generator)
+
+ updated_llm_responses = []
+ for identifier_name, debug_info in collected_debug_info:
+ for response in debug_info.llm_responses:
+ assert isinstance(response.metadata, IssueIdentificationLLMResponseMetadata)
+ updated_response = response.evolve(response.ref().metadata.identifier_name, identifier_name)
+ updated_llm_responses.append(updated_response)
+
+ return IssueIdentificationDebugInfo(llm_responses=tuple(updated_llm_responses))
+
+
+def run(
+ identifier_inputs: IdentifierInputs,
+ project_context: ProjectContext,
+ config: VetConfig,
+) -> Generator[IssueIdentifierResult, None, IssueIdentificationDebugInfo]:
+ """
+ Run all the registered and configured issue identifiers on the given inputs.
+ """
+ enabled_issue_codes = get_enabled_issue_codes(config)
+ identifiers = _build_identifiers(_get_enabled_identifier_names(config), enabled_issue_codes)
+ ensure_global_resource_limits(max_dollars=config.max_identifier_spend_dollars)
+
+ issue_generators: list[Generator[GeneratedIssueSchema, None, tuple[str, IssueIdentificationDebugInfo]]] = []
+ compatible_enabled_identifier_names: list[str] = []
+ # The set of issue codes that can be detected by the compatible identifiers. A subset of enabled_issue_codes.
+ detectable_issue_codes: set[IssueCode] = set()
+ for identifier_name, identifier in identifiers:
+ # 1. Identification
+ try:
+ inputs = identifier.to_required_inputs(identifier_inputs)
+ identified_issues_generator = identifier.identify_issues(inputs, project_context, config)
+ compatible_enabled_identifier_names.append(identifier_name)
+ detectable_issue_codes.update(identifier.enabled_issue_codes)
+ except IdentifierInputsMissingError as e:
+ logger.debug(
+ "skipping identifier {} because of missing inputs: {}",
+ identifier_name,
+ e,
+ )
+ continue
+
+ # 2. Collation for agentic identifiers
+ if identifier.requires_agentic_collation and config.enable_collation:
+ try:
+ collated_issues_generator = collate_issues_with_agent(
+ identified_issues_generator,
+ identifier_inputs,
+ project_context,
+ config,
+ identifier.enabled_issue_codes,
+ )
+ except IdentifierInputsMissingError as e:
+ logger.warning(
+ "collate_issues_with_agent requires commit message and diff, skipping: {}",
+ e,
+ )
+ continue
+ else:
+ collated_issues_generator = identified_issues_generator
+
+ # 3. Filtration
+ if config.filter_issues:
+ filtered_results_generator = filter_issues(
+ collated_issues_generator,
+ identifier_inputs,
+ project_context,
+ config,
+ is_code_based_issue_generator=identifier.identifies_code_issues,
+ )
+ else:
+ filtered_results_generator = collated_issues_generator
+
+ issue_generators.append(_generate_with_name_in_debug_info(identifier_name, filtered_results_generator))
+
+ logger.info(
+ "Using the following issue identifiers compatible with the input: {}",
+ ", ".join([n for n in compatible_enabled_identifier_names]),
+ )
+
+ multiplexed_generators = multiplex_generators(issue_generators, max_workers=config.max_identify_workers)
+ multiplexed_generators_with_combined_debug_info = _combine_issue_generator_debug_info(multiplexed_generators)
+
+ # 4. Deduplicate issues across all identifiers
+ if config.enable_deduplication:
+ deduplicated_generator = deduplicate_issues(
+ multiplexed_generators_with_combined_debug_info,
+ config,
+ tuple(detectable_issue_codes),
+ )
+ else:
+ deduplicated_generator = multiplexed_generators_with_combined_debug_info
+
+ # Conversion from GeneratedIssueSchema to IssueIdentifierResult
+ converted_issues_generator = convert_to_issue_identifier_result(
+ deduplicated_generator, project_context, tuple(enabled_issue_codes)
+ )
+
+ # Yield out results
+ debug_info = yield from converted_issues_generator
+
+ return debug_info
diff --git a/vet/issue_identifiers/test_prompt_lengths.py b/vet/issue_identifiers/test_prompt_lengths.py
@@ -0,0 +1,60 @@
+from imbue_core.data_types import IssueIdentifierType
+from imbue_core.frozen_utils import FrozenDict
+from imbue_core.itertools import first
+from imbue_tools.get_conversation_history.input_data_types import CommitInputs
+from imbue_tools.repo_utils.project_context import BaseProjectContext
+from imbue_tools.types.vet_config import VetConfig
+from vet.issue_identifiers import registry
+from vet.issue_identifiers.identification_guides import (
+ ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE,
+)
+from vet.repo_utils import VET_MAX_PROMPT_TOKENS
+
+EMPTY_PROJECT_CONTEXT = BaseProjectContext(file_contents_by_path=FrozenDict(), cached_prompt_prefix="")
+DEFAULT_VET_CONFIG = VetConfig()
+
+
+# Helper functions to extract a base prompt for different identifier types.
+PROMPT_EXTRACTOR_FUNCTIONS = {
+ IssueIdentifierType.BATCHED_COMMIT_CHECK: lambda identifier: identifier._get_prompt(
+ EMPTY_PROJECT_CONTEXT,
+ DEFAULT_VET_CONFIG,
+ CommitInputs(maybe_goal="", maybe_diff=""),
+ ),
+ IssueIdentifierType.CORRECTNESS_COMMIT_CLASSIFIER: lambda identifier: identifier._get_prompt(
+ EMPTY_PROJECT_CONTEXT,
+ DEFAULT_VET_CONFIG,
+ CommitInputs(maybe_goal="", maybe_diff=""),
+ ),
+}
+
+
+def _estimate_tokens(prompt: str) -> int:
+ """
+ Estimate the number of tokens in a prompt.
+ This is a rough estimate and may not be accurate for all models.
+ """
+ # A factor of 1/4.5 appears to be a reasonable empirical estimate for current models.
+ # We use a slighly larger factor (1/4) to have a more conservative estimate.
+ return round(len(prompt) / 4)
+
+
+def test_prompt_lengths() -> None:
+ """
+ Test that the prompt lengths for various issue identifiers do not exceed the maximum allowed length.
+ This is important to ensure that the LLM can process the prompts without raising errors.
+ """
+
+ for identifier_name, extract_prompt in PROMPT_EXTRACTOR_FUNCTIONS.items():
+ identifier = first(
+ [
+ harness.make_issue_identifier(tuple(ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE[c] for c in codes))
+ for name, harness, codes in registry.HARNESS_PRESETS
+ if name == identifier_name
+ ]
+ )
+ prompt = extract_prompt(identifier)
+ num_tokens = _estimate_tokens(prompt)
+ assert (
+ num_tokens <= VET_MAX_PROMPT_TOKENS
+ ), f"Prompt for {identifier_name} exceeds VET_MAX_PROMPT_TOKENS. Consider increasing VET_MAX_PROMPT_TOKENS or shortening the prompt. "
diff --git a/imbue_verify/issue_identifiers/utils.py b/vet/issue_identifiers/utils.py
diff --git a/vet/issue_identifiers/utils_test.py b/vet/issue_identifiers/utils_test.py
@@ -0,0 +1,92 @@
+import contextvars
+import threading
+from typing import Generator
+
+from vet.issue_identifiers.utils import multiplex_generators
+from vet.issue_identifiers.utils import xml_post_escape
+
+
+def test_xml_post_escape_does_not_escape_if_not_necessary() -> None:
+ input_string = "<root><code_part>hello</code_part></root>"
+ assert xml_post_escape(input_string, "code_part") == input_string
+
+
+def test_xml_post_escape_properly_escapes_single_line() -> None:
+ input_string = "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
+ assert xml_post_escape(input_string, "code_part") == "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
+
+
+def test_xml_post_escape_properly_escapes_multi_line() -> None:
+ input_string = """
+ <root>
+ <code_part>
+ 1 < 2
+ </code_part>
+ </root>
+ """
+ assert (
+ xml_post_escape(input_string, "code_part")
+ == """
+ <root>
+ <code_part>
+ 1 < 2
+ </code_part>
+ </root>
+ """
+ )
+
+
+def test_xml_post_escape_does_not_escape_if_not_asked_to() -> None:
+ input_string = "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
+ assert xml_post_escape(input_string, "desc") == "<root><desc>Hey</desc><code_part>1 < 2</code_part></root>"
+
+
+def test_xml_post_escape_does_not_change_case() -> None:
+ input_string = "<root><desc>Hey</desc><code_part>1 < 2</CODE_PART></root>"
+ assert xml_post_escape(input_string, "code_part") == "<root><desc>Hey</desc><code_part>1 < 2</CODE_PART></root>"
+
+
+def test_xml_post_escape_does_nothing_if_element_not_present() -> None:
+ input_string = "<root><greeting>hello</greeting></root>"
+ assert xml_post_escape(input_string, "code_part") == input_string
+
+
+def _generator_with_barrier(value: int, count: int, barrier: threading.Barrier) -> Generator[int, None, int]:
+ for i in range(count):
+ barrier.wait(timeout=1.0)
+ yield value + i
+ return value * 100
+
+
+def test_multiplex_generators_runs_in_parallel() -> None:
+ barrier = threading.Barrier(2)
+
+ gen1 = _generator_with_barrier(0, 3, barrier)
+ gen2 = _generator_with_barrier(10, 3, barrier)
+
+ multiplexed = multiplex_generators([gen1, gen2], max_workers=2)
+
+ results = []
+ for item in multiplexed:
+ results.append(item)
+
+ assert len(results) == 6
+ assert set(results) == {0, 1, 2, 10, 11, 12}
+
+
+def test_multiple_generators_transfers_contextvars() -> None:
+ """Test that existing context variables are transferred to the generator threads."""
+ var = contextvars.ContextVar("test_var", default=123)
+
+ def _gen_with_contextvar() -> Generator[int, None, None]:
+ yield var.get()
+
+ gen = _gen_with_contextvar()
+
+ multiplexed = multiplex_generators([gen])
+
+ results = []
+ for item in multiplexed:
+ results.append(item)
+
+ assert results == [123]
diff --git a/imbue_verify/py.typed b/vet/py.typed
diff --git a/vet/repo_utils.py b/vet/repo_utils.py
@@ -0,0 +1,73 @@
+from pathlib import Path
+
+from imbue_core.async_monkey_patches import log_exception
+from imbue_core.computing_environment.data_types import RunCommandError
+from imbue_core.simple_git import SyncLocalGitRepo
+from imbue_tools.repo_utils.find_relative_to import find_relative_to_commit_hash
+from vet.errors import GitException
+
+# Maximum length of LLM prompts used within vet in tokens, without the repository-specific context.
+# Currently, the prompt is well under 10k tokens, but this value might need to be bumped up if we add a lot of additional
+# identification guides, few-shot examples, or other context.
+VET_MAX_PROMPT_TOKENS = 10000
+
+
+def get_code_to_check(relative_to: str, repo_path: Path) -> tuple[str, str, str]:
+ """
+ Returns:
+ - The commit hash to use as the base commit for the diff.
+ - The combined diff including staged, unstaged, and untracked changes. (compatible with `git apply`)
+ - The combined diff but with binary diffs shortened. (cannot be applied if binary changes are present)
+ """
+ try:
+ base_commit = find_relative_to_commit_hash(relative_to, repo_path=repo_path)
+ except RunCommandError as e:
+ raise GitException(f"Unable to determine base commit for code verification: {e}") from e
+
+ repo = SyncLocalGitRepo(repo_path)
+
+ # Get the combined diff which includes all changes; staged, unstaged, and untracked.
+ try:
+ combined_diff = repo.get_git_diff(commit_hash=base_commit)
+ combined_diff_no_binary = repo.get_git_diff(commit_hash=base_commit, include_binary=False)
+ except RunCommandError as e:
+ raise GitException(f"Unable to get diff to {base_commit}: {e}") from e
+
+ # Get untracked files since we want to include these as part of the unstaged and full changes
+ try:
+ untracked_files = repo.get_untracked_files()
+ except RunCommandError as e:
+ raise GitException(f"Unable to get untracked files: {e}") from e
+
+ # Create diffs for untracked files (treat them as new files)
+ untracked_diffs = []
+ untracked_diffs_no_binary = []
+ for file_path in untracked_files:
+ if file_path: # Skip empty lines
+ try:
+ untracked_diff = repo.get_untracked_file_diff(file_path, include_binary=True)
+ untracked_diffs.append(untracked_diff)
+ except RunCommandError as e:
+ log_exception(
+ e,
+ "Skipping untracked file we couldn't diff: {file_path}",
+ file_path=file_path,
+ )
+
+ try:
+ untracked_diff_no_binary = repo.get_untracked_file_diff(file_path, include_binary=False)
+ untracked_diffs_no_binary.append(untracked_diff_no_binary)
+ except RunCommandError as e:
+ log_exception(
+ e,
+ "Skipping untracked file we couldn't diff (no binary): {file_path}",
+ file_path=file_path,
+ )
+
+ # Add untracked files to unstaged changes and the combined diff
+ if untracked_diffs:
+ combined_diff += "\n" + "\n".join(untracked_diffs)
+ if untracked_diffs_no_binary:
+ combined_diff_no_binary += "\n" + "\n".join(untracked_diffs_no_binary)
+
+ return base_commit, combined_diff, combined_diff_no_binary
diff --git a/vet/repo_utils_test.py b/vet/repo_utils_test.py
@@ -0,0 +1,97 @@
+import subprocess
+from pathlib import Path
+
+from syrupy.assertion import SnapshotAssertion
+
+from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
+from imbue_core.nested_evolver import assign
+from imbue_core.nested_evolver import chill
+from imbue_core.nested_evolver import evolver
+from imbue_tools.repo_utils.project_context import LazyProjectContext
+from vet.repo_utils import get_code_to_check
+
+
+def test_get_code_to_check(simple_test_git_repo: Path) -> None:
+ """Test that get_code_to_check correctly handles staged, unstaged, and untracked files"""
+ repo_path = simple_test_git_repo
+ first_commit = subprocess.run(
+ ["git", "rev-parse", "HEAD"],
+ cwd=repo_path,
+ capture_output=True,
+ text=True,
+ check=True,
+ ).stdout.strip()
+
+ # Create an untracked file
+ new_file_content = "This is a new untracked file\nwith multiple lines\nof content"
+ (repo_path / "new_file.txt").write_text(new_file_content)
+ (repo_path / "new_file.bin").write_bytes(b"\x00\x01\x02")
+
+ # Create a committed change
+ (repo_path / "file1.txt").write_text("committed modified content\n")
+ (repo_path / "file1.bin").write_bytes(b"\x00\x01\x02")
+ subprocess.run(["git", "add", "file1.txt"], cwd=repo_path, check=True)
+ subprocess.run(["git", "commit", "-m", "Modify file1"], cwd=repo_path, check=True)
+
+ # Create a staged change
+ with open((repo_path / "file1.txt"), "a+") as f:
+ # make sure to have multiple newlines to sepearate changes so they don't get
+ # picked up in same diff block
+ f.write("\nstaged written modified content\n")
+ subprocess.run(["git", "add", "file1.txt"], cwd=repo_path, check=True)
+
+ # Create an unstaged change
+ with open((repo_path / "file1.txt"), "a+") as f:
+ f.write("\nunstaged written modified content")
+
+ git_hash, diff, diff_no_binary = get_code_to_check(first_commit, repo_path=repo_path)
+
+ assert git_hash == first_commit
+
+ # Verify the untracked file is included in the diffs
+ assert "new_file.txt" in diff
+ assert "new_file.bin" in diff
+ assert "new_file.txt" in diff_no_binary
+ assert "new_file.bin" in diff_no_binary
+ assert "Binary files /dev/null and b/new_file.bin differ" in diff_no_binary
+
+ # Verify tracked changes are also included
+ assert "file1.txt" in diff
+ assert "+staged written modified content" in diff
+ assert "+unstaged written modified content" in diff
+ assert "+committed modified content" in diff
+ assert "file1.bin" in diff
+
+ assert "file1.txt" in diff_no_binary
+ assert "+staged written modified content" in diff_no_binary
+ assert "+unstaged written modified content" in diff_no_binary
+ assert "+committed modified content" in diff_no_binary
+ assert "Binary files /dev/null and b/file1.bin differ" in diff_no_binary
+
+
+def test_build_context(simple_test_git_repo: Path, snapshot: SnapshotAssertion) -> None:
+ first_commit = subprocess.run(
+ ["git", "rev-parse", "HEAD"],
+ cwd=simple_test_git_repo,
+ capture_output=True,
+ text=True,
+ check=True,
+ ).stdout.strip()
+ git_hash, diff, _diff_no_binary = get_code_to_check(first_commit, repo_path=simple_test_git_repo)
+ project_context = LazyProjectContext.build(
+ git_hash,
+ diff,
+ language_model_name=AnthropicModelName.CLAUDE_4_5_HAIKU_2025_10_01,
+ repo_path=simple_test_git_repo,
+ tokens_to_reserve=20000,
+ ).to_base_project_context()
+ assert project_context.repo_path == simple_test_git_repo
+
+ # the temp dir isn't the same every time so we need to remove it
+ project_context_evolver = evolver(project_context)
+ assign(
+ project_context_evolver.repo_path,
+ lambda: None,
+ )
+ project_context_without_repo_path = chill(project_context_evolver)
+ assert project_context_without_repo_path == snapshot
diff --git a/vet_types/pyproject.toml b/vet_types/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vet_types"
version = "0.1.0"
-description = "Type definitions for VET (imbue-verify) without telemetry dependencies"
+description = "Type definitions for Vet without telemetry dependencies"
dependencies = [
"pydantic",
"imbue_core",
diff --git a/vet_types/vet_types/__init__.py b/vet_types/vet_types/__init__.py
@@ -1,4 +1,4 @@
-"""Shared type definitions for imbue_verify."""
+"""Shared type definitions for Vet."""
from vet_types.chat_state import ContentBlock
from vet_types.chat_state import ContentBlockTypes
diff --git a/vet_types/vet_types/chat_state.py b/vet_types/vet_types/chat_state.py
@@ -1,4 +1,4 @@
-"""Chat state types for imbue_verify."""
+"""Chat state types for Vet."""
from typing import Annotated
from typing import Any
@@ -56,7 +56,9 @@ class CommandBlock(ContentBlock):
object_type: str = "CommandBlock"
type: Literal["command"] = "command"
command: str
- is_automated: bool = Field(default=False, description="Whether the command is automated")
+ is_automated: bool = Field(
+ default=False, description="Whether the command is automated"
+ )
ToolInput = dict[str, Any]
@@ -67,13 +69,17 @@ class ToolUseBlock(ContentBlock):
type: Literal["tool_use"] = "tool_use"
id: ToolUseID = Field(..., description="Unique identifier for this tool use")
name: str = Field(..., description="Name of the tool being used")
- input: ToolInput = Field(default_factory=ToolInput, description="Input parameters for the tool")
+ input: ToolInput = Field(
+ default_factory=ToolInput, description="Input parameters for the tool"
+ )
class ToolResultContent(SerializableModel):
"""Base class for tool result content with type discriminator"""
- content_type: str = Field(..., description="Type discriminator for tool result content")
+ content_type: str = Field(
+ ..., description="Type discriminator for tool result content"
+ )
class SimpleToolContent(ToolResultContent):
@@ -108,9 +114,15 @@ class ToolResultBlock(ContentBlock):
type: Literal["tool_result"] = "tool_result"
tool_use_id: ToolUseID = Field(..., description="ID of the corresponding tool use")
tool_name: str = Field(..., description="Name of the tool that was used")
- invocation_string: str = Field(..., description="String representation of how the tool was invoked")
- content: ToolResultContentType = Field(..., description="Result content from the tool execution")
- is_error: bool = Field(default=False, description="Whether the tool execution resulted in an error")
+ invocation_string: str = Field(
+ ..., description="String representation of how the tool was invoked"
+ )
+ content: ToolResultContentType = Field(
+ ..., description="Result content from the tool execution"
+ )
+ is_error: bool = Field(
+ default=False, description="Whether the tool execution resulted in an error"
+ )
class WarningBlock(ContentBlock):
@@ -118,7 +130,9 @@ class WarningBlock(ContentBlock):
type: Literal["warning"] = "warning"
message: str = Field(..., description="Warning message")
traceback: str | None = Field(..., description="Warning traceback")
- warning_type: str | None = Field(..., description="Type of warning, i.e. name of the exception that was raised")
+ warning_type: str | None = Field(
+ ..., description="Type of warning, i.e. name of the exception that was raised"
+ )
class ErrorBlock(ContentBlock):
@@ -126,7 +140,9 @@ class ErrorBlock(ContentBlock):
type: Literal["error"] = "error"
message: str = Field(..., description="Error message")
traceback: str = Field(..., description="Error traceback")
- error_type: str = Field(..., description="Type of error, i.e. name of the exception that was raised")
+ error_type: str = Field(
+ ..., description="Type of error, i.e. name of the exception that was raised"
+ )
class FileBlock(ContentBlock):
diff --git a/vet_types/vet_types/ids.py b/vet_types/vet_types/ids.py
@@ -1,4 +1,4 @@
-"""ID types for imbue_verify."""
+"""ID types for Vet."""
from abc import ABC
from typing import Any
@@ -32,17 +32,23 @@ class ObjectID(TypeID, ABC):
# For convenience, don't require the caller to strip the prefix from existing IDs.
if prefix is not None:
if prefix != self.tag:
- raise TypeIDPrefixMismatchError(f"Expected prefix '{self.tag}', got '{prefix}'")
+ raise TypeIDPrefixMismatchError(
+ f"Expected prefix '{self.tag}', got '{prefix}'"
+ )
value = suffix
super().__init__(self.tag, value)
@classmethod
- def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
+ def __get_pydantic_core_schema__(
+ cls, source_type: type, handler: GetCoreSchemaHandler
+ ) -> core_schema.CoreSchema:
"""
Support transparently deserializing strings into ObjectID instances and vice versa.
"""
return core_schema.no_info_before_validator_function(
- lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value),
+ lambda raw_value: (
+ cls(raw_value) if isinstance(raw_value, str) else raw_value
+ ),
core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
@@ -71,12 +77,16 @@ class NonEmptyStr(str):
return value
@classmethod
- def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
+ def __get_pydantic_core_schema__(
+ cls, source_type: type, handler: GetCoreSchemaHandler
+ ) -> core_schema.CoreSchema:
"""
Support transparently deserializing strings into ObjectID instances and vice versa.
"""
return core_schema.no_info_before_validator_function(
- lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value),
+ lambda raw_value: (
+ cls(raw_value) if isinstance(raw_value, str) else raw_value
+ ),
core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
diff --git a/vet_types/vet_types/messages.py b/vet_types/vet_types/messages.py
@@ -1,4 +1,4 @@
-"""Message types for imbue_verify conversation history.
+"""Message types for Vet conversation history.
These are simplified versions that avoid dependencies on external telemetry libraries.
"""
@@ -65,11 +65,15 @@ class Message(SerializableModel):
# the source of the message, which can be either the agent, user, or runner.
source: AgentMessageSource
# roughly when the message was created, in UTC.
- approximate_creation_time: datetime.datetime = Field(default_factory=get_current_time)
+ approximate_creation_time: datetime.datetime = Field(
+ default_factory=get_current_time
+ )
@property
def is_ephemeral(self) -> bool:
- raise NotImplementedError("All messages must be subclassed off of PersistentMessage or EphemeralMessage")
+ raise NotImplementedError(
+ "All messages must be subclassed off of PersistentMessage or EphemeralMessage"
+ )
class PersistentMessage(Message):