vet

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit be084ce91b81503c5bb3978eb760b34fdb050e9b
parent 91088763849a9c1db036598089fbfe0f6e725e09
Author: andrewlaack-collab <andrew.laack@imbue.com>
Date:   Sun,  1 Feb 2026 02:35:56 +0000

Remove unnecessary things (#5)

* Removed some internal notes

* Remove useless git stuff

* Refactoring git stuff and computing environment

* Unused files removal

* Dead code

* Removed more unused code.

* More pruning

* More code

* Removing

* Removals
Diffstat:
M README.md | 4 ++--
M imbue_core/imbue_core/agents/agent_api/client.py | 1 -
D imbue_core/imbue_core/agents/data_types/__init__.py | 0
D imbue_core/imbue_core/agents/data_types/ids.py | 64 ----------------------------------------------------------------
M imbue_core/imbue_core/agents/llm_apis/anthropic_api.py | 8 ++------
M imbue_core/imbue_core/agents/llm_apis/language_model_api.py | 2 +-
D imbue_core/imbue_core/agents/llm_apis/llm_testing_utils.py | 86 -------------------------------------------------------------------------------
M imbue_core/imbue_core/async_utils.py | 471 -------------------------------------------------------------------------------
M imbue_core/imbue_core/caching.py | 82 +++++++++++++++----------------------------------------------------------------
M imbue_core/imbue_core/cattrs_serialization.py | 26 ++------------------------
M imbue_core/imbue_core/common.py | 130 -------------------------------------------------------------------------------
D imbue_core/imbue_core/computing_environment/__init__.py | 0
D imbue_core/imbue_core/computing_environment/computing_environment.py | 1080 -------------------------------------------------------------------------------
D imbue_core/imbue_core/computing_environment/data_types.py | 31 -------------------------------
D imbue_core/imbue_core/error_utils.py | 28 ----------------------------
M imbue_core/imbue_core/frozen_utils.py | 33 ---------------------------------
D imbue_core/imbue_core/git.py | 587 -------------------------------------------------------------------------------
D imbue_core/imbue_core/git_data_types.py | 35 -----------------------------------
D imbue_core/imbue_core/ids.py | 44 --------------------------------------------
M imbue_core/imbue_core/itertools.py | 28 ----------------------------
D imbue_core/imbue_core/llm_testing_utils.py | 57 ---------------------------------------------------------
M imbue_core/imbue_core/nested_evolver.py | 97 +++++++++++++++++++++++++++----------------------------------------------------
M imbue_core/imbue_core/pydantic_serialization.py | 63 +++------------------------------------------------------------
D imbue_core/imbue_core/retry_utils.py | 30 ------------------------------
M imbue_core/imbue_core/secrets_utils.py | 19 ++-----------------
D imbue_core/imbue_core/section.py | 131 -------------------------------------------------------------------------------
M imbue_core/imbue_core/serialization.py | 188 ++++++++++++++++++++++++++++++++-----------------------------------------------
D imbue_core/imbue_core/simple_git.py | 215 -------------------------------------------------------------------------------
D imbue_core/imbue_core/test_utils.py | 88 -------------------------------------------------------------------------------
M imbue_core/pyproject.toml | 1 -
M imbue_tools/README.md | 2 +-
D imbue_tools/imbue_tools/conftest.py | 44 --------------------------------------------
M imbue_tools/imbue_tools/get_conversation_history/get_conversation_history.py | 12 ------------
M imbue_tools/imbue_tools/repo_utils/context_retrieval.py | 15 ---------------
M imbue_tools/imbue_tools/repo_utils/diff_utils.py | 60 ------------------------------------------------------------
M imbue_tools/imbue_tools/repo_utils/errors.py | 15 ---------------
M imbue_tools/imbue_tools/repo_utils/file_system.py | 31 -------------------------------
D imbue_tools/imbue_tools/repo_utils/find_relative_to.py | 30 ------------------------------
M imbue_tools/imbue_tools/repo_utils/stubify_file.py | 97 ++++++++++++++++---------------------------------------------------------------
M imbue_tools/imbue_tools/repo_utils/subrepo_formatting.py | 23 -----------------------
M vet/api.py | 2 +-
M vet/conftest.py | 3 +--
M vet/errors.py | 17 +++++++++++++++++
A vet/git.py | 236 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M vet/issue_identifiers/agentic_issue_collation.py | 2 +-
M vet/issue_identifiers/base.py | 2 +-
M vet/issue_identifiers/issue_deduplication.py | 2 +-
M vet/issue_identifiers/issue_evaluation.py | 2 +-
M vet/repo_utils.py | 6 +++---
M vet_types/vet_types/chat_state.py | 32 ++++++++------------------------
M vet_types/vet_types/ids.py | 20 +++++---------------
M vet_types/vet_types/messages.py | 22 ++++++++--------------
52 files changed, 438 insertions(+), 3866 deletions(-)

diff --git a/README.md b/README.md @@ -1,4 +1,4 @@ -# Vet : Verify EveryThing +# Vet : Verify Everything Vet is a standalone verification tool for **code changes** and **coding agent behavior**. @@ -7,7 +7,7 @@ It reviews git diffs, and optionally an agent's conversation history, to find is ## Installation ```bash -pip install vet +pip install verify-everything ``` ## Quickstart diff --git a/imbue_core/imbue_core/agents/agent_api/client.py b/imbue_core/imbue_core/agents/agent_api/client.py @@ -80,7 +80,6 @@ class CachedAgentClient(AgentClient[AgentOptionsT]): # This means that if the generator is not exhausted, the cache will not be updated. # If we do want a way to still cache interactions, even if we early exit the generator, # then we could use a separate thread to get the agent response and cache it in the background. - # See https://gitlab.com/generally-intelligent/generally_intelligent/-/merge_requests/7323#note_2897340073 agent_interaction_record = AgentInteractionRecord.from_agent_interaction(agent_interaction) update_cache(agent_interaction_record, cache_path) diff --git a/imbue_core/imbue_core/agents/data_types/__init__.py b/imbue_core/imbue_core/agents/data_types/__init__.py diff --git a/imbue_core/imbue_core/agents/data_types/ids.py b/imbue_core/imbue_core/agents/data_types/ids.py @@ -1,64 +0,0 @@ -from abc import ABC - -from pydantic import GetCoreSchemaHandler -from pydantic_core import core_schema -from typeid import TypeID -from typeid import get_prefix_and_suffix - - -class TypeIDPrefixMismatchError(Exception): - pass - - -class ObjectID(TypeID, ABC): - """ - A convenience class for string-based object IDs. - - Use in place of strings for IDs. (We don't use raw UUIDs because they are not supported by SQLite.) - - Use `tag` to prefix the ID with the ID type. (We don't use `prefix` because it's already taken by the ancestor class.) - - """ - - # Override this in subclasses to specify the ID type. 
- tag: str = "oid" - - def __init__(self, value: str | None = None) -> None: - if value is not None: - prefix, suffix = get_prefix_and_suffix(value) - # For convenience, don't require the caller to strip the prefix from existing IDs. - if prefix is not None: - if prefix != self.tag: - raise TypeIDPrefixMismatchError(f"Expected prefix '{self.tag}', got '{prefix}'") - value = suffix - super().__init__(self.tag, value) - - @classmethod - def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - """ - Support transparently deserializing strings into ObjectID instances and vice versa. - """ - return core_schema.no_info_before_validator_function( - lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value), - core_schema.union_schema( - [ - core_schema.is_instance_schema(cls), - core_schema.str_schema(), - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: str(instance), return_schema=core_schema.str_schema() - ), - ) - - -class TaskID(ObjectID): - tag: str = "tsk" - - -class ProjectID(ObjectID): - tag: str = "prj" - - -class AgentMessageID(ObjectID): - tag: str = "agm" diff --git a/imbue_core/imbue_core/agents/llm_apis/anthropic_api.py b/imbue_core/imbue_core/agents/llm_apis/anthropic_api.py @@ -781,13 +781,9 @@ class AnthropicAPI(LanguageModelAPI): def _get_api_key_or_auth_token() -> tuple[str | None, str | None]: api_key = get_secret("ANTHROPIC_API_KEY") - # The standard environment variable for this is ANTHROPIC_AUTH_TOKEN, - # but we don't use it since it has some bad interactions with Claude Code. 
- auth_token = get_secret("IMBUE_ANTHROPIC_AUTH_TOKEN") + auth_token = get_secret("ANTHROPIC_AUTH_TOKEN") if not api_key and not auth_token: - raise MissingAPIKeyError( - "Neither of ANTHROPIC_API_KEY and IMBUE_ANTHROPIC_AUTH_TOKEN environment variables is set" - ) + raise MissingAPIKeyError("Neither ANTHROPIC_API_KEY nor ANTHROPIC_AUTH_TOKEN environment variable is set") return api_key, auth_token diff --git a/imbue_core/imbue_core/agents/llm_apis/language_model_api.py b/imbue_core/imbue_core/agents/llm_apis/language_model_api.py @@ -155,7 +155,7 @@ class LanguageModelAPI(abc.ABC, MutableModel): if "PYTEST_CURRENT_TEST" not in os.environ: logger.warning( - "You are trying to call a language model from outside of a hammer with no global resource limits set. That is a bad idea because the spend will not be restricted, and you may end up accidentally spending much more than you expected." + "You are trying to call a language model with no global resource limits set. That is a bad idea because the spend will not be restricted, and you may end up accidentally spending much more than you expected." 
) return None diff --git a/imbue_core/imbue_core/agents/llm_apis/llm_testing_utils.py b/imbue_core/imbue_core/agents/llm_apis/llm_testing_utils.py @@ -1,86 +0,0 @@ -from google.genai.types import CountTokensResponse -from syrupy.assertion import SnapshotAssertion -from syrupy.extensions.single_file import SingleFileAmberSnapshotExtension -from syrupy.extensions.single_file import SingleFileSnapshotExtension - -from imbue_core.agents.llm_apis.data_types import CachedCostedModelResponse -from imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse -from imbue_core.caching import AsyncCache -from imbue_core.frozen_utils import FrozenMapping - - -async def check_llm_responses_in_cache(snapshot: SnapshotAssertion, temp_cache: AsyncCache, suffix: str = "") -> None: - """Runs as the test fixture completes to check that the LLM inputs and outputs stay the same, in a human-readable format.""" - - async with temp_cache as cache: - all_keys: tuple[str, ...] = await cache.get_all_keys() # Contains both the streaming and non-streaming keys? - value_by_key: FrozenMapping[str, CachedCostedModelResponse | None] = await cache.get_all(all_keys) - - cache_items: list[tuple[str, CachedCostedModelResponse]] = [ - (k, v) for k, v in value_by_key.items() if v is not None - ] - cache_items.sort(key=lambda x: x[1].timestamp) - for cache_index, (cache_key, cached_response) in enumerate(cache_items): - prompt: bytes = b"" - joined_responses: bytes = b"" - metadata_lines: list[str] = [] - - metadata_lines.append(f"{cache_index=} (when cache is sorted by timestamp)") - metadata_lines.append(f"{cache_key=}") # Keys must be stable and not too big. - if cached_response.inputs is not None: - prompt = cached_response.inputs.prompt.encode("utf-8") - metadata_lines.append(f"request metdata ({type(cached_response.inputs)})") - for field, field_value in cached_response.inputs.__dict__.items(): - if field != "prompt": # print the prompt separately below. 
- metadata_lines.append(f" {field}: {field_value}") - - metadata_lines.append("cached_response metadata:") - for field, field_value in cached_response.__dict__.items(): - if field not in ("inputs", "response"): # already printed above - metadata_lines.append(f" {field}: {field_value}") - - if cached_response.response is not None: - match cached_response.response: - case CostedLanguageModelResponse(): - joined_responses = "".join([r.text for r in cached_response.response.responses]).encode("utf-8") - for response_index, response in enumerate(cached_response.response.responses): - metadata_lines.append(f"response[{response_index}] metadata:") - for ( - field, - field_value, - ) in cached_response.response.__dict__.items(): - if field != "responses": # already printed the responses above - metadata_lines.append(f" {field}: {field_value}") - case CountTokensResponse(): - metadata_lines.append("response metadata:") - for field, field_value in cached_response.response.__dict__.items(): - metadata_lines.append(f" {field}: {field_value}") - - snapshotted_prompt = snapshot( - extension_class=SingleFileSnapshotExtension, - name=f"{cache_index:03d}_inputs{suffix}", - ) - - # TODO nasty syrupy hacking - snapshot_contents, _ = snapshotted_prompt._recall_data(snapshotted_prompt.index) - - assert ( - snapshotted_prompt == prompt - ), f"Your prompt changed, did you mean for this to happen?\nExpected prompt: {snapshot_contents!r}\nPrompt: {prompt!r}" - - snapshotted_response = snapshot( - extension_class=SingleFileSnapshotExtension, - name=f"{cache_index:03d}_response{suffix}", - ) - - assert ( - snapshotted_response == joined_responses - ), "Your response changed; maybe you aren't actually hitting the cache?" - - snapshotted_metadata = snapshot( - extension_class=SingleFileAmberSnapshotExtension, - name=f"{cache_index:03d}_metadata{suffix}", - ) - assert snapshotted_metadata == "\n".join( - metadata_lines - ), "Metadata changed; maybe you aren't actually hitting the cache?" 
diff --git a/imbue_core/imbue_core/async_utils.py b/imbue_core/imbue_core/async_utils.py @@ -1,45 +1,13 @@ import asyncio import functools -import inspect -import os -import platform -import sys import threading -import traceback -from contextlib import AbstractAsyncContextManager -from contextlib import AbstractContextManager -from contextlib import contextmanager -from datetime import datetime -from http.server import BaseHTTPRequestHandler -from http.server import HTTPServer -from pathlib import Path -from types import FrameType -from typing import Any -from typing import AsyncGenerator from typing import Awaitable from typing import Callable -from typing import Coroutine -from typing import Generator -from typing import Generic -from typing import Iterable -from typing import Iterator from typing import ParamSpec from typing import TypeVar -from typing import cast -from urllib.parse import parse_qs -from urllib.parse import urlparse - -from loguru import logger -from traceback_with_variables.core import _iter_lines - -from imbue_core.async_monkey_patches import log_exception -from imbue_core.async_monkey_patches import safe_cancel P = ParamSpec("P") R = TypeVar("R") -S = TypeVar("S") - -ALL_EVENT_LOOPS: list[asyncio.AbstractEventLoop] = [] def sync(func: Callable[P, Awaitable[R]]) -> Callable[P, R]: @@ -55,56 +23,6 @@ def sync(func: Callable[P, Awaitable[R]]) -> Callable[P, R]: return wrapper -def sync_generator( - func: Callable[P, AsyncGenerator[R, None]], -) -> Callable[P, Generator[R, None, None]]: - """Decorator that runs an async generator synchronously by dispatching to - an event loop running in a separate thread. 
- """ - - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Generator[R, None, None]: - loop = _get_or_create_event_loop() - agen = func(*args, **kwargs) - while True: - try: - future = asyncio.run_coroutine_threadsafe(agen.__anext__(), loop) - yield future.result() - except StopAsyncIteration: - break - - return wrapper - - -@contextmanager -# pyre-ignore[24]: pyre doesn't understand AbstractAsyncContextManager -def sync_contextmanager( - async_context_manager: AbstractAsyncContextManager[S], -) -> Generator[S, None, None]: - sync_aenter = sync(async_context_manager.__aenter__) - sync_aexit = sync(async_context_manager.__aexit__) - - enter_result = sync_aenter() - try: - yield enter_result - except BaseException as e: - if not sync_aexit(e.__class__, e, e.__traceback__): - raise - else: - sync_aexit(None, None, None) - - -# pyre doesn't understand AbstractAsyncContextManager -def sync_contextmanager_func( - cm_func: Callable[P, AbstractAsyncContextManager[S]], # pyre-ignore[24] -) -> Callable[P, AbstractContextManager[S]]: # pyre-ignore[24] - @functools.wraps(cm_func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> AbstractContextManager[S]: # pyre-ignore[24] - return sync_contextmanager(cm_func(*args, **kwargs)) - - return wrapper - - _LOOP: asyncio.AbstractEventLoop | None = None _LOOP_LOCK: threading.Lock = threading.Lock() @@ -124,395 +42,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: return _LOOP # pyre-ignore[7]: we just made _LOOP, so it's not None unless it got destroyed just now -def shorten_filename(filename: str) -> str: - path = Path(filename) - while path.parent: - path = path.parent - if not (path / "__init__.py").exists(): - break - - try: - shortened = str(Path(filename).relative_to(path)) - except ValueError: - shortened = filename # in case the path cannot be made relative - - return shortened - - -# TODO: I'd really like to print these task groups in a hierarchical way instead of flat -- we know 
which groups -# launched which other groups, so we could print them in a tree structure. That would be a lot more readable. -# It might also be nice to print without any stacks at all. As long as the tasks had good names, that would make it -# possible to very easily understand everything that was currently executing. -# I could even imagine controls that allowed for printing just the task groups themselves, which would also be easier -# to understand. -def get_all_async_task_stacks( - num_skipped_frames: int = 0, - log_variables: bool = False, - loop: asyncio.AbstractEventLoop | None = None, -) -> Iterator[str]: - """Yields the lines of a report for all stack frames of all async tasks including variables.""" - tasks_by_task_group: dict[asyncio.TaskGroup | None, list[asyncio.Task]] = {} - owning_task_by_task_group: dict[asyncio.TaskGroup, asyncio.Task] = {} - - for task in asyncio.all_tasks(loop=loop): - if task.done(): - continue - task_group = cast(asyncio.TaskGroup | None, getattr(task, "task_group", None)) - owned_task_group = cast(asyncio.TaskGroup | None, getattr(task, "owned_task_group", None)) - if owned_task_group is not None: - owning_task_by_task_group[owned_task_group] = task - if owned_task_group not in tasks_by_task_group: - tasks_by_task_group[owned_task_group] = [] - else: - tasks_by_task_group.setdefault(task_group, []).append(task) - - all_owning_tasks = set(owning_task_by_task_group.values()) - - task_group_keys = list(tasks_by_task_group.keys()) - for task_group in cast(list[asyncio.TaskGroup | None], [None]) + [x for x in task_group_keys if x is not None]: - if task_group not in tasks_by_task_group: - continue - tasks = tasks_by_task_group[task_group] - if task_group is None: - yield f"\n\n{'=' * 40}\nNo TaskGroup:\n" - else: - yield f"\n\n{'=' * 40}\nTaskGroup: {getattr(task_group, 'name', 'unknown')}\n" - owning_task = None - if task_group is not None: - owning_task = owning_task_by_task_group.get(task_group) - is_first_line_skipped = False - 
all_tasks = tasks - if owning_task is not None: - yield f"Owning Task: {owning_task.get_name()}\n" - is_first_line_skipped = True - all_tasks.insert(0, owning_task) - for task in all_tasks: - # skip these -- they'll be printed at the top of the group that they own - if task_group is None and task in all_owning_tasks: - continue - if is_first_line_skipped: - is_first_line_skipped = False - else: - yield f"{'-' * 40}\nTask {task.get_name()}:\n" - frames = extract_frames(task) - for frame in frames: - frame_infos = inspect.getouterframes(frame)[num_skipped_frames:] - # Use private method _iter_lines to traceback async tasks, which is not explicitly handled in the API - if log_variables: - for line in _iter_lines( - e=None, - frame_infos=frame_infos, - fmt=None, - for_file=None, - ): - yield line + "\n" - else: - frame_summaries = [ - traceback.FrameSummary( - shorten_filename(info.filename), - lineno=info.lineno, - name=info.function, - line=(info.code_context[0].strip() if info.code_context else None), - ) - for info in frame_infos - ] - yield from traceback.format_list(frame_summaries) - - -def extract_frames(task: asyncio.Task) -> list[FrameType]: - """Extract the stack frames of an async task.""" - coro = task.get_coro() - assert isinstance(coro, Coroutine) - frames = [] - while coro is not None and coro.cr_frame is not None: - frames.append(coro.cr_frame) - coro = coro.cr_await # type: ignore - # this happens at the very bottom of the call stack, there it seems to often be a FutureIter, Event, etc - if type(coro).__name__ != "coroutine": - break - return frames - - -def print_all_async_task_stacks(log_variables: bool = False) -> None: - """Prints the stack frames of all running tasks.""" - for line in get_all_async_task_stacks(log_variables=log_variables): - print(line) - - -def dump_all_async_task_stacks(log_path: str | Path, log_variables: bool = False) -> None: - """Dump the stack frames of all running tasks to file.""" - with open(log_path, "w") as f: - for 
line in get_all_async_task_stacks(log_variables=log_variables): - if log_variables: - line += "\n" - f.write(line) - - -async def periodically_log_async_stacks(log_dir: str | Path, interval: float, log_variables: bool = False) -> None: - """Periodically print the stack traces of all running tasks.""" - Path(log_dir).mkdir(parents=True, exist_ok=True) - while True: - log_path = Path(log_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" - dump_all_async_task_stacks(log_path=log_path, log_variables=log_variables) - logger.debug("Dumped asyncio stack traces to {}", log_path) - await asyncio.sleep(interval) - - -async def is_task_group_complete( - task_group: asyncio.TaskGroup, trace_task: asyncio.Task, buffer_time: float = 1.0 -) -> None: - """Continuously check if all tasks except the stack trace logger are done.""" - while True: - if all(task.done() for task in task_group._tasks if task is not trace_task): - await asyncio.sleep(buffer_time) # Wait for buffer time in case new tasks are added - # Recheck to confirm - if all(task.done() for task in task_group._tasks if task is not trace_task): - break - await asyncio.sleep(1.0) - - -async def inject_async_stack_trace_logger( - task_group: asyncio.TaskGroup, - log_dir: str | Path, - log_interval: float = 60.0, - log_variables: bool = False, -) -> None: - """Inject a stack trace logger into the task group.""" - trace_task = asyncio.create_task( - periodically_log_async_stacks(log_dir=log_dir, interval=log_interval, log_variables=log_variables), - name="periodically_log_async_stacks", - ) - await is_task_group_complete(task_group, trace_task) - safe_cancel(trace_task) - try: - await trace_task - except asyncio.CancelledError: - pass - - -class AsyncTaskStacksHandler(BaseHTTPRequestHandler): - def do_GET(self) -> None: - try: - parsed_url = urlparse(self.path) - query_params = parse_qs(parsed_url.query) - log_variables = query_params.get("locals", ["false"])[0].lower() in [ - "true", - "1", - "yes", - ] - - 
self.send_response(200) - self.send_header("Content-type", "text/plain") - self.end_headers() - - pid = os.getpid() - command_line = " ".join(sys.argv) - - # Path to the Python executable - python_executable = sys.executable - - # Python version - python_version = platform.python_version() - - # Print the collected information - self.wfile.write(f"Process {pid}: {python_executable} {command_line}\n".encode("utf-8")) - self.wfile.write(f"Python v{python_version} ({python_executable})\n\n".encode("utf-8")) - - for loop in ALL_EVENT_LOOPS: - for line in get_all_async_task_stacks(log_variables=log_variables, loop=loop): - self.wfile.write(line.encode("utf-8")) - except BaseException as e: - log_exception(e, "exception in AsyncTaskStacksHandler") - raise - - -def run_async_stackframe_server_thread(port_range_low: int, port_range_high: int) -> None: - try: - success = False - for port in range(port_range_low, port_range_high): - try: - server_address = ("localhost", port) - httpd = HTTPServer(server_address, AsyncTaskStacksHandler) - success = True - print(f"Starting async stack trace server on port {port}. 
Process pid: {os.getpid()}") - break - except OSError: - continue - - if not success: - logger.info("Could not find an open port to start the async stack trace server, continuing without it.") - return - - httpd.serve_forever() - except BaseException as e: - log_exception(e, "exception in run_async_stackframe_server_thread") - raise - - -IS_STACKFRAME_SERVER_RUNNING = False -STACKFRAME_SERVER_PORT_LOW = 60000 -STACKFRAME_SERVER_PORT_HIGH = 61000 - - -def run_async_stackframe_server_for_loop(event_loop: asyncio.AbstractEventLoop) -> None: - ALL_EVENT_LOOPS.append(event_loop) - run_async_stackframe_server() - - -def run_async_stackframe_server() -> None: - global IS_STACKFRAME_SERVER_RUNNING - if not IS_STACKFRAME_SERVER_RUNNING: - IS_STACKFRAME_SERVER_RUNNING = True - - port_str = os.environ.get("ASYNC_STACKFRAME_SERVER_PORT") - if port_str is not None: - try: - port = int(port_str) - port_range_low = port - port_range_high = port + 1 - except ValueError: - logger.error("ASYNC_STACKFRAME_SERVER_PORT is not an integer: {}", port_str) - raise - else: - port_range_low = STACKFRAME_SERVER_PORT_LOW - port_range_high = STACKFRAME_SERVER_PORT_HIGH - - thread = threading.Thread( - target=run_async_stackframe_server_thread, - daemon=True, - kwargs={ - "port_range_low": port_range_low, - "port_range_high": port_range_high, - }, - ) - thread.start() - - -def with_timeout(func: Callable[P, Awaitable[R]], timeout_secs: float) -> Callable[P, Awaitable[R]]: - """Decorator that adds a timeout to an async function.""" - - @functools.wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return await asyncio.wait_for(func(*args, **kwargs), timeout_secs) - - return wrapper - - -T = TypeVar("T") - - -async def gather_with_limited_concurrency(coros: Iterable[Awaitable[T]], n: int) -> list[T]: - """Like asyncio.gather() but will only run `n` in parallel at a time. 
- - Note that a call like `asyncio.gather(*coros)` is now `gather_with_limited_concurrency(coros, n=10), - without the splat. - """ - semaphore = asyncio.Semaphore(n) - - async def sem_coro(coro: Awaitable[T]) -> T: - async with semaphore: - return await coro - - return list(await asyncio.gather(*(sem_coro(c) for c in coros))) - - -_NOT_FOUND = object() - - -class AsyncCachedProperty(Generic[T]): - """A descriptor factory that behaves very similarly to `functools.cached_property`, but for - async methods! - - The type annotations here are rough; it's not realistic to get them perfect without using a .pyi file. - """ - - def __init__(self, func: Callable[[Any], Coroutine[None, None, T]]) -> None: - self.func = func - self.attrname: str | None = None - self.__doc__ = func.__doc__ - - def __set_name__(self, owner: type, name: str) -> None: - if self.attrname is None: - self.attrname = name - elif name != self.attrname: - raise TypeError("Cannot assign the same AsyncCachedProperty to multiple names") - - def _get_attrname(self) -> str: - if self.attrname is None: - raise TypeError("Cannot use AsyncCachedProperty instance without calling __set_name__") - return self.attrname - - def _get_cache(self, instance: object) -> dict[str, Any]: - try: - return instance.__dict__ - except AttributeError: - raise TypeError( - "Cannot use AsyncCachedProperty with instances that do not have a __dict__ attribute" - ) from None - - def __get__(self, instance: object, owner: type | None = None) -> Awaitable[T]: - if instance is None: - return self # type: ignore - attrname = self._get_attrname() - cache = self._get_cache(instance) - val = cache.get(attrname, _NOT_FOUND) - if val is not _NOT_FOUND: - return cast(Awaitable[T], val) - - task = asyncio.create_task(self.func(instance)) - cache[attrname] = task - return task - - def __delete__(self, instance: object) -> None: - if instance is None: - raise TypeError("Cannot delete AsyncCachedProperty on a class") - attrname = 
self._get_attrname() - cache = self._get_cache(instance) - try: - awaitable = cache.pop(attrname) - if not awaitable.done(): - safe_cancel(awaitable) - except KeyError: - raise AttributeError(f"Cannot delete attribute {self.attrname!r}") from None - - def __set__(self, instance: object, value: T) -> None: - if instance is None: - raise TypeError("Cannot set AsyncCachedProperty on a class") - attrname = self._get_attrname() - cache = self._get_cache(instance) - existing = cache.pop(attrname, None) - if existing is not None and not existing.done(): - safe_cancel(existing) - fut: asyncio.Future[T] = asyncio.Future() - fut.set_result(value) - cache[attrname] = fut - - -def wrapped_asyncio_run( - main: Awaitable[T], - *, - debug: bool | None = None, - loop_factory: Callable[..., asyncio.AbstractEventLoop] | None = None, -) -> T: - """ - This is a lightweight wrapper with a singular purpose -- it is here to enable apyspy - - apyspy is our async equivalent to py-spy. - - Without apyspy, it's really annoying to debug why some async task is stuck. - """ - - async def wrapper_main() -> T: - ALL_EVENT_LOOPS.append(asyncio.get_event_loop()) - result = await main - ALL_EVENT_LOOPS.pop() - return result - - run_async_stackframe_server() - with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner: - return runner.run(wrapper_main()) - - def make_async(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: """ Turn the annotated function into an async function by running it in a thread. 
diff --git a/imbue_core/imbue_core/caching.py b/imbue_core/imbue_core/caching.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import os from functools import lru_cache from pathlib import Path from types import TracebackType @@ -98,7 +97,9 @@ class AsyncCache(AsyncCacheInterface[ValueType], Generic[ValueType]): loop = asyncio.get_running_loop() cache = self.cache assert cache is not None - result = await loop.run_in_executor(None, cache.__exit__, exc_type, exc_val, exc_tb) + result = await loop.run_in_executor( + None, cache.__exit__, exc_type, exc_val, exc_tb + ) self.cache = None return result @@ -115,15 +116,13 @@ class AsyncCache(AsyncCacheInterface[ValueType], Generic[ValueType]): cache = self.cache assert cache is not None loop = asyncio.get_running_loop() - assert isinstance(value, self.value_cls), f"Expected {self.value_cls}, got {type(value)}" + assert isinstance(value, self.value_cls), ( + f"Expected {self.value_cls}, got {type(value)}" + ) serialized_value = serialize_to_json(value) - return await loop.run_in_executor(None, cache.set, key, serialized_value, expire, read, tag, retry) - - async def delete(self, key: str, retry: bool = False) -> bool: - cache = self.cache - assert cache is not None - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, cache.delete, key, retry) + return await loop.run_in_executor( + None, cache.set, key, serialized_value, expire, read, tag, retry + ) async def get( self, @@ -137,13 +136,15 @@ class AsyncCache(AsyncCacheInterface[ValueType], Generic[ValueType]): cache = self.cache assert cache is not None loop = asyncio.get_running_loop() - value = await loop.run_in_executor(None, cache.get, key, None, read, expire_time, tag, retry) + value = await loop.run_in_executor( + None, cache.get, key, None, read, expire_time, tag, retry + ) if value is None: return default deserialized_value = deserialize_from_json(value) - assert isinstance( - deserialized_value, self.value_cls - ), 
f"Expected {self.value_cls}, got {type(deserialized_value)}" + assert isinstance(deserialized_value, self.value_cls), ( + f"Expected {self.value_cls}, got {type(deserialized_value)}" + ) return deserialized_value # TODO: this is not smart implementation, but at least it will be possible to optimize later without refactoring @@ -180,56 +181,3 @@ def get_cache(data_path: Path) -> Cache: eviction_policy="none", size_limit=2**36, ) - - -def get_default_llm_response_cache() -> Path: - return Path(os.environ.get("RESPONSE_CACHE_PATH", os.path.expanduser("~/.llm_response_cache"))) - - -def get_default_count_tokens_cache() -> Path: - return Path(os.environ.get("COUNT_TOKENS_CACHE_PATH", os.path.expanduser("~/.count_tokens_cache"))) - - -def get_test_llm_response_cache() -> Path: - return Path(os.path.expanduser("~/.llm_test_response_cache")) - - -class InMemoryCache(AsyncCacheInterface[ValueType], Generic[ValueType]): - def __init__(self, values: tuple[ValueType, ...]) -> None: - self.values = values - - async def __aenter__(self) -> Self: - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - pass - - async def get( - self, - key: str, - default: ValueType | None = None, - read: bool = False, - expire_time: bool = False, - tag: bool = False, - retry: bool = False, - ) -> ValueType | None: - return self.values[int(key)] - - async def get_all( - self, - keys: Sequence[str], - default: ValueType | None = None, - read: bool = False, - expire_time: bool = False, - tag: bool = False, - retry: bool = False, - ) -> FrozenMapping[str, ValueType | None]: - return FrozenDict(zip(await self.get_all_keys(), self.values)) - - async def get_all_keys(self, reverse: bool = False) -> tuple[str, ...]: - return tuple(map(str, range(len(self.values)))) diff --git a/imbue_core/imbue_core/cattrs_serialization.py b/imbue_core/imbue_core/cattrs_serialization.py @@ -868,10 +868,9 @@ def 
_serialize_to_json_dumpable_object( # with `is_reversible=False` since we won't know the type to be able to recreate the object. assert is_reversible, "Cannot restructure inputs if is_reversible=False" - # TODO: this is a hack to make it possible to serialize ExecutionContexts for class method hammers. + # TODO: this is a hack to make it possible to serialize ExecutionContexts for class methods. # This lets us serialize ExecutionContexts for calls to class methods without serializing the class itself. - # The long-term solutions are 1) either get rid of all class method hammers, - # or 2) write a custom hook that can serialize type objects. + # The long-term solution is to write a custom hook that can serialize type objects. if type(obj) is dict and "__class__" in obj: del obj["__class__"] @@ -990,24 +989,3 @@ def deserialize_from_dict( return _deserialize_using_type_marker(data, as_type, converter=converter) except Exception as e: raise SerializationError(str(e)) from e - - -def deserialize_from_dict_with_type(data: dict[str, Any], obj_type: type[T]) -> T: - try: - converter = CONVERTER_FACTORY.get_converter(for_javascript=False, exclude_dont_serialize_fields=False) - result = converter.structure(data, obj_type) - assert isinstance(result, obj_type), f"Expected an object of type {obj_type}, but got {result}" - return result - except Exception as e: - raise SerializationError(str(e)) from e - - -def deserialize_from_json_with_type(data: str | bytes | bytearray, obj_type: type[T]) -> T: - try: - converter = CONVERTER_FACTORY.get_converter(for_javascript=False, exclude_dont_serialize_fields=False) - return cast( - T, - _deserialize_serialized_object(json.loads(data), obj_type, converter=converter), - ) - except Exception as e: - raise SerializationError(str(e)) from e diff --git a/imbue_core/imbue_core/common.py b/imbue_core/imbue_core/common.py @@ -1,135 +1,5 @@ -import functools -import hashlib -import inspect -import os -import platform -import sys import uuid 
-from pathlib import Path -from types import FrameType - -import pathspec - - -def is_on_osx() -> bool: - return platform.system().lower() == "darwin" - - -def is_running_within_a_pytest_tree() -> bool: - """ - This is true if this, or any parent process, is running under pytest. - - This is different from `is_running_within_a_pytest_process` in that it is true if we are logically testing or not. - - This is usually what you want to check - (eg, this will be true even if you are a separately launched integration server process) - """ - return "PYTEST_CURRENT_TEST" in os.environ - - -def is_running_within_a_pytest_process() -> bool: - """ - This is true if the current process is literally running pytest. - - This is different from `is_running_within_a_pytest_tree` in that it checks if the current process is pytest itself, - which is most useful for knowing whether we are running a bunch of unit tests in this process or not. - """ - return "pytest" in sys.modules - - -def is_live_debugging() -> bool: - """ - Returns True if the current process is being debugged, for example by PyCharm or another IDE. - """ - # this is unfortunately true when measuring coverage and in other cases, sigh - # return sys.gettrace() is not None - # but this is only true when debugging in pycharm, I think? - return sys.breakpointhook.__module__ != "sys" - - -@functools.lru_cache(maxsize=1) -def get_filesystem_root() -> str: - env_value = os.getenv("SCIENCE_FILESYSTEM_ROOT") - if not env_value: - if is_on_osx(): - return "/tmp/science" - else: - # When on the physical cluster (and possibly other core clusters), this path is mounted to a unique per-container file path. - # Anything produced at runtime >10mb should likely go here, as well as anything you might want to dig up for later debugging. - # The hosts clean up the paths from dead containers periodically, but large data processing jobs should still clean up after themselves. 
- return "/mnt/private" - return env_value - - -@functools.lru_cache(maxsize=1) -def get_temp_dir() -> str: - temp_dir = os.path.join(get_filesystem_root(), "tmp") - os.makedirs(temp_dir, exist_ok=True) - return temp_dir - - -def hash_string(string: str) -> str: - return hashlib.md5(string.encode("utf-8")).hexdigest() - - -def get_current_function_name() -> str: - frame = inspect.currentframe() - if frame is None: - return "no_frame" - prev_frame = frame.f_back - if prev_frame is None or not isinstance(prev_frame, FrameType): - return "no_previous_frame" - return prev_frame.f_code.co_name - - -def filter_excluded_files(files: list[Path], directory: Path, exclude_file_name: str = ".gitignore") -> list[Path]: - """Remove files from the list that are matched by a .gitignore or similarly-specified exclude file such as - .gitignore or ratchet_excluded.txt. - """ - - # Underneath the root directory, find all the excluders. - # They can occur in subfolders and if they do they apply only to that subfolder. - excluders = {path for path in directory.rglob(exclude_file_name) if not path.is_symlink()} - - # Per excluder, make a pathspec. - for excluder in excluders: - with excluder.open("r") as exclude_file: - exclude_spec = pathspec.GitIgnoreSpec.from_lines(exclude_file) - - # Now we have two cases - We keep the file if the excluder doesn't apply because it's in a different - # folder, or if it applies but doesn't match - prefix = os.path.dirname(excluder) - files = [ - file - for file in files - if not (file.is_relative_to(prefix) and exclude_spec.match_file(file.relative_to(prefix))) - ] - - return files def generate_id() -> str: return uuid.uuid4().hex - - -def generate_id_from_existing_id(existing_id: str, seed: int) -> str: - return hashlib.md5(f"{existing_id}-{seed}".encode()).hexdigest() - - -def truncate_string(s: str, max_length: int) -> str: - if len(s) <= max_length: - return s - return s[: max_length - 3] + "..." 
- - -def parse_bool_environment_variable(var_name: str) -> bool: - env_var = os.environ.get(var_name, "0").lower() - - assert env_var in ( - "0", - "1", - "true", - "false", - ), f"{var_name} environment variable must be '0', '1', 'true', or 'false'. Current value: '{env_var}'" - - return env_var in ("1", "true") diff --git a/imbue_core/imbue_core/computing_environment/__init__.py b/imbue_core/imbue_core/computing_environment/__init__.py diff --git a/imbue_core/imbue_core/computing_environment/computing_environment.py b/imbue_core/imbue_core/computing_environment/computing_environment.py @@ -1,1080 +0,0 @@ -from __future__ import annotations - -import asyncio -import shlex -import time -from datetime import datetime -from pathlib import Path -from typing import Protocol -from typing import Sequence -from typing import TYPE_CHECKING -from uuid import uuid4 - -import anyio -from loguru import logger -from tenacity import TryAgain -from tenacity import retry -from tenacity import retry_all -from tenacity import retry_if_exception_type -from tenacity import stop_after_attempt -from tenacity import wait_random_exponential - -from imbue_core.async_monkey_patches import log_exception -from imbue_core.computing_environment.data_types import AnyPath -from imbue_core.computing_environment.data_types import FailedToMakeCommitError -from imbue_core.computing_environment.data_types import PatchApplicationError -from imbue_core.computing_environment.data_types import RunCommandError -from imbue_core.git_data_types import CommitTimestamp -from imbue_core.retry_utils import log_before_sleep -from imbue_core.section import Section -from imbue_core.time_utils import get_current_time - -# Import the types needed for file modes -if TYPE_CHECKING: - # for proper file mode typing - from _typeshed import OpenBinaryModeReading - from _typeshed import OpenBinaryModeWriting - from _typeshed import OpenTextModeReading - from _typeshed import OpenTextModeWriting - - -class 
ComputingEnvironment(Protocol): - """Protocol defining the interface for a computing environment. - - This protocol specifies the required methods for interacting with a computing - environment, including running commands and file operations. - """ - - async def run_command( - self, - command: Sequence[str], - check: bool = True, - secrets: dict[str, str] | None = None, - cwd: AnyPath | None = None, - is_error_logged: bool = True, - ) -> str: ... - - async def run_git( - self, - command: Sequence[str], - check: bool = True, - cwd: AnyPath | None = None, - is_error_logged: bool = True, - is_stripped: bool = True, - retry_on_git_lock_error: bool = True, - ) -> str: ... - - async def write_file( - self, - relative_path: AnyPath, - content: str | bytes | None, - cwd: AnyPath | None = None, - mode: OpenTextModeWriting | OpenBinaryModeWriting = "w", - mkdir_if_missing: bool = True, - ) -> None: ... - - async def read_file( - self, - relative_path: AnyPath, - cwd: AnyPath | None = None, - mode: OpenTextModeReading | OpenBinaryModeReading = "r", - mkdir_if_missing: bool = True, - ) -> str | bytes: ... - - async def delete_file( - self, - relative_path: AnyPath, - cwd: AnyPath | None = None, - ) -> None: ... 
- - -def _get_temp_patch_file() -> anyio.Path: - # this is a bad idea because it triggers the file watcher - # patch_file = (self.base_path / str(uuid4())).with_suffix(".patch") - patch_file = (Path("/tmp") / uuid4().hex).with_suffix(".patch") - return anyio.Path(patch_file) - - -async def run_command_with_retry_on_git_lock_error( - computing_environment: ComputingEnvironment, - command: Sequence[str], - check: bool = True, - is_error_logged: bool = True, - cwd: AnyPath | None = None, -) -> str: - max_retries = 50 - retry_count = 0 - retry_delay = 0.1 # seconds - while True: - try: - return await computing_environment.run_command( - command, - check=check, - is_error_logged=is_error_logged and retry_count >= max_retries, - cwd=cwd, - ) - except RunCommandError as e: - error_message = str(e) - if "fatal: Unable to create" in error_message and ".git/index.lock': File exists" in error_message: - if retry_count >= max_retries: - raise - await asyncio.sleep(retry_delay) - retry_count += 1 - else: - raise - - -@retry( - wait=wait_random_exponential(multiplier=0.1, max=2, exp_base=2), - reraise=True, - stop=stop_after_attempt(50), - before_sleep=log_before_sleep, -) -async def wait_for_git_index_lock_to_be_free(local_sync_repo_path: Path) -> None: - # Path to the git index lock file using anyio.Path to avoid blocking - lock_file_path = anyio.Path(local_sync_repo_path / ".git" / "index.lock") - if await lock_file_path.exists(): - raise TryAgain - - -async def apply_patch_without_git(computing_environment: ComputingEnvironment, diff: str) -> None: - if diff.strip() == "": - return - patch_file = _get_temp_patch_file() - try: - await computing_environment.write_file(patch_file, diff) - await computing_environment.run_command(("bash", "-c", f"patch -p1 < {patch_file}")) - except RunCommandError as e: - raise PatchApplicationError(f"Failed to apply patch: {e}") from e - finally: - await computing_environment.delete_file(patch_file) - - -async def 
is_repo_dirty(computing_environment: ComputingEnvironment, is_untracked_considered: bool = True) -> bool: - """Check if the repo has any uncommitted changes.""" - return bool( - await computing_environment.run_git( - ( - "status", - "--porcelain", - *([] if is_untracked_considered else ["--untracked-files=no"]), - ) - ) - ) - - -async def are_all_commits_pushed(computing_environment: ComputingEnvironment) -> bool: - """Check if the repo has any unpushed commits.""" - output = await computing_environment.run_git(("cherry",)) - return output.strip() == "" - - -async def are_all_remote_commits_pulled( - computing_environment: ComputingEnvironment, -) -> bool: - # note this will fail if the branch hasn't been pushed to the remote - output = await computing_environment.run_command( - ("bash", "-c", "git fetch && git rev-list HEAD..@{upstream} --count") - ) - return output.strip() == "0" - - -async def assert_repo_is_clean(computing_environment: ComputingEnvironment) -> None: - """Assert that the repo has no uncommitted changes.""" - assert not await is_repo_dirty( - computing_environment - ), "You have untracked files. 
Please address them before using this script (this is to prevent accidentally adding large files unintentionally)" - - -async def get_branch_name(computing_environment: ComputingEnvironment, is_error_logged: bool = True) -> str: - """Get the name of the current branch.""" - return await computing_environment.run_git(("symbolic-ref", "--short", "HEAD"), is_error_logged=is_error_logged) - - -async def rename_branch( - computing_environment: ComputingEnvironment, - old_name: str, - new_name: str, - force_if_exists: bool = True, -) -> None: - """Rename the given branch.""" - if force_if_exists: - await computing_environment.run_git(("branch", "-M", old_name, new_name)) - else: - await computing_environment.run_git(("branch", "-m", old_name, new_name)) - - -async def get_branch_description(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Get the description of the given branch.""" - try: - return await computing_environment.run_git( - ("config", f"branch.{branch_name}.description"), is_error_logged=False - ) - except RunCommandError as e: - if e.returncode == 1: - # no description set - return "" - raise - - -async def is_branch_exists(computing_environment: ComputingEnvironment, branch_name: str) -> bool: - """Check if the given branch exists.""" - result = await computing_environment.run_git( - ("rev-parse", "--verify", "--quiet", branch_name), - is_error_logged=False, - check=False, - ) - return result.strip() != "" - - -async def is_detached_head(computing_environment: ComputingEnvironment) -> bool: - """Check if the current HEAD is detached.""" - result = await computing_environment.run_git(("rev-parse", "--abbrev-ref", "HEAD"), is_error_logged=False) - return result.strip() == "HEAD" - - -async def set_branch_description( - computing_environment: ComputingEnvironment, branch_name: str, description: str -) -> None: - """Set the description of the given branch.""" - await computing_environment.run_git(("config", 
f"branch.{branch_name}.description", description)) - - -async def get_branch_commit(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Get the commit of the given branch.""" - return await computing_environment.run_git(("rev-parse", branch_name)) - - -async def get_all_branch_names_pointing_to_commit( - computing_environment: ComputingEnvironment, commit_hash: str -) -> tuple[str, ...]: - """Get all branch names that point to the given commit.""" - result = await computing_environment.run_git( - ( - "for-each-ref", - "refs/heads/", - "--format='%(refname:short)'", - "--points-at", - commit_hash, - ) - ) - branch_names = tuple(result.splitlines()) - # strip the quotes - return tuple(branch_name.strip("'") for branch_name in branch_names) - - -async def is_branch_child_of_branch( - computing_environment: ComputingEnvironment, - child_branch_name: str, - parent_branch_name: str, -) -> bool: - """Check if the given branch is a child of the parent branch.""" - try: - await computing_environment.run_git( - ("merge-base", "--is-ancestor", parent_branch_name, child_branch_name), - is_error_logged=False, - ) - return True - except RunCommandError as e: - if e.stderr.strip() == "" and e.returncode == 1: - # we expect this command to give an empty stderr and a return code of 1 - # if the child branch is not an ancestor of the parent branch - return False - raise - - -async def is_commit_on_branch( - computing_environment: ComputingEnvironment, - commit_hash: str, - branch_name: str, - local_only: bool = True, -) -> bool: - """Check if the given commit is on the given branch.""" - if local_only: - result = await computing_environment.run_git(("branch", "--contains", commit_hash)) - else: - result = await computing_environment.run_git(("branch", "-a", "--contains", commit_hash)) - return any(branch_name == x.strip() for x in result.splitlines()) - - -async def fetch_and_get_first_entry_in_fetch_head( - computing_environment: ComputingEnvironment, 
remote: str, fetch_refs: Sequence[str] -) -> str: - """Fetch the given refs from the remote and return the first entry in FETCH_HEAD.""" - refs_str = " ".join(fetch_refs) - command = [ - "bash", - "-c", - ( - f"git fetch {remote} {refs_str} && " - # get first commit from FETCH_HEAD - "git rev-parse FETCH_HEAD" - ), - ] - result = await run_command_with_retry_on_git_lock_error(computing_environment, command) - return result.strip() - - -async def fetch_branch(computing_environment: ComputingEnvironment, branch_name: str) -> None: - """Fetch the given branch from the remote.""" - await computing_environment.run_git(("fetch", "origin", branch_name)) - - -async def is_branch_present(computing_environment: ComputingEnvironment, branch_name: str) -> bool: - """Check if branch with given name is present.""" - result = await computing_environment.run_git(("branch",)) - return branch_name in result.splitlines() - - -async def create_reset_and_checkout_branch(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Create new branch with given name.""" - return await computing_environment.run_git(("switch", "-C", branch_name)) - - -async def switch_branch(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Switch to branch with given name.""" - return await computing_environment.run_git(("switch", branch_name)) - - -async def delete_branch(computing_environment: ComputingEnvironment, branch_name: str, delete_remote: bool) -> str: - """Delete branch with given name.""" - result = await computing_environment.run_git(("branch", "-D", branch_name)) - if delete_remote: - result = await computing_environment.run_git(("push", "origin", "--delete", branch_name)) - return result - - -async def update_branch_to_hash(computing_environment: ComputingEnvironment, branch_name: str, git_hash: str) -> None: - """Update the given branch to reference the given git hash.""" - # here we do it without checking out the branch - await 
computing_environment.run_git(("branch", "-f", branch_name, git_hash)) - - -async def switch_and_create_branch_if_needed(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Switch to new branch, creating it if it doesn't already exist.""" - if await is_branch_present(computing_environment, branch_name): - await switch_branch(computing_environment, branch_name) - else: - await create_reset_and_checkout_branch(computing_environment, branch_name) - return await get_branch_name(computing_environment) - - -async def merge_branches( - computing_environment: ComputingEnvironment, - base_branch_name: str, - merge_branch_name: str, - is_moving_to_base_branch: bool = True, -) -> str: - """Merge `merge_branch_name` into `base_branch_name`.""" - await switch_branch(computing_environment, base_branch_name) - await computing_environment.run_git(("merge", merge_branch_name)) - if not is_moving_to_base_branch: - await switch_branch(computing_environment, "-") - return await get_branch_name(computing_environment) - - -async def get_merge_base(computing_environment: ComputingEnvironment, branch_name: str, target_branch: str) -> str: - """Get the merge base of the given branch and target branch. - - The merge base is the most recent commit that is on both branches. 
- """ - return await computing_environment.run_git(["merge-base", branch_name, target_branch], is_error_logged=False) - - -async def checkout_hash(computing_environment: ComputingEnvironment, git_hash: str) -> str: - """Checkout given git hash.""" - return await computing_environment.run_git(("checkout", git_hash)) - - -async def force_add(computing_environment: ComputingEnvironment, *paths: str) -> None: - """Force-add the specified paths to the git index.""" - await computing_environment.run_git(("add", "-f", *paths)) - - -async def git_add(computing_environment: ComputingEnvironment, *paths: str) -> None: - """Add the specified paths to the git index.""" - await computing_environment.run_git(("add", *paths)) - - -def convert_datetime_to_git_timestamp(dt: datetime) -> str: - return datetime.isoformat(dt) - - -def convert_git_timestamp_to_datetime(timestamp: str) -> datetime: - return datetime.fromisoformat(timestamp) - - -def get_commit_ts_for_current_time() -> CommitTimestamp: - """Get the commit timestamp for the current time.""" - current_time = get_current_time() - return CommitTimestamp( - author_ts=convert_datetime_to_git_timestamp(current_time), - committer_ts=convert_datetime_to_git_timestamp(current_time), - ) - - -def _convert_time_to_commit_ts( - time: str | datetime | CommitTimestamp | None, -) -> CommitTimestamp: - if time is None: - return get_commit_ts_for_current_time() - elif isinstance(time, datetime): - return CommitTimestamp( - author_ts=convert_datetime_to_git_timestamp(time), - committer_ts=convert_datetime_to_git_timestamp(time), - ) - elif isinstance(time, CommitTimestamp): - return time - else: - # assume it's a git timestamp - return CommitTimestamp(author_ts=time, committer_ts=time) - - -async def make_commit( - computing_environment: ComputingEnvironment, - commit_message: str, - allow_empty: bool = False, - amend: bool = False, - commit_time: str | datetime | CommitTimestamp | None = None, -) -> str: - if commit_message.strip() == "": 
- commit_message = "No commit message provided" - - commit_ts = _convert_time_to_commit_ts(commit_time) - time_args = f'GIT_AUTHOR_DATE="{commit_ts.author_ts}" GIT_COMMITTER_DATE="{commit_ts.committer_ts}" ' - - commit_message = shlex.quote(commit_message) - no_changes_message = "No changes to commit" - amend_args = "--amend " if amend else "" - if allow_empty or amend: - bash_command = f"""git add . && {time_args}git commit {amend_args}--allow-empty -m {commit_message} > /dev/null && git rev-parse HEAD""" - else: - bash_command = f"""git add . && ( git status | grep -q "nothing to commit" && echo "{no_changes_message}" ) || ( {time_args}git commit {amend_args}-m {commit_message} > /dev/null && git rev-parse HEAD )""" - - with Section(f"committing changes with message: '{commit_message}'", log_level="DEBUG"): - stdout = await run_command_with_retry_on_git_lock_error( - computing_environment, - ["bash", "-c", bash_command], - ) - stdout = stdout.strip() - if stdout == no_changes_message: - raise FailedToMakeCommitError(f"Failed to make commit with message: {commit_message}. 
{bash_command=}") - new_git_hash = stdout - return new_git_hash - - -async def get_tree_hash_for_commit(computing_environment: ComputingEnvironment, commit: str) -> str: - """Get the tree hash for the given commit.""" - return await computing_environment.run_git(["rev-parse", commit + "^{tree}"]) - - -async def get_commit_timestamp(computing_environment: ComputingEnvironment, commit: str) -> CommitTimestamp: - """Get the commit timestamp for the given commit.""" - split_token = "<|>" - result = await computing_environment.run_git(["show", "-s", "--format=%aI<|>%cI", commit]) - author_ts, committer_ts = result.split(split_token) - return CommitTimestamp(author_ts=author_ts.strip(), committer_ts=committer_ts.strip()) - - -async def tag_commit(computing_environment: ComputingEnvironment, tag: str, commit_hash: str) -> None: - """Tag the given commit with the given tag.""" - # We use -f to force the tag to be created even if it already exists. - await computing_environment.run_git(("tag", "-f", tag, commit_hash)) - - -async def git_push(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Push changes to remote branch with given name.""" - return await computing_environment.run_git(("push", "origin", branch_name)) - - -async def force_push(computing_environment: ComputingEnvironment, branch_name: str) -> str: - """Push changes to remote branch with given name.""" - return await computing_environment.run_git(("push", "--force", "origin", branch_name)) - - -async def force_push_commit_with_retry( - computing_environment: ComputingEnvironment, - commit: str, - branch_name: str, - timeout: float = 30.0, -) -> None: - start_time = time.monotonic() - sleep_time = 0.5 - while True: - try: - await force_push_commit(computing_environment, commit, branch_name) - break - except Exception as exc: - if time.monotonic() - start_time > timeout: - raise TimeoutError( - f"Timeout reached: Could not force push {commit} to {branch_name} in {timeout} seconds." 
- ) from exc - logger.info("Force push of {} to {} failed; trying again...", commit, branch_name) - await asyncio.sleep(sleep_time) - sleep_time *= 2 - - -async def force_push_commit(computing_environment: ComputingEnvironment, commit: str, branch_name: str) -> None: - try: - await computing_environment.run_git(["push", "-f", "origin", f"{commit}:{branch_name}"], is_error_logged=False) - except RunCommandError as e: - if "fatal: bad object" in e.stderr: - # TODO : We're retrying failed fetches here. However, there is also a separate - # force_push_commit_with_retry method that retries the entire force_push_commit. - # We should probably try the fetch only once, and then rely on the outer - # force_push_commit_with_retry to retry the entire force_push_commit call when retrying is = - # desired? - NUM_TRIES = 3 - for _ in range(NUM_TRIES): - try: - await computing_environment.run_git(["fetch", "origin", commit], is_error_logged=False) - except RunCommandError as fetch_e: - if "not our ref" in fetch_e.stderr: - # FIXME: actually, this has been getting worse... I suspect perhaps rate limiting or something? We are checking thing out much more than usual... - await asyncio.sleep(2) - else: - raise fetch_e - else: - start_time = time.monotonic() - while time.monotonic() - start_time < 10: - try: - await computing_environment.run_git(["push", "-f", "origin", f"{commit}:{branch_name}"]) - except RunCommandError as repush_e: - if "not our ref" in repush_e.stderr: - # FIXME: actually, this has been getting worse... I suspect perhaps rate limiting or something? We are checking thing out much more than usual... 
- await asyncio.sleep(2) - else: - raise repush_e - else: - return - raise Exception(f"Could not force push commit {commit}") - raise Exception(f"Could not fetch commit {commit} to force push it") - else: - raise - - -async def get_staged_files( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get list of all files in repo that are currently staged.""" - result = await computing_environment.run_git(("diff", "--full-index", "--binary", "--name-only", "--cached")) - return tuple(result.splitlines()) - - -async def get_unstaged_files( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get list of all files in repo that are currently unstaged.""" - result = await computing_environment.run_git(("diff", "--full-index", "--binary", "--name-only")) - return tuple(result.splitlines()) - - -async def restore_all_staged_files(computing_environment: ComputingEnvironment) -> None: - """Restore all staged files.""" - await computing_environment.run_git(("restore", "--staged", ".")) - - -async def restore_all_unstaged_changes( - computing_environment: ComputingEnvironment, -) -> None: - """Restore all unstaged changes.""" - await computing_environment.run_git(("restore", ".")) - - -async def apply_patch_via_git_with_conflict_markers( - computing_environment: ComputingEnvironment, - git_diff: str, - is_error_logged: bool = True, -) -> None: - """Apply a diff to repo with conflict markers.""" - if git_diff.strip() == "": - return - if not git_diff.endswith("\n"): - # git requires a newline at the end of the patch - git_diff += "\n" - patch_file = _get_temp_patch_file() - try: - await computing_environment.write_file(patch_file, git_diff) - await computing_environment.run_command( - [ - "bash", - "-c", - f"git add . 
&& git apply --verbose {patch_file} || git apply -3 --verbose {patch_file}", - ], - is_error_logged=is_error_logged, - ) - except RunCommandError as e: - raise PatchApplicationError(f"Failed to apply patch: {e}") from e - finally: - await computing_environment.delete_file(patch_file) - - -async def is_repo_conflicted(computing_environment: ComputingEnvironment) -> bool: - output = await computing_environment.run_git(["status"], is_error_logged=False, check=False) - if "Unmerged paths:" in output: - return True - return False - - -async def get_head_hash(computing_environment: ComputingEnvironment) -> str: - """Get the hash of the current HEAD commit.""" - git_hash = await computing_environment.run_git(["rev-parse", "HEAD"]) - assert len(git_hash) == 40, f"Expected 40-character git hash, got {git_hash}" - return git_hash - - -async def get_parent_commit_hash(computing_environment: ComputingEnvironment, commit_hash: str) -> str: - """Get the parent commit hash of the given commit hash.""" - git_hash = await computing_environment.run_git(["rev-parse", f"{commit_hash}^"]) - assert len(git_hash) == 40, f"Expected 40-character git hash, got {git_hash}" - return git_hash - - -async def get_most_recent_sibling_branch_of_branch( - computing_environment: ComputingEnvironment, target_branch_name: str -) -> tuple[str, ...] | None: - """Get the most recent sibling branch of the given branch name. - - This is the first branch that shares a common ancestor commit with the given branch name. - Note, that it is possible that there are multiple branches that share the most recent common ancestor. - In this case, we return all of them. - - Also if there are no sibling branches (either at all or within the max lookback), we return None. - """ - # FIXME: this is imprecise -- it really just needs to be head 2, but I didn't want to figure out the regex... 
- output = await computing_environment.run_command( - [ - "bash", - "-c", - f'git log --decorate=full --oneline {target_branch_name} | grep "refs/heads/" | head -n 10', - ], - check=True, - ) - for line in output.splitlines(): - branch_list_string = line.split(" (", maxsplit=1)[-1].split(")", maxsplit=1)[0] - parsed_branch_names = [] - for branch_name_string in branch_list_string.split(", "): - parsed_branch_name = branch_name_string.rsplit(" ")[-1] - if parsed_branch_name.startswith("refs/heads") and parsed_branch_name != f"refs/heads/{target_branch_name}": - parsed_branch_names.append(parsed_branch_name.replace("refs/heads/", "", 1)) - # otherwise, just return the first one - if len(parsed_branch_names) > 0: - return tuple(parsed_branch_names) - return None - - -async def get_nth_commit_ago( - computing_environment: ComputingEnvironment, - commit_hash: str, - n: int, - is_error_logged: bool = True, -) -> str: - """Get the nth commit ago of the given commit hash.""" - git_hash = await computing_environment.run_git(["rev-parse", f"{commit_hash}~{n}"], is_error_logged=is_error_logged) - assert len(git_hash) == 40, f"Expected 40-character git hash, got {git_hash}" - return git_hash - - -async def get_initial_repo_commit_hash(computing_environment: ComputingEnvironment, commit_hash: str = "HEAD") -> str: - """Get the initial commit hash of the repo. 
- - As written, if invoked on an empty repo with no commits (immediately after `git init`), this fails with: - `fatal: ambiguous argument 'HEAD': unknown revision or path not in the working tree.` - """ - # --max-parents=0: only consider commits with no parents - # --date-order: sort by date (newest first) - output = await computing_environment.run_git(["rev-list", "--max-parents=0", commit_hash, "--date-order"]) - # assume the oldest commit with no parents is the initial repo commit - all_root_commits = output.splitlines() - root_commit = all_root_commits[-1] - assert len(root_commit) == 40, f"Expected 40-character git hash, got {root_commit}" - return root_commit - - -async def get_upto_nth_commit_ago(computing_environment: ComputingEnvironment, commit_hash: str, n: int) -> str: - """Get the commit hash of the upto nth commit ago of the given commit hash. - - If the commit history is shorter than n, it will return the first commit. - """ - try: - return await get_nth_commit_ago(computing_environment, commit_hash, n, is_error_logged=False) - except RunCommandError as e: - if "unknown revision or path not in the working tree" in e.stderr: - return await get_initial_repo_commit_hash(computing_environment, commit_hash) - raise - - -async def get_commit_message(computing_environment: ComputingEnvironment, commit_hash: str) -> str: - """Get the commit message of the given commit hash.""" - return await computing_environment.run_git(["log", "-1", "--pretty=%B", commit_hash]) - - -async def get_commit_count_between_hashes( - computing_environment: ComputingEnvironment, old_hash: str, new_hash: str -) -> int: - """Get the number of commits between two hashes.""" - output = await computing_environment.run_git(["rev-list", "--count", f"{old_hash}..{new_hash}"]) - return int(output.strip()) - - -async def fetch_and_checkout_hash( - computing_environment: ComputingEnvironment, - git_hash: str, - is_error_logged: bool = True, -) -> None: - await 
computing_environment.run_command( - ["bash", "-c", f"git fetch origin {git_hash} && git checkout {git_hash}"], - is_error_logged=is_error_logged, - ) - - -async def fetch_ref_and_checkout_hash( - computing_environment: ComputingEnvironment, - ref: str, - git_hash: str, - is_error_logged: bool = True, -) -> None: - await computing_environment.run_command( - ["bash", "-c", f"git fetch origin {ref} && git checkout {git_hash}"], - is_error_logged=is_error_logged, - ) - - -@retry( - stop=stop_after_attempt(5), - wait=wait_random_exponential(min=0.5, max=30, exp_base=2), - reraise=True, - retry=retry_all(retry_if_exception_type(RunCommandError)), - before_sleep=log_before_sleep, -) -async def wait_for_git_hash_to_checkout(computing_environment: ComputingEnvironment, git_hash: str) -> None: - await fetch_and_checkout_hash(computing_environment, git_hash, is_error_logged=False) - - -async def wait_for_git_hash_with_ref_to_checkout( - computing_environment: ComputingEnvironment, - git_hash: str, - ref: str, - timeout: float = 20.0, -) -> None: - with Section(f"Checking out git hash {git_hash} with ref {ref}", log_level="DEBUG"): - start_time = time.monotonic() - while True: - try: - await fetch_ref_and_checkout_hash(computing_environment, ref, git_hash, is_error_logged=False) - break - except RunCommandError as exc: - if time.monotonic() - start_time > timeout: - log_exception( - exc, - "Timeout reached: Git hash {git_hash} is not available after {timeout} seconds.", - git_hash=git_hash, - timeout=timeout, - ) - raise TimeoutError(f"Timeout reached: Git hash {git_hash} is not available.") from exc - logger.info("Waiting for git hash {} to be available...", git_hash) - await asyncio.sleep(0.5) - - -async def force_checkout_git_hash_immediate_on_branch( - computing_environment: ComputingEnvironment, git_hash: str, branch_name: str -) -> None: - # Here we clear any uncommited changes, then change branch, before checking out the new hash - # We need to do these three steps so 
we don't affect the currently checked out branch - command = f"git reset --hard && git checkout -B {branch_name} && git reset --hard {git_hash}" - logger.debug("Running command: {}", command) - await run_command_with_retry_on_git_lock_error( - computing_environment, - [ - "bash", - "-c", - command, - ], - ) - - -async def get_git_folder_paths( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the paths of all the git folders in the repo.""" - result = await computing_environment.run_command(["ls", ".git"]) - return tuple(result.splitlines()) - - -async def apply_patch_via_git( - computing_environment: ComputingEnvironment, git_diff: str, is_error_logged: bool -) -> None: - """Apply a diff to repo.""" - if git_diff.strip() == "": - return - patch_file = _get_temp_patch_file() - await computing_environment.write_file(patch_file, git_diff) - # NOTE: --allow-empty is necessary because the patch may be empty, or result in no changes - # update (2024-11-22) --allow-empty is not available in the git version in our devcontainers - # so we have to do a janky error check below - try: - await computing_environment.run_git(("apply", "--verbose", str(patch_file)), is_error_logged=is_error_logged) - except RunCommandError as e: - raise PatchApplicationError(f"Failed to apply patch: {e}") from e - finally: - await computing_environment.delete_file(patch_file) - - -async def get_git_diff( - computing_environment: ComputingEnvironment, - commit_hash: str | None = None, - staged: bool = False, - is_error_logged: bool = True, - include_binary: bool = True, -) -> str: - """Get the diff for the current repo state.""" - # make sure `is_stripped=False` otherwise patch can be invalid - command = ["diff", "--full-index"] - if include_binary: - # Without --binary, diffs of binary files will just contain a summary statement such as "Binary files a/file.bin and b/file.bin differ". - # Such diffs cannot be applied, but are useful for inclusion in LLM prompts. 
- command.append("--binary") - if staged: - command.append("--staged") - if commit_hash: - command.append(commit_hash) - return await computing_environment.run_git(command, is_stripped=False, is_error_logged=is_error_logged) - - -async def get_diff_between_hashes(computing_environment: ComputingEnvironment, old_hash: str, new_hash: str) -> str: - """Get the diff between two git hashes.""" - # make sure `is_stripped=False` otherwise patch can be invalid - return await computing_environment.run_git( - ["diff", "--full-index", "--binary", old_hash, new_hash], is_stripped=False - ) - - -async def get_patch_for_commit(computing_environment: ComputingEnvironment, commit_hash: str) -> str: - """Get the patch for a given commit hash.""" - return await computing_environment.run_git(["show", "--pretty=format:", "--patch", commit_hash], is_stripped=False) - - -async def get_untracked_files( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the untracked files in the repo.""" - result = await computing_environment.run_git(["ls-files", "--others", "--exclude-standard"], is_error_logged=False) - return tuple([line.strip() for line in result.splitlines() if line.strip()]) - - -async def get_untracked_file_diff( - computing_environment: ComputingEnvironment, - file_path: str, - include_binary: bool = True, -) -> str: - """Get the diff for a untracked file. - - Note this function will raise a RunCommandError if the there is no diff for the untracked file or if there - is another error running the command. So it is best to use this function after checking that the file is untracked - using get_untracked_files function. - """ - command = ["diff", "--no-index"] - if include_binary: - command.append("--binary") - untracked_diff = await computing_environment.run_git( - command + ["/dev/null", str(file_path)], - # Unfortunately, `--no-index` implies `--exit-code`, which will cause git diff to return an exit code of 1 - # if the diff is not empty. 
So we can't use check=True here. We'll check for an empty output to detect failures. - check=False, - is_error_logged=True, - is_stripped=False, - ) - if not untracked_diff: - raise RunCommandError(f"Unable to diff untracked file {file_path}") - return untracked_diff - - -async def get_staged_unstaged_and_combined_diffs( - computing_environment: ComputingEnvironment, base_commit: str = "HEAD" -) -> tuple[str, str, str]: - """Get the staged diff, the unstaged diff, and the combined diff""" - staged_diff = await get_git_diff(computing_environment, staged=True) - unstaged_diff = await get_git_diff(computing_environment, staged=False) - combined_diff = await get_git_diff(computing_environment, commit_hash=base_commit) - return staged_diff, unstaged_diff, combined_diff - - -async def get_unmerged_blob_hashes( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the blob hashes of all the unmerged files in the repo.""" - result = await computing_environment.run_command( - ["bash", "-c", "git ls-files --unmerged | awk '{print $2}' | sort -u"], - check=False, - ) - return tuple(line.strip() for line in result.splitlines() if line.strip() != "") - - -async def get_staged_blob_hashes( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the blob hashes of all the staged files in the repo.""" - staged_blob_hashes = await computing_environment.run_command( - [ - "bash", - "-c", - 'staged_blobs=$(git diff --full-index --binary --cached --name-only --diff-filter=ACMRT | while read file; do git ls-files --stage "$file" | awk \'{print $2}\'; done); echo "$staged_blobs"', - ] - ) - return tuple(line.strip() for line in staged_blob_hashes.splitlines() if line.strip() != "") - - -async def get_blob_content_by_hash(computing_environment: ComputingEnvironment, blob_hash: str) -> bytes: - """Get the content of a blob by its hash.""" - result = await computing_environment.run_git(["cat-file", "-p", blob_hash], is_stripped=False) - 
return result.encode("utf-8") - - -async def get_unmerged_and_staged_blob_contents_by_hash( - computing_environment: ComputingEnvironment, -) -> dict[str, bytes]: - """Get the contents of all the unmerged and staged blobs in the repo.""" - unmerged_blob_hashes = await get_unmerged_blob_hashes(computing_environment) - staged_blob_hashes = await get_staged_blob_hashes(computing_environment) - all_relevant_blob_hashes = unmerged_blob_hashes + staged_blob_hashes - blob_content_tasks_by_hash = { - blob_hash: get_blob_content_by_hash(computing_environment, blob_hash) for blob_hash in all_relevant_blob_hashes - } - blob_contents = await asyncio.gather(*blob_content_tasks_by_hash.values()) - return { - blob_hash: blob_content for blob_hash, blob_content in zip(blob_content_tasks_by_hash.keys(), blob_contents) - } - - -async def write_blob_content(computing_environment: ComputingEnvironment, blob_hash: str, blob_content: bytes) -> None: - """Write the content of a blob to the repo.""" - # write the blob content to a temp file - temp_file = anyio.Path(f"/tmp/{blob_hash}") - try: - await computing_environment.write_file(temp_file, blob_content, mode="wb") - # write the blob to the repo - result = await computing_environment.run_git(["hash-object", "-w", str(temp_file)]) - assert result.strip() == blob_hash, f"Expected blob hash {blob_hash}, got {result.strip()}" - finally: - await computing_environment.delete_file(temp_file) - - -async def write_blob_content_by_hash( - computing_environment: ComputingEnvironment, blob_content_by_hash: dict[str, bytes] -) -> None: - """Write the content of all the blobs to the repo.""" - tasks = [] - for blob_hash, blob_content in blob_content_by_hash.items(): - tasks.append(write_blob_content(computing_environment, blob_hash, blob_content)) - await asyncio.gather(*tasks) - - -async def get_modified_files_with_conflicts( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the modified files with conflicts.""" - 
commands = [ - "diff --check --staged --full-index --binary", - "diff --check --full-index --binary", - ] - conflicted_files = set() - for command in commands: - result = await computing_environment.run_git(command.split(), check=False, is_error_logged=False) - # output is of the form: - # test.txt:2: leftover conflict marker - - for line in result.splitlines(): - parts = line.split(":", maxsplit=1) - if len(parts) == 2: - file_path, message = parts - if "leftover conflict marker" in message.strip(): - conflicted_files.add(file_path) - return tuple(conflicted_files) - - -async def get_conflicted_pathnames( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the pathnames of all the conflicted files in the repo.""" - result = await computing_environment.run_git(["diff", "--full-index", "--binary", "--name-only", "--diff-filter=U"]) - return tuple(result.splitlines()) - - -async def get_conflicted_contents_by_path( - computing_environment: ComputingEnvironment, -) -> dict[str, bytes]: - """Get the contents of all the conflicted files in the repo.""" - conflicted_files = await get_conflicted_pathnames(computing_environment) - conflicted_contents_by_path: dict[str, bytes] = {} - for file_path in conflicted_files: - content = await computing_environment.read_file(file_path, mode="rb") - assert isinstance(content, bytes), f"Expected bytes, got {type(content)}" - conflicted_contents_by_path[file_path] = content - return conflicted_contents_by_path - - -async def get_modified_pathnames( - computing_environment: ComputingEnvironment, -) -> tuple[str, ...]: - """Get the pathnames of all the modified files in the repo.""" - result = await computing_environment.run_command(["bash", "-c", "git status --porcelain | awk '{print $2}'"]) - return tuple(result.splitlines()) - - -async def get_modified_file_contents_by_path( - computing_environment: ComputingEnvironment, -) -> dict[str, bytes]: - """Get the contents of all the modified files in the repo.""" 
- modified_pathnames = await get_modified_pathnames(computing_environment) - modified_file_contents_by_path: dict[str, bytes] = {} - for pathname in modified_pathnames: - content = await computing_environment.read_file(pathname, mode="rb") - assert isinstance(content, bytes), f"Expected bytes, got {type(content)}" - modified_file_contents_by_path[pathname] = content - return modified_file_contents_by_path - - -async def get_repo_url(computing_environment: ComputingEnvironment) -> str: - repo_url = await computing_environment.run_git(["remote", "get-url", "origin"]) - if repo_url.startswith("git@"): - # convert ssh url to https - repo_url = repo_url.replace(":", "/") - repo_url = f"https://{repo_url[4:]}" - if "https://oauth2:" in repo_url: - # remove the oauth2 prefix - # repo_url is something like https://oauth2:{token}@gitlab.com/.../.git - # change it to https://gitlab.com/.../.git - suffix = repo_url.split("@")[-1] - repo_url = "https://" + suffix - return repo_url - - -async def get_main_branch_name_for_repo( - computing_environment: ComputingEnvironment, default_branch: str | None = None -) -> str: - """Get the name of the main branch for the repo. - - Attempts to detect whether the repository uses 'main', 'master', or another name - as its primary branch by checking for common branch names in order of preference. 
- """ - possible_main_branches = ["main", "master", "trunk", "development"] - - if default_branch is not None and default_branch not in possible_main_branches: - possible_main_branches.insert(0, default_branch) - - # First check if any of the common main branch names exist - # and return the first one that does - for branch in possible_main_branches: - if await is_branch_exists(computing_environment, branch): - return branch - - # If we couldn't find a common main branch, try to determine the default branch - # This gets the branch that HEAD points to in a newly cloned repo - default_remote_branch = await computing_environment.run_git( - ["symbolic-ref", "refs/remotes/origin/HEAD"], is_error_logged=False - ) - if default_remote_branch: - # Format is typically refs/remotes/origin/main, so extract the last part - default_remote_branch = default_remote_branch.strip().split("/")[-1] - return default_remote_branch - raise ValueError("Could not detect main branch for repo.") diff --git a/imbue_core/imbue_core/computing_environment/data_types.py b/imbue_core/imbue_core/computing_environment/data_types.py @@ -1,31 +0,0 @@ -from __future__ import annotations - -import subprocess -from pathlib import Path -from typing import Any - -import anyio - -# Use AnyPath type to match Sanctum -AnyPath = Path | str | anyio.Path - - -class RunCommandError(subprocess.CalledProcessError): - """Custom exception for errors encountered during Git commands.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.cwd = kwargs.get("cwd", None) - if "cwd" in kwargs: - del kwargs["cwd"] - super().__init__(*args, **kwargs) - - def __str__(self) -> str: - return f"Command `{self.cmd}` returned non-zero exit status {self.returncode}.\nOutput: {self.stdout}\nError: {self.stderr}\nCWD: {self.cwd}" - - -class PatchApplicationError(Exception): - """Custom exception for errors encountered during patch application.""" - - -class FailedToMakeCommitError(Exception): - """Custom exception for 
errors encountered during commit creation.""" diff --git a/imbue_core/imbue_core/error_utils.py b/imbue_core/imbue_core/error_utils.py @@ -1,28 +0,0 @@ -"""Error handling utilities.""" - -import sys -import traceback - -import traceback_with_variables -from traceback_with_variables import Format - - -def get_traceback_with_vars(exception: BaseException | None = None) -> str: - - # be careful of potential performance regressions with increasing these limits - tb_format = Format(max_value_str_len=100_000, max_exc_str_len=2_000_000) - if exception is None: - # no exception passed in; get the current exception. this will still be None if not in an exception handler - exception = sys.exception() - try: - if exception is not None: - # we are in an exception handler, use that for the traceback - # for some reason this breaks when casting to an `Exception`, so just using type: ignore - return traceback_with_variables.format_exc(exception, fmt=tb_format) # type: ignore - else: - # not in an exception handler, just get the current stack - return traceback_with_variables.format_cur_tb(fmt=tb_format) - except Exception as e: - return ( - f"got exception while formatting traceback with `traceback_with_variables`: {traceback.format_exception(e)}" - ) diff --git a/imbue_core/imbue_core/frozen_utils.py b/imbue_core/imbue_core/frozen_utils.py @@ -6,10 +6,7 @@ from typing import Any from typing import Iterable from typing import Mapping from typing import NoReturn -from typing import Protocol -from typing import Sequence from typing import TYPE_CHECKING -from typing import TypeAlias from typing import TypeVar from typing import cast @@ -17,13 +14,8 @@ if TYPE_CHECKING: from _typeshed import SupportsKeysAndGetItem -class _SupportsLessThan(Protocol): - def __lt__(self, __other: Any) -> bool: ... 
- - T = TypeVar("T") TV = TypeVar("TV") -TK = TypeVar("TK", bound=_SupportsLessThan) class FrozenMapping(Mapping[T, TV], ABC): @@ -91,10 +83,6 @@ class FrozenDict(dict[T, TV], FrozenMapping[T, TV]): return (FrozenDict, (dict(self),)) -def empty_mapping() -> FrozenDict[Any, Any]: - return FrozenDict() - - def deep_freeze_mapping(mapping: Mapping[T, TV]) -> FrozenDict[T, Any]: return FrozenDict({key: cast(TV, _deep_freeze_any(value)) for key, value in mapping.items()}) @@ -118,24 +106,3 @@ def _deep_freeze_any(input_object: object) -> object: return tuple(_freeze_iterable_values(input_object)) return input_object - - -def deep_freeze_sequence(sequence: Sequence[T]) -> tuple[Any, ...]: - return tuple(_freeze_iterable_values(sequence)) - - -# Recursive type alias that captures the possible types of JSON objects (e.g. from json.loads). -JSON: TypeAlias = "str | int | bool | float | None | dict[str, JSON] | list[JSON]" - - -# Immutable version of JSON. -FrozenJSON: TypeAlias = "str | int | bool | float | None | FrozenDict[str, FrozenJSON] | tuple[FrozenJSON, ...]" - - -def deep_freeze_json(json: JSON) -> FrozenJSON: - if isinstance(json, dict): - return FrozenDict({k: deep_freeze_json(v) for k, v in json.items()}) - elif isinstance(json, list): - return tuple(deep_freeze_json(v) for v in json) - else: - return json diff --git a/imbue_core/imbue_core/git.py b/imbue_core/imbue_core/git.py @@ -1,587 +0,0 @@ -"""Utility abstractions for interacting with git repositories.""" - -from __future__ import annotations - -import asyncio -import contextlib -import shlex -import shutil -import subprocess -import sys -from asyncio.subprocess import PIPE -from asyncio.subprocess import STDOUT -from contextlib import asynccontextmanager -from io import StringIO -from pathlib import Path -from types import TracebackType -from typing import Any -from typing import AsyncGenerator -from typing import AsyncIterator -from typing import Self -from typing import Sequence -from typing import 
TYPE_CHECKING -from typing import TextIO - -import anyio -import attr -from loguru import logger - -from imbue_core.async_monkey_patches import log_exception -from imbue_core.async_utils import sync -from imbue_core.async_utils import sync_contextmanager_func -from imbue_core.computing_environment.computing_environment import assert_repo_is_clean -from imbue_core.computing_environment.computing_environment import get_head_hash -from imbue_core.computing_environment.computing_environment import git_add -from imbue_core.computing_environment.computing_environment import is_repo_dirty -from imbue_core.computing_environment.computing_environment import make_commit -from imbue_core.computing_environment.computing_environment import ( - restore_all_staged_files, -) -from imbue_core.computing_environment.computing_environment import ( - restore_all_unstaged_changes, -) -from imbue_core.computing_environment.computing_environment import ( - run_command_with_retry_on_git_lock_error, -) -from imbue_core.computing_environment.data_types import AnyPath -from imbue_core.computing_environment.data_types import RunCommandError - -if TYPE_CHECKING: - # for proper file mode typing - from _typeshed import OpenBinaryMode - from _typeshed import OpenBinaryModeReading - from _typeshed import OpenBinaryModeWriting - from _typeshed import OpenTextMode - from _typeshed import OpenTextModeReading - from _typeshed import OpenTextModeWriting - -PYTHON_EXTENSION = ".py" - - -def is_path_in_git_repo(path: Path) -> bool: - """Check if a path is in a git repository.""" - if path.is_file(): - path = path.parent - completed_process = subprocess.run( - ["git", "-C", path, "rev-parse", "--is-inside-work-tree"], - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - if completed_process.returncode != 0: - return False - result = completed_process.stdout.decode().strip() - assert result in ("true", "false"), result - return result == "true" - - -def 
get_git_repo_root() -> Path: - """Gets a Path to the current git repo root, assuming that our cwd is somewhere inside the repo.""" - completed_process = subprocess.run( - ("git", "rev-parse", "--show-toplevel"), - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - ) - root_dir = Path(completed_process.stdout.decode().strip()) - assert root_dir.is_dir(), f"{root_dir} must be a directory" - return root_dir - - -def get_git_repo_root_from_path(path: Path) -> Path: - """Gets a Path to the git repo root for the given path.""" - if path.is_file(): - path = path.parent - completed_process = subprocess.run( - ["git", "-C", path, "rev-parse", "--show-toplevel"], - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - ) - root_dir = Path(completed_process.stdout.decode().strip()) - assert root_dir.is_dir(), f"{root_dir} must be a directory" - return root_dir - - -@attr.s(auto_attribs=True, frozen=True) -class LocalGitRepo: - """ - DEPRECATED: Unless you either need asyncio or the ability to run git commands remotely on non-local compute environments, consider using SyncLocalGitRepo - from simple_git.py instead. SimpleGitRepo provides a subset of the functions available through computing_environment.py + LocalGitRepo, but - in a single class made for synchronous use. - """ - - base_path: Path - - @classmethod - def build_from_cwd(cls) -> Self: - """Create a `LocalGitRepo` instance from the current working directory.""" - return cls(get_git_repo_root()) - - async def run_git( - self, - command: Sequence[str], - check: bool = True, - cwd: AnyPath | None = None, - is_error_logged: bool = True, - is_stripped: bool = True, - retry_on_git_lock_error: bool = True, - ) -> str: - """Run a git command in the repo. 
- - Example: - ``` - git_repo.run_git("status") - ``` - """ - # TODO: check for whether hooks should actually be run when we call this function - # Note: this used to be within an asyncio lock to prevent the program from concurrently running git commands. - # This lock was removed since it was within global state, a dangerous pattern, and wasn't preventing other users from interacting with the git repo. - command = ["git"] + list(command) - if not retry_on_git_lock_error: - result = await self.run_command(command, check=check, is_error_logged=is_error_logged, cwd=cwd) - else: - result = await run_command_with_retry_on_git_lock_error( - self, command, check=check, is_error_logged=is_error_logged, cwd=cwd - ) - if is_stripped: - return result.strip() - return result - - sync_run_git = sync(run_git) - - async def run_command( - self, - command: Sequence[str], - check: bool = True, - secrets: dict[str, str] | None = None, - cwd: AnyPath | None = None, - is_error_logged: bool = True, - ) -> str: - """Run a command in the repo. - - Note, this can be used to run any command, not just git. - """ - command_string = shlex.join(command) - logger.trace( - f"Running command: {command_string=} from cwd={cwd or self.base_path} with {secrets=} {check=} {is_error_logged=}" - ) - proc = await asyncio.create_subprocess_exec( - *command, - cwd=cwd or self.base_path, - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=secrets, - ) - # note, need to be carefull not to strip() lines since whitespace may be important (e.g. for diffs) - # return joined lines since mostly we only use the output for logging, and this way we arn't - # passing around lots of lists. 
Also it's easy to parse by lines if needed - stdout_bytes, stderr_bytes = await proc.communicate() - try: - stdout = stdout_bytes.decode("UTF-8") - except UnicodeDecodeError as e: - # If we don't encounter this, it likely means something was fixed upstream and we can safely delete - log_exception( - e, - "Command {command_string} failed to decode stdout, replacing any invalid bytes which could lead to problems later", - command_string=command_string, - ) - stdout = stdout_bytes.decode("UTF-8", errors="replace") - stderr = stderr_bytes.decode("UTF-8") - if check and proc.returncode != 0: - error_message = f"command run from cwd={self.base_path} failed with exit code {proc.returncode} and stdout:\n{stdout}\nstderr:\n{stderr}" - if is_error_logged: - logger.error( - f"command attempted: '{command_string}' from cwd={self.base_path}\nerror message: {error_message}" - ) - # this should not be None, but do this to satisfy type checker, int or None we throw the same error - returncode = proc.returncode or -1 - raise RunCommandError( - cmd=command_string, - stderr=stderr, - returncode=returncode, - cwd=cwd or self.base_path, - ) - return stdout - - @contextlib.asynccontextmanager - async def _open_file( - self, - relative_path: AnyPath, - cwd: AnyPath | None = None, - mode: OpenTextMode | OpenBinaryMode = "r", - mkdir_if_missing: bool = True, - ) -> AsyncGenerator[anyio.AsyncFile[Any], None]: - logger.trace("opening file {} in cwd {} with mode {}", relative_path, cwd, mode) - if cwd is not None: - sb_file_path = str(Path(cwd) / relative_path) - else: - sb_file_path = str(self.base_path / relative_path) - - if mkdir_if_missing: - parent_dir = anyio.Path(sb_file_path).parent - await parent_dir.mkdir(parents=True, exist_ok=True) - - f: anyio.AsyncFile[Any] | None = None - try: - f = await anyio.Path(sb_file_path).open(mode=mode) # type: ignore - yield f - finally: - if f is not None: - await f.aclose() - - async def write_file( - self, - relative_path: AnyPath, - content: str 
| bytes | None, - cwd: AnyPath | None = None, - mode: OpenTextModeWriting | OpenBinaryModeWriting = "w", - mkdir_if_missing: bool = True, - ) -> None: - if content is None: - await self.delete_file(relative_path, cwd=cwd) - return - - async with self._open_file(relative_path, cwd=cwd, mode=mode, mkdir_if_missing=mkdir_if_missing) as f: - logger.trace("writing to file {} in cwd {} with mode {}", relative_path, cwd, mode) - # pyre-fixme[6]: content can be bytes - await f.write(content) - - async def delete_file(self, relative_path: AnyPath, cwd: AnyPath | None = None) -> None: - logger.trace("deleting the file {} in cwd {}", relative_path, cwd) - if cwd is not None: - sb_file_path = str(Path(cwd) / relative_path) - else: - sb_file_path = str(self.base_path / relative_path) - await anyio.Path(sb_file_path).unlink() - - async def read_file( - self, - relative_path: AnyPath, - cwd: AnyPath | None = None, - mode: OpenTextModeReading | OpenBinaryModeReading = "r", - mkdir_if_missing: bool = True, - ) -> str | bytes: - async with self._open_file(relative_path, cwd=cwd, mode=mode, mkdir_if_missing=mkdir_if_missing) as f: - logger.trace("reading file {} in cwd {} with mode {}", relative_path, cwd, mode) - content = await f.read() - assert isinstance(content, str) or isinstance(content, bytes) - return content - - async def head_hash(self) -> str: - """Get the hash of the current HEAD commit.""" - return await get_head_hash(self) - - async def is_git_repo(self) -> bool: - """Check that repo is valid git repo.""" - return await anyio.Path(self.base_path / ".git").exists() - - sync_is_git_repo = sync(is_git_repo) - - async def assert_clean(self) -> None: - await assert_repo_is_clean(self) - - sync_assert_clean = sync(assert_clean) - - async def configure_git( - self, - git_user_name: str | None = None, - git_user_email: str | None = None, - initial_commit_message: str = "initial commit", - is_recreating: bool = False, - ) -> None: - """Configure git repo with user name and 
email.""" - if is_recreating: - if await self.is_git_repo(): - await asyncio.to_thread(shutil.rmtree, self.base_path / ".git") - - # order here is important - # ref https://stackoverflow.com/questions/11656761/git-please-tell-me-who-you-are-error?noredirect=1 - await self.run_git(("init",)) - if git_user_name: - await self.run_git(("config", "user.name", f"'{git_user_name}'")) - if git_user_email: - await self.run_git(("config", "user.email", f"'{git_user_email}'")) - await self.run_git(("add", ".")) - await self.run_git(("commit", "-m", f"'{initial_commit_message}'")) - branch_name = await self.run_git(("symbolic-ref", "HEAD")) - if not branch_name == "refs/heads/main": - # rename master to main for consistency - await self.run_git(("branch", "-m", "master", "main")) - - sync_configure_git = sync(configure_git) - - @asynccontextmanager - async def temporary_commit( - self, - tag_prefix: str, - commit_message: str, - raise_on_head_hash_change: bool = False, - ) -> AsyncIterator[str]: - """Context manager to make a temporary commit and tag in the repo.""" - await self.run_git(("commit", "-am", commit_message, "--allow-empty", "--no-verify")) - head_hash = await self.head_hash() - tag = f"{tag_prefix}/{head_hash}" - await self.run_git(("tag", "--force", tag)) - await self.run_git(("push", "origin", tag, "--no-verify")) - try: - yield head_hash - finally: - # This is susceptible to a race condition (if the user makes a commit between the time we check the head hash and the time we reset the state). - # So it's important to keep any block that uses this context manager short - make the commit, copy it to the controller, and work there. Don't hold the repo hostage. - current_head_hash = await self.head_hash() - if current_head_hash != head_hash and raise_on_head_hash_change: - raise AssertionError( - f"Head hash has changed from {head_hash} to {current_head_hash} since the temporary commit was made. Giving up on resetting git state, please address this manually." 
- ) - else: - await self.run_git(("reset", "HEAD~")) - - sync_temporary_commit = sync_contextmanager_func(temporary_commit) - - async def copy_repo(self, new_repo_path: Path, exists_ok: bool = True) -> "LocalGitRepo": - """Make a full copy of this repo in a new directory. - - Note, this will copy all the files in the repo into a new local directory, but will not handle - configuring the new directory as a git repo. - """ - if await anyio.Path(new_repo_path).exists(): - if not exists_ok: - raise FileExistsError( - f"New repo path '{new_repo_path} already exists. Set `exists_ok=True` if you are happy overwriting it, otherwise select new path." - ) - await asyncio.to_thread(shutil.rmtree, new_repo_path) - await asyncio.to_thread( - shutil.copytree, - self.base_path, - new_repo_path, - dirs_exist_ok=True, - ignore=shutil.ignore_patterns(".git", ".gitsecret"), - ) - return LocalGitRepo(new_repo_path) - - sync_copy_repo = sync(copy_repo) - - async def is_path_in_repo(self, file_path: str | Path | anyio.Path) -> bool: - """Check whether a given file path is within this repo. - - FIXME: It doesn't seem entirely necessary to enumerate all of the files with a particular extension - just to check if a single file (whose path we know) is in the repo. - """ - if isinstance(file_path, (str, Path)): - file_path = anyio.Path(file_path) - extension = file_path.suffix - return file_path in await self.get_all_files_by_extension(extension=extension) - - async def _get_file_path(self, file_path: str | Path) -> anyio.Path: - path = anyio.Path(file_path) - if not path.is_absolute(): - path = anyio.Path(self.base_path / path) - assert await path.exists(), f"File {path} does not exist." - return path - - async def safely_read_file_from_repo(self, file_path: str | Path) -> str: - """Safely read file from repo.""" - path = await self._get_file_path(file_path) - assert await self.is_path_in_repo(path), f"File {path} is not in repo." 
- return await path.read_text() - - sync_safely_read_file_from_repo = sync(safely_read_file_from_repo) - - async def get_all_files_by_extension(self, extension: str = PYTHON_EXTENSION) -> tuple[Path, ...]: - """Get absolute path of all files in the repo with given extension.""" - paths: list[Path] = [] - async for path in anyio.Path(self.base_path).rglob(f"*{extension}"): - paths.append(Path(path)) - return tuple(paths) - - -@attr.s(auto_attribs=True, frozen=True, kw_only=True) -class WritableLocalGitRepo(LocalGitRepo): - """A Local Git Repo with support for modifying files and reseting to an initial state. - - Note, this does not handle creating a copy of an existing repo, or anything. Rather it adds some additional - functionality for actually writing to files in the repo. For the creation of a separate copy of an existing - repo which supports making changes without affecting the main repo see the `temp_writable_local_git_repo` context - manager. - - It is also recommended the `build_from_repo` function when creating a WritableLocalGitRepo, as this will - make sure that any untracked and uncommited changes are managed correctly. 
- """ - - initial_git_hash: str - stash_git_hash: str | None - - @classmethod - async def build_from_repo(cls, repo: LocalGitRepo) -> "WritableLocalGitRepo": - """Create a writable repo from an local repo.""" - init_hash = await repo.head_hash() - if await is_repo_dirty(repo): - await repo.run_git(("add", ".")) - stash_hash = await make_commit(repo, "stashing uncommited and untracked changes") - else: - stash_hash = None - - return cls( - base_path=repo.base_path, - initial_git_hash=init_hash, - stash_git_hash=stash_hash, - ) - - async def _setup(self) -> None: - init_hash = await self.head_hash() - expected_hash = self.stash_git_hash or self.initial_git_hash - assert init_hash == expected_hash, "git repo is not currently at expected commit" - assert await self.is_git_repo(), f"{self.base_path} is not a git repo" - await self.assert_clean() - - async def reset(self) -> None: - """Reset the repo to the state it was in when this class was created.""" - await restore_all_staged_files(self) - await restore_all_unstaged_changes(self) - if self.stash_git_hash: - # hard to reset to commit with stashed untracked and uncommited changes - await self.run_git(("reset", "--hard", self.stash_git_hash)) - # soft reset to return to initial commit but keep the untracked and uncommited changes - await self.run_git(("reset", "--soft", self.initial_git_hash)) - # unstage untracked and uncommited changes - await restore_all_staged_files(self) - else: - await self.run_git(("reset", "--hard", self.initial_git_hash)) - await self.run_git(("clean", "-f")) - await self.run_git(("checkout", self.initial_git_hash)) - - current_git_hash = await self.head_hash() - assert ( - current_git_hash == self.initial_git_hash - ), f"base branch changed, current git hash ({current_git_hash}) != initial git hash ({self.initial_git_hash})" - await self.assert_clean() - - async def apply_change_to_file(self, file_path: str | Path, new_contents: str) -> None: - """Apply change to a single file.""" - path = 
await self._get_file_path(file_path) - assert await self.is_path_in_repo(path), f"File {path} is not in repo." - await path.write_text(new_contents) - await git_add(self, str(path)) - - async def __aenter__(self) -> "WritableLocalGitRepo": - await self._setup() - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - await self.reset() - - -async def copy_files_from_one_repo_to_another( - src_repo_path: Path, dst_repo_path: Path, relative_file_paths: Sequence[str | Path] -) -> None: - """Copies files from src to dst repo using the relative file paths.""" - for relative_path in relative_file_paths: - src_file_path = src_repo_path / relative_path - dst_file_path = anyio.Path(dst_repo_path / relative_path) - # make sure necessary directories exist in destination - await dst_file_path.parent.mkdir(parents=True, exist_ok=True) - await asyncio.to_thread(shutil.copy2, src_file_path, dst_file_path) - - -def get_repo_url_from_folder(repo_path: Path) -> str: - try: - repo_url = subprocess.check_output( - ["git", "remote", "get-url", "origin"], - cwd=repo_path, - universal_newlines=True, - ).strip() - except subprocess.CalledProcessError: - raise - else: - if repo_url.startswith("git@"): - # convert ssh url to https - repo_url = repo_url.replace(":", "/") - repo_url = f"https://{repo_url[4:]}" - if "https://oauth2:" in repo_url: - # remove the oauth2 prefix - # repo_url is something like https://oauth2:{token}@gitlab.com/.../.git - # change it to https://gitlab.com/.../.git - # This will happen if repo was originallycloned using oauth2 - suffix = repo_url.split("@")[-1] - repo_url = "https://" + suffix - return repo_url - - -def get_repo_base_path() -> Path: - working_directory = Path(__file__).parent - try: - return Path( - _run_command_and_capture_output(["git", "rev-parse", "--show-toplevel"], cwd=working_directory).strip() - ) - except 
subprocess.CalledProcessError as e: - try: - return working_directory.parents[1] - except IndexError: - raise UnableToFindRepoBase() from e - - -def _run_command_and_capture_output(args: Sequence[str], cwd: Path | None = None) -> str: - arg_str = " ".join(shlex.quote(arg) for arg in args) - print(f"Running command: {arg_str}", file=sys.stderr) - with subprocess.Popen(args, text=True, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc: - with StringIO() as output: - _handle_output(proc, output, sys.stderr) - if proc.wait() != 0: - raise subprocess.CalledProcessError(proc.returncode, cmd=args, output=output.getvalue()) - return output.getvalue() - - -class UnableToFindRepoBase(Exception): - """Raised when the base of the repository cannot be found.""" - - -def _handle_output(process: subprocess.Popen[str], *files: TextIO) -> None: - process_stdout = process.stdout - assert process_stdout is not None - while True: - output = process_stdout.read(1) - if output: - for f in files: - f.write(output) - elif process.poll() is not None: - break - - -def get_diff_without_index(diff: str) -> str: - new_lines = [] - for line in diff.splitlines(): - if line.startswith("index "): - # We replace index lines with "index 0000000..0000000 100644" because: - # - `0000000..0000000` ensures no real object hashes are referenced, making the diff neutral. - # - `100644` is the standard file mode for non-executable files in git diffs, ensuring compatibility. - # - This keeps the diff format valid while removing specific index information. - new_lines.append("index 0000000..0000000 100644") - else: - new_lines.append(line) - return "\n".join(new_lines).strip() - - -def is_diffs_without_index_equal(diff_1: str, diff_2: str) -> bool: - return get_diff_without_index(diff_1) == get_diff_without_index(diff_2) - - -# Copy-pasted from imbue to avoid moving the whole hammers machinery over to imbue-core. 
-async def get_lines_from_process(shell_command: str, is_exit_code_validated: bool = True, **kwargs: Any) -> list[str]: - p = await asyncio.create_subprocess_shell(shell_command, stdin=PIPE, stdout=PIPE, stderr=STDOUT, **kwargs) - lines = [x.decode("UTF-8") for x in (await p.communicate())[0].splitlines()] - if is_exit_code_validated: - joined_lines = "\n".join(lines) - assert ( - p.returncode == 0 - ), f"command failed: {shell_command}\nwith output:\n{joined_lines} with exit code {p.returncode}" - return lines diff --git a/imbue_core/imbue_core/git_data_types.py b/imbue_core/imbue_core/git_data_types.py @@ -1,35 +0,0 @@ -from datetime import datetime -from typing import Annotated - -from pydantic.functional_validators import PlainValidator - -from imbue_core.pydantic_serialization import FrozenModel -from imbue_core.pydantic_serialization import SerializableModel - - -def _validate_git_timestamp(value: str) -> str: - try: - datetime.fromisoformat(value) - return value - except ValueError: - raise ValueError(f"Invalid git timestamp: {value}") - - -class CommitTimestamp(SerializableModel): - author_ts: Annotated[str, PlainValidator(_validate_git_timestamp)] - committer_ts: Annotated[str, PlainValidator(_validate_git_timestamp)] - - -class CommitMetadata(FrozenModel): - commit: str - tree_hash: str - message: str - commit_time: CommitTimestamp - - @property - def body(self) -> str: - return self.message.split("\n", 1)[-1] - - @property - def subject(self) -> str: - return self.message.split("\n", 1)[0] diff --git a/imbue_core/imbue_core/ids.py b/imbue_core/imbue_core/ids.py @@ -1,44 +0,0 @@ -from typing import Any -from typing import Self - -from pydantic import GetCoreSchemaHandler -from pydantic_core import core_schema - - -class NonEmptyStr(str): - # pyre-fixme[11]: pyre seems to have some trouble with Self in some specific cases, including type[Self] - def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: - value = str.__new__(cls, *args, **kwargs) - 
if len(value) == 0: - raise ValueError("NonEmptyStr cannot be empty") - return value - - @classmethod - def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - """ - Support transparently deserializing strings into ObjectID instances and vice versa. - """ - return core_schema.no_info_before_validator_function( - lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value), - core_schema.union_schema( - [ - core_schema.is_instance_schema(cls), - core_schema.str_schema(), - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: str(instance), return_schema=core_schema.str_schema() - ), - ) - - -class ExternalID(NonEmptyStr): - pass - - -class AssistantMessageID(ExternalID): - pass - - -class ToolUseID(ExternalID): - pass diff --git a/imbue_core/imbue_core/itertools.py b/imbue_core/imbue_core/itertools.py @@ -1,17 +1,13 @@ import contextlib import itertools -from typing import AsyncGenerator -from typing import Callable from typing import Generator from typing import Iterable from typing import Sequence from typing import TypeVar -from typing import cast from imbue_core.errors import ImbueError T = TypeVar("T") -TV = TypeVar("TV") class ImbueItertoolsValueError(ImbueError, ValueError): @@ -43,30 +39,6 @@ def first(iterable: Iterable[T]) -> T | None: return next(iter(iterable), None) -async def iterable_to_async(iterable: Iterable[T]) -> AsyncGenerator[T, None]: - for item in iterable: - yield item - - -# TODO delete/migrate out computronium/computronium/universal/utils.py -def generate_unique( - source: Iterable[T], - key: Callable[[T], TV] = cast(Callable[[T], TV], lambda item: item), -) -> Generator[T, None, None]: - unique = set() - for item in source: - value = key(item) - if value in unique: - continue - yield item - unique.add(value) - - -def generate_flattened(iterable: Iterable[Iterable[T]]) -> Generator[T, None, None]: - for item in iterable: - 
yield from item - - # TODO replace with itertools.batched when we can require Python 3.12+ def generate_chunks(iterable: Iterable[T], chunk_size: int) -> Generator[tuple[T, ...], None, None]: """Yield successive n-sized chunks from any iterable""" diff --git a/imbue_core/imbue_core/llm_testing_utils.py b/imbue_core/imbue_core/llm_testing_utils.py @@ -1,57 +0,0 @@ -from pathlib import Path - -from loguru import logger -from syrupy.assertion import SnapshotAssertion - -from imbue_core.caching import AsyncCache -from imbue_core.cattrs_serialization import deserialize_from_json -from imbue_core.cattrs_serialization import serialize_to_json - - -async def preload_llm_cache(persistent_cache_path: Path, temp_cache: AsyncCache) -> None: - logger.info( - "Loading existing cache from {persistent_cache_path}", - persistent_cache_path=persistent_cache_path, - ) - assert persistent_cache_path.exists(), f"Cache file {persistent_cache_path} does not exist." - existing_data = deserialize_from_json(persistent_cache_path.read_text()) - async with temp_cache as cache: - for key, value in existing_data.items(): - await cache.set(key, value) - - -async def record_llm_responses_in_cache(temp_cache: AsyncCache, persistent_cache_path: Path) -> None: - logger.info( - "Updating cache (!!!) at {persistent_cache_path}", - persistent_cache_path=persistent_cache_path, - ) - async with temp_cache as cache: - all_keys = await cache.get_all_keys() - data = await cache.get_all(all_keys) - if data: - persistent_cache_path.parent.mkdir(parents=True, exist_ok=True) - persistent_cache_path.write_text(serialize_to_json(data), encoding="utf-8") - - -def _sanitize_snapshot_name(snapshot_name: str) -> str: - return snapshot_name.replace("/", "").replace("\\", "") - - -def get_cache_file_from_snapshot_core(snapshot: SnapshotAssertion, suffix: str) -> Path: - # To prevent syrupy from cleaning the cache up immediately after written, we use this suffix and register it with pytest. 
- # Make sure to add a line like `--snapshot-ignore-file-extensions={suffix}` to the project's pytest.ini - - # Goal here is a cache file per test, not per test-file. - test_file = Path(snapshot.test_location.filepath) - snapshot_dir = test_file.parent / "__snapshots__" / test_file.stem - snapshot_dir.mkdir(parents=True, exist_ok=True) - cache_file = snapshot_dir / f"{_sanitize_snapshot_name(snapshot.test_location.testname)}.{suffix}" - return cache_file - - -def get_cache_file_from_snapshot(snapshot: SnapshotAssertion) -> Path: - return get_cache_file_from_snapshot_core(snapshot, "llm_cache_json") - - -def get_count_tokens_cache_file_from_snapshot(snapshot: SnapshotAssertion) -> Path: - return get_cache_file_from_snapshot_core(snapshot, "count_tokens_cache_json") diff --git a/imbue_core/imbue_core/nested_evolver.py b/imbue_core/imbue_core/nested_evolver.py @@ -2,8 +2,6 @@ One of the design goals is that mypy, autocomplete, and automatic refactoring work for the assignments made into these nested structures. -See: https://imbue-ai.slack.com/archives/C05D0SM2RT5/p1726185313480779?thread_ts=1722865932.537289&cid=C05D0SM2RT5 - If you make changes here and then the tests fail with: ``` E RecursionError: maximum recursion depth exceeded @@ -17,7 +15,6 @@ import threading from typing import Any from typing import Callable from typing import Generic -from typing import TypeGuard from typing import TypeVar from typing import cast @@ -28,7 +25,8 @@ from imbue_core.frozen_utils import FrozenDict from imbue_core.pydantic_utils import model_update _T = TypeVar("_T") -ObjectType = TypeVar("ObjectType") + +_threading_local = threading.local() def evolver(obj: _T) -> _T: @@ -62,47 +60,6 @@ def chill(evolver: _T) -> _T: return cast_evolver.chill() -_threading_local = threading.local() - - -# TODO: since mutate and thaw are stateful, if you call one without the other, you run into problems. 
-def thaw(obj: _T) -> _T: - global _threading_local - if hasattr(_threading_local, "evolved_obj"): - raise ValueError("Thaw does not support nested thawing.") - # pyre-ignore[16]: we're deliberately setting evolved_obj for the first time here - _threading_local.evolved_obj = evolver(obj) - return _threading_local.evolved_obj - - -# TODO: mypy complains because the input isn't anything related to ObjectType, but the output is. -# This also means the type checking doesn't quite work since it can't infer the return type of this function correctly -def mutate(dest: _T, src: Callable[[], _T]) -> ObjectType: # type: ignore - assign(dest, src) - try: - # pyre-ignore[34]: we don't have generic functions yet, so pyre complains that ObjectType isn't in the input - evolved_obj: ObjectType = _threading_local.evolved_obj # pyre-ignore[16]: pyre doesn't know about evolved_obj - return chill(evolved_obj) - except AttributeError as e: - raise ValueError("You must call mutate on a thawed object") from e - finally: - delattr(_threading_local, "evolved_obj") - - -def mutate_from_dict(dest: ObjectType, src: dict[str, Any]) -> ObjectType: - # Warning: using this function doesn't provide mypy type checking at the call site, but it allows a single interface for attrs and pydantic classes - # In most cases the above function should be used instead - evolved_obj = evolver(dest) - for key, value in src.items(): - assign(getattr(evolved_obj, key), lambda: value) - return chill(evolved_obj) - - -def evolver_isinstance(evolver: Any, cls: type[_T]) -> TypeGuard[_T]: - assert isinstance(evolver, _Evolver) # Tricked you, type system! 
- return evolver.isinstance(cls) - - class _RegularValue: regular_value: Any @@ -146,7 +103,13 @@ class _FrozenDictValue: class _Evolver(Generic[_T]): # pyre-ignore[13]: pyre is confused by the trickery here - _value: _RegularValue | _AttrValue | _TupleValue | _FrozenDictValue | _PydanticModelValue + _value: ( + _RegularValue + | _AttrValue + | _TupleValue + | _FrozenDictValue + | _PydanticModelValue + ) def __init__(self, initial_value: _T) -> None: super().__init__() @@ -174,14 +137,18 @@ class _Evolver(Generic[_T]): if item not in value.child_evolver_by_name: child_obj = getattr(value.attr_value, item) result = evolver(child_obj) - assert isinstance(result, _Evolver), "Expose a lie to the type system." + assert isinstance(result, _Evolver), ( + "Expose a lie to the type system." + ) value.child_evolver_by_name[item] = result return value.child_evolver_by_name[item] elif isinstance(value, _PydanticModelValue): if item not in value.child_evolver_by_name: child_obj = getattr(value.pydantic_model_value, item) result = evolver(child_obj) - assert isinstance(result, _Evolver), "Expose a lie to the type system." + assert isinstance(result, _Evolver), ( + "Expose a lie to the type system." + ) value.child_evolver_by_name[item] = result return value.child_evolver_by_name[item] raise TypeError( @@ -204,7 +171,9 @@ class _Evolver(Generic[_T]): elif isinstance(value, _FrozenDictValue): if key not in value.frozen_dict_evolvers: # Presumably we're going to evolver_assign to this very soon. - cast(_FrozenDictValue, self._value).frozen_dict_evolvers[key] = _Evolver(_RegularValue(None)) + cast(_FrozenDictValue, self._value).frozen_dict_evolvers[key] = ( + _Evolver(_RegularValue(None)) + ) return cast(_FrozenDictValue, self._value).frozen_dict_evolvers[key] raise TypeError( f"You're using [square_brackets] access {key=} on an object of {type(self._value)=} that doesn't support this (should have been a mypy error)." 
@@ -214,38 +183,38 @@ class _Evolver(Generic[_T]): """Recursively apply the recorded changes to the original object and return a new frozen instance.""" if isinstance(self._value, _AttrValue): new_children: dict[str, Any] = { - name: chill(child) for name, child in self._value.child_evolver_by_name.items() + name: chill(child) + for name, child in self._value.child_evolver_by_name.items() } assert attr.has(self._value.attr_value.__class__) return cast( _T, - attr.evolve(cast(Any, cast(_AttrValue, self._value).attr_value), **new_children), + attr.evolve( + cast(Any, cast(_AttrValue, self._value).attr_value), **new_children + ), ) elif isinstance(self._value, _PydanticModelValue): return cast( _T, model_update( self._value.pydantic_model_value, - update={name: chill(child) for name, child in self._value.child_evolver_by_name.items()}, + update={ + name: chill(child) + for name, child in self._value.child_evolver_by_name.items() + }, ), ) elif isinstance(self._value, _TupleValue): - return cast(_T, tuple(evolver.chill() for evolver in self._value.tuple_evolvers)) + return cast( + _T, tuple(evolver.chill() for evolver in self._value.tuple_evolvers) + ) elif isinstance(self._value, _RegularValue): return cast(_T, self._value.regular_value) elif isinstance(self._value, _FrozenDictValue): return cast( _T, - FrozenDict({k: v.chill() for k, v in self._value.frozen_dict_evolvers.items()}), + FrozenDict( + {k: v.chill() for k, v in self._value.frozen_dict_evolvers.items()} + ), ) raise ValueError(f"This Evolver has no value to evolve, {type(self._value)=}.") - - def isinstance(self, cls: type[_T]) -> bool: - """Check if the object being evolved is an instance of the given class.""" - if isinstance(self._value, _AttrValue): - return isinstance(self._value.attr_value, cls) - elif isinstance(self._value, _PydanticModelValue): - return isinstance(self._value.pydantic_model_value, cls) - elif isinstance(self._value, _FrozenDictValue): - return cls == FrozenDict - return False diff 
--git a/imbue_core/imbue_core/pydantic_serialization.py b/imbue_core/imbue_core/pydantic_serialization.py @@ -2,17 +2,12 @@ import threading from typing import Any from typing import TypeVar from typing import cast -from typing import get_args from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Discriminator -from pydantic import GetCoreSchemaHandler -from pydantic import Json from pydantic.alias_generators import to_camel -from pydantic_core import core_schema as pyd_core_schema -from imbue_core.frozen_utils import FrozenDict from imbue_core.nested_evolver import _Evolver from imbue_core.nested_evolver import chill from imbue_core.nested_evolver import evolver @@ -28,7 +23,9 @@ class EvolvableModel: # pyre-ignore[47]: pyre is not so easily tricked def evolve(self: T, attribute: V, new_value: V) -> T: # pyre-ignore[16]: pyre doesn't know about evolved_obj - assert _threading_local.evolved_obj is not None, ".ref() must be called before evolve" + assert _threading_local.evolved_obj is not None, ( + ".ref() must be called before evolve" + ) assert isinstance(attribute, _Evolver) # Tricked you, type system! dest_evolver: _Evolver[T] = cast(_Evolver[T], attribute) @@ -45,20 +42,6 @@ class EvolvableModel: return _threading_local.evolved_obj -class FrozenModel(EvolvableModel, BaseModel): - """ - The base class for most internal data (that does not need to be serialized). - - We generally prefer to keep data immutable in order to avoid side effects, race conditions, etc - """ - - model_config = ConfigDict( - frozen=True, - extra="forbid", - arbitrary_types_allowed=False, - ) - - class MutableModel(BaseModel): """ The base class for any internal data that strictly must be mutable. Should be used sparingly. 
@@ -100,23 +83,6 @@ class SerializableModel(EvolvableModel, BaseModel, Serializable): pydantic_extra.clear() -def model_dump(obj: BaseModel, is_camel_case: bool = False) -> dict: - return obj.model_dump(by_alias=is_camel_case) - - -def model_load(model_type: type[T], data: dict) -> T: - return model_type.model_validate(data) - - -def model_dump_json(obj: BaseModel | Json, is_camel_case: bool = False) -> str: - # pyre-fixme[16]: pyre complains that obj can be pydantic.types.AnyType, which has no model_dump_json - return obj.model_dump_json(by_alias=is_camel_case) - - -def model_load_json(model_type: type[T], data: str) -> T: - return model_type.model_validate_json(data) - - # this is mostly here for the default cases. # When you want to upgrade a model (and keep it backwards compatible), you can make a custom discriminator # (eg, that looks for the old type name or converts the old class names) @@ -150,26 +116,3 @@ def build_discriminator( return getattr(obj, field_name) return Discriminator(discriminator=discriminator) - - -class PydanticFrozenDictAnnotation: - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> pyd_core_schema.CoreSchema: - def validate_from_dict(d: dict | FrozenDict) -> FrozenDict: - return FrozenDict(d) - - frozendict_schema = pyd_core_schema.chain_schema( - [ - # pyre-ignore[16]: pyre is confused by using dict as a type like this - handler.generate_schema(dict[*get_args(source_type)]), - pyd_core_schema.no_info_plain_validator_function(validate_from_dict), - pyd_core_schema.is_instance_schema(FrozenDict), - ] - ) - return pyd_core_schema.json_or_python_schema( - json_schema=frozendict_schema, - python_schema=frozendict_schema, - serialization=pyd_core_schema.plain_serializer_function_ser_schema(dict), - ) diff --git a/imbue_core/imbue_core/retry_utils.py b/imbue_core/imbue_core/retry_utils.py @@ -1,30 +0,0 @@ -from typing import Callable - -from loguru import logger -from tenacity 
import RetryCallState - - -def _log_before_sleep(retry_state: RetryCallState, log_fn: Callable[[str], None]) -> None: - fn_name = getattr(retry_state.fn, "__name__", "unknown") - sleep_time = retry_state.next_action.sleep if retry_state.next_action is not None else 0 - outcome = retry_state.outcome - if outcome is not None: - exception = outcome.exception() - error_message = type(exception).__name__ + ": " + str(exception) - else: - error_message = "unknown" - log_fn( - f"Retrying {fn_name} in {sleep_time:.2f} seconds, attempt {retry_state.attempt_number} after error: {error_message}" - ) - - -def log_before_sleep(retry_state: RetryCallState) -> None: - _log_before_sleep(retry_state, logger.debug) - - -def log_error_before_sleep(retry_state: RetryCallState) -> None: - _log_before_sleep(retry_state, logger.error) - - -def log_trace_before_sleep(retry_state: RetryCallState) -> None: - _log_before_sleep(retry_state, logger.trace) diff --git a/imbue_core/imbue_core/secrets_utils.py b/imbue_core/imbue_core/secrets_utils.py @@ -61,21 +61,6 @@ def parse_secrets_file(filepath: str | pathlib.Path) -> dict[str, str]: return out -# TODO: this is gross and bad--we should make better handling for secrets. 
-# Right now we read the necessary secrets out of the bashenv files def get_secret(secret_name: str) -> str | None: - value = os.environ.get(secret_name) - if value is not None: - return value - secrets_files = ( - "science/secrets/environment_vars/bashenv.sh", - "science/secrets/environment_vars/bashenv_secrets.sh", - "science/secrets/environment_vars/common_vars.sh", - ) - for file in secrets_files: - if os.path.exists(file): - secrets = parse_secrets_file(file) - value = secrets.get(secret_name, None) - if value is not None: - return value - return None + """Get a secret from environment variables.""" + return os.environ.get(secret_name) diff --git a/imbue_core/imbue_core/section.py b/imbue_core/imbue_core/section.py @@ -1,131 +0,0 @@ -""" -Provides a context manager for logging a potentially time-consuming process, or a "section". - -- Prints logs at start and end of a section. - -- Prints a Markdown-like heading for nested sections: "#" for top-level sections, "##" for one level down, and so on. - -- Emits structured logs for easier query. - -The SectionWrapper context manager can be used to time the __enter__ and __exit__ methods of an existing context manager. -""" - -import contextlib -import threading -import time -from asyncio import CancelledError -from types import TracebackType -from typing import TypeVar - -from loguru import logger - -_monotonic_base = time.monotonic() - - -def _monotonic_time() -> float: - """A wrapper around time.monotonic() to make the return values a bit smaller and easier to read by a human.""" - return time.monotonic() - _monotonic_base - - -class _ThreadLocal(threading.local): - def __init__(self) -> None: - self.next_section_level: int = 0 - - -_thread_local = _ThreadLocal() - - -class Section(contextlib.ContextDecorator): - def __init__(self, message: str, log_level: int | str = "INFO") -> None: - # TODO: loguru doesn't properly display integer log levels like e.g. 
logging.INFO - self.message = message - self.log_level = log_level - - def __enter__(self) -> "Section": - level = _thread_local.next_section_level - _thread_local.next_section_level += 1 - self.header = "#" * (level + 1) # pyre-ignore[16] - self.start_monotonic_time = _monotonic_time() # pyre-ignore[16] - start_clock_time = time.time() - self.section = { # pyre-ignore[16] - "name": self.message, - "level": level, - "start_monotonic_time": self.start_monotonic_time, - "start_clock_time": start_clock_time, - } - logger.log( - self.log_level, - f"{self.header} Start: {self.message}", - section=self.section, - ) - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - _thread_local.next_section_level -= 1 - finish_monotonic_time = _monotonic_time() - finish_clock_time = time.time() - # pyre-ignore[16]: we set this on __enter__ - duration_seconds = finish_monotonic_time - self.start_monotonic_time - section = self.section | { # pyre-ignore[16]: we set this on __enter__ - "finish_monotonic_time": finish_monotonic_time, - "finish_clock_time": finish_clock_time, - "duration_seconds": duration_seconds, - } - self.elapsed = duration_seconds # pyre-ignore[16] - - header = self.header # pyre-ignore[16]: we set this on __enter__ - - if exc_val is None: - logger.log( - self.log_level, - f"{header} Done: {self.message} (took {duration_seconds:.2f} seconds)", - section=section | {"result": "success"}, - ) - else: - if isinstance(exc_val, CancelledError): - logger.log( - self.log_level, - f"{header} Cancelled: {self.message} (took {duration_seconds:.2f} seconds)", - section=section | {"result": "cancelled"}, - ) - else: - logger.log( - self.log_level, - f"{header} Failed: {self.message} (within {duration_seconds:.2f} seconds)", - section=section | {"result": "failed"}, - ) - - -T = TypeVar("T") - - -# pyre-ignore[24]: pyre doesn't understand AbstractContextManager -class 
SectionWrapper(contextlib.AbstractContextManager[T]): - # pyre-ignore[24] - def __init__( - self, - cm: contextlib.AbstractContextManager[T], - enter_message: str, - exit_message: str, - ) -> None: - self._cm = cm - self._enter_message = enter_message - self._exit_message = exit_message - - def __enter__(self) -> T: - with Section(self._enter_message): - return self._cm.__enter__() - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - with Section(self._exit_message): - self._cm.__exit__(exc_type, exc_val, exc_tb) diff --git a/imbue_core/imbue_core/serialization.py b/imbue_core/imbue_core/serialization.py @@ -1,6 +1,5 @@ import builtins import datetime -import inspect import json from enum import Enum from functools import cached_property @@ -10,7 +9,6 @@ from pathlib import PosixPath from traceback import format_tb from types import TracebackType from typing import Any -from typing import Callable from typing import Hashable from typing import Iterable from typing import Mapping @@ -34,9 +32,9 @@ from imbue_core.fixed_traceback import FixedTraceback from imbue_core.pydantic_serialization import SerializableModel from imbue_core.serialization_types import Serializable -assert ( - version("yasoo") == "0.12.6" -), "This code was written for yasoo 0.12.6 and requires inheriting / monkeypatching the deserializer, so you probably don't want to use any other version without fixing TupleDeserializer" +assert version("yasoo") == "0.12.6", ( + "This code was written for yasoo 0.12.6 and requires inheriting / monkeypatching the deserializer, so you probably don't want to use any other version without fixing TupleDeserializer" +) T = TypeVar("T", bound=Hashable) @@ -57,7 +55,12 @@ class TupleDeserializer(Deserializer): return data if isinstance(data, list): list_types = self._get_list_types(obj_type, data) - return tuple([self._deserialize(d, t, type_key, allow_extra_fields, 
all_globals) for t, d in list_types]) + return tuple( + [ + self._deserialize(d, t, type_key, allow_extra_fields, all_globals) + for t, d in list_types + ] + ) assert isinstance(data, dict), f"Expected a dict, but got {type(data)}" @@ -65,7 +68,11 @@ class TupleDeserializer(Deserializer): if type_key is not None: type_data = data.get(type_key, None) - if type_data is not None and type_data.startswith("builtins.") and type_data != "builtins.dict": + if ( + type_data is not None + and type_data.startswith("builtins.") + and type_data != "builtins.dict" + ): return data["value"] # TODO: we need to potentially handle `builtins.dict` @@ -81,10 +88,15 @@ class TupleDeserializer(Deserializer): # return data["value"] # TODO: remove this hack. Many of our sqlite files (search s3_sqlite_path) have FrozenDicts - if isinstance(type_key, str) and data.get(type_key, None) == "flax.core.frozen_dict.FrozenDict": + if ( + isinstance(type_key, str) + and data.get(type_key, None) == "flax.core.frozen_dict.FrozenDict" + ): data[type_key] = "imbue_core.frozen_utils.FrozenMapping" # we deliberately pass in a `None` type_key sometimes, which results in just returning obj_type - obj_type = self._get_object_type(obj_type, data, type_key, all_globals) # pyre-ignore[6] + obj_type = self._get_object_type( + obj_type, data, type_key, all_globals + ) # pyre-ignore[6] if type_key in data: data.pop(type_key) real_type, generic_args = normalize_type(obj_type, all_globals) @@ -95,7 +107,9 @@ class TupleDeserializer(Deserializer): bases = {ancestor for b in bases for ancestor in b.__bases__} if not ignore_custom_deserializer: - deserialization_method = self._custom_deserializers.get(obj_type, self._custom_deserializers.get(real_type)) + deserialization_method = self._custom_deserializers.get( + obj_type, self._custom_deserializers.get(real_type) + ) if deserialization_method: return deserialization_method(data) for base_class, method in self._inheritance_deserializers.items(): @@ -119,7 +133,7 @@ 
class TupleDeserializer(Deserializer): if e.name.lower() == value.lower(): return e return real_type(value) - # TODO (49780118-61e5-446b-b44b-cabb3ffc0ba2): serialization currently breaks with builtin.dicts and dicts with non-string keys + # TODO: serialization currently breaks with builtin.dicts and dicts with non-string keys # if you have weird keys in your dict this branch won't be hit and your object won't be properly deserialized elif issubclass(real_type, Mapping): key_type = generic_args[0] if generic_args else None @@ -141,9 +155,13 @@ class TupleDeserializer(Deserializer): ) elif issubclass(real_type, Iterable): # If we got here it means data is not a list, so obj_type came from the data itself and is safe to use - return self._load_iterable(data, obj_type, type_key, allow_extra_fields, all_globals) + return self._load_iterable( + data, obj_type, type_key, allow_extra_fields, all_globals + ) elif real_type != obj_type: - return self._deserialize(data, real_type, type_key, allow_extra_fields, external_globals) + return self._deserialize( + data, real_type, type_key, allow_extra_fields, external_globals + ) else: raise @@ -165,7 +183,9 @@ class TupleDeserializer(Deserializer): # TODO: probably a good idea to ensure that all dicts are frozen as well... class FrozenSerializer(Serializer): - def __init__(self, force_serialization: bool, allow_unsafe_list_serialization: bool = False) -> None: + def __init__( + self, force_serialization: bool, allow_unsafe_list_serialization: bool = False + ) -> None: super().__init__() self._force_serialization = force_serialization self._allow_unsafe_list_serialization = allow_unsafe_list_serialization @@ -183,10 +203,12 @@ class FrozenSerializer(Serializer): logger.info("Converting list to tuple for serialization: {}", obj) obj = tuple(obj) else: - raise Exception(f"Lists are not allowed for serialization. Use tuples instead. 
Current iterable: {obj}") - assert isinstance( - obj, (tuple, frozenset, bytes) - ), f"All iterables should be tuples or frozenset. Received {obj}" + raise Exception( + f"Lists are not allowed for serialization. Use tuples instead. Current iterable: {obj}" + ) + assert isinstance(obj, (tuple, frozenset, bytes)), ( + f"All iterables should be tuples or frozenset. Received {obj}" + ) return cast( list[object], tuple( @@ -217,7 +239,9 @@ class FrozenSerializer(Serializer): return obj # type: ignore if globals: - self._custom_serializers = resolve_types(self._custom_serializers, globals) # type: ignore + self._custom_serializers = resolve_types( + self._custom_serializers, globals + ) # type: ignore result = self._serialize( obj, @@ -254,7 +278,9 @@ class SerializedException(SerializableModel): traceback_dict: JsonTypeAlias @classmethod - def build(cls, exception: BaseException, traceback: TracebackType | None = None) -> "SerializedException": + def build( + cls, exception: BaseException, traceback: TracebackType | None = None + ) -> "SerializedException": if traceback is None: traceback = exception.__traceback__ assert traceback is not None, " ".join( @@ -265,59 +291,16 @@ class SerializedException(SerializableModel): ) return SerializedException( # pyre-fixme[28]: pyre doesn't understand pydantic exception=get_fully_qualified_name_for_error(exception), - args=tuple(_convert_serialized_exception_args(x, traceback) for x in exception.args), + args=tuple( + _convert_serialized_exception_args(x, traceback) for x in exception.args + ), traceback_dict=FixedTraceback.from_tb(traceback).as_dict(), ) - @cached_property - def traceback(self) -> FixedTraceback | None: - traceback_dict = self.traceback_dict - if traceback_dict is None: - return None - return FixedTraceback.from_dict(traceback_dict) - - @cached_property - def exception_module(self) -> str: - if "." 
in self.exception: - return self.exception.rsplit(".", maxsplit=1)[0] - return "" - - @cached_property - def exception_type(self) -> str: - return self.exception.rsplit(".", maxsplit=1)[-1] - - @cached_property - def exception_class(self) -> type[BaseException]: - if self.exception_module: - return cast( - type[BaseException], - getattr(import_module(self.exception_module), self.exception_type, None), - ) - else: - return cast(type[BaseException], getattr(builtins, self.exception_type, None)) - - def construct_instance(self) -> BaseException: - try: - exception = self.exception_class(*cast(tuple[Serializable, ...], self.args)) - except TypeError as e: - message_with_arg_info = ( - f"Failed to construct exception {self.exception_class} with args {self.args}.", - "Ensure that the exception class is serializable and can be constructed with the provided args.", - ) - raise TypeError(" ".join(message_with_arg_info)) from e - - return exception - - def as_formatted_traceback(self) -> str: - if self.traceback is None: - traceback_str = "" - else: - # pyre-ignore[6]: pyre doesn't know that FixedTraceback is a traceback (since it's not a TracebackType) - traceback_str = "".join(format_tb(self.traceback)) - return f"Traceback (most recent call last):\n{traceback_str}\n{self.exception}: {self.args}" - -def _convert_serialized_exception_args(error: Serializable, traceback: TracebackType | None = None) -> JsonTypeAlias: +def _convert_serialized_exception_args( + error: Serializable, traceback: TracebackType | None = None +) -> JsonTypeAlias: if isinstance(error, BaseException): return SerializedException.build(error, traceback=traceback) elif isinstance(error, (list, tuple)): @@ -338,14 +321,24 @@ def _convert_to_json_serializable_with_better_errors( return obj # type: ignore if isinstance(obj, Mapping): return { - key: _convert_to_json_serializable_with_better_errors(value, f"{path}.{key}") for key, value in obj.items() + key: _convert_to_json_serializable_with_better_errors( 
+ value, f"{path}.{key}" + ) + for key, value in obj.items() } if isinstance(obj, Iterable): - return [_convert_to_json_serializable_with_better_errors(item, f"{path}[{i}]") for i, item in enumerate(obj)] - raise TypeError(f'Found object of type "{type(obj).__name__}" at {path} which cannot be serialized') + return [ + _convert_to_json_serializable_with_better_errors(item, f"{path}[{i}]") + for i, item in enumerate(obj) + ] + raise TypeError( + f'Found object of type "{type(obj).__name__}" at {path} which cannot be serialized' + ) -SERIALIZER = FrozenSerializer(force_serialization=False, allow_unsafe_list_serialization=False) +SERIALIZER = FrozenSerializer( + force_serialization=False, allow_unsafe_list_serialization=False +) DESERIALIZER = TupleDeserializer() # note: you cannot change this without changing other calls to yasoo, this is its default @@ -408,14 +401,14 @@ def serialize_datetime(data: datetime.datetime) -> dict: @DESERIALIZER.register() def deserialize_datetime(data: dict) -> datetime.datetime: - return datetime.datetime.fromtimestamp(data["time"], datetime.timezone.utc if data.get("tzaware", None) else None) - + return datetime.datetime.fromtimestamp( + data["time"], datetime.timezone.utc if data.get("tzaware", None) else None + ) -def serialize_to_dict(obj: Any) -> dict[str, Any]: - return cast(dict[str, Any], SERIALIZER.serialize(obj)) - -def serialize_to_json(obj: Any, indent: int | None = None, sort_keys: bool = False) -> str: +def serialize_to_json( + obj: Any, indent: int | None = None, sort_keys: bool = False +) -> str: try: return json.dumps(SERIALIZER.serialize(obj), indent=indent, sort_keys=sort_keys) except Exception as e: @@ -424,39 +417,8 @@ def serialize_to_json(obj: Any, indent: int | None = None, sort_keys: bool = Fal def deserialize_from_json(data: str) -> Any: try: - return DESERIALIZER.deserialize(json.loads(data)) # pyre-ignore[20]: pyre doesn't understand deserialize + return DESERIALIZER.deserialize( + json.loads(data) + ) # 
pyre-ignore[20]: pyre doesn't understand deserialize except Exception as e: raise SerializationError(str(e)) from e - - -def deserialize_from_dict(data: dict[str, Any]) -> Any: - try: - return DESERIALIZER.deserialize(data) # pyre-ignore[20]: pyre doesn't understand deserialize - except Exception as e: - raise SerializationError(str(e)) from e - - -def deserialize_from_dict_with_type(data: dict[str, Any], obj_type: type[T]) -> T: - try: - result = DESERIALIZER.deserialize(data, obj_type=obj_type) - assert isinstance(result, obj_type), f"Expected an object of type {obj_type}, but got {result}" - return result - except Exception as e: - raise SerializationError(str(e)) from e - - -def deserialize_from_json_with_type(data: str | bytes | bytearray, obj_type: type[T]) -> T: - try: - return deserialize_from_dict_with_type(json.loads(data), obj_type=obj_type) - except Exception as e: - raise SerializationError(str(e)) from e - - -def get_serializable_properties(obj: Any) -> dict[str, Any]: - members = inspect.getmembers(type(obj)) - marked_members = [name for name, member in members if is_serializable_property(member)] - return {name: getattr(obj, name) for name in marked_members} - - -def is_serializable_property(func: Callable) -> bool: - return getattr(func, "_imbue_is_serializable_property", False) diff --git a/imbue_core/imbue_core/simple_git.py b/imbue_core/imbue_core/simple_git.py @@ -1,215 +0,0 @@ -"""This file implements a subset of the functionality found in git.py and compute_environment.py, but everything is synchronous.""" - -import shlex -import subprocess -import time -from pathlib import Path -from typing import Sequence - -from loguru import logger - -from imbue_core.async_monkey_patches import log_exception -from imbue_core.computing_environment.data_types import AnyPath -from imbue_core.computing_environment.data_types import RunCommandError - - -class SyncLocalGitRepo: - """ - Provides different operations that you can perform over a git repository. 
- - Implements a subset of our async LocalGitRepo, but also pulls in some of the functions from computing_environment.py - that normally would operate on a LocalGitRepo. - - Over time, we can probably replace LocalGitRepo and computing_environment more or less fully, as we're migrating - code away from asyncio. - - Feel free to move additional functions from computing_environment.py into this class as needed. Usually, you just need - to remove async/await keywords, and replace calls to compute_environment. member functions with self. calls. - """ - - _base_path: Path - - def __init__(self, base_path: Path) -> None: - self._base_path = base_path - - @property - def base_path(self) -> Path: - """The base path of the git repo.""" - return self._base_path - - def run_git( - self, - command: Sequence[str], - check: bool = True, - cwd: AnyPath | None = None, - is_error_logged: bool = True, - is_stripped: bool = True, - retry_on_git_lock_error: bool = True, - ) -> str: - """Run a git command in the repo. - - Example: - ``` - git_repo.run_git("status") - ``` - """ - command = ["git"] + list(command) - if not retry_on_git_lock_error: - result = self.run_command(command, check=check, is_error_logged=is_error_logged, cwd=cwd) - else: - result = self._run_command_with_retry_on_git_lock_error( - command, check=check, is_error_logged=is_error_logged, cwd=cwd - ) - if is_stripped: - return result.strip() - return result - - def run_command( - self, - command: Sequence[str], - check: bool = True, - secrets: dict[str, str] | None = None, - cwd: AnyPath | None = None, - is_error_logged: bool = True, - ) -> str: - """Run a command in the repo. - - Note, this can be used to run any command, not just git. 
- """ - command_string = shlex.join(command) - logger.trace( - f"Running command: {command_string=} from cwd={cwd or self.base_path} with {secrets=} {check=} {is_error_logged=}" - ) - completed_proc = subprocess.run( - command, - cwd=cwd or self._base_path, - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=secrets, - ) - # note, need to be carefull not to strip() lines since whitespace may be important (e.g. for diffs) - # return joined lines since mostly we only use the output for logging, and this way we arn't - # passing around lots of lists. Also it's easy to parse by lines if needed - try: - stdout = completed_proc.stdout.decode("UTF-8") - except UnicodeDecodeError as e: - # If we don't encounter this, it likely means something was fixed upstream and we can safely delete - log_exception( - e, - "Command {command_string} failed to decode stdout, replacing any invalid bytes which could lead to problems later", - command_string=command_string, - ) - stdout = completed_proc.stdout.decode("UTF-8", errors="replace") - stderr = completed_proc.stderr.decode("UTF-8") - if check and completed_proc.returncode != 0: - error_message = f"command run from cwd={self.base_path} failed with exit code {completed_proc.returncode} and stdout:\n{stdout}\nstderr:\n{stderr}" - if is_error_logged: - logger.error( - f"command attempted: '{command_string}' from cwd={self.base_path}\nerror message: {error_message}" - ) - # this should not be None, but do this to satisfy type checker, int or None we throw the same error - returncode = completed_proc.returncode or -1 - raise RunCommandError( - cmd=command_string, - stderr=stderr, - returncode=returncode, - cwd=cwd or self.base_path, - ) - return stdout - - def get_git_diff( - self, - commit_hash: str | None = None, - staged: bool = False, - is_error_logged: bool = True, - include_binary: bool = True, - ) -> str: - """Get the diff for the current repo state.""" - # make sure `is_stripped=False` otherwise 
patch can be invalid - command = ["diff", "--full-index"] - if include_binary: - # Without --binary, diffs of binary files will just contain a summary statement such as "Binary files a/file.bin and b/file.bin differ". - # Such diffs cannot be applied, but are useful for inclusion in LLM prompts. - command.append("--binary") - if staged: - command.append("--staged") - if commit_hash: - command.append(commit_hash) - return self.run_git(command, is_stripped=False, is_error_logged=is_error_logged) - - def get_untracked_files(self) -> tuple[str, ...]: - """Get the untracked files in the repo.""" - result = self.run_git(["ls-files", "--others", "--exclude-standard"], is_error_logged=False) - return tuple([line.strip() for line in result.splitlines() if line.strip()]) - - def get_untracked_file_diff(self, file_path: str, include_binary: bool = True) -> str: - """Get the diff for a untracked file. - - Note this function will raise a RunCommandError if the there is no diff for the untracked file or if there - is another error running the command. So it is best to use this function after checking that the file is untracked - using get_untracked_files function. - """ - command = ["diff", "--no-index"] - if include_binary: - command.append("--binary") - untracked_diff = self.run_git( - command + ["/dev/null", str(file_path)], - # Unfortunately, `--no-index` implies `--exit-code`, which will cause git diff to return an exit code of 1 - # if the diff is not empty. So we can't use check=True here. We'll check for an empty output to detect failures. 
- check=False, - is_error_logged=True, - is_stripped=False, - ) - if not untracked_diff: - raise RunCommandError(f"Unable to diff untracked file {file_path}") - return untracked_diff - - def is_commit_a_branch(self, commit_hash: str) -> bool: - """Check if the given git ref is a branch.""" - try: - self.run_git( - ("show-ref", "--verify", "-q", f"refs/heads/{commit_hash}"), - is_error_logged=False, - check=True, - ) - return True - except RunCommandError as e: - if e.returncode == 1: - return False - raise - - def get_merge_base(self, branch_name: str, target_branch: str) -> str: - """Get the merge base of the given branch and target branch. - - The merge base is the most recent commit that is on both branches. - """ - return self.run_git(["merge-base", branch_name, target_branch], is_error_logged=False) - - def _run_command_with_retry_on_git_lock_error( - self, - command: Sequence[str], - check: bool = True, - is_error_logged: bool = True, - cwd: AnyPath | None = None, - ) -> str: - max_retries = 50 - retry_count = 0 - retry_delay = 0.1 # seconds - while True: - try: - return self.run_command( - command, - check=check, - is_error_logged=is_error_logged and retry_count >= max_retries, - cwd=cwd, - ) - except RunCommandError as e: - error_message = str(e) - if "fatal: Unable to create" in error_message and ".git/index.lock': File exists" in error_message: - if retry_count >= max_retries: - raise - time.sleep(retry_delay) - retry_count += 1 - else: - raise diff --git a/imbue_core/imbue_core/test_utils.py b/imbue_core/imbue_core/test_utils.py @@ -1,88 +0,0 @@ -import contextlib -import shutil -import tempfile -import time -from pathlib import Path -from typing import AsyncGenerator -from typing import Callable -from typing import Generator - -from loguru import logger -from syrupy.assertion import SnapshotAssertion - -from imbue_core.agents.llm_apis.data_types import CachedCostedLanguageModelResponse -from imbue_core.agents.llm_apis.data_types import 
CachedCostedModelResponse -from imbue_core.agents.llm_apis.data_types import CachedCountTokensResponse -from imbue_core.agents.llm_apis.llm_testing_utils import check_llm_responses_in_cache -from imbue_core.caching import AsyncCache -from imbue_core.llm_testing_utils import get_cache_file_from_snapshot -from imbue_core.llm_testing_utils import get_count_tokens_cache_file_from_snapshot -from imbue_core.llm_testing_utils import preload_llm_cache -from imbue_core.llm_testing_utils import record_llm_responses_in_cache - - -def info_if_not_quiet(quiet: bool, message: str) -> None: - if not quiet: - logger.info(message) - - -@contextlib.contextmanager -def create_temp_dir(root_dir: Path) -> Generator[Path, None, None]: - with tempfile.TemporaryDirectory(dir=root_dir) as temp_dir: - yield Path(temp_dir) - shutil.rmtree(temp_dir) - - -def wait_until(condition: Callable[[], bool], timeout: float = 5.0, interval: float = 0.5) -> None: - start_time = time.monotonic() - while True: - if condition(): - return - if time.monotonic() - start_time > timeout: - raise TimeoutError("Condition not met within timeout period") - time.sleep(interval) - - -async def make_llm_cache_with_snapshot_core( - snapshot: SnapshotAssertion, - json_cache_file: Path, - value_cls: type[CachedCostedModelResponse], - quiet: bool = True, - suffix: str = "", -) -> AsyncGenerator[Path, None]: - info_if_not_quiet(quiet, f"Using llm_cache_pathfixture: {json_cache_file=}") - - with tempfile.TemporaryDirectory() as cache_path_string: - cache_path = Path(cache_path_string) - cache_context = AsyncCache(cache_path, value_cls) - - if not snapshot.session.update_snapshots: - if json_cache_file.exists(): - await preload_llm_cache(json_cache_file, cache_context) - - yield cache_path - info_if_not_quiet(quiet, "Finished with llm_cache_pathfixture, updating cache if needed.") - - if snapshot.session.update_snapshots: - await record_llm_responses_in_cache(cache_context, json_cache_file) - - await 
check_llm_responses_in_cache(snapshot, cache_context, suffix) - info_if_not_quiet(quiet, "Finished with llm_cache_pathfixture, checking cache.") - - -async def make_llm_cache_with_snapshot(snapshot: SnapshotAssertion, quiet: bool = True) -> AsyncGenerator[Path, None]: - json_cache_file = get_cache_file_from_snapshot(snapshot) - async for path in make_llm_cache_with_snapshot_core( - snapshot, json_cache_file, CachedCostedLanguageModelResponse, quiet - ): - yield path - - -async def make_count_tokens_cache_with_snapshot( - snapshot: SnapshotAssertion, quiet: bool = True -) -> AsyncGenerator[Path, None]: - json_cache_file = get_count_tokens_cache_file_from_snapshot(snapshot) - async for path in make_llm_cache_with_snapshot_core( - snapshot, json_cache_file, CachedCountTokensResponse, quiet, "_count_tokens" - ): - yield path diff --git a/imbue_core/pyproject.toml b/imbue_core/pyproject.toml @@ -11,7 +11,6 @@ authors = [ { name="Imbue", email="imbue@imbue.com" }, ] license = "MIT" -# NOTE: This list is replicated in sculptor/pyproject.toml, and the copy there must be kept in sync. dependencies = [ "anyio", "attrs", diff --git a/imbue_tools/README.md b/imbue_tools/README.md @@ -1,5 +1,5 @@ # Purpose -Shared functionality for imbue-cli tools like vet, imbue-retrieve, etc. +Shared functionality for CLI tools like vet. 
# Contents - formatting git repos as LLM input diff --git a/imbue_tools/imbue_tools/conftest.py b/imbue_tools/imbue_tools/conftest.py @@ -1,44 +0,0 @@ -from pathlib import Path -from typing import Callable -from typing import Generator - -import pytest -from pytest_asyncio import fixture as async_fixture -from syrupy.assertion import SnapshotAssertion - -from imbue_core.agents.configs import LanguageModelGenerationConfig -from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName -from imbue_core.async_monkey_patches_test import explode_on_error # noqa: F401 -from imbue_core.test_repo_utils import make_simple_test_git_repo -from imbue_core.test_utils import make_llm_cache_with_snapshot - -llm_cache_path = async_fixture(make_llm_cache_with_snapshot) - - -# this is copied from sculptor/conftest.py -# (it must be copied rather than imported because of the autouse) -@pytest.fixture(autouse=True) -def always_explode_on_error( - explode_on_error: Callable[[], Generator[None, None, None]], -) -> Generator[None, None, None]: - """ - Ensures that we do not log errors or exceptions during testing. - - If your test is checking error handling behavior (and you expect to see a log_exception call), - use the `expect_exact_logged_errors` decorator to suppress the logging of those errors. 
- """ - yield - - -def llm_config_for_test( - llm_cache_path: Path, snapshot: SnapshotAssertion, is_caching_inputs: bool = False -) -> LanguageModelGenerationConfig: - return LanguageModelGenerationConfig( - model_name=AnthropicModelName.CLAUDE_4_SONNET_2025_05_14, - cache_path=llm_cache_path, - is_running_offline=not snapshot.session.update_snapshots, - is_caching_inputs=is_caching_inputs, - ) - - -simple_test_git_repo = pytest.fixture(make_simple_test_git_repo) diff --git a/imbue_tools/imbue_tools/get_conversation_history/get_conversation_history.py b/imbue_tools/imbue_tools/get_conversation_history/get_conversation_history.py @@ -1,5 +1,4 @@ import json -from pathlib import Path from typing import assert_never from loguru import logger @@ -11,9 +10,6 @@ from vet_types.messages import ChatInputUserMessage from vet_types.messages import ConversationMessageUnion from vet_types.messages import ResponseBlockAgentMessage -CONVERSATION_FILE_ENV_VAR = "CONVERSATION_FILE" -TASK_SOURCE_BRANCH_ENV_VAR = "TASK_SOURCE_BRANCH" - class ConversationLoadingError(Exception): pass @@ -58,14 +54,6 @@ def format_conversation_history_for_prompt( # === loading from file === -def load_conversation_history( - conversation_file_path: Path, -) -> tuple[ConversationMessageUnion, ...]: - """Load a jsonl file into a list of conversation messages""" - file_contents = conversation_file_path.read_text() - return parse_conversation_history(file_contents) - - def parse_conversation_history( conversation_str: str, ) -> tuple[ConversationMessageUnion, ...]: diff --git a/imbue_tools/imbue_tools/repo_utils/context_retrieval.py b/imbue_tools/imbue_tools/repo_utils/context_retrieval.py @@ -1,9 +1,6 @@ -import asyncio import threading import time -from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator from typing import Generator import pygit2 @@ -12,7 +9,6 @@ from pygit2.enums import ObjectType from pygit2.repository import Repository from 
imbue_core.async_utils import make_async -from imbue_core.git import LocalGitRepo from imbue_tools.repo_utils.diff_utils import apply_diffs_to_files from imbue_tools.repo_utils.file_system import FileContents from imbue_tools.repo_utils.file_system import InMemoryFileSystem @@ -34,8 +30,6 @@ class RepoContextManager: # We need the sync lock due to pygit2 being synchronous. # It is mostly used for the blob data cache, but also for the repo contents by git hash cache. self._lock = threading.Lock() - # We need the async lock for tests TODO: we can probably remove this - self._local_repo_async_lock: asyncio.Lock = asyncio.Lock() @classmethod def build(cls, repo_path: Path) -> "RepoContextManager": @@ -125,12 +119,3 @@ class RepoContextManager: else: raise ValueError(f"Unexpected entry type in git tree: {entry.type}") - - @asynccontextmanager - async def tmp_repo_context(self) -> AsyncGenerator[LocalGitRepo, None]: - """ - This function is only used in tests - TODO: we can probably remove it - """ - async with self._local_repo_async_lock: - yield LocalGitRepo(self.repo_path) diff --git a/imbue_tools/imbue_tools/repo_utils/diff_utils.py b/imbue_tools/imbue_tools/repo_utils/diff_utils.py @@ -8,7 +8,6 @@ from async_lru import alru_cache # type: ignore[undefined-attribute]: pyre on m from loguru import logger from imbue_tools.repo_utils.errors import DiffApplicationError -from imbue_tools.repo_utils.errors import DiffCalculationError from imbue_tools.repo_utils.file_system import FileContents from imbue_tools.repo_utils.file_system import InMemoryFileSystem from imbue_tools.repo_utils.file_system import SymlinkContents @@ -18,65 +17,6 @@ from imbue_tools.repo_utils.file_system_utils import ( from imbue_tools.repo_utils.file_system_utils import ( temporary_local_dir_from_in_memory_file_system, ) -from imbue_tools.repo_utils.file_system_utils import write_file_contents_to_dir - - -class NonZeroReturncodeError(Exception): - pass - - -async def 
get_diff_between_files(old_file_contents: InMemoryFileSystem, new_file_contents: InMemoryFileSystem) -> str: - with ( - tempfile.TemporaryDirectory() as old_repo_dir, - tempfile.TemporaryDirectory() as new_repo_dir, - ): - # Get all changed file contents to prevent writing more than necessary - changed_old_file_contents_dict = {} - changed_new_file_contents_dict = {} - old_file_contents_dict = old_file_contents.files - new_file_contents_dict = new_file_contents.files - for file_path in old_file_contents_dict.keys() | new_file_contents_dict.keys(): - if file_path not in old_file_contents_dict: - changed_new_file_contents_dict[file_path] = new_file_contents_dict[file_path] - elif file_path not in new_file_contents_dict: - changed_old_file_contents_dict[file_path] = old_file_contents_dict[file_path] - elif old_file_contents_dict[file_path] != new_file_contents_dict[file_path]: - changed_old_file_contents_dict[file_path] = old_file_contents_dict[file_path] - changed_new_file_contents_dict[file_path] = new_file_contents_dict[file_path] - - changed_old_file_contents = InMemoryFileSystem.build(changed_old_file_contents_dict) - changed_new_file_contents = InMemoryFileSystem.build(changed_new_file_contents_dict) - - await write_file_contents_to_dir(changed_old_file_contents, old_repo_dir) - await write_file_contents_to_dir(changed_new_file_contents, new_repo_dir) - - try: - result = subprocess.run( - ( - "git", - "diff", - "--no-index", - "--relative", - "--full-index", - "--binary", - old_repo_dir, - new_repo_dir, - ), - capture_output=True, - text=True, - timeout=10.0, - ) - if result.returncode == 0 or result.returncode == 1: - diff = result.stdout - else: - raise NonZeroReturncodeError(f"git diff process returned with non-zero returncode {result.returncode}") - except Exception as e: - raise DiffCalculationError from e - - diff = diff.replace(old_repo_dir, "") - diff = diff.replace(new_repo_dir, "") - - return diff @alru_cache diff --git 
a/imbue_tools/imbue_tools/repo_utils/errors.py b/imbue_tools/imbue_tools/repo_utils/errors.py @@ -1,6 +1,3 @@ -from imbue_core.errors import ExpectedError - - class PromptAssemblyError(Exception): """Raised when there is an error assembling the prompt.""" @@ -9,17 +6,5 @@ class ContextLengthExceededError(PromptAssemblyError): """Raised when the context length exceeds the maximum allowed length.""" -class InvalidVersionedConfigError(ExpectedError): - pass - - -class MissingVersionedConfigError(ExpectedError): - pass - - class DiffApplicationError(Exception): pass - - -class DiffCalculationError(Exception): - pass diff --git a/imbue_tools/imbue_tools/repo_utils/file_system.py b/imbue_tools/imbue_tools/repo_utils/file_system.py @@ -4,7 +4,6 @@ from typing import Mapping from imbue_core.frozen_utils import FrozenDict from imbue_core.frozen_utils import deep_freeze_mapping from imbue_core.pydantic_serialization import SerializableModel -from imbue_tools.repo_utils.subrepo_formatting import BaseFilenamePattern class SymlinkContents(SerializableModel): @@ -70,33 +69,3 @@ def _try_decode_file_contents(contents: FileContents) -> DecodedTextFileContents return contents.decode("utf-8") except UnicodeDecodeError: return None - - -def filter_files_patterns( - file_system: InMemoryFileSystem, - include_patterns: tuple[str, ...] | None = None, - exclude_patterns: tuple[str, ...] | None = None, -) -> InMemoryFileSystem: - """Filter all files based on include/exclude patterns. - If an include pattern is provided, only files that match the include pattern will be included. - If an exclude pattern is provided, files that match the exclude pattern will be excluded. If no include or exclude patterns are provided, all files will be included. 
- - Args: - file_system: The file system to filter - include_patterns: Glob patterns for files to include - exclude_patterns: Glob patterns for files to exclude - - Returns: - Filtered InMemoryFileSystem - """ - include_spec = BaseFilenamePattern.from_lines(include_patterns or ()) - exclude_spec = BaseFilenamePattern.from_lines(exclude_patterns or ()) - filtered_files = { - file_path: content - for file_path, content in file_system.files.items() - if ( - (not include_patterns or include_spec.match_file(file_path)) - and (not exclude_patterns or not exclude_spec.match_file(file_path)) - ) - } - return InMemoryFileSystem.build(filtered_files) diff --git a/imbue_tools/imbue_tools/repo_utils/find_relative_to.py b/imbue_tools/imbue_tools/repo_utils/find_relative_to.py @@ -1,30 +0,0 @@ -from pathlib import Path - -from imbue_core.simple_git import SyncLocalGitRepo - - -def find_relative_to_commit_hash(relative_to: str, repo_path: Path) -> str: - """ - Find the commit hash to use as the source to compare against. - - If relative_to is "HEAD", it will return the current commit hash. - - If relative_to is a branch name, it will find the last common ancestor between that branch and the current state. - - If relative_to is a commit hash or tag, it will return that commit hash. - """ - repo = SyncLocalGitRepo(repo_path) - if relative_to.startswith("HEAD"): - # The current commit hash or relative to it (e.g. "HEAD~1") - base_commit = repo.run_git(["rev-parse", relative_to], check=True) - else: - # Check if relative_to is the name of a branch. - is_branch = repo.is_commit_a_branch(relative_to) - if is_branch: - # Yes, it's a branch. - # Since we're comparing to a branch, this command will find the last common ancestor - # between that branch and the current state. This is typically what we want for branches. - # (Think of this as getting the diff that would be applied if this branch was to be merged into relative_to.) 
- base_commit = repo.get_merge_base(relative_to, "HEAD") - else: - # Not a branch. relative_to might be a commit hash or tag. - base_commit = relative_to - - return base_commit diff --git a/imbue_tools/imbue_tools/repo_utils/stubify_file.py b/imbue_tools/imbue_tools/repo_utils/stubify_file.py @@ -17,16 +17,6 @@ def check_on_body(stmt: cst.CSTNode, check: Callable[[cst.CSTNode], bool]) -> bo return check(first_body_item) -def is_assign(stmt: cst.CSTNode) -> bool: - if not m.matches(stmt, m.SimpleStatementLine()): - return False - # pyre-ignore[16]: m.SimpleStatementLine has a body attribute which is a Sequence - first_body_item = stmt.body[0] - if m.matches(first_body_item, m.Assign()): - return True - return False - - class CompressTransformer(cst.CSTTransformer): DESCRIPTION = str = "Replaces function body with ..." replacement_string = '"__FUNC_BODY_REPLACEMENT_STRING__"' @@ -35,7 +25,9 @@ class CompressTransformer(cst.CSTTransformer): self.keep_constant = keep_constant self.keep_indent = keep_indent - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: new_body = [ stmt for stmt in updated_node.body @@ -43,12 +35,16 @@ class CompressTransformer(cst.CSTTransformer): or m.matches(stmt, m.FunctionDef()) or ( self.keep_constant - and check_on_body(stmt, lambda first_body_item: m.matches(first_body_item, m.Assign())) + and check_on_body( + stmt, lambda first_body_item: m.matches(first_body_item, m.Assign()) + ) ) ] return updated_node.with_changes(body=new_body) - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: # Remove docstring in the class body new_body = [ stmt @@ -56,13 +52,18 @@ class CompressTransformer(cst.CSTTransformer): if not check_on_body( stmt, lambda 
first_body_item: m.matches(first_body_item, m.Expr()) - or (hasattr(first_body_item, "value") and m.matches(first_body_item.value, m.SimpleString())), + or ( + hasattr(first_body_item, "value") + and m.matches(first_body_item.value, m.SimpleString()) + ), ) ] # pyre-fixme[6]: cst.IndentedBlock has a body attribute which is a Sequence[BaseStatement], not a Sequence[BaseSmallStatement] like new_body return updated_node.with_changes(body=cst.IndentedBlock(body=new_body)) - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.BaseStatement: + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.BaseStatement: if not self.keep_indent: # replace with unindented statement new_expr = cst.Expr(value=cst.SimpleString(value=self.replacement_string)) @@ -75,57 +76,9 @@ class CompressTransformer(cst.CSTTransformer): new_expr = [ cst.Expr(value=cst.SimpleString(value=self.replacement_string)), ] - return updated_node.with_changes(body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=new_expr)])) - - -class GlobalVariableVisitor(cst.CSTVisitor): - METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) - - def __init__(self) -> None: - self.assigns: list[list[Any]] = [] - - def leave_Assign(self, original_node: cst.Assign) -> None: - stmt = original_node - start_pos = self.get_metadata(cst.metadata.PositionProvider, stmt).start - end_pos = self.get_metadata(cst.metadata.PositionProvider, stmt).end - self.assigns.append([stmt, start_pos, end_pos]) - - -def remove_lines(raw_code: str, remove_line_intervals: list[tuple[int, int]]) -> str: - # TODO: speed up this function - # remove_line_intervals.sort() - - # Remove lines - new_code = "" - for i, line in enumerate(raw_code.splitlines()): - # intervals are one-based - if not any(start <= i + 1 <= end for start, end in remove_line_intervals): - new_code += line + "\n" - if any(start == i + 1 for start, _ in remove_line_intervals): - 
new_code += "...\n" - return new_code - - -def compress_assign_stmts(raw_code: str, total_lines: int = 30, prefix_lines: int = 10, suffix_lines: int = 10) -> str: - try: - tree = cst.parse_module(raw_code) - except cst.ParserSyntaxError as e: - logger.debug( - "Failed to compress assign statements: {exception_class}, {exception}", - exception_class=e.__class__.__name__, - exception=e, - ) - return raw_code - - wrapper = cst.metadata.MetadataWrapper(tree) - visitor = GlobalVariableVisitor() - wrapper.visit(visitor) - - remove_line_intervals = [] - for stmt in visitor.assigns: - if stmt[2].line - stmt[1].line > total_lines: - remove_line_intervals.append((stmt[1].line + prefix_lines, stmt[2].line - suffix_lines)) - return remove_lines(raw_code, remove_line_intervals) + return updated_node.with_changes( + body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=new_expr)]) + ) def stubify_code_file( @@ -133,10 +86,6 @@ def stubify_code_file( raw_code: str, keep_constant: bool = True, keep_indent: bool = False, - compress_assign: bool = False, - total_lines: int = 30, - prefix_lines: int = 10, - suffix_lines: int = 10, ) -> str: try: tree = cst.parse_module(raw_code) @@ -148,14 +97,6 @@ def stubify_code_file( modified_tree = tree.visit(transformer) code = modified_tree.code - if compress_assign: - code = compress_assign_stmts( - code, - total_lines=total_lines, - prefix_lines=prefix_lines, - suffix_lines=suffix_lines, - ) - if keep_indent: code = code.replace(CompressTransformer.replacement_string + "\n", "...\n") code = code.replace(CompressTransformer.replacement_string, "...\n") diff --git a/imbue_tools/imbue_tools/repo_utils/subrepo_formatting.py b/imbue_tools/imbue_tools/repo_utils/subrepo_formatting.py @@ -1,5 +1,4 @@ import functools -from collections import defaultdict from enum import Enum from typing import Annotated from typing import Iterable @@ -239,28 +238,6 @@ def format_subrepo(formatted_repo_contents: Mapping[str, str]) -> str: return 
escape_all_jinja_variables(escape_prompt_markers(repo_context_str)) -def format_subrepo_context_full(repo: Mapping[str, str]) -> str: - """Like get_repo_context but there's no checking for context limits (so we can use the api checks instead) - and the selected strategy is always the full repo contents.""" - formatted_repo_context = format_subrepo( - format_all_for_agent( - format_subrepo_context_into_filecontexts( - full_repo_contents=repo, - path_to_format_style=defaultdict(lambda: ContextFormatStyle.FULL_FILE, {}), - ) - ) - ) - - repo_context_core_prompt = formatted_subrepo_to_prompt( - repo_context_str=formatted_repo_context, - is_shortened=False, - has_hidden_files=False, - template=BASE_REPO_CONTEXT_TEMPLATE, - ) - - return repo_context_core_prompt - - def formatted_subrepo_to_prompt( repo_context_str: str, is_shortened: bool, has_hidden_files: bool, template: str ) -> str: diff --git a/vet/api.py b/vet/api.py @@ -39,7 +39,7 @@ def get_issues_with_raw_responses( # should be not None and not empty if conversation_history: try: - # TODO: we use the imbue verify config here, but we may want to configure this separately + # TODO: we use the config here, but we may want to configure this separately goal = get_goal_from_conversation(conversation_history, config.language_model_generation_config) logger.info("Generated goal from conversation history: {}", goal) except Exception as e: diff --git a/vet/conftest.py b/vet/conftest.py @@ -9,8 +9,7 @@ from imbue_core.test_repo_utils import make_simple_test_git_repo simple_test_git_repo = pytest.fixture(make_simple_test_git_repo) -# this is copied from sculptor/conftest.py -# (it must be copied rather than imported because of the autouse) +# This fixture must be defined locally (not imported) because of the autouse flag. 
@pytest.fixture(autouse=True) def always_explode_on_error( explode_on_error: Callable[[], Generator[None, None, None]], diff --git a/vet/errors.py b/vet/errors.py @@ -1,2 +1,19 @@ +import subprocess +from typing import Any + + class GitException(Exception): pass + + +class RunCommandError(subprocess.CalledProcessError): + """Custom exception for errors encountered during shell commands.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.cwd = kwargs.get("cwd", None) + if "cwd" in kwargs: + del kwargs["cwd"] + super().__init__(*args, **kwargs) + + def __str__(self) -> str: + return f"Command `{self.cmd}` returned non-zero exit status {self.returncode}.\nOutput: {self.stdout}\nError: {self.stderr}\nCWD: {self.cwd}" diff --git a/vet/git.py b/vet/git.py @@ -0,0 +1,236 @@ +"""Git utilities for vet.""" + +import shlex +import subprocess +import time +from pathlib import Path +from typing import Sequence + +import anyio +from loguru import logger + +from imbue_core.async_monkey_patches import log_exception +from vet.errors import RunCommandError + +# Flexible path type alias +AnyPath = Path | str | anyio.Path + + +class SyncLocalGitRepo: + """ + Provides different operations that you can perform over a git repository. + """ + + _base_path: Path + + def __init__(self, base_path: Path) -> None: + self._base_path = base_path + + @property + def base_path(self) -> Path: + """The base path of the git repo.""" + return self._base_path + + def run_git( + self, + command: Sequence[str], + check: bool = True, + cwd: AnyPath | None = None, + is_error_logged: bool = True, + is_stripped: bool = True, + retry_on_git_lock_error: bool = True, + ) -> str: + """Run a git command in the repo. 
+ + Example: + ``` + git_repo.run_git("status") + ``` + """ + command = ["git"] + list(command) + if not retry_on_git_lock_error: + result = self.run_command(command, check=check, is_error_logged=is_error_logged, cwd=cwd) + else: + result = self._run_command_with_retry_on_git_lock_error( + command, check=check, is_error_logged=is_error_logged, cwd=cwd + ) + if is_stripped: + return result.strip() + return result + + def run_command( + self, + command: Sequence[str], + check: bool = True, + secrets: dict[str, str] | None = None, + cwd: AnyPath | None = None, + is_error_logged: bool = True, + ) -> str: + """Run a command in the repo. + + Note, this can be used to run any command, not just git. + """ + command_string = shlex.join(command) + logger.trace( + f"Running command: {command_string=} from cwd={cwd or self.base_path} with {secrets=} {check=} {is_error_logged=}" + ) + completed_proc = subprocess.run( + command, + cwd=cwd or self._base_path, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=secrets, + ) + # note, need to be carefull not to strip() lines since whitespace may be important (e.g. for diffs) + # return joined lines since mostly we only use the output for logging, and this way we arn't + # passing around lots of lists. 
Also it's easy to parse by lines if needed + try: + stdout = completed_proc.stdout.decode("UTF-8") + except UnicodeDecodeError as e: + # If we don't encounter this, it likely means something was fixed upstream and we can safely delete + log_exception( + e, + "Command {command_string} failed to decode stdout, replacing any invalid bytes which could lead to problems later", + command_string=command_string, + ) + stdout = completed_proc.stdout.decode("UTF-8", errors="replace") + stderr = completed_proc.stderr.decode("UTF-8") + if check and completed_proc.returncode != 0: + error_message = f"command run from cwd={self.base_path} failed with exit code {completed_proc.returncode} and stdout:\n{stdout}\nstderr:\n{stderr}" + if is_error_logged: + logger.error( + f"command attempted: '{command_string}' from cwd={self.base_path}\nerror message: {error_message}" + ) + # this should not be None, but do this to satisfy type checker, int or None we throw the same error + returncode = completed_proc.returncode or -1 + raise RunCommandError( + cmd=command_string, + stderr=stderr, + returncode=returncode, + cwd=cwd or self.base_path, + ) + return stdout + + def get_git_diff( + self, + commit_hash: str | None = None, + staged: bool = False, + is_error_logged: bool = True, + include_binary: bool = True, + ) -> str: + """Get the diff for the current repo state.""" + # make sure `is_stripped=False` otherwise patch can be invalid + command = ["diff", "--full-index"] + if include_binary: + # Without --binary, diffs of binary files will just contain a summary statement such as "Binary files a/file.bin and b/file.bin differ". + # Such diffs cannot be applied, but are useful for inclusion in LLM prompts. 
+ command.append("--binary") + if staged: + command.append("--staged") + if commit_hash: + command.append(commit_hash) + return self.run_git(command, is_stripped=False, is_error_logged=is_error_logged) + + def get_untracked_files(self) -> tuple[str, ...]: + """Get the untracked files in the repo.""" + result = self.run_git(["ls-files", "--others", "--exclude-standard"], is_error_logged=False) + return tuple([line.strip() for line in result.splitlines() if line.strip()]) + + def get_untracked_file_diff(self, file_path: str, include_binary: bool = True) -> str: + """Get the diff for a untracked file. + + Note this function will raise a RunCommandError if the there is no diff for the untracked file or if there + is another error running the command. So it is best to use this function after checking that the file is untracked + using get_untracked_files function. + """ + command = ["diff", "--no-index"] + if include_binary: + command.append("--binary") + untracked_diff = self.run_git( + command + ["/dev/null", str(file_path)], + # Unfortunately, `--no-index` implies `--exit-code`, which will cause git diff to return an exit code of 1 + # if the diff is not empty. So we can't use check=True here. We'll check for an empty output to detect failures. + check=False, + is_error_logged=True, + is_stripped=False, + ) + if not untracked_diff: + raise RunCommandError(f"Unable to diff untracked file {file_path}") + return untracked_diff + + def is_commit_a_branch(self, commit_hash: str) -> bool: + """Check if the given git ref is a branch.""" + try: + self.run_git( + ("show-ref", "--verify", "-q", f"refs/heads/{commit_hash}"), + is_error_logged=False, + check=True, + ) + return True + except RunCommandError as e: + if e.returncode == 1: + return False + raise + + def get_merge_base(self, branch_name: str, target_branch: str) -> str: + """Get the merge base of the given branch and target branch. + + The merge base is the most recent commit that is on both branches. 
+ """ + return self.run_git(["merge-base", branch_name, target_branch], is_error_logged=False) + + def _run_command_with_retry_on_git_lock_error( + self, + command: Sequence[str], + check: bool = True, + is_error_logged: bool = True, + cwd: AnyPath | None = None, + ) -> str: + max_retries = 50 + retry_count = 0 + retry_delay = 0.1 # seconds + while True: + try: + return self.run_command( + command, + check=check, + is_error_logged=is_error_logged and retry_count >= max_retries, + cwd=cwd, + ) + except RunCommandError as e: + error_message = str(e) + if "fatal: Unable to create" in error_message and ".git/index.lock': File exists" in error_message: + if retry_count >= max_retries: + raise + time.sleep(retry_delay) + retry_count += 1 + else: + raise + + +def find_relative_to_commit_hash(relative_to: str, repo_path: Path) -> str: + """ + Find the commit hash to use as the source to compare against. + - If relative_to is "HEAD", it will return the current commit hash. + - If relative_to is a branch name, it will find the last common ancestor between that branch and the current state. + - If relative_to is a commit hash or tag, it will return that commit hash. + """ + repo = SyncLocalGitRepo(repo_path) + if relative_to.startswith("HEAD"): + # The current commit hash or relative to it (e.g. "HEAD~1") + base_commit = repo.run_git(["rev-parse", relative_to], check=True) + else: + # Check if relative_to is the name of a branch. + is_branch = repo.is_commit_a_branch(relative_to) + if is_branch: + # Yes, it's a branch. + # Since we're comparing to a branch, this command will find the last common ancestor + # between that branch and the current state. This is typically what we want for branches. + # (Think of this as getting the diff that would be applied if this branch was to be merged into relative_to.) + base_commit = repo.get_merge_base(relative_to, "HEAD") + else: + # Not a branch. relative_to might be a commit hash or tag. 
+ base_commit = relative_to + + return base_commit diff --git a/vet/issue_identifiers/agentic_issue_collation.py b/vet/issue_identifiers/agentic_issue_collation.py @@ -132,7 +132,7 @@ def collate_issues_with_agent( issues: The issues to collate. identifier_inputs: The inputs which determine the content provided to the identifiers. project_context: Loaded data corresponding to the inputs, e.g. diffs or files. - config: Settings for imbue verify. + config: Settings enabled_issue_codes: The issue types used by the issue identifiers. Returns: diff --git a/vet/issue_identifiers/base.py b/vet/issue_identifiers/base.py @@ -39,7 +39,7 @@ class IssueIdentifier(SerializableModel, abc.ABC, Generic[T]): Args: identifier_inputs: The inputs which determine the content provided to the identifier. project_context: Loaded data corresponding to the inputs, e.g. diffs or files. - config: Settings for imbue verify. + config: Settings Returns: A generator of identified issues. When done iterating, returns the debug info. diff --git a/vet/issue_identifiers/issue_deduplication.py b/vet/issue_identifiers/issue_deduplication.py @@ -119,7 +119,7 @@ def deduplicate_issues( Args: issues: The issues to deduplicate. - config: Settings for imbue verify. + config: Settings enabled_issue_codes: The issue types used by the issue identifiers. Returns: diff --git a/vet/issue_identifiers/issue_evaluation.py b/vet/issue_identifiers/issue_evaluation.py @@ -266,7 +266,7 @@ def filter_issues( results: The issues to filter. inputs: The inputs which determine the content provided to the evaluator. project_context: Loaded data corresponding to the inputs, e.g. diffs or files. - config: Settings for imbue verify. + config: Settings Returns: A generator of issues with the passes_filtration flag set. 
diff --git a/vet/repo_utils.py b/vet/repo_utils.py @@ -1,10 +1,10 @@ from pathlib import Path from imbue_core.async_monkey_patches import log_exception -from imbue_core.computing_environment.data_types import RunCommandError -from imbue_core.simple_git import SyncLocalGitRepo -from imbue_tools.repo_utils.find_relative_to import find_relative_to_commit_hash from vet.errors import GitException +from vet.errors import RunCommandError +from vet.git import SyncLocalGitRepo +from vet.git import find_relative_to_commit_hash # Maximum length of LLM prompts used within vet in tokens, without the repository-specific context. # Currently, the prompt is well under 10k tokens, but this value might need to be bumped up if we add a lot of additional diff --git a/vet_types/vet_types/chat_state.py b/vet_types/vet_types/chat_state.py @@ -56,9 +56,7 @@ class CommandBlock(ContentBlock): object_type: str = "CommandBlock" type: Literal["command"] = "command" command: str - is_automated: bool = Field( - default=False, description="Whether the command is automated" - ) + is_automated: bool = Field(default=False, description="Whether the command is automated") ToolInput = dict[str, Any] @@ -69,17 +67,13 @@ class ToolUseBlock(ContentBlock): type: Literal["tool_use"] = "tool_use" id: ToolUseID = Field(..., description="Unique identifier for this tool use") name: str = Field(..., description="Name of the tool being used") - input: ToolInput = Field( - default_factory=ToolInput, description="Input parameters for the tool" - ) + input: ToolInput = Field(default_factory=ToolInput, description="Input parameters for the tool") class ToolResultContent(SerializableModel): """Base class for tool result content with type discriminator""" - content_type: str = Field( - ..., description="Type discriminator for tool result content" - ) + content_type: str = Field(..., description="Type discriminator for tool result content") class SimpleToolContent(ToolResultContent): @@ -114,15 +108,9 @@ class 
ToolResultBlock(ContentBlock): type: Literal["tool_result"] = "tool_result" tool_use_id: ToolUseID = Field(..., description="ID of the corresponding tool use") tool_name: str = Field(..., description="Name of the tool that was used") - invocation_string: str = Field( - ..., description="String representation of how the tool was invoked" - ) - content: ToolResultContentType = Field( - ..., description="Result content from the tool execution" - ) - is_error: bool = Field( - default=False, description="Whether the tool execution resulted in an error" - ) + invocation_string: str = Field(..., description="String representation of how the tool was invoked") + content: ToolResultContentType = Field(..., description="Result content from the tool execution") + is_error: bool = Field(default=False, description="Whether the tool execution resulted in an error") class WarningBlock(ContentBlock): @@ -130,9 +118,7 @@ class WarningBlock(ContentBlock): type: Literal["warning"] = "warning" message: str = Field(..., description="Warning message") traceback: str | None = Field(..., description="Warning traceback") - warning_type: str | None = Field( - ..., description="Type of warning, i.e. name of the exception that was raised" - ) + warning_type: str | None = Field(..., description="Type of warning, i.e. name of the exception that was raised") class ErrorBlock(ContentBlock): @@ -140,9 +126,7 @@ class ErrorBlock(ContentBlock): type: Literal["error"] = "error" message: str = Field(..., description="Error message") traceback: str = Field(..., description="Error traceback") - error_type: str = Field( - ..., description="Type of error, i.e. name of the exception that was raised" - ) + error_type: str = Field(..., description="Type of error, i.e. 
name of the exception that was raised") class FileBlock(ContentBlock): diff --git a/vet_types/vet_types/ids.py b/vet_types/vet_types/ids.py @@ -32,23 +32,17 @@ class ObjectID(TypeID, ABC): # For convenience, don't require the caller to strip the prefix from existing IDs. if prefix is not None: if prefix != self.tag: - raise TypeIDPrefixMismatchError( - f"Expected prefix '{self.tag}', got '{prefix}'" - ) + raise TypeIDPrefixMismatchError(f"Expected prefix '{self.tag}', got '{prefix}'") value = suffix super().__init__(self.tag, value) @classmethod - def __get_pydantic_core_schema__( - cls, source_type: type, handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: """ Support transparently deserializing strings into ObjectID instances and vice versa. """ return core_schema.no_info_before_validator_function( - lambda raw_value: ( - cls(raw_value) if isinstance(raw_value, str) else raw_value - ), + lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value), core_schema.union_schema( [ core_schema.is_instance_schema(cls), @@ -77,16 +71,12 @@ class NonEmptyStr(str): return value @classmethod - def __get_pydantic_core_schema__( - cls, source_type: type, handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: """ Support transparently deserializing strings into ObjectID instances and vice versa. 
""" return core_schema.no_info_before_validator_function( - lambda raw_value: ( - cls(raw_value) if isinstance(raw_value, str) else raw_value - ), + lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value), core_schema.union_schema( [ core_schema.is_instance_schema(cls), diff --git a/vet_types/vet_types/messages.py b/vet_types/vet_types/messages.py @@ -36,20 +36,18 @@ class LLMModel(StrEnum): class AgentMessageSource(StrEnum): """ - Messages can come the AGENT (in-container LLM), USER (chat messages & direct interactions), - SCULPTOR_SYSTEM (multifaceted sculptor app and service code) and RUNNER (the process - controlling a task on the server.) + Messages can come from the AGENT (LLM), USER (chat messages & direct interactions), + SYSTEM (app and service code) and RUNNER (the process controlling a task). """ - # Messages coming directly from the agent from inside the environment. + # Messages coming directly from the agent. AGENT = "AGENT" - # Messages coming directly from a user interacting with the interface, ie chat + # Messages coming directly from a user interacting with the interface, ie chat. USER = "USER" - # Messages coming from sculptor-mediated actions and automations, like local sync updates - # or manual sync operations. - SCULPTOR_SYSTEM = "SCULPTOR_SYSTEM" + # Messages coming from system-mediated actions and automations. + SYSTEM = "SYSTEM" # Messages coming from the task runner wrapper, such as environment shutdown. RUNNER = "RUNNER" @@ -65,15 +63,11 @@ class Message(SerializableModel): # the source of the message, which can be either the agent, user, or runner. source: AgentMessageSource # roughly when the message was created, in UTC. 
- approximate_creation_time: datetime.datetime = Field( - default_factory=get_current_time - ) + approximate_creation_time: datetime.datetime = Field(default_factory=get_current_time) @property def is_ephemeral(self) -> bool: - raise NotImplementedError( - "All messages must be subclassed off of PersistentMessage or EphemeralMessage" - ) + raise NotImplementedError("All messages must be subclassed off of PersistentMessage or EphemeralMessage") class PersistentMessage(Message):