vet

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit 11aec54c2ec3757e78f21441dda886ef118416b8
parent 551ef4753c2cf429d2ace51ce1cd3835364a7cd9
Author: andrewlaack-collab <andrew.laack@imbue.com>
Date:   Thu, 12 Feb 2026 20:34:40 +0000

Undo restrictive safeguards on context (#71)


Diffstat:
Mpyproject.toml | 1-
Muv.lock | 23-----------------------
Mvet/api.py | 17-----------------
Mvet/imbue_tools/get_conversation_history/get_conversation_history.py | 21++-------------------
Mvet/imbue_tools/get_conversation_history/input_data_types.py | 10++--------
Mvet/imbue_tools/util_prompts/conversation_prefix.py | 3---
Mvet/imbue_tools/util_prompts/goal_from_conversation.py | 6+-----
Mvet/issue_identifiers/harnesses/agentic.py | 20++++----------------
Mvet/issue_identifiers/harnesses/conversation_single_prompt.py | 12+-----------
Mvet/issue_identifiers/harnesses/single_prompt.py | 48+++++-------------------------------------------
Mvet/issue_identifiers/issue_evaluation.py | 18++----------------
Dvet/truncation.py | 111-------------------------------------------------------------------------------
Dvet/truncation_test.py | 217-------------------------------------------------------------------------------
13 files changed, 17 insertions(+), 490 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml @@ -72,7 +72,6 @@ include = ["vet*"] [dependency-groups] dev = [ "black", - "hypothesis", "isort", "pytest", "syrupy", diff --git a/uv.lock b/uv.lock @@ -480,18 +480,6 @@ wheels = [ ] [[package]] -name = "hypothesis" -version = "6.151.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "sortedcontainers" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/47/03/9fd03d5db09029250e69745c1600edab16fe90947636f77a12ba92d79939/hypothesis-6.151.4.tar.gz", hash = "sha256:658a62da1c3ccb36746ac2f7dc4bb1a6e76bd314e0dc54c4e1aaba2503d5545c", size = 475706, upload-time = "2026-01-29T01:30:14.985Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/6d/01ad1b6c3b8cb2bb47eeaa9765dabc27cbe68e3b59f6cff83d5668f57780/hypothesis-6.151.4-py3-none-any.whl", hash = "sha256:a1cf7e0fdaa296d697a68ff3c0b3912c0050f07aa37e7d2ff33a966749d1d9b4", size = 543146, upload-time = "2026-01-29T01:30:12.805Z" }, -] - -[[package]] name = "idna" version = "3.11" source = { registry = "https://pypi.org/simple" } @@ -1331,15 +1319,6 @@ wheels = [ ] [[package]] -name = "sortedcontainers" -version = "2.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, -] - -[[package]] name = "syrupy" version = "5.1.0" source = { registry = "https://pypi.org/simple" } @@ -1563,7 +1542,6 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "black" }, - { name = "hypothesis" }, { name = "isort" }, { name = "pytest" }, { name = "syrupy" }, @@ -1599,7 +1577,6 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "black" }, - { name = "hypothesis" }, { name = "isort" }, { name = "pytest" }, { name = "syrupy" }, diff --git a/vet/api.py b/vet/api.py @@ -20,10 +20,6 @@ from vet.issue_identifiers import registry from vet.issue_identifiers.utils import ReturnCapturingGenerator from vet.repo_utils import VET_MAX_PROMPT_TOKENS from vet.repo_utils import get_code_to_check -from vet.truncation import ContextBudget -from vet.truncation import get_available_tokens -from vet.truncation import get_token_budget -from vet.truncation import truncate_to_token_limit from vet.vet_types.messages import ConversationMessageUnion @@ -55,22 +51,10 @@ def get_issues_with_raw_responses( goal = "" lm_config = config.language_model_generation_config - - available_tokens = get_available_tokens(config) - diff_budget = get_token_budget(available_tokens, ContextBudget.DIFF) - if diff_no_binary: - diff_no_binary, diff_truncated = truncate_to_token_limit( - diff_no_binary, - max_tokens=diff_budget, - count_tokens=lm_config.count_tokens, - label="diff", - truncate_end=True, - ) diff_no_binary_tokens = lm_config.count_tokens(diff_no_binary) else: diff_no_binary_tokens = 0 - diff_truncated = False project_context = LazyProjectContext.build( base_commit, @@ -93,7 +77,6 @@ def get_issues_with_raw_responses( maybe_diff=diff_no_binary or None, maybe_goal=goal, maybe_conversation_history=conversation_history, - diff_truncated=diff_truncated, maybe_extra_context=extra_context, ) diff --git a/vet/imbue_tools/get_conversation_history/get_conversation_history.py b/vet/imbue_tools/get_conversation_history/get_conversation_history.py @@ -1,12 +1,10 @@ import json -from typing import Callable from typing import assert_never from loguru import logger from pydantic import TypeAdapter from pydantic import ValidationError -from vet.truncation import truncate_to_token_limit from vet.vet_types.chat_state import ContentBlockTypes from vet.vet_types.messages import ChatInputUserMessage from vet.vet_types.messages import ConversationMessageUnion @@ -48,24 +46,9 @@ def delete_unnecessary_conversation_message_fields( def format_conversation_history_for_prompt( conversation_history: tuple[ConversationMessageUnion, ...], - max_tokens: int | None = None, - count_tokens: Callable[[str], int] | None = None, -) -> tuple[str, bool]: +) -> str: formatted_messages = [delete_unnecessary_conversation_message_fields(message) for message in conversation_history] - result = "\n".join(message for message in formatted_messages if message is not None) - - if max_tokens is not None and count_tokens is not None: - - result, was_truncated = truncate_to_token_limit( - result, - max_tokens=max_tokens, - count_tokens=count_tokens, - label="conversation history", - truncate_end=False, - ) - return result, was_truncated - - return result, False + return "\n".join(message for message in formatted_messages if message is not None) # === loading from file === diff --git a/vet/imbue_tools/get_conversation_history/input_data_types.py b/vet/imbue_tools/get_conversation_history/input_data_types.py @@ -12,24 +12,18 @@ class IdentifierInputsMissingError(Exception): class IdentifierInputs(SerializableModel): - # goal + # goal (for now, commit message) and diff to check maybe_goal: str | None = None - goal_truncated: bool = False - - # diff maybe_diff: str | None = None - diff_truncated: bool = False - # files to check + # whole files to check maybe_files: tuple[str, ...] | None = None # conversation history to check maybe_conversation_history: tuple[ConversationMessageUnion, ...] | None = None - conversation_truncated: bool = False # additional user supplied context maybe_extra_context: str | None = None - extra_context_truncated: bool = False class CommitInputs(IdentifierInputs): diff --git a/vet/imbue_tools/util_prompts/conversation_prefix.py b/vet/imbue_tools/util_prompts/conversation_prefix.py @@ -2,9 +2,6 @@ CONVERSATION_PREFIX_TEMPLATE = """[ROLE=SYSTEM_CACHED] You will be provided a conversation history between a user and another agent. The other agent may be from any model provider or model family. The conversation history includes the user's messages and the agent's text-based messages, but may be missing some automated messages and tool calls/tool call results. Examine the conversation carefully and be prepared to answer questions about it. -{% if conversation_truncated %} -Note: Earlier conversation messages were removed due to size constraints. Do not assume details about prior messages that are not visible. -{% endif %} Note: This conversation is being analyzed while still in progress. The agent's final messages may reference actions it is currently performing (such as running verification tools). Do not treat these as completed claims — the results may not yet be visible because the action is still executing at the time of this analysis. Here is the conversation history between the user and the other agent. {% filter indent(width=2) %} diff --git a/vet/imbue_tools/util_prompts/goal_from_conversation.py b/vet/imbue_tools/util_prompts/goal_from_conversation.py @@ -33,11 +33,7 @@ def prompt_for_getting_goal_from_conversation( ) -> str: env = jinja2.Environment(undefined=jinja2.StrictUndefined) jinja_template = env.from_string(PROMPT_TEMPLATE) - formatted_history, conversation_truncated = format_conversation_history_for_prompt(conversation_history) - return jinja_template.render( - conversation_history=formatted_history, - conversation_truncated=conversation_truncated, - ) + return jinja_template.render(conversation_history=format_conversation_history_for_prompt(conversation_history)) def get_goal_from_conversation_with_usage( diff --git a/vet/issue_identifiers/harnesses/agentic.py b/vet/issue_identifiers/harnesses/agentic.py @@ -41,16 +41,12 @@ PROMPT_TEMPLATE = """You are analyzing a code repository for potential issues. T Assume that a user requested work to be done and a programmer delivered the diff below. The changes from the diff are present in the codebase but are not yet committed. -{% if goal_truncated %} -Note: The user request was truncated. The full request may contain additional details not shown. -{% endif %} + ### User request ### {% filter indent(width=2) %} {{ commit_message }} {% endfilter %} -{% if diff_truncated %} -Note: The diff below was truncated due to size constraints. Do not assume details about code or context that is not visible. -{% endif %} + ### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ### {% filter indent(width=2) %} {{ unified_diff }} @@ -117,16 +113,12 @@ ISSUE_TYPE_PROMPT_TEMPLATE = """You are analyzing a code repository for potentia Assume that a user requested work to be done and a programmer delivered the diff below. The changes from the diff are present in the codebase but are not yet committed. -{% if goal_truncated %} -Note: The user request was truncated. The full request may contain additional details not shown. -{% endif %} + ### User request ### {% filter indent(width=2) %} {{ commit_message }} {% endfilter %} -{% if diff_truncated %} -Note: The diff below was truncated due to size constraints. Do not assume details about code or context that is not visible. -{% endif %} + ### Diff (lines starting with `-` indicate removed code, and lines starting with `+` indicate added code) ### {% filter indent(width=2) %} {{ unified_diff }} @@ -216,9 +208,7 @@ class _AgenticIssueIdentifier(IssueIdentifier[CommitInputs]): { "repo_path": project_context.repo_path, "commit_message": escape_prompt_markers(identifier_inputs.goal), - "goal_truncated": identifier_inputs.goal_truncated, "unified_diff": escape_prompt_markers(identifier_inputs.diff), - "diff_truncated": identifier_inputs.diff_truncated, "guides": formatted_guides, "response_schema": self._response_schema, "additional_guidance": additional_guidance_by_issue_code, @@ -241,9 +231,7 @@ class _AgenticIssueIdentifier(IssueIdentifier[CommitInputs]): { "repo_path": project_context.repo_path, "commit_message": escape_prompt_markers(identifier_inputs.goal), - "goal_truncated": identifier_inputs.goal_truncated, "unified_diff": escape_prompt_markers(identifier_inputs.diff), - "diff_truncated": identifier_inputs.diff_truncated, "guide": formatted_guide, "response_schema": self._response_schema, "issue_type": guide.issue_code, diff --git a/vet/issue_identifiers/harnesses/conversation_single_prompt.py b/vet/issue_identifiers/harnesses/conversation_single_prompt.py @@ -31,9 +31,6 @@ from vet.issue_identifiers.common import format_issue_identification_guide_for_l from vet.issue_identifiers.common import generate_issues_from_response_texts from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness from vet.issue_identifiers.identification_guides import IssueIdentificationGuide -from vet.truncation import ContextBudget -from vet.truncation import get_available_tokens -from vet.truncation import get_token_budget PROMPT_TEMPLATE = ( CONVERSATION_PREFIX_TEMPLATE @@ -84,14 +81,8 @@ class _ConversationSinglePromptIssueIdentifier(IssueIdentifier[ConversationInput guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides } - lm_config = config.language_model_generation_config - available_tokens = get_available_tokens(config) - conversation_budget = get_token_budget(available_tokens, ContextBudget.CONVERSATION) - - conversation_history, conversation_truncated = format_conversation_history_for_prompt( + conversation_history = format_conversation_history_for_prompt( identifier_inputs.conversation_history, - max_tokens=conversation_budget, - count_tokens=lm_config.count_tokens, ) env = jinja2.Environment(undefined=jinja2.StrictUndefined) @@ -100,7 +91,6 @@ class _ConversationSinglePromptIssueIdentifier(IssueIdentifier[ConversationInput cached_prompt_prefix=project_context.cached_prompt_prefix, cache_full_prompt=config.cache_full_prompt, conversation_history=conversation_history, - conversation_truncated=conversation_truncated or identifier_inputs.conversation_truncated, # pyre-fixme[16]: SubrepoContext need not have a formatted_repo_context, and instruction_context can be None instruction_context=project_context.instruction_context.formatted_repo_context, response_schema=self._response_schema, diff --git a/vet/issue_identifiers/harnesses/single_prompt.py b/vet/issue_identifiers/harnesses/single_prompt.py @@ -28,18 +28,12 @@ from vet.issue_identifiers.common import format_issue_identification_guide_for_l from vet.issue_identifiers.common import generate_issues_from_response_texts from vet.issue_identifiers.harnesses.base import IssueIdentifierHarness from vet.issue_identifiers.identification_guides import IssueIdentificationGuide -from vet.truncation import ContextBudget -from vet.truncation import get_available_tokens -from vet.truncation import get_token_budget -from vet.truncation import truncate_to_token_limit USER_REQUEST_PREFIX_TEMPLATE = """{{cached_prompt_prefix}} [ROLE=USER_CACHED] I'm working on a project, adding commits one after another. The current state of the project is captured by the codebase snapshot above. {% if extra_context %} -{% if extra_context_truncated %} -Note: Additional context was truncated due to size constraints. Do not assume details about content that is not visible. -{% endif %} + === ADDITIONAL CONTEXT BEGIN === {{ extra_context }} === ADDITIONAL CONTEXT END === @@ -48,15 +42,11 @@ Note: Additional context was truncated due to size constraints. Do not assume de Assume that I asked for a piece of work to be done by specifying the user request and another programmer has delivered the diff. {% if include_request_and_diff %} Below, you can see the user request, as well as the delivered diff. IMPORTANT: The codebase snapshot already includes the changes made in this diff! -{% if goal_truncated %} -Note: The user request was truncated. The full request may contain additional details not shown. -{% endif %} + === USER REQUEST BEGIN === {{ commit_message }} === USER REQUEST END === -{% if diff_truncated %} -Note: The diff below was truncated due to size constraints. Do not assume details about code or context that is not visible. -{% endif %} + === DIFF BEGIN (unified; lines starting with `-` are removed and `+` are added) === {{ unified_diff }} === DIFF END === @@ -137,45 +127,17 @@ class _SinglePromptIssueIdentifier(IssueIdentifier[CommitInputs]): guide.issue_code: format_issue_identification_guide_for_llm(guide) for guide in sorted_guides } - lm_config = config.language_model_generation_config - available_tokens = get_available_tokens(config) - goal_budget = get_token_budget(available_tokens, ContextBudget.GOAL) - extra_context_budget = get_token_budget(available_tokens, ContextBudget.EXTRA_CONTEXT) - - goal, goal_truncated = truncate_to_token_limit( - identifier_inputs.goal, - max_tokens=goal_budget, - count_tokens=lm_config.count_tokens, - label="goal", - truncate_end=True, - ) - - extra_context = identifier_inputs.maybe_extra_context or "" - if extra_context: - extra_context, extra_context_truncated = truncate_to_token_limit( - extra_context, - max_tokens=extra_context_budget, - count_tokens=lm_config.count_tokens, - label="extra context", - truncate_end=True, - ) - extra_context_truncated = extra_context_truncated or identifier_inputs.extra_context_truncated - else: - extra_context_truncated = False - env = jinja2.Environment(undefined=jinja2.StrictUndefined) jinja_template = env.from_string(PROMPT_TEMPLATE) + extra_context = identifier_inputs.maybe_extra_context return jinja_template.render( { "include_request_and_diff": True, "cached_prompt_prefix": project_context.cached_prompt_prefix, "cache_full_prompt": config.cache_full_prompt, "extra_context": (escape_prompt_markers(extra_context) if extra_context else None), - "extra_context_truncated": extra_context_truncated, - "commit_message": escape_prompt_markers(goal), - "goal_truncated": goal_truncated or identifier_inputs.goal_truncated, + "commit_message": escape_prompt_markers(identifier_inputs.goal), "unified_diff": escape_prompt_markers(identifier_inputs.diff), - "diff_truncated": identifier_inputs.diff_truncated, "guides": formatted_guides, "response_schema": self._response_schema, } diff --git a/vet/issue_identifiers/issue_evaluation.py b/vet/issue_identifiers/issue_evaluation.py @@ -26,9 +26,6 @@ from vet.issue_identifiers.common import format_issue_identification_guide_for_l from vet.issue_identifiers.harnesses.single_prompt import USER_REQUEST_PREFIX_TEMPLATE from vet.issue_identifiers.identification_guides import IssueIdentificationGuide from vet.issue_identifiers.utils import ReturnCapturingGenerator -from vet.truncation import ContextBudget -from vet.truncation import get_available_tokens -from vet.truncation import get_token_budget CODE_BASED_CRITERIA = ( "1. The issue is based on specific code, and not merely on the absence of information in the codebase snapshot. (true/false)", @@ -134,25 +131,14 @@ def _format_prompt( if is_code_based_issue: template_vars["include_request_and_diff"] = True template_vars["commit_message"] = escape_prompt_markers(inputs.maybe_goal or "") - template_vars["goal_truncated"] = inputs.goal_truncated template_vars["unified_diff"] = escape_prompt_markers(inputs.maybe_diff or "") - template_vars["diff_truncated"] = inputs.diff_truncated template_vars["extra_context"] = ( escape_prompt_markers(inputs.maybe_extra_context) if inputs.maybe_extra_context else None ) - template_vars["extra_context_truncated"] = inputs.extra_context_truncated else: - lm_config = config.language_model_generation_config - available_tokens = get_available_tokens(config) - conversation_budget = get_token_budget(available_tokens, ContextBudget.CONVERSATION) - - conversation_history, conversation_truncated = format_conversation_history_for_prompt( - inputs.maybe_conversation_history or (), - max_tokens=conversation_budget, - count_tokens=lm_config.count_tokens, + template_vars["conversation_history"] = format_conversation_history_for_prompt( + inputs.maybe_conversation_history or () ) - template_vars["conversation_history"] = conversation_history - template_vars["conversation_truncated"] = conversation_truncated or inputs.conversation_truncated return jinja_template.render(template_vars) diff --git a/vet/truncation.py b/vet/truncation.py @@ -1,111 +0,0 @@ -from enum import Enum -from typing import Callable - -from loguru import logger - -from vet.imbue_tools.types.vet_config import VetConfig -from vet.repo_utils import VET_MAX_PROMPT_TOKENS - - -class ContextBudget(Enum): - REPO_CONTEXT = 50 - DIFF = 30 - CONVERSATION = 10 - EXTRA_CONTEXT = 6 - GOAL = 4 - - -def get_token_budget(total_available: int, budget: ContextBudget) -> int: - return int(total_available * budget.value / 100) - - -def get_available_tokens(config: "VetConfig") -> int: - lm_config = config.language_model_generation_config - context_window = lm_config.get_max_context_length() - return context_window - VET_MAX_PROMPT_TOKENS - config.max_output_tokens - - -def truncate_to_token_limit( - text: str, - max_tokens: int, - count_tokens: Callable[[str], int], - label: str, - truncate_end: bool = True, -) -> tuple[str, bool]: - if not text: - return text, False - - if max_tokens <= 0: - logger.warning("{} budget is zero or negative, returning empty string", label.capitalize()) - return "", True - - token_count = count_tokens(text) - if token_count <= max_tokens: - return text, False - - logger.warning( - "{} exceeds token limit ({} > {}), truncating", - label.capitalize(), - token_count, - max_tokens, - ) - - if truncate_end: - truncated = _find_truncation_point_from_end(text, max_tokens, count_tokens) - else: - truncated = _find_truncation_point_from_start(text, max_tokens, count_tokens) - - return truncated, True - - -def _find_truncation_point_from_end( - text: str, - max_tokens: int, - count_tokens: Callable[[str], int], -) -> str: - char_estimate = min(max_tokens * 4, len(text)) - - low, high = 0, char_estimate - result = "" - - if high < len(text) and count_tokens(text[:high]) <= max_tokens: - low = high - high = len(text) - - while low <= high: - mid = (low + high) // 2 - candidate = text[:mid] - if count_tokens(candidate) <= max_tokens: - result = candidate - low = mid + 1 - else: - high = mid - 1 - - return result - - -def _find_truncation_point_from_start( - text: str, - max_tokens: int, - count_tokens: Callable[[str], int], -) -> str: - char_estimate = min(max_tokens * 4, len(text)) - start_estimate = max(0, len(text) - char_estimate) - - low, high = 0, start_estimate - result = text[start_estimate:] - - if count_tokens(result) > max_tokens: - low = start_estimate - high = len(text) - - while low <= high: - mid = (low + high) // 2 - candidate = text[mid:] - if count_tokens(candidate) <= max_tokens: - result = candidate - high = mid - 1 - else: - low = mid + 1 - - return result diff --git a/vet/truncation_test.py b/vet/truncation_test.py @@ -1,217 +0,0 @@ -from typing import Callable - -import tiktoken -from hypothesis import assume -from hypothesis import given -from hypothesis import settings -from hypothesis import strategies as st - -from vet.truncation import ContextBudget -from vet.truncation import get_token_budget -from vet.truncation import truncate_to_token_limit - - -def word_count(text: str) -> int: - return len(text.split()) - - -def char_count(text: str) -> int: - return len(text) - - -def char_div4_count(text: str) -> int: - return len(text) // 4 + 1 if text else 0 - - -_tiktoken_encoder = tiktoken.get_encoding("cl100k_base") - - -def tiktoken_count(text: str) -> int: - return len(_tiktoken_encoder.encode(text)) - - -SIMPLE_TOKEN_COUNTERS: list[Callable[[str], int]] = [ - word_count, - char_count, - char_div4_count, -] - -ALL_TOKEN_COUNTERS: list[Callable[[str], int]] = SIMPLE_TOKEN_COUNTERS + [tiktoken_count] - -ascii_text = st.text( - alphabet=st.characters(min_codepoint=32, max_codepoint=126), - min_size=0, - max_size=1000, -) - -unicode_text = st.text( - alphabet=st.characters( - min_codepoint=32, - max_codepoint=0xFFFF, - blacklist_categories=("Cs",), - ), - min_size=0, - max_size=500, -) - -code_text = st.from_regex( - r"[a-zA-Z_][a-zA-Z0-9_]{0,20}(\s*[+\-*/=<>!&|]+\s*[a-zA-Z_][a-zA-Z0-9_]{0,20}){0,10}", - fullmatch=True, -) - -repeated_char_text = st.builds( - lambda char, count: char * count, - char=st.characters(min_codepoint=32, max_codepoint=126), - count=st.integers(min_value=0, max_value=500), -) - -mixed_text = st.builds( - lambda code, comment, uni: f"{code} // {comment}\n{uni}", - code=st.from_regex(r"[a-z_][a-z0-9_]{0,10}", fullmatch=True), - comment=ascii_text, - uni=st.text(min_size=0, max_size=100), -) - - -def test_context_budgets_sum_to_100(): - total = sum(budget.value for budget in ContextBudget) - assert total == 100, f"ContextBudget values must sum to 100, got {total}" - - -@given( - total_tokens=st.integers(min_value=0, max_value=1_000_000), - budget=st.sampled_from(list(ContextBudget)), -) -def test_get_token_budget_is_mathematically_correct(total_tokens: int, budget: ContextBudget): - result = get_token_budget(total_tokens, budget) - expected = total_tokens * budget.value // 100 - assert result == expected - - -@given( - text=st.text(min_size=0, max_size=10_000), - max_tokens=st.integers(min_value=0, max_value=10_000), - truncate_end=st.booleans(), - count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS), -) -def test_truncate_always_respects_token_limit_simple(text: str, max_tokens: int, truncate_end: bool, count_tokens): - result, _ = truncate_to_token_limit( - text, - max_tokens=max_tokens, - count_tokens=count_tokens, - label="test", - truncate_end=truncate_end, - ) - - assert count_tokens(result) <= max_tokens, ( - f"Token limit violated: got {count_tokens(result)} > {max_tokens} " - f"(counter={count_tokens.__name__}, truncate_end={truncate_end})" - ) - - -@given( - text=st.text(min_size=1, max_size=1000), - max_tokens=st.integers(min_value=1, max_value=100), -) -def test_truncate_end_produces_prefix(text: str, max_tokens: int): - result, was_truncated = truncate_to_token_limit( - text, - max_tokens=max_tokens, - count_tokens=char_count, - label="test", - truncate_end=True, - ) - - assert text.startswith(result), f"Result '{result[:50]}...' is not a prefix of original" - - if was_truncated: - assert len(result) < len(text) - - -@given( - text=st.text(min_size=1, max_size=1000), - max_tokens=st.integers(min_value=1, max_value=100), -) -def test_truncate_start_produces_suffix(text: str, max_tokens: int): - result, was_truncated = truncate_to_token_limit( - text, - max_tokens=max_tokens, - count_tokens=char_count, - label="test", - truncate_end=False, - ) - - assert text.endswith(result), f"Result '...{result[-50:]}' is not a suffix of original" - - if was_truncated: - assert len(result) < len(text) - - -@given( - text=st.text(min_size=0, max_size=1000), - budget_multiplier=st.integers(min_value=1, max_value=10), - count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS), -) -def test_text_within_budget_unchanged(text: str, budget_multiplier: int, count_tokens): - token_count = count_tokens(text) - - max_tokens = token_count * budget_multiplier + 10 - - result, was_truncated = truncate_to_token_limit( - text, - max_tokens=max_tokens, - count_tokens=count_tokens, - label="test", - ) - - assert result == text - assert was_truncated is False - - -@given(count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS)) -def test_empty_text_always_returns_empty(count_tokens): - result, was_truncated = truncate_to_token_limit( - "", - max_tokens=100, - count_tokens=count_tokens, - label="test", - ) - - assert result == "" - assert was_truncated is False - - -@given( - text=st.text(min_size=1, max_size=1000), - count_tokens=st.sampled_from(SIMPLE_TOKEN_COUNTERS), -) -def test_zero_budget_returns_empty_and_truncated(text: str, count_tokens): - assume(count_tokens(text) > 0) - - result, was_truncated = truncate_to_token_limit( - text, - max_tokens=0, - count_tokens=count_tokens, - label="test", - ) - - assert result == "" - assert was_truncated is True - - -@settings(max_examples=100) # 20 per strategy * 5 strategies -@given( - text=st.one_of(ascii_text, unicode_text, code_text, repeated_char_text, mixed_text), - max_tokens=st.integers(min_value=0, max_value=500), - truncate_end=st.booleans(), -) -def test_truncate_respects_limit_tiktoken(text: str, max_tokens: int, truncate_end: bool): - result, _ = truncate_to_token_limit( - text, - max_tokens=max_tokens, - count_tokens=tiktoken_count, - label="test", - truncate_end=truncate_end, - ) - - assert tiktoken_count(result) <= max_tokens