vet

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit e2ad4d9a0c05b6fcf96fa98edf4e05a335d9485f
parent 6f2f275154adfa37391d9f0c3fd2ea65e2a50b4f
Author: andrewlaack-collab <andrew.laack@imbue.com>
Date:   Tue,  3 Feb 2026 00:29:11 +0000

Context Fixes (#8)

* Wrote tests, updated code

* Formatter

* Refactoring stuff

* Refactoring

* Refactoring

* refactor

* refactoring

* Removing useless tests

* Pruning tests
Diffstat:
M pyproject.toml | 1 +
M uv.lock | 23 +++++++++++++++++++++++
M vet/api.py | 27 +++++++++++++++++++++++++--
M vet/cli/main.py | 2 +-
M vet/imbue_tools/get_conversation_history/get_conversation_history.py | 21 +++++++++++++++++++--
M vet/imbue_tools/get_conversation_history/input_data_types.py | 13 +++++++++++--
M vet/imbue_tools/types/vet_config.py | 2 --
M vet/imbue_tools/util_prompts/conversation_prefix.py | 3 +++
M vet/imbue_tools/util_prompts/goal_from_conversation.py | 10 ++++++++--
M vet/issue_identifiers/harnesses/agentic.py | 20 ++++++++++++++++----
M vet/issue_identifiers/harnesses/conversation_single_prompt.py | 20 ++++++++++++++++++--
M vet/issue_identifiers/harnesses/single_prompt.py | 49 ++++++++++++++++++++++++++++++++++++++++++++-----
M vet/issue_identifiers/issue_evaluation.py | 26 ++++++++++++++++++++++----
A vet/truncation.py | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A vet/truncation_test.py | 224 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
15 files changed, 527 insertions(+), 26 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml @@ -68,6 +68,7 @@ include = ["vet*"] [dependency-groups] dev = [ "black", + "hypothesis", "moto>=4.1.12", ] diff --git a/uv.lock b/uv.lock @@ -817,6 +817,18 @@ wheels = [ ] [[package]] +name = "hypothesis" +version = "6.151.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/03/9fd03d5db09029250e69745c1600edab16fe90947636f77a12ba92d79939/hypothesis-6.151.4.tar.gz", hash = "sha256:658a62da1c3ccb36746ac2f7dc4bb1a6e76bd314e0dc54c4e1aaba2503d5545c", size = 475706, upload-time = "2026-01-29T01:30:14.985Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/6d/01ad1b6c3b8cb2bb47eeaa9765dabc27cbe68e3b59f6cff83d5668f57780/hypothesis-6.151.4-py3-none-any.whl", hash = "sha256:a1cf7e0fdaa296d697a68ff3c0b3912c0050f07aa37e7d2ff33a966749d1d9b4", size = 543146, upload-time = "2026-01-29T01:30:12.805Z" }, +] + +[[package]] name = "idna" version = "3.11" source = { registry = "https://pypi.org/simple" } @@ -2175,6 +2187,15 @@ wheels = [ ] [[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + +[[package]] name = "syrupy" version = "5.1.0" source = { registry = "https://pypi.org/simple" } @@ -2443,6 +2464,7 @@ dependencies = [ 
[package.dev-dependencies] dev = [ { name = "black" }, + { name = "hypothesis" }, { name = "moto" }, ] @@ -2494,6 +2516,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "black" }, + { name = "hypothesis" }, { name = "moto", specifier = ">=4.1.12" }, ] diff --git a/vet/api.py b/vet/api.py @@ -18,11 +18,17 @@ from vet.imbue_tools.get_conversation_history.input_data_types import Identifier from vet.imbue_tools.repo_utils.project_context import LazyProjectContext from vet.imbue_tools.repo_utils.project_context import ProjectContext from vet.imbue_tools.types.vet_config import VetConfig -from vet.imbue_tools.util_prompts.goal_from_conversation import get_goal_from_conversation +from vet.imbue_tools.util_prompts.goal_from_conversation import ( + get_goal_from_conversation, +) from vet.issue_identifiers import registry from vet.issue_identifiers.utils import ReturnCapturingGenerator -from vet.repo_utils import VET_MAX_PROMPT_TOKENS from vet.repo_utils import get_code_to_check +from vet.repo_utils import VET_MAX_PROMPT_TOKENS +from vet.truncation import ContextBudget +from vet.truncation import get_available_tokens +from vet.truncation import get_token_budget +from vet.truncation import truncate_to_token_limit def get_issues_with_raw_responses( @@ -33,6 +39,7 @@ def get_issues_with_raw_responses( config: VetConfig, repo_path: Path, conversation_history: tuple[ConversationMessageUnion, ...] 
| None = None, + extra_context: str | None = None, ) -> tuple[tuple[IdentifiedVerifyIssue, ...], IssueIdentificationDebugInfo, ProjectContext]: if not goal or not goal.strip(): logger.info("No goal was provided, generating one from conversation history") @@ -52,10 +59,22 @@ def get_issues_with_raw_responses( goal = "" lm_config = config.language_model_generation_config + + available_tokens = get_available_tokens(config) + diff_budget = get_token_budget(available_tokens, ContextBudget.DIFF) + if diff_no_binary: + diff_no_binary, diff_truncated = truncate_to_token_limit( + diff_no_binary, + max_tokens=diff_budget, + count_tokens=lm_config.count_tokens, + label="diff", + truncate_end=True, + ) diff_no_binary_tokens = lm_config.count_tokens(diff_no_binary) else: diff_no_binary_tokens = 0 + diff_truncated = False project_context = LazyProjectContext.build( base_commit, @@ -72,6 +91,8 @@ def get_issues_with_raw_responses( maybe_diff=diff_no_binary or None, maybe_goal=goal, maybe_conversation_history=conversation_history, + diff_truncated=diff_truncated, + maybe_extra_context=extra_context, ) results_generator = registry.run( @@ -96,6 +117,7 @@ def find_issues( goal: str, config: VetConfig, conversation_history: tuple[ConversationMessageUnion, ...] 
| None = None, + extra_context: str | None = None, ) -> tuple[IdentifiedVerifyIssue, ...]: logger.info( "Finding issues in {repo_path} relative to commit hash {relative_to}", @@ -121,5 +143,6 @@ def find_issues( config=config, repo_path=repo_path, conversation_history=conversation_history, + extra_context=extra_context, ) return issues diff --git a/vet/cli/main.py b/vet/cli/main.py @@ -456,7 +456,6 @@ def main(argv: list[str] | None = None) -> int: filter_issues_below_confidence=args.confidence_threshold, max_identify_workers=args.max_workers, max_output_tokens=max_output_tokens or 20000, - extra_context=extra_context, ) issues = find_issues( @@ -465,6 +464,7 @@ def main(argv: list[str] | None = None) -> int: goal=goal, config=config, conversation_history=conversation_history, + extra_context=extra_context, ) output_fields = args.output_fields if args.output_fields else OUTPUT_FIELDS diff --git a/vet/imbue_tools/get_conversation_history/get_conversation_history.py b/vet/imbue_tools/get_conversation_history/get_conversation_history.py @@ -1,4 +1,5 @@ import json +from typing import Callable from typing import assert_never from loguru import logger @@ -6,6 +7,7 @@ from pydantic import TypeAdapter from pydantic import ValidationError from vet.vet_types.chat_state import ContentBlockTypes +from vet.truncation import truncate_to_token_limit from vet.vet_types.messages import ChatInputUserMessage from vet.vet_types.messages import ConversationMessageUnion from vet.vet_types.messages import ResponseBlockAgentMessage @@ -46,9 +48,24 @@ def delete_unnecessary_conversation_message_fields( def format_conversation_history_for_prompt( conversation_history: tuple[ConversationMessageUnion, ...], -) -> str: + max_tokens: int | None = None, + count_tokens: Callable[[str], int] | None = None, +) -> tuple[str, bool]: formatted_messages = [delete_unnecessary_conversation_message_fields(message) for message in conversation_history] - return "\n".join(message for message in 
formatted_messages if message is not None) + result = "\n".join(message for message in formatted_messages if message is not None) + + if max_tokens is not None and count_tokens is not None: + + result, was_truncated = truncate_to_token_limit( + result, + max_tokens=max_tokens, + count_tokens=count_tokens, + label="conversation history", + truncate_end=False, + ) + return result, was_truncated + + return result, False # === loading from file === diff --git a/vet/imbue_tools/get_conversation_history/input_data_types.py b/vet/imbue_tools/get_conversation_history/input_data_types.py @@ -12,15 +12,24 @@ class IdentifierInputsMissingError(Exception): class IdentifierInputs(SerializableModel): - # goal (for now, commit message) and diff to check + # goal maybe_goal: str | None = None + goal_truncated: bool = False + + # diff maybe_diff: str | None = None + diff_truncated: bool = False - # whole files to check + # files to check maybe_files: tuple[str, ...] | None = None # conversation history to check maybe_conversation_history: tuple[ConversationMessageUnion, ...] | None = None + conversation_truncated: bool = False + + # additional user supplied context + maybe_extra_context: str | None = None + extra_context_truncated: bool = False class CommitInputs(IdentifierInputs): diff --git a/vet/imbue_tools/types/vet_config.py b/vet/imbue_tools/types/vet_config.py @@ -47,8 +47,6 @@ class VetConfig(SerializableModel): # contexts (such as black_box_evals) where the same inputs are being evaluated multiple times. cache_full_prompt: bool = False - extra_context: str | None = None - @classmethod def build( cls, diff --git a/vet/imbue_tools/util_prompts/conversation_prefix.py b/vet/imbue_tools/util_prompts/conversation_prefix.py @@ -2,6 +2,9 @@ CONVERSATION_PREFIX_TEMPLATE = """[ROLE=SYSTEM_CACHED] You will be provided a conversation history between a user and another agent. The other agent may be from any model provider or model family. 
The conversation history includes the user's messages and the agent's text-based messages, but may be missing some automated messages and tool calls/tool call results. Examine the conversation carefully and be prepared to answer questions about it. +{% if conversation_truncated %} +Note: Earlier conversation messages were removed due to size constraints. Do not assume details about prior messages that are not visible. +{% endif %} Here is the conversation history between the user and the other agent. {% filter indent(width=2) %} ``` diff --git a/vet/imbue_tools/util_prompts/goal_from_conversation.py b/vet/imbue_tools/util_prompts/goal_from_conversation.py @@ -9,7 +9,9 @@ from vet.vet_types.messages import ConversationMessageUnion from vet.imbue_tools.get_conversation_history.get_conversation_history import ( format_conversation_history_for_prompt, ) -from vet.imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE +from vet.imbue_tools.util_prompts.conversation_prefix import ( + CONVERSATION_PREFIX_TEMPLATE, +) # TODO: see how this does on actual examples where the agent did something other than what the user asked for PROMPT_TEMPLATE = ( @@ -35,7 +37,11 @@ def prompt_for_getting_goal_from_conversation( ) -> str: env = jinja2.Environment(undefined=jinja2.StrictUndefined) jinja_template = env.from_string(PROMPT_TEMPLATE) - return jinja_template.render(conversation_history=format_conversation_history_for_prompt(conversation_history)) + formatted_history, conversation_truncated = format_conversation_history_for_prompt(conversation_history) + return jinja_template.render( + conversation_history=formatted_history, + conversation_truncated=conversation_truncated, + ) def get_goal_from_conversation_with_usage( diff --git a/vet/issue_identifiers/harnesses/agentic.py b/vet/issue_identifiers/harnesses/agentic.py @@ -45,12 +45,16 @@ PROMPT_TEMPLATE = """You are analyzing a code repository for potential issues. 
T Assume that a user requested work to be done and a programmer delivered the diff below. The changes from the diff are present in the codebase but are not yet committed. - +{% if goal_truncated %} +Note: The user request was truncated. The full request may contain additional details not shown. +{% endif %} ### User request ### {% filter indent(width=2) %} {{ commit_message }} {% endfilter %} - +{% if diff_truncated %} +Note: The diff below was truncated due to size constraints. Do not assume details about code or context that is not visible. +{% endif %} ### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ### {% filter indent(width=2) %} {{ unified_diff }} @@ -117,12 +121,16 @@ ISSUE_TYPE_PROMPT_TEMPLATE = """You are analyzing a code repository for potentia Assume that a user requested work to be done and a programmer delivered the diff below. The changes from the diff are present in the codebase but are not yet committed. - +{% if goal_truncated %} +Note: The user request was truncated. The full request may contain additional details not shown. +{% endif %} ### User request ### {% filter indent(width=2) %} {{ commit_message }} {% endfilter %} - +{% if diff_truncated %} +Note: The diff below was truncated due to size constraints. Do not assume details about code or context that is not visible. 
+{% endif %} ### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ### {% filter indent(width=2) %} {{ unified_diff }} @@ -212,7 +220,9 @@ class _AgenticIssueIdentifier(IssueIdentifier[CommitInputs]): { "repo_path": project_context.repo_path, "commit_message": escape_prompt_markers(identifier_inputs.goal), + "goal_truncated": identifier_inputs.goal_truncated, "unified_diff": escape_prompt_markers(identifier_inputs.diff), + "diff_truncated": identifier_inputs.diff_truncated, "guides": formatted_guides, "response_schema": self._response_schema, "additional_guidance": additional_guidance_by_issue_code, @@ -235,7 +245,9 @@ class _AgenticIssueIdentifier(IssueIdentifier[CommitInputs]): { "repo_path": project_context.repo_path, "commit_message": escape_prompt_markers(identifier_inputs.goal), + "goal_truncated": identifier_inputs.goal_truncated, "unified_diff": escape_prompt_markers(identifier_inputs.diff), + "diff_truncated": identifier_inputs.diff_truncated, "guide": formatted_guide, "response_schema": self._response_schema, "issue_type": guide.issue_code, diff --git a/vet/issue_identifiers/harnesses/conversation_single_prompt.py b/vet/issue_identifiers/harnesses/conversation_single_prompt.py @@ -24,7 +24,9 @@ from vet.imbue_tools.get_conversation_history.get_conversation_history import ( from vet.imbue_tools.get_conversation_history.input_data_types import ConversationInputs from vet.imbue_tools.repo_utils.project_context import ProjectContext from vet.imbue_tools.types.vet_config import VetConfig -from vet.imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE +from vet.imbue_tools.util_prompts.conversation_prefix import ( + CONVERSATION_PREFIX_TEMPLATE, +) from vet.issue_identifiers.base import IssueIdentifier from vet.issue_identifiers.common import GeneratedIssueSchema from vet.issue_identifiers.common import GeneratedResponseSchema @@ -39,6 +41,9 @@ from 
vet.issue_identifiers.harnesses.base import IssueIdentifierHarness from vet.issue_identifiers.identification_guides import ( IssueIdentificationGuide, ) +from vet.truncation import ContextBudget +from vet.truncation import get_available_tokens +from vet.truncation import get_token_budget PROMPT_TEMPLATE = ( CONVERSATION_PREFIX_TEMPLATE @@ -89,12 +94,23 @@ class _ConversationSinglePromptIssueIdentifier(IssueIdentifier[ConversationInput guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides } + lm_config = config.language_model_generation_config + available_tokens = get_available_tokens(config) + conversation_budget = get_token_budget(available_tokens, ContextBudget.CONVERSATION) + + conversation_history, conversation_truncated = format_conversation_history_for_prompt( + identifier_inputs.conversation_history, + max_tokens=conversation_budget, + count_tokens=lm_config.count_tokens, + ) + env = jinja2.Environment(undefined=jinja2.StrictUndefined) jinja_template = env.from_string(PROMPT_TEMPLATE) return jinja_template.render( cached_prompt_prefix=project_context.cached_prompt_prefix, cache_full_prompt=config.cache_full_prompt, - conversation_history=format_conversation_history_for_prompt(identifier_inputs.conversation_history), + conversation_history=conversation_history, + conversation_truncated=conversation_truncated or identifier_inputs.conversation_truncated, # pyre-fixme[16]: SubrepoContext need not have a formatted_repo_context, and instruction_context can be None instruction_context=project_context.instruction_context.formatted_repo_context, response_schema=self._response_schema, diff --git a/vet/issue_identifiers/harnesses/single_prompt.py b/vet/issue_identifiers/harnesses/single_prompt.py @@ -34,12 +34,18 @@ from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness from vet.issue_identifiers.identification_guides import ( IssueIdentificationGuide, ) +from vet.truncation import ContextBudget +from 
vet.truncation import get_available_tokens +from vet.truncation import get_token_budget +from vet.truncation import truncate_to_token_limit USER_REQUEST_PREFIX_TEMPLATE = """{{cached_prompt_prefix}} [ROLE=USER_CACHED] I'm working on a project, adding commits one after another. The current state of the project is captured by the codebase snapshot above. {% if extra_context %} - +{% if extra_context_truncated %} +Note: Additional context was truncated due to size constraints. Do not assume details about content that is not visible. +{% endif %} === ADDITIONAL CONTEXT BEGIN === {{ extra_context }} === ADDITIONAL CONTEXT END === @@ -48,11 +54,15 @@ I'm working on a project, adding commits one after another. The current state of Assume that I asked for a piece of work to be done by specifying the user request and another programmer has delivered the diff. {% if include_request_and_diff %} Below, you can see the user request, as well as the delivered diff. IMPORTANT: The codebase snapshot already includes the changes made in this diff! - +{% if goal_truncated %} +Note: The user request was truncated. The full request may contain additional details not shown. +{% endif %} === USER REQUEST BEGIN === {{ commit_message }} === USER REQUEST END === - +{% if diff_truncated %} +Note: The diff below was truncated due to size constraints. Do not assume details about code or context that is not visible. 
+{% endif %} === DIFF BEGIN (unified; lines starting with `-` are removed and `+` are added) === {{ unified_diff }} === DIFF END === @@ -133,6 +143,32 @@ class _SinglePromptIssueIdentifier(IssueIdentifier[CommitInputs]): guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides } + lm_config = config.language_model_generation_config + available_tokens = get_available_tokens(config) + goal_budget = get_token_budget(available_tokens, ContextBudget.GOAL) + extra_context_budget = get_token_budget(available_tokens, ContextBudget.EXTRA_CONTEXT) + + goal, goal_truncated = truncate_to_token_limit( + identifier_inputs.goal, + max_tokens=goal_budget, + count_tokens=lm_config.count_tokens, + label="goal", + truncate_end=True, + ) + + extra_context = identifier_inputs.maybe_extra_context or "" + if extra_context: + extra_context, extra_context_truncated = truncate_to_token_limit( + extra_context, + max_tokens=extra_context_budget, + count_tokens=lm_config.count_tokens, + label="extra context", + truncate_end=True, + ) + extra_context_truncated = extra_context_truncated or identifier_inputs.extra_context_truncated + else: + extra_context_truncated = False + env = jinja2.Environment(undefined=jinja2.StrictUndefined) jinja_template = env.from_string(PROMPT_TEMPLATE) return jinja_template.render( @@ -140,9 +176,12 @@ class _SinglePromptIssueIdentifier(IssueIdentifier[CommitInputs]): "include_request_and_diff": True, "cached_prompt_prefix": project_context.cached_prompt_prefix, "cache_full_prompt": config.cache_full_prompt, - "extra_context": (escape_prompt_markers(config.extra_context) if config.extra_context else None), - "commit_message": escape_prompt_markers(identifier_inputs.goal), + "extra_context": (escape_prompt_markers(extra_context) if extra_context else None), + "extra_context_truncated": extra_context_truncated, + "commit_message": escape_prompt_markers(goal), + "goal_truncated": goal_truncated or identifier_inputs.goal_truncated, 
"unified_diff": escape_prompt_markers(identifier_inputs.diff), + "diff_truncated": identifier_inputs.diff_truncated, "guides": formatted_guides, "response_schema": self._response_schema, } diff --git a/vet/issue_identifiers/issue_evaluation.py b/vet/issue_identifiers/issue_evaluation.py @@ -25,7 +25,9 @@ from vet.imbue_tools.repo_utils.context_utils import escape_prompt_markers from vet.imbue_tools.repo_utils.project_context import ProjectContext from vet.imbue_tools.types.vet_config import DEFAULT_CONFIDENCE_THRESHOLD from vet.imbue_tools.types.vet_config import VetConfig -from vet.imbue_tools.util_prompts.conversation_prefix import CONVERSATION_PREFIX_TEMPLATE +from vet.imbue_tools.util_prompts.conversation_prefix import ( + CONVERSATION_PREFIX_TEMPLATE, +) from vet.issue_identifiers.common import GeneratedIssueSchema from vet.issue_identifiers.common import ( extract_invocation_info_from_costed_response, @@ -40,6 +42,9 @@ from vet.issue_identifiers.identification_guides import ( ISSUE_IDENTIFICATION_GUIDES_BY_ISSUE_CODE, ) from vet.issue_identifiers.utils import ReturnCapturingGenerator +from vet.truncation import ContextBudget +from vet.truncation import get_available_tokens +from vet.truncation import get_token_budget CODE_BASED_CRITERIA = ( "1. The issue is based on specific code, and not merely on the absence of information in the codebase snapshot. 
(true/false)", @@ -145,12 +150,25 @@ def _format_prompt( if is_code_based_issue: template_vars["include_request_and_diff"] = True template_vars["commit_message"] = escape_prompt_markers(inputs.maybe_goal or "") + template_vars["goal_truncated"] = inputs.goal_truncated template_vars["unified_diff"] = escape_prompt_markers(inputs.maybe_diff or "") - template_vars["extra_context"] = escape_prompt_markers(config.extra_context) if config.extra_context else None + template_vars["diff_truncated"] = inputs.diff_truncated + template_vars["extra_context"] = ( + escape_prompt_markers(inputs.maybe_extra_context) if inputs.maybe_extra_context else None + ) + template_vars["extra_context_truncated"] = inputs.extra_context_truncated else: - template_vars["conversation_history"] = format_conversation_history_for_prompt( - inputs.maybe_conversation_history or () + lm_config = config.language_model_generation_config + available_tokens = get_available_tokens(config) + conversation_budget = get_token_budget(available_tokens, ContextBudget.CONVERSATION) + + conversation_history, conversation_truncated = format_conversation_history_for_prompt( + inputs.maybe_conversation_history or (), + max_tokens=conversation_budget, + count_tokens=lm_config.count_tokens, ) + template_vars["conversation_history"] = conversation_history + template_vars["conversation_truncated"] = conversation_truncated or inputs.conversation_truncated return jinja_template.render(template_vars) diff --git a/vet/truncation.py b/vet/truncation.py @@ -0,0 +1,112 @@ +from enum import Enum +from typing import Callable + +from loguru import logger + +from vet.repo_utils import VET_MAX_PROMPT_TOKENS + +from vet.imbue_tools.types.vet_config import VetConfig + + +class ContextBudget(Enum): + REPO_CONTEXT = 50 + DIFF = 30 + CONVERSATION = 10 + EXTRA_CONTEXT = 6 + GOAL = 4 + + +def get_token_budget(total_available: int, budget: ContextBudget) -> int: + return int(total_available * budget.value / 100) + + +def 
get_available_tokens(config: "VetConfig") -> int: + lm_config = config.language_model_generation_config + context_window = lm_config.get_max_context_length() + return context_window - VET_MAX_PROMPT_TOKENS - config.max_output_tokens + + +def truncate_to_token_limit( + text: str, + max_tokens: int, + count_tokens: Callable[[str], int], + label: str, + truncate_end: bool = True, +) -> tuple[str, bool]: + if not text: + return text, False + + if max_tokens <= 0: + logger.warning("{} budget is zero or negative, returning empty string", label.capitalize()) + return "", True + + token_count = count_tokens(text) + if token_count <= max_tokens: + return text, False + + logger.warning( + "{} exceeds token limit ({} > {}), truncating", + label.capitalize(), + token_count, + max_tokens, + ) + + if truncate_end: + truncated = _find_truncation_point_from_end(text, max_tokens, count_tokens) + else: + truncated = _find_truncation_point_from_start(text, max_tokens, count_tokens) + + return truncated, True + + +def _find_truncation_point_from_end( + text: str, + max_tokens: int, + count_tokens: Callable[[str], int], +) -> str: + char_estimate = min(max_tokens * 4, len(text)) + + low, high = 0, char_estimate + result = "" + + if high < len(text) and count_tokens(text[:high]) <= max_tokens: + low = high + high = len(text) + + while low <= high: + mid = (low + high) // 2 + candidate = text[:mid] + if count_tokens(candidate) <= max_tokens: + result = candidate + low = mid + 1 + else: + high = mid - 1 + + return result + + +def _find_truncation_point_from_start( + text: str, + max_tokens: int, + count_tokens: Callable[[str], int], +) -> str: + char_estimate = min(max_tokens * 4, len(text)) + start_estimate = max(0, len(text) - char_estimate) + + low, high = 0, start_estimate + result = text[start_estimate:] + + if count_tokens(result) > max_tokens: + low = start_estimate + high = len(text) + + while low <= high: + mid = (low + high) // 2 + candidate = text[mid:] + if 
count_tokens(candidate) <= max_tokens: + result = candidate + high = mid - 1 + else: + low = mid + 1 + + return result diff --git a/vet/truncation_test.py b/vet/truncation_test.py @@ -0,0 +1,224 @@ +from typing import Callable + +import tiktoken +from hypothesis import given, settings, strategies as st, assume + +from vet.truncation import ContextBudget, get_token_budget, truncate_to_token_limit + + +def word_count(text: str) -> int: + return len(text.split()) + + +def char_count(text: str) -> int: + return len(text) + + +def char_div4_count(text: str) -> int: + return len(text) // 4 + 1 if text else 0 + + +_tiktoken_encoder = tiktoken.get_encoding("cl100k_base") + + +def tiktoken_count(text: str) -> int: + return len(_tiktoken_encoder.encode(text)) + + +SIMPLE_TOKEN_COUNTERS: list[Callable[[str], int]] = [ + word_count, + char_count, + char_div4_count, +] + +ALL_TOKEN_COUNTERS: list[Callable[[str], int]] = SIMPLE_TOKEN_COUNTERS + [ + tiktoken_count +] + +ascii_text = st.text( + alphabet=st.characters(min_codepoint=32, max_codepoint=126), + min_size=0, + max_size=1000, +) + +unicode_text = st.text( + alphabet=st.characters( + min_codepoint=32, + max_codepoint=0xFFFF, + blacklist_categories=("Cs",), + ), + min_size=0, + max_size=500, +) + +code_text = st.from_regex( + r"[a-zA-Z_][a-zA-Z0-9_]{0,20}(\s*[+\-*/=<>!&|]+\s*[a-zA-Z_][a-zA-Z0-9_]{0,20}){0,10}", + fullmatch=True, +) + +repeated_char_text = st.builds( + lambda char, count: char * count, + char=st.characters(min_codepoint=32, max_codepoint=126), + count=st.integers(min_value=0, max_value=500), +) + +mixed_text = st.builds( + lambda code, comment, uni: f"{code} // {comment}\n{uni}", + code=st.from_regex(r"[a-z_][a-z0-9_]{0,10}", fullmatch=True), + comment=ascii_text, + uni=st.text(min_size=0, max_size=100), +) + + +def test_context_budgets_sum_to_100(): + total = sum(budget.value for budget in ContextBudget) + assert total == 100, f"ContextBudget values must sum to 100, got {total}" + + +@given( + 
total_tokens=st.integers(min_value=0, max_value=1_000_000), + budget=st.sampled_from(list(ContextBudget)), +) +def test_get_token_budget_is_mathematically_correct( + total_tokens: int, budget: ContextBudget +): + result = get_token_budget(total_tokens, budget) + expected = total_tokens * budget.value // 100 + assert result == expected + + +@given( + text=st.text(min_size=0, max_size=10_000), + max_tokens=st.integers(min_value=0, max_value=10_000), + truncate_end=st.booleans(), + count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS), +) +def test_truncate_always_respects_token_limit_simple( + text: str, max_tokens: int, truncate_end: bool, count_tokens +): + result, _ = truncate_to_token_limit( + text, + max_tokens=max_tokens, + count_tokens=count_tokens, + label="test", + truncate_end=truncate_end, + ) + + assert count_tokens(result) <= max_tokens, ( + f"Token limit violated: got {count_tokens(result)} > {max_tokens} " + f"(counter={count_tokens.__name__}, truncate_end={truncate_end})" + ) + + +@given( + text=st.text(min_size=1, max_size=1000), + max_tokens=st.integers(min_value=1, max_value=100), +) +def test_truncate_end_produces_prefix(text: str, max_tokens: int): + result, was_truncated = truncate_to_token_limit( + text, + max_tokens=max_tokens, + count_tokens=char_count, + label="test", + truncate_end=True, + ) + + assert text.startswith(result), ( + f"Result '{result[:50]}...' 
is not a prefix of original" + ) + + if was_truncated: + assert len(result) < len(text) + + +@given( + text=st.text(min_size=1, max_size=1000), + max_tokens=st.integers(min_value=1, max_value=100), +) +def test_truncate_start_produces_suffix(text: str, max_tokens: int): + result, was_truncated = truncate_to_token_limit( + text, + max_tokens=max_tokens, + count_tokens=char_count, + label="test", + truncate_end=False, + ) + + assert text.endswith(result), ( + f"Result '...{result[-50:]}' is not a suffix of original" + ) + + if was_truncated: + assert len(result) < len(text) + + +@given( + text=st.text(min_size=0, max_size=1000), + budget_multiplier=st.integers(min_value=1, max_value=10), + count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS), +) +def test_text_within_budget_unchanged(text: str, budget_multiplier: int, count_tokens): + token_count = count_tokens(text) + + max_tokens = token_count * budget_multiplier + 10 + + result, was_truncated = truncate_to_token_limit( + text, + max_tokens=max_tokens, + count_tokens=count_tokens, + label="test", + ) + + assert result == text + assert was_truncated is False + + +@given(count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS)) +def test_empty_text_always_returns_empty(count_tokens): + result, was_truncated = truncate_to_token_limit( + "", + max_tokens=100, + count_tokens=count_tokens, + label="test", + ) + + assert result == "" + assert was_truncated is False + + +@given( + text=st.text(min_size=1, max_size=1000), + count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS), +) +def test_zero_budget_returns_empty_and_truncated(text: str, count_tokens): + assume(count_tokens(text) > 0) + + result, was_truncated = truncate_to_token_limit( + text, + max_tokens=0, + count_tokens=count_tokens, + label="test", + ) + + assert result == "" + assert was_truncated is True + + +@settings(max_examples=100) # 20 per strategy * 5 strategies +@given( + text=st.one_of(ascii_text, unicode_text, code_text, repeated_char_text, mixed_text), + 
max_tokens=st.integers(min_value=0, max_value=500), + truncate_end=st.booleans(), +) +def test_truncate_respects_limit_tiktoken( + text: str, max_tokens: int, truncate_end: bool +): + result, _ = truncate_to_token_limit( + text, + max_tokens=max_tokens, + count_tokens=tiktoken_count, + label="test", + truncate_end=truncate_end, + ) + + assert tiktoken_count(result) <= max_tokens