commit be084ce91b81503c5bb3978eb760b34fdb050e9b
parent 91088763849a9c1db036598089fbfe0f6e725e09
Author: andrewlaack-collab <andrew.laack@imbue.com>
Date: Sun, 1 Feb 2026 02:35:56 +0000
Remove unnecessary things (#5)
* Removed some internal notes
* Remove useless git stuff
* Refactoring git stuff and computing environment
* Unused files removal
* Dead code
* Removed more unused code.
* More pruning
* More code
* Removing
* Removals
Diffstat:
52 files changed, 438 insertions(+), 3866 deletions(-)
diff --git a/README.md b/README.md
@@ -1,4 +1,4 @@
-# Vet : Verify EveryThing
+# Vet : Verify Everything
Vet is a standalone verification tool for **code changes** and **coding agent behavior**.
@@ -7,7 +7,7 @@ It reviews git diffs, and optionally an agent's conversation history, to find is
## Installation
```bash
-pip install vet
+pip install verify-everything
```
## Quickstart
diff --git a/imbue_core/imbue_core/agents/agent_api/client.py b/imbue_core/imbue_core/agents/agent_api/client.py
@@ -80,7 +80,6 @@ class CachedAgentClient(AgentClient[AgentOptionsT]):
# This means that if the generator is not exhausted, the cache will not be updated.
# If we do want a way to still cache interactions, even if we early exit the generator,
# then we could use a separate thread to get the agent response and cache it in the background.
- # See https://gitlab.com/generally-intelligent/generally_intelligent/-/merge_requests/7323#note_2897340073
agent_interaction_record = AgentInteractionRecord.from_agent_interaction(agent_interaction)
update_cache(agent_interaction_record, cache_path)
diff --git a/imbue_core/imbue_core/agents/data_types/__init__.py b/imbue_core/imbue_core/agents/data_types/__init__.py
diff --git a/imbue_core/imbue_core/agents/data_types/ids.py b/imbue_core/imbue_core/agents/data_types/ids.py
@@ -1,64 +0,0 @@
-from abc import ABC
-
-from pydantic import GetCoreSchemaHandler
-from pydantic_core import core_schema
-from typeid import TypeID
-from typeid import get_prefix_and_suffix
-
-
-class TypeIDPrefixMismatchError(Exception):
- pass
-
-
-class ObjectID(TypeID, ABC):
- """
- A convenience class for string-based object IDs.
-
- Use in place of strings for IDs. (We don't use raw UUIDs because they are not supported by SQLite.)
-
- Use `tag` to prefix the ID with the ID type. (We don't use `prefix` because it's already taken by the ancestor class.)
-
- """
-
- # Override this in subclasses to specify the ID type.
- tag: str = "oid"
-
- def __init__(self, value: str | None = None) -> None:
- if value is not None:
- prefix, suffix = get_prefix_and_suffix(value)
- # For convenience, don't require the caller to strip the prefix from existing IDs.
- if prefix is not None:
- if prefix != self.tag:
- raise TypeIDPrefixMismatchError(f"Expected prefix '{self.tag}', got '{prefix}'")
- value = suffix
- super().__init__(self.tag, value)
-
- @classmethod
- def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
- """
- Support transparently deserializing strings into ObjectID instances and vice versa.
- """
- return core_schema.no_info_before_validator_function(
- lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value),
- core_schema.union_schema(
- [
- core_schema.is_instance_schema(cls),
- core_schema.str_schema(),
- ]
- ),
- serialization=core_schema.plain_serializer_function_ser_schema(
- lambda instance: str(instance), return_schema=core_schema.str_schema()
- ),
- )
-
-
-class TaskID(ObjectID):
- tag: str = "tsk"
-
-
-class ProjectID(ObjectID):
- tag: str = "prj"
-
-
-class AgentMessageID(ObjectID):
- tag: str = "agm"
diff --git a/imbue_core/imbue_core/agents/llm_apis/anthropic_api.py b/imbue_core/imbue_core/agents/llm_apis/anthropic_api.py
@@ -781,13 +781,9 @@ class AnthropicAPI(LanguageModelAPI):
def _get_api_key_or_auth_token() -> tuple[str | None, str | None]:
api_key = get_secret("ANTHROPIC_API_KEY")
- # The standard environment variable for this is ANTHROPIC_AUTH_TOKEN,
- # but we don't use it since it has some bad interactions with Claude Code.
- auth_token = get_secret("IMBUE_ANTHROPIC_AUTH_TOKEN")
+ auth_token = get_secret("ANTHROPIC_AUTH_TOKEN")
if not api_key and not auth_token:
- raise MissingAPIKeyError(
- "Neither of ANTHROPIC_API_KEY and IMBUE_ANTHROPIC_AUTH_TOKEN environment variables is set"
- )
+ raise MissingAPIKeyError("Neither ANTHROPIC_API_KEY nor ANTHROPIC_AUTH_TOKEN environment variable is set")
return api_key, auth_token
diff --git a/imbue_core/imbue_core/agents/llm_apis/language_model_api.py b/imbue_core/imbue_core/agents/llm_apis/language_model_api.py
@@ -155,7 +155,7 @@ class LanguageModelAPI(abc.ABC, MutableModel):
if "PYTEST_CURRENT_TEST" not in os.environ:
logger.warning(
- "You are trying to call a language model from outside of a hammer with no global resource limits set. That is a bad idea because the spend will not be restricted, and you may end up accidentally spending much more than you expected."
+ "You are trying to call a language model with no global resource limits set. That is a bad idea because the spend will not be restricted, and you may end up accidentally spending much more than you expected."
)
return None
diff --git a/imbue_core/imbue_core/agents/llm_apis/llm_testing_utils.py b/imbue_core/imbue_core/agents/llm_apis/llm_testing_utils.py
@@ -1,86 +0,0 @@
-from google.genai.types import CountTokensResponse
-from syrupy.assertion import SnapshotAssertion
-from syrupy.extensions.single_file import SingleFileAmberSnapshotExtension
-from syrupy.extensions.single_file import SingleFileSnapshotExtension
-
-from imbue_core.agents.llm_apis.data_types import CachedCostedModelResponse
-from imbue_core.agents.llm_apis.data_types import CostedLanguageModelResponse
-from imbue_core.caching import AsyncCache
-from imbue_core.frozen_utils import FrozenMapping
-
-
-async def check_llm_responses_in_cache(snapshot: SnapshotAssertion, temp_cache: AsyncCache, suffix: str = "") -> None:
- """Runs as the test fixture completes to check that the LLM inputs and outputs stay the same, in a human-readable format."""
-
- async with temp_cache as cache:
- all_keys: tuple[str, ...] = await cache.get_all_keys() # Contains both the streaming and non-streaming keys?
- value_by_key: FrozenMapping[str, CachedCostedModelResponse | None] = await cache.get_all(all_keys)
-
- cache_items: list[tuple[str, CachedCostedModelResponse]] = [
- (k, v) for k, v in value_by_key.items() if v is not None
- ]
- cache_items.sort(key=lambda x: x[1].timestamp)
- for cache_index, (cache_key, cached_response) in enumerate(cache_items):
- prompt: bytes = b""
- joined_responses: bytes = b""
- metadata_lines: list[str] = []
-
- metadata_lines.append(f"{cache_index=} (when cache is sorted by timestamp)")
- metadata_lines.append(f"{cache_key=}") # Keys must be stable and not too big.
- if cached_response.inputs is not None:
- prompt = cached_response.inputs.prompt.encode("utf-8")
- metadata_lines.append(f"request metdata ({type(cached_response.inputs)})")
- for field, field_value in cached_response.inputs.__dict__.items():
- if field != "prompt": # print the prompt separately below.
- metadata_lines.append(f" {field}: {field_value}")
-
- metadata_lines.append("cached_response metadata:")
- for field, field_value in cached_response.__dict__.items():
- if field not in ("inputs", "response"): # already printed above
- metadata_lines.append(f" {field}: {field_value}")
-
- if cached_response.response is not None:
- match cached_response.response:
- case CostedLanguageModelResponse():
- joined_responses = "".join([r.text for r in cached_response.response.responses]).encode("utf-8")
- for response_index, response in enumerate(cached_response.response.responses):
- metadata_lines.append(f"response[{response_index}] metadata:")
- for (
- field,
- field_value,
- ) in cached_response.response.__dict__.items():
- if field != "responses": # already printed the responses above
- metadata_lines.append(f" {field}: {field_value}")
- case CountTokensResponse():
- metadata_lines.append("response metadata:")
- for field, field_value in cached_response.response.__dict__.items():
- metadata_lines.append(f" {field}: {field_value}")
-
- snapshotted_prompt = snapshot(
- extension_class=SingleFileSnapshotExtension,
- name=f"{cache_index:03d}_inputs{suffix}",
- )
-
- # TODO nasty syrupy hacking
- snapshot_contents, _ = snapshotted_prompt._recall_data(snapshotted_prompt.index)
-
- assert (
- snapshotted_prompt == prompt
- ), f"Your prompt changed, did you mean for this to happen?\nExpected prompt: {snapshot_contents!r}\nPrompt: {prompt!r}"
-
- snapshotted_response = snapshot(
- extension_class=SingleFileSnapshotExtension,
- name=f"{cache_index:03d}_response{suffix}",
- )
-
- assert (
- snapshotted_response == joined_responses
- ), "Your response changed; maybe you aren't actually hitting the cache?"
-
- snapshotted_metadata = snapshot(
- extension_class=SingleFileAmberSnapshotExtension,
- name=f"{cache_index:03d}_metadata{suffix}",
- )
- assert snapshotted_metadata == "\n".join(
- metadata_lines
- ), "Metadata changed; maybe you aren't actually hitting the cache?"
diff --git a/imbue_core/imbue_core/async_utils.py b/imbue_core/imbue_core/async_utils.py
@@ -1,45 +1,13 @@
import asyncio
import functools
-import inspect
-import os
-import platform
-import sys
import threading
-import traceback
-from contextlib import AbstractAsyncContextManager
-from contextlib import AbstractContextManager
-from contextlib import contextmanager
-from datetime import datetime
-from http.server import BaseHTTPRequestHandler
-from http.server import HTTPServer
-from pathlib import Path
-from types import FrameType
-from typing import Any
-from typing import AsyncGenerator
from typing import Awaitable
from typing import Callable
-from typing import Coroutine
-from typing import Generator
-from typing import Generic
-from typing import Iterable
-from typing import Iterator
from typing import ParamSpec
from typing import TypeVar
-from typing import cast
-from urllib.parse import parse_qs
-from urllib.parse import urlparse
-
-from loguru import logger
-from traceback_with_variables.core import _iter_lines
-
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.async_monkey_patches import safe_cancel
P = ParamSpec("P")
R = TypeVar("R")
-S = TypeVar("S")
-
-ALL_EVENT_LOOPS: list[asyncio.AbstractEventLoop] = []
def sync(func: Callable[P, Awaitable[R]]) -> Callable[P, R]:
@@ -55,56 +23,6 @@ def sync(func: Callable[P, Awaitable[R]]) -> Callable[P, R]:
return wrapper
-def sync_generator(
- func: Callable[P, AsyncGenerator[R, None]],
-) -> Callable[P, Generator[R, None, None]]:
- """Decorator that runs an async generator synchronously by dispatching to
- an event loop running in a separate thread.
- """
-
- @functools.wraps(func)
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> Generator[R, None, None]:
- loop = _get_or_create_event_loop()
- agen = func(*args, **kwargs)
- while True:
- try:
- future = asyncio.run_coroutine_threadsafe(agen.__anext__(), loop)
- yield future.result()
- except StopAsyncIteration:
- break
-
- return wrapper
-
-
-@contextmanager
-# pyre-ignore[24]: pyre doesn't understand AbstractAsyncContextManager
-def sync_contextmanager(
- async_context_manager: AbstractAsyncContextManager[S],
-) -> Generator[S, None, None]:
- sync_aenter = sync(async_context_manager.__aenter__)
- sync_aexit = sync(async_context_manager.__aexit__)
-
- enter_result = sync_aenter()
- try:
- yield enter_result
- except BaseException as e:
- if not sync_aexit(e.__class__, e, e.__traceback__):
- raise
- else:
- sync_aexit(None, None, None)
-
-
-# pyre doesn't understand AbstractAsyncContextManager
-def sync_contextmanager_func(
- cm_func: Callable[P, AbstractAsyncContextManager[S]], # pyre-ignore[24]
-) -> Callable[P, AbstractContextManager[S]]: # pyre-ignore[24]
- @functools.wraps(cm_func)
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> AbstractContextManager[S]: # pyre-ignore[24]
- return sync_contextmanager(cm_func(*args, **kwargs))
-
- return wrapper
-
-
_LOOP: asyncio.AbstractEventLoop | None = None
_LOOP_LOCK: threading.Lock = threading.Lock()
@@ -124,395 +42,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
return _LOOP # pyre-ignore[7]: we just made _LOOP, so it's not None unless it got destroyed just now
-def shorten_filename(filename: str) -> str:
- path = Path(filename)
- while path.parent:
- path = path.parent
- if not (path / "__init__.py").exists():
- break
-
- try:
- shortened = str(Path(filename).relative_to(path))
- except ValueError:
- shortened = filename # in case the path cannot be made relative
-
- return shortened
-
-
-# TODO: I'd really like to print these task groups in a hierarchical way instead of flat -- we know which groups
-# launched which other groups, so we could print them in a tree structure. That would be a lot more readable.
-# It might also be nice to print without any stacks at all. As long as the tasks had good names, that would make it
-# possible to very easily understand everything that was currently executing.
-# I could even imagine controls that allowed for printing just the task groups themselves, which would also be easier
-# to understand.
-def get_all_async_task_stacks(
- num_skipped_frames: int = 0,
- log_variables: bool = False,
- loop: asyncio.AbstractEventLoop | None = None,
-) -> Iterator[str]:
- """Yields the lines of a report for all stack frames of all async tasks including variables."""
- tasks_by_task_group: dict[asyncio.TaskGroup | None, list[asyncio.Task]] = {}
- owning_task_by_task_group: dict[asyncio.TaskGroup, asyncio.Task] = {}
-
- for task in asyncio.all_tasks(loop=loop):
- if task.done():
- continue
- task_group = cast(asyncio.TaskGroup | None, getattr(task, "task_group", None))
- owned_task_group = cast(asyncio.TaskGroup | None, getattr(task, "owned_task_group", None))
- if owned_task_group is not None:
- owning_task_by_task_group[owned_task_group] = task
- if owned_task_group not in tasks_by_task_group:
- tasks_by_task_group[owned_task_group] = []
- else:
- tasks_by_task_group.setdefault(task_group, []).append(task)
-
- all_owning_tasks = set(owning_task_by_task_group.values())
-
- task_group_keys = list(tasks_by_task_group.keys())
- for task_group in cast(list[asyncio.TaskGroup | None], [None]) + [x for x in task_group_keys if x is not None]:
- if task_group not in tasks_by_task_group:
- continue
- tasks = tasks_by_task_group[task_group]
- if task_group is None:
- yield f"\n\n{'=' * 40}\nNo TaskGroup:\n"
- else:
- yield f"\n\n{'=' * 40}\nTaskGroup: {getattr(task_group, 'name', 'unknown')}\n"
- owning_task = None
- if task_group is not None:
- owning_task = owning_task_by_task_group.get(task_group)
- is_first_line_skipped = False
- all_tasks = tasks
- if owning_task is not None:
- yield f"Owning Task: {owning_task.get_name()}\n"
- is_first_line_skipped = True
- all_tasks.insert(0, owning_task)
- for task in all_tasks:
- # skip these -- they'll be printed at the top of the group that they own
- if task_group is None and task in all_owning_tasks:
- continue
- if is_first_line_skipped:
- is_first_line_skipped = False
- else:
- yield f"{'-' * 40}\nTask {task.get_name()}:\n"
- frames = extract_frames(task)
- for frame in frames:
- frame_infos = inspect.getouterframes(frame)[num_skipped_frames:]
- # Use private method _iter_lines to traceback async tasks, which is not explicitly handled in the API
- if log_variables:
- for line in _iter_lines(
- e=None,
- frame_infos=frame_infos,
- fmt=None,
- for_file=None,
- ):
- yield line + "\n"
- else:
- frame_summaries = [
- traceback.FrameSummary(
- shorten_filename(info.filename),
- lineno=info.lineno,
- name=info.function,
- line=(info.code_context[0].strip() if info.code_context else None),
- )
- for info in frame_infos
- ]
- yield from traceback.format_list(frame_summaries)
-
-
-def extract_frames(task: asyncio.Task) -> list[FrameType]:
- """Extract the stack frames of an async task."""
- coro = task.get_coro()
- assert isinstance(coro, Coroutine)
- frames = []
- while coro is not None and coro.cr_frame is not None:
- frames.append(coro.cr_frame)
- coro = coro.cr_await # type: ignore
- # this happens at the very bottom of the call stack, there it seems to often be a FutureIter, Event, etc
- if type(coro).__name__ != "coroutine":
- break
- return frames
-
-
-def print_all_async_task_stacks(log_variables: bool = False) -> None:
- """Prints the stack frames of all running tasks."""
- for line in get_all_async_task_stacks(log_variables=log_variables):
- print(line)
-
-
-def dump_all_async_task_stacks(log_path: str | Path, log_variables: bool = False) -> None:
- """Dump the stack frames of all running tasks to file."""
- with open(log_path, "w") as f:
- for line in get_all_async_task_stacks(log_variables=log_variables):
- if log_variables:
- line += "\n"
- f.write(line)
-
-
-async def periodically_log_async_stacks(log_dir: str | Path, interval: float, log_variables: bool = False) -> None:
- """Periodically print the stack traces of all running tasks."""
- Path(log_dir).mkdir(parents=True, exist_ok=True)
- while True:
- log_path = Path(log_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
- dump_all_async_task_stacks(log_path=log_path, log_variables=log_variables)
- logger.debug("Dumped asyncio stack traces to {}", log_path)
- await asyncio.sleep(interval)
-
-
-async def is_task_group_complete(
- task_group: asyncio.TaskGroup, trace_task: asyncio.Task, buffer_time: float = 1.0
-) -> None:
- """Continuously check if all tasks except the stack trace logger are done."""
- while True:
- if all(task.done() for task in task_group._tasks if task is not trace_task):
- await asyncio.sleep(buffer_time) # Wait for buffer time in case new tasks are added
- # Recheck to confirm
- if all(task.done() for task in task_group._tasks if task is not trace_task):
- break
- await asyncio.sleep(1.0)
-
-
-async def inject_async_stack_trace_logger(
- task_group: asyncio.TaskGroup,
- log_dir: str | Path,
- log_interval: float = 60.0,
- log_variables: bool = False,
-) -> None:
- """Inject a stack trace logger into the task group."""
- trace_task = asyncio.create_task(
- periodically_log_async_stacks(log_dir=log_dir, interval=log_interval, log_variables=log_variables),
- name="periodically_log_async_stacks",
- )
- await is_task_group_complete(task_group, trace_task)
- safe_cancel(trace_task)
- try:
- await trace_task
- except asyncio.CancelledError:
- pass
-
-
-class AsyncTaskStacksHandler(BaseHTTPRequestHandler):
- def do_GET(self) -> None:
- try:
- parsed_url = urlparse(self.path)
- query_params = parse_qs(parsed_url.query)
- log_variables = query_params.get("locals", ["false"])[0].lower() in [
- "true",
- "1",
- "yes",
- ]
-
- self.send_response(200)
- self.send_header("Content-type", "text/plain")
- self.end_headers()
-
- pid = os.getpid()
- command_line = " ".join(sys.argv)
-
- # Path to the Python executable
- python_executable = sys.executable
-
- # Python version
- python_version = platform.python_version()
-
- # Print the collected information
- self.wfile.write(f"Process {pid}: {python_executable} {command_line}\n".encode("utf-8"))
- self.wfile.write(f"Python v{python_version} ({python_executable})\n\n".encode("utf-8"))
-
- for loop in ALL_EVENT_LOOPS:
- for line in get_all_async_task_stacks(log_variables=log_variables, loop=loop):
- self.wfile.write(line.encode("utf-8"))
- except BaseException as e:
- log_exception(e, "exception in AsyncTaskStacksHandler")
- raise
-
-
-def run_async_stackframe_server_thread(port_range_low: int, port_range_high: int) -> None:
- try:
- success = False
- for port in range(port_range_low, port_range_high):
- try:
- server_address = ("localhost", port)
- httpd = HTTPServer(server_address, AsyncTaskStacksHandler)
- success = True
- print(f"Starting async stack trace server on port {port}. Process pid: {os.getpid()}")
- break
- except OSError:
- continue
-
- if not success:
- logger.info("Could not find an open port to start the async stack trace server, continuing without it.")
- return
-
- httpd.serve_forever()
- except BaseException as e:
- log_exception(e, "exception in run_async_stackframe_server_thread")
- raise
-
-
-IS_STACKFRAME_SERVER_RUNNING = False
-STACKFRAME_SERVER_PORT_LOW = 60000
-STACKFRAME_SERVER_PORT_HIGH = 61000
-
-
-def run_async_stackframe_server_for_loop(event_loop: asyncio.AbstractEventLoop) -> None:
- ALL_EVENT_LOOPS.append(event_loop)
- run_async_stackframe_server()
-
-
-def run_async_stackframe_server() -> None:
- global IS_STACKFRAME_SERVER_RUNNING
- if not IS_STACKFRAME_SERVER_RUNNING:
- IS_STACKFRAME_SERVER_RUNNING = True
-
- port_str = os.environ.get("ASYNC_STACKFRAME_SERVER_PORT")
- if port_str is not None:
- try:
- port = int(port_str)
- port_range_low = port
- port_range_high = port + 1
- except ValueError:
- logger.error("ASYNC_STACKFRAME_SERVER_PORT is not an integer: {}", port_str)
- raise
- else:
- port_range_low = STACKFRAME_SERVER_PORT_LOW
- port_range_high = STACKFRAME_SERVER_PORT_HIGH
-
- thread = threading.Thread(
- target=run_async_stackframe_server_thread,
- daemon=True,
- kwargs={
- "port_range_low": port_range_low,
- "port_range_high": port_range_high,
- },
- )
- thread.start()
-
-
-def with_timeout(func: Callable[P, Awaitable[R]], timeout_secs: float) -> Callable[P, Awaitable[R]]:
- """Decorator that adds a timeout to an async function."""
-
- @functools.wraps(func)
- async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
- return await asyncio.wait_for(func(*args, **kwargs), timeout_secs)
-
- return wrapper
-
-
-T = TypeVar("T")
-
-
-async def gather_with_limited_concurrency(coros: Iterable[Awaitable[T]], n: int) -> list[T]:
- """Like asyncio.gather() but will only run `n` in parallel at a time.
-
- Note that a call like `asyncio.gather(*coros)` is now `gather_with_limited_concurrency(coros, n=10),
- without the splat.
- """
- semaphore = asyncio.Semaphore(n)
-
- async def sem_coro(coro: Awaitable[T]) -> T:
- async with semaphore:
- return await coro
-
- return list(await asyncio.gather(*(sem_coro(c) for c in coros)))
-
-
-_NOT_FOUND = object()
-
-
-class AsyncCachedProperty(Generic[T]):
- """A descriptor factory that behaves very similarly to `functools.cached_property`, but for
- async methods!
-
- The type annotations here are rough; it's not realistic to get them perfect without using a .pyi file.
- """
-
- def __init__(self, func: Callable[[Any], Coroutine[None, None, T]]) -> None:
- self.func = func
- self.attrname: str | None = None
- self.__doc__ = func.__doc__
-
- def __set_name__(self, owner: type, name: str) -> None:
- if self.attrname is None:
- self.attrname = name
- elif name != self.attrname:
- raise TypeError("Cannot assign the same AsyncCachedProperty to multiple names")
-
- def _get_attrname(self) -> str:
- if self.attrname is None:
- raise TypeError("Cannot use AsyncCachedProperty instance without calling __set_name__")
- return self.attrname
-
- def _get_cache(self, instance: object) -> dict[str, Any]:
- try:
- return instance.__dict__
- except AttributeError:
- raise TypeError(
- "Cannot use AsyncCachedProperty with instances that do not have a __dict__ attribute"
- ) from None
-
- def __get__(self, instance: object, owner: type | None = None) -> Awaitable[T]:
- if instance is None:
- return self # type: ignore
- attrname = self._get_attrname()
- cache = self._get_cache(instance)
- val = cache.get(attrname, _NOT_FOUND)
- if val is not _NOT_FOUND:
- return cast(Awaitable[T], val)
-
- task = asyncio.create_task(self.func(instance))
- cache[attrname] = task
- return task
-
- def __delete__(self, instance: object) -> None:
- if instance is None:
- raise TypeError("Cannot delete AsyncCachedProperty on a class")
- attrname = self._get_attrname()
- cache = self._get_cache(instance)
- try:
- awaitable = cache.pop(attrname)
- if not awaitable.done():
- safe_cancel(awaitable)
- except KeyError:
- raise AttributeError(f"Cannot delete attribute {self.attrname!r}") from None
-
- def __set__(self, instance: object, value: T) -> None:
- if instance is None:
- raise TypeError("Cannot set AsyncCachedProperty on a class")
- attrname = self._get_attrname()
- cache = self._get_cache(instance)
- existing = cache.pop(attrname, None)
- if existing is not None and not existing.done():
- safe_cancel(existing)
- fut: asyncio.Future[T] = asyncio.Future()
- fut.set_result(value)
- cache[attrname] = fut
-
-
-def wrapped_asyncio_run(
- main: Awaitable[T],
- *,
- debug: bool | None = None,
- loop_factory: Callable[..., asyncio.AbstractEventLoop] | None = None,
-) -> T:
- """
- This is a lightweight wrapper with a singular purpose -- it is here to enable apyspy
-
- apyspy is our async equivalent to py-spy.
-
- Without apyspy, it's really annoying to debug why some async task is stuck.
- """
-
- async def wrapper_main() -> T:
- ALL_EVENT_LOOPS.append(asyncio.get_event_loop())
- result = await main
- ALL_EVENT_LOOPS.pop()
- return result
-
- run_async_stackframe_server()
- with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
- return runner.run(wrapper_main())
-
-
def make_async(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
"""
Turn the annotated function into an async function by running it in a thread.
diff --git a/imbue_core/imbue_core/caching.py b/imbue_core/imbue_core/caching.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
-import os
from functools import lru_cache
from pathlib import Path
from types import TracebackType
@@ -98,7 +97,9 @@ class AsyncCache(AsyncCacheInterface[ValueType], Generic[ValueType]):
loop = asyncio.get_running_loop()
cache = self.cache
assert cache is not None
- result = await loop.run_in_executor(None, cache.__exit__, exc_type, exc_val, exc_tb)
+ result = await loop.run_in_executor(
+ None, cache.__exit__, exc_type, exc_val, exc_tb
+ )
self.cache = None
return result
@@ -115,15 +116,13 @@ class AsyncCache(AsyncCacheInterface[ValueType], Generic[ValueType]):
cache = self.cache
assert cache is not None
loop = asyncio.get_running_loop()
- assert isinstance(value, self.value_cls), f"Expected {self.value_cls}, got {type(value)}"
+ assert isinstance(value, self.value_cls), (
+ f"Expected {self.value_cls}, got {type(value)}"
+ )
serialized_value = serialize_to_json(value)
- return await loop.run_in_executor(None, cache.set, key, serialized_value, expire, read, tag, retry)
-
- async def delete(self, key: str, retry: bool = False) -> bool:
- cache = self.cache
- assert cache is not None
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, cache.delete, key, retry)
+ return await loop.run_in_executor(
+ None, cache.set, key, serialized_value, expire, read, tag, retry
+ )
async def get(
self,
@@ -137,13 +136,15 @@ class AsyncCache(AsyncCacheInterface[ValueType], Generic[ValueType]):
cache = self.cache
assert cache is not None
loop = asyncio.get_running_loop()
- value = await loop.run_in_executor(None, cache.get, key, None, read, expire_time, tag, retry)
+ value = await loop.run_in_executor(
+ None, cache.get, key, None, read, expire_time, tag, retry
+ )
if value is None:
return default
deserialized_value = deserialize_from_json(value)
- assert isinstance(
- deserialized_value, self.value_cls
- ), f"Expected {self.value_cls}, got {type(deserialized_value)}"
+ assert isinstance(deserialized_value, self.value_cls), (
+ f"Expected {self.value_cls}, got {type(deserialized_value)}"
+ )
return deserialized_value
# TODO: this is not smart implementation, but at least it will be possible to optimize later without refactoring
@@ -180,56 +181,3 @@ def get_cache(data_path: Path) -> Cache:
eviction_policy="none",
size_limit=2**36,
)
-
-
-def get_default_llm_response_cache() -> Path:
- return Path(os.environ.get("RESPONSE_CACHE_PATH", os.path.expanduser("~/.llm_response_cache")))
-
-
-def get_default_count_tokens_cache() -> Path:
- return Path(os.environ.get("COUNT_TOKENS_CACHE_PATH", os.path.expanduser("~/.count_tokens_cache")))
-
-
-def get_test_llm_response_cache() -> Path:
- return Path(os.path.expanduser("~/.llm_test_response_cache"))
-
-
-class InMemoryCache(AsyncCacheInterface[ValueType], Generic[ValueType]):
- def __init__(self, values: tuple[ValueType, ...]) -> None:
- self.values = values
-
- async def __aenter__(self) -> Self:
- return self
-
- async def __aexit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- pass
-
- async def get(
- self,
- key: str,
- default: ValueType | None = None,
- read: bool = False,
- expire_time: bool = False,
- tag: bool = False,
- retry: bool = False,
- ) -> ValueType | None:
- return self.values[int(key)]
-
- async def get_all(
- self,
- keys: Sequence[str],
- default: ValueType | None = None,
- read: bool = False,
- expire_time: bool = False,
- tag: bool = False,
- retry: bool = False,
- ) -> FrozenMapping[str, ValueType | None]:
- return FrozenDict(zip(await self.get_all_keys(), self.values))
-
- async def get_all_keys(self, reverse: bool = False) -> tuple[str, ...]:
- return tuple(map(str, range(len(self.values))))
diff --git a/imbue_core/imbue_core/cattrs_serialization.py b/imbue_core/imbue_core/cattrs_serialization.py
@@ -868,10 +868,9 @@ def _serialize_to_json_dumpable_object(
# with `is_reversible=False` since we won't know the type to be able to recreate the object.
assert is_reversible, "Cannot restructure inputs if is_reversible=False"
- # TODO: this is a hack to make it possible to serialize ExecutionContexts for class method hammers.
+ # TODO: this is a hack to make it possible to serialize ExecutionContexts for class methods.
# This lets us serialize ExecutionContexts for calls to class methods without serializing the class itself.
- # The long-term solutions are 1) either get rid of all class method hammers,
- # or 2) write a custom hook that can serialize type objects.
+ # The long-term solution is to write a custom hook that can serialize type objects.
if type(obj) is dict and "__class__" in obj:
del obj["__class__"]
@@ -990,24 +989,3 @@ def deserialize_from_dict(
return _deserialize_using_type_marker(data, as_type, converter=converter)
except Exception as e:
raise SerializationError(str(e)) from e
-
-
-def deserialize_from_dict_with_type(data: dict[str, Any], obj_type: type[T]) -> T:
- try:
- converter = CONVERTER_FACTORY.get_converter(for_javascript=False, exclude_dont_serialize_fields=False)
- result = converter.structure(data, obj_type)
- assert isinstance(result, obj_type), f"Expected an object of type {obj_type}, but got {result}"
- return result
- except Exception as e:
- raise SerializationError(str(e)) from e
-
-
-def deserialize_from_json_with_type(data: str | bytes | bytearray, obj_type: type[T]) -> T:
- try:
- converter = CONVERTER_FACTORY.get_converter(for_javascript=False, exclude_dont_serialize_fields=False)
- return cast(
- T,
- _deserialize_serialized_object(json.loads(data), obj_type, converter=converter),
- )
- except Exception as e:
- raise SerializationError(str(e)) from e
diff --git a/imbue_core/imbue_core/common.py b/imbue_core/imbue_core/common.py
@@ -1,135 +1,5 @@
-import functools
-import hashlib
-import inspect
-import os
-import platform
-import sys
import uuid
-from pathlib import Path
-from types import FrameType
-
-import pathspec
-
-
-def is_on_osx() -> bool:
- return platform.system().lower() == "darwin"
-
-
-def is_running_within_a_pytest_tree() -> bool:
- """
- This is true if this, or any parent process, is running under pytest.
-
- This is different from `is_running_within_a_pytest_process` in that it is true if we are logically testing or not.
-
- This is usually what you want to check
- (eg, this will be true even if you are a separately launched integration server process)
- """
- return "PYTEST_CURRENT_TEST" in os.environ
-
-
-def is_running_within_a_pytest_process() -> bool:
- """
- This is true if the current process is literally running pytest.
-
- This is different from `is_running_within_a_pytest_tree` in that it checks if the current process is pytest itself,
- which is most useful for knowing whether we are running a bunch of unit tests in this process or not.
- """
- return "pytest" in sys.modules
-
-
-def is_live_debugging() -> bool:
- """
- Returns True if the current process is being debugged, for example by PyCharm or another IDE.
- """
- # this is unfortunately true when measuring coverage and in other cases, sigh
- # return sys.gettrace() is not None
- # but this is only true when debugging in pycharm, I think?
- return sys.breakpointhook.__module__ != "sys"
-
-
-@functools.lru_cache(maxsize=1)
-def get_filesystem_root() -> str:
- env_value = os.getenv("SCIENCE_FILESYSTEM_ROOT")
- if not env_value:
- if is_on_osx():
- return "/tmp/science"
- else:
- # When on the physical cluster (and possibly other core clusters), this path is mounted to a unique per-container file path.
- # Anything produced at runtime >10mb should likely go here, as well as anything you might want to dig up for later debugging.
- # The hosts clean up the paths from dead containers periodically, but large data processing jobs should still clean up after themselves.
- return "/mnt/private"
- return env_value
-
-
-@functools.lru_cache(maxsize=1)
-def get_temp_dir() -> str:
- temp_dir = os.path.join(get_filesystem_root(), "tmp")
- os.makedirs(temp_dir, exist_ok=True)
- return temp_dir
-
-
-def hash_string(string: str) -> str:
- return hashlib.md5(string.encode("utf-8")).hexdigest()
-
-
-def get_current_function_name() -> str:
- frame = inspect.currentframe()
- if frame is None:
- return "no_frame"
- prev_frame = frame.f_back
- if prev_frame is None or not isinstance(prev_frame, FrameType):
- return "no_previous_frame"
- return prev_frame.f_code.co_name
-
-
-def filter_excluded_files(files: list[Path], directory: Path, exclude_file_name: str = ".gitignore") -> list[Path]:
- """Remove files from the list that are matched by a .gitignore or similarly-specified exclude file such as
- .gitignore or ratchet_excluded.txt.
- """
-
- # Underneath the root directory, find all the excluders.
- # They can occur in subfolders and if they do they apply only to that subfolder.
- excluders = {path for path in directory.rglob(exclude_file_name) if not path.is_symlink()}
-
- # Per excluder, make a pathspec.
- for excluder in excluders:
- with excluder.open("r") as exclude_file:
- exclude_spec = pathspec.GitIgnoreSpec.from_lines(exclude_file)
-
- # Now we have two cases - We keep the file if the excluder doesn't apply because it's in a different
- # folder, or if it applies but doesn't match
- prefix = os.path.dirname(excluder)
- files = [
- file
- for file in files
- if not (file.is_relative_to(prefix) and exclude_spec.match_file(file.relative_to(prefix)))
- ]
-
- return files
def generate_id() -> str:
return uuid.uuid4().hex
-
-
-def generate_id_from_existing_id(existing_id: str, seed: int) -> str:
- return hashlib.md5(f"{existing_id}-{seed}".encode()).hexdigest()
-
-
-def truncate_string(s: str, max_length: int) -> str:
- if len(s) <= max_length:
- return s
- return s[: max_length - 3] + "..."
-
-
-def parse_bool_environment_variable(var_name: str) -> bool:
- env_var = os.environ.get(var_name, "0").lower()
-
- assert env_var in (
- "0",
- "1",
- "true",
- "false",
- ), f"{var_name} environment variable must be '0', '1', 'true', or 'false'. Current value: '{env_var}'"
-
- return env_var in ("1", "true")
diff --git a/imbue_core/imbue_core/computing_environment/__init__.py b/imbue_core/imbue_core/computing_environment/__init__.py
diff --git a/imbue_core/imbue_core/computing_environment/computing_environment.py b/imbue_core/imbue_core/computing_environment/computing_environment.py
@@ -1,1080 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import shlex
-import time
-from datetime import datetime
-from pathlib import Path
-from typing import Protocol
-from typing import Sequence
-from typing import TYPE_CHECKING
-from uuid import uuid4
-
-import anyio
-from loguru import logger
-from tenacity import TryAgain
-from tenacity import retry
-from tenacity import retry_all
-from tenacity import retry_if_exception_type
-from tenacity import stop_after_attempt
-from tenacity import wait_random_exponential
-
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.computing_environment.data_types import AnyPath
-from imbue_core.computing_environment.data_types import FailedToMakeCommitError
-from imbue_core.computing_environment.data_types import PatchApplicationError
-from imbue_core.computing_environment.data_types import RunCommandError
-from imbue_core.git_data_types import CommitTimestamp
-from imbue_core.retry_utils import log_before_sleep
-from imbue_core.section import Section
-from imbue_core.time_utils import get_current_time
-
-# Import the types needed for file modes
-if TYPE_CHECKING:
- # for proper file mode typing
- from _typeshed import OpenBinaryModeReading
- from _typeshed import OpenBinaryModeWriting
- from _typeshed import OpenTextModeReading
- from _typeshed import OpenTextModeWriting
-
-
-class ComputingEnvironment(Protocol):
- """Protocol defining the interface for a computing environment.
-
- This protocol specifies the required methods for interacting with a computing
- environment, including running commands and file operations.
- """
-
- async def run_command(
- self,
- command: Sequence[str],
- check: bool = True,
- secrets: dict[str, str] | None = None,
- cwd: AnyPath | None = None,
- is_error_logged: bool = True,
- ) -> str: ...
-
- async def run_git(
- self,
- command: Sequence[str],
- check: bool = True,
- cwd: AnyPath | None = None,
- is_error_logged: bool = True,
- is_stripped: bool = True,
- retry_on_git_lock_error: bool = True,
- ) -> str: ...
-
- async def write_file(
- self,
- relative_path: AnyPath,
- content: str | bytes | None,
- cwd: AnyPath | None = None,
- mode: OpenTextModeWriting | OpenBinaryModeWriting = "w",
- mkdir_if_missing: bool = True,
- ) -> None: ...
-
- async def read_file(
- self,
- relative_path: AnyPath,
- cwd: AnyPath | None = None,
- mode: OpenTextModeReading | OpenBinaryModeReading = "r",
- mkdir_if_missing: bool = True,
- ) -> str | bytes: ...
-
- async def delete_file(
- self,
- relative_path: AnyPath,
- cwd: AnyPath | None = None,
- ) -> None: ...
-
-
-def _get_temp_patch_file() -> anyio.Path:
- # this is a bad idea because it triggers the file watcher
- # patch_file = (self.base_path / str(uuid4())).with_suffix(".patch")
- patch_file = (Path("/tmp") / uuid4().hex).with_suffix(".patch")
- return anyio.Path(patch_file)
-
-
-async def run_command_with_retry_on_git_lock_error(
- computing_environment: ComputingEnvironment,
- command: Sequence[str],
- check: bool = True,
- is_error_logged: bool = True,
- cwd: AnyPath | None = None,
-) -> str:
- max_retries = 50
- retry_count = 0
- retry_delay = 0.1 # seconds
- while True:
- try:
- return await computing_environment.run_command(
- command,
- check=check,
- is_error_logged=is_error_logged and retry_count >= max_retries,
- cwd=cwd,
- )
- except RunCommandError as e:
- error_message = str(e)
- if "fatal: Unable to create" in error_message and ".git/index.lock': File exists" in error_message:
- if retry_count >= max_retries:
- raise
- await asyncio.sleep(retry_delay)
- retry_count += 1
- else:
- raise
-
-
-@retry(
- wait=wait_random_exponential(multiplier=0.1, max=2, exp_base=2),
- reraise=True,
- stop=stop_after_attempt(50),
- before_sleep=log_before_sleep,
-)
-async def wait_for_git_index_lock_to_be_free(local_sync_repo_path: Path) -> None:
- # Path to the git index lock file using anyio.Path to avoid blocking
- lock_file_path = anyio.Path(local_sync_repo_path / ".git" / "index.lock")
- if await lock_file_path.exists():
- raise TryAgain
-
-
-async def apply_patch_without_git(computing_environment: ComputingEnvironment, diff: str) -> None:
- if diff.strip() == "":
- return
- patch_file = _get_temp_patch_file()
- try:
- await computing_environment.write_file(patch_file, diff)
- await computing_environment.run_command(("bash", "-c", f"patch -p1 < {patch_file}"))
- except RunCommandError as e:
- raise PatchApplicationError(f"Failed to apply patch: {e}") from e
- finally:
- await computing_environment.delete_file(patch_file)
-
-
-async def is_repo_dirty(computing_environment: ComputingEnvironment, is_untracked_considered: bool = True) -> bool:
- """Check if the repo has any uncommitted changes."""
- return bool(
- await computing_environment.run_git(
- (
- "status",
- "--porcelain",
- *([] if is_untracked_considered else ["--untracked-files=no"]),
- )
- )
- )
-
-
-async def are_all_commits_pushed(computing_environment: ComputingEnvironment) -> bool:
- """Check if the repo has any unpushed commits."""
- output = await computing_environment.run_git(("cherry",))
- return output.strip() == ""
-
-
-async def are_all_remote_commits_pulled(
- computing_environment: ComputingEnvironment,
-) -> bool:
- # note this will fail if the branch hasn't been pushed to the remote
- output = await computing_environment.run_command(
- ("bash", "-c", "git fetch && git rev-list HEAD..@{upstream} --count")
- )
- return output.strip() == "0"
-
-
-async def assert_repo_is_clean(computing_environment: ComputingEnvironment) -> None:
- """Assert that the repo has no uncommitted changes."""
- assert not await is_repo_dirty(
- computing_environment
- ), "You have untracked files. Please address them before using this script (this is to prevent accidentally adding large files unintentionally)"
-
-
-async def get_branch_name(computing_environment: ComputingEnvironment, is_error_logged: bool = True) -> str:
- """Get the name of the current branch."""
- return await computing_environment.run_git(("symbolic-ref", "--short", "HEAD"), is_error_logged=is_error_logged)
-
-
-async def rename_branch(
- computing_environment: ComputingEnvironment,
- old_name: str,
- new_name: str,
- force_if_exists: bool = True,
-) -> None:
- """Rename the given branch."""
- if force_if_exists:
- await computing_environment.run_git(("branch", "-M", old_name, new_name))
- else:
- await computing_environment.run_git(("branch", "-m", old_name, new_name))
-
-
-async def get_branch_description(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Get the description of the given branch."""
- try:
- return await computing_environment.run_git(
- ("config", f"branch.{branch_name}.description"), is_error_logged=False
- )
- except RunCommandError as e:
- if e.returncode == 1:
- # no description set
- return ""
- raise
-
-
-async def is_branch_exists(computing_environment: ComputingEnvironment, branch_name: str) -> bool:
- """Check if the given branch exists."""
- result = await computing_environment.run_git(
- ("rev-parse", "--verify", "--quiet", branch_name),
- is_error_logged=False,
- check=False,
- )
- return result.strip() != ""
-
-
-async def is_detached_head(computing_environment: ComputingEnvironment) -> bool:
- """Check if the current HEAD is detached."""
- result = await computing_environment.run_git(("rev-parse", "--abbrev-ref", "HEAD"), is_error_logged=False)
- return result.strip() == "HEAD"
-
-
-async def set_branch_description(
- computing_environment: ComputingEnvironment, branch_name: str, description: str
-) -> None:
- """Set the description of the given branch."""
- await computing_environment.run_git(("config", f"branch.{branch_name}.description", description))
-
-
-async def get_branch_commit(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Get the commit of the given branch."""
- return await computing_environment.run_git(("rev-parse", branch_name))
-
-
-async def get_all_branch_names_pointing_to_commit(
- computing_environment: ComputingEnvironment, commit_hash: str
-) -> tuple[str, ...]:
- """Get all branch names that point to the given commit."""
- result = await computing_environment.run_git(
- (
- "for-each-ref",
- "refs/heads/",
- "--format='%(refname:short)'",
- "--points-at",
- commit_hash,
- )
- )
- branch_names = tuple(result.splitlines())
- # strip the quotes
- return tuple(branch_name.strip("'") for branch_name in branch_names)
-
-
-async def is_branch_child_of_branch(
- computing_environment: ComputingEnvironment,
- child_branch_name: str,
- parent_branch_name: str,
-) -> bool:
- """Check if the given branch is a child of the parent branch."""
- try:
- await computing_environment.run_git(
- ("merge-base", "--is-ancestor", parent_branch_name, child_branch_name),
- is_error_logged=False,
- )
- return True
- except RunCommandError as e:
- if e.stderr.strip() == "" and e.returncode == 1:
- # we expect this command to give an empty stderr and a return code of 1
- # if the child branch is not an ancestor of the parent branch
- return False
- raise
-
-
-async def is_commit_on_branch(
- computing_environment: ComputingEnvironment,
- commit_hash: str,
- branch_name: str,
- local_only: bool = True,
-) -> bool:
- """Check if the given commit is on the given branch."""
- if local_only:
- result = await computing_environment.run_git(("branch", "--contains", commit_hash))
- else:
- result = await computing_environment.run_git(("branch", "-a", "--contains", commit_hash))
- return any(branch_name == x.strip() for x in result.splitlines())
-
-
-async def fetch_and_get_first_entry_in_fetch_head(
- computing_environment: ComputingEnvironment, remote: str, fetch_refs: Sequence[str]
-) -> str:
- """Fetch the given refs from the remote and return the first entry in FETCH_HEAD."""
- refs_str = " ".join(fetch_refs)
- command = [
- "bash",
- "-c",
- (
- f"git fetch {remote} {refs_str} && "
- # get first commit from FETCH_HEAD
- "git rev-parse FETCH_HEAD"
- ),
- ]
- result = await run_command_with_retry_on_git_lock_error(computing_environment, command)
- return result.strip()
-
-
-async def fetch_branch(computing_environment: ComputingEnvironment, branch_name: str) -> None:
- """Fetch the given branch from the remote."""
- await computing_environment.run_git(("fetch", "origin", branch_name))
-
-
-async def is_branch_present(computing_environment: ComputingEnvironment, branch_name: str) -> bool:
- """Check if branch with given name is present."""
- result = await computing_environment.run_git(("branch",))
- return branch_name in result.splitlines()
-
-
-async def create_reset_and_checkout_branch(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Create new branch with given name."""
- return await computing_environment.run_git(("switch", "-C", branch_name))
-
-
-async def switch_branch(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Switch to branch with given name."""
- return await computing_environment.run_git(("switch", branch_name))
-
-
-async def delete_branch(computing_environment: ComputingEnvironment, branch_name: str, delete_remote: bool) -> str:
- """Delete branch with given name."""
- result = await computing_environment.run_git(("branch", "-D", branch_name))
- if delete_remote:
- result = await computing_environment.run_git(("push", "origin", "--delete", branch_name))
- return result
-
-
-async def update_branch_to_hash(computing_environment: ComputingEnvironment, branch_name: str, git_hash: str) -> None:
- """Update the given branch to reference the given git hash."""
- # here we do it without checking out the branch
- await computing_environment.run_git(("branch", "-f", branch_name, git_hash))
-
-
-async def switch_and_create_branch_if_needed(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Switch to new branch, creating it if it doesn't already exist."""
- if await is_branch_present(computing_environment, branch_name):
- await switch_branch(computing_environment, branch_name)
- else:
- await create_reset_and_checkout_branch(computing_environment, branch_name)
- return await get_branch_name(computing_environment)
-
-
-async def merge_branches(
- computing_environment: ComputingEnvironment,
- base_branch_name: str,
- merge_branch_name: str,
- is_moving_to_base_branch: bool = True,
-) -> str:
- """Merge `merge_branch_name` into `base_branch_name`."""
- await switch_branch(computing_environment, base_branch_name)
- await computing_environment.run_git(("merge", merge_branch_name))
- if not is_moving_to_base_branch:
- await switch_branch(computing_environment, "-")
- return await get_branch_name(computing_environment)
-
-
-async def get_merge_base(computing_environment: ComputingEnvironment, branch_name: str, target_branch: str) -> str:
- """Get the merge base of the given branch and target branch.
-
- The merge base is the most recent commit that is on both branches.
- """
- return await computing_environment.run_git(["merge-base", branch_name, target_branch], is_error_logged=False)
-
-
-async def checkout_hash(computing_environment: ComputingEnvironment, git_hash: str) -> str:
- """Checkout given git hash."""
- return await computing_environment.run_git(("checkout", git_hash))
-
-
-async def force_add(computing_environment: ComputingEnvironment, *paths: str) -> None:
- """Force-add the specified paths to the git index."""
- await computing_environment.run_git(("add", "-f", *paths))
-
-
-async def git_add(computing_environment: ComputingEnvironment, *paths: str) -> None:
- """Add the specified paths to the git index."""
- await computing_environment.run_git(("add", *paths))
-
-
-def convert_datetime_to_git_timestamp(dt: datetime) -> str:
- return datetime.isoformat(dt)
-
-
-def convert_git_timestamp_to_datetime(timestamp: str) -> datetime:
- return datetime.fromisoformat(timestamp)
-
-
-def get_commit_ts_for_current_time() -> CommitTimestamp:
- """Get the commit timestamp for the current time."""
- current_time = get_current_time()
- return CommitTimestamp(
- author_ts=convert_datetime_to_git_timestamp(current_time),
- committer_ts=convert_datetime_to_git_timestamp(current_time),
- )
-
-
-def _convert_time_to_commit_ts(
- time: str | datetime | CommitTimestamp | None,
-) -> CommitTimestamp:
- if time is None:
- return get_commit_ts_for_current_time()
- elif isinstance(time, datetime):
- return CommitTimestamp(
- author_ts=convert_datetime_to_git_timestamp(time),
- committer_ts=convert_datetime_to_git_timestamp(time),
- )
- elif isinstance(time, CommitTimestamp):
- return time
- else:
- # assume it's a git timestamp
- return CommitTimestamp(author_ts=time, committer_ts=time)
-
-
-async def make_commit(
- computing_environment: ComputingEnvironment,
- commit_message: str,
- allow_empty: bool = False,
- amend: bool = False,
- commit_time: str | datetime | CommitTimestamp | None = None,
-) -> str:
- if commit_message.strip() == "":
- commit_message = "No commit message provided"
-
- commit_ts = _convert_time_to_commit_ts(commit_time)
- time_args = f'GIT_AUTHOR_DATE="{commit_ts.author_ts}" GIT_COMMITTER_DATE="{commit_ts.committer_ts}" '
-
- commit_message = shlex.quote(commit_message)
- no_changes_message = "No changes to commit"
- amend_args = "--amend " if amend else ""
- if allow_empty or amend:
- bash_command = f"""git add . && {time_args}git commit {amend_args}--allow-empty -m {commit_message} > /dev/null && git rev-parse HEAD"""
- else:
- bash_command = f"""git add . && ( git status | grep -q "nothing to commit" && echo "{no_changes_message}" ) || ( {time_args}git commit {amend_args}-m {commit_message} > /dev/null && git rev-parse HEAD )"""
-
- with Section(f"committing changes with message: '{commit_message}'", log_level="DEBUG"):
- stdout = await run_command_with_retry_on_git_lock_error(
- computing_environment,
- ["bash", "-c", bash_command],
- )
- stdout = stdout.strip()
- if stdout == no_changes_message:
- raise FailedToMakeCommitError(f"Failed to make commit with message: {commit_message}. {bash_command=}")
- new_git_hash = stdout
- return new_git_hash
-
-
-async def get_tree_hash_for_commit(computing_environment: ComputingEnvironment, commit: str) -> str:
- """Get the tree hash for the given commit."""
- return await computing_environment.run_git(["rev-parse", commit + "^{tree}"])
-
-
-async def get_commit_timestamp(computing_environment: ComputingEnvironment, commit: str) -> CommitTimestamp:
- """Get the commit timestamp for the given commit."""
- split_token = "<|>"
- result = await computing_environment.run_git(["show", "-s", "--format=%aI<|>%cI", commit])
- author_ts, committer_ts = result.split(split_token)
- return CommitTimestamp(author_ts=author_ts.strip(), committer_ts=committer_ts.strip())
-
-
-async def tag_commit(computing_environment: ComputingEnvironment, tag: str, commit_hash: str) -> None:
- """Tag the given commit with the given tag."""
- # We use -f to force the tag to be created even if it already exists.
- await computing_environment.run_git(("tag", "-f", tag, commit_hash))
-
-
-async def git_push(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Push changes to remote branch with given name."""
- return await computing_environment.run_git(("push", "origin", branch_name))
-
-
-async def force_push(computing_environment: ComputingEnvironment, branch_name: str) -> str:
- """Push changes to remote branch with given name."""
- return await computing_environment.run_git(("push", "--force", "origin", branch_name))
-
-
-async def force_push_commit_with_retry(
- computing_environment: ComputingEnvironment,
- commit: str,
- branch_name: str,
- timeout: float = 30.0,
-) -> None:
- start_time = time.monotonic()
- sleep_time = 0.5
- while True:
- try:
- await force_push_commit(computing_environment, commit, branch_name)
- break
- except Exception as exc:
- if time.monotonic() - start_time > timeout:
- raise TimeoutError(
- f"Timeout reached: Could not force push {commit} to {branch_name} in {timeout} seconds."
- ) from exc
- logger.info("Force push of {} to {} failed; trying again...", commit, branch_name)
- await asyncio.sleep(sleep_time)
- sleep_time *= 2
-
-
-async def force_push_commit(computing_environment: ComputingEnvironment, commit: str, branch_name: str) -> None:
- try:
- await computing_environment.run_git(["push", "-f", "origin", f"{commit}:{branch_name}"], is_error_logged=False)
- except RunCommandError as e:
- if "fatal: bad object" in e.stderr:
- # TODO : We're retrying failed fetches here. However, there is also a separate
- # force_push_commit_with_retry method that retries the entire force_push_commit.
- # We should probably try the fetch only once, and then rely on the outer
- # force_push_commit_with_retry to retry the entire force_push_commit call when retrying is =
- # desired?
- NUM_TRIES = 3
- for _ in range(NUM_TRIES):
- try:
- await computing_environment.run_git(["fetch", "origin", commit], is_error_logged=False)
- except RunCommandError as fetch_e:
- if "not our ref" in fetch_e.stderr:
- # FIXME: actually, this has been getting worse... I suspect perhaps rate limiting or something? We are checking thing out much more than usual...
- await asyncio.sleep(2)
- else:
- raise fetch_e
- else:
- start_time = time.monotonic()
- while time.monotonic() - start_time < 10:
- try:
- await computing_environment.run_git(["push", "-f", "origin", f"{commit}:{branch_name}"])
- except RunCommandError as repush_e:
- if "not our ref" in repush_e.stderr:
- # FIXME: actually, this has been getting worse... I suspect perhaps rate limiting or something? We are checking thing out much more than usual...
- await asyncio.sleep(2)
- else:
- raise repush_e
- else:
- return
- raise Exception(f"Could not force push commit {commit}")
- raise Exception(f"Could not fetch commit {commit} to force push it")
- else:
- raise
-
-
-async def get_staged_files(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get list of all files in repo that are currently staged."""
- result = await computing_environment.run_git(("diff", "--full-index", "--binary", "--name-only", "--cached"))
- return tuple(result.splitlines())
-
-
-async def get_unstaged_files(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get list of all files in repo that are currently unstaged."""
- result = await computing_environment.run_git(("diff", "--full-index", "--binary", "--name-only"))
- return tuple(result.splitlines())
-
-
-async def restore_all_staged_files(computing_environment: ComputingEnvironment) -> None:
- """Restore all staged files."""
- await computing_environment.run_git(("restore", "--staged", "."))
-
-
-async def restore_all_unstaged_changes(
- computing_environment: ComputingEnvironment,
-) -> None:
- """Restore all unstaged changes."""
- await computing_environment.run_git(("restore", "."))
-
-
-async def apply_patch_via_git_with_conflict_markers(
- computing_environment: ComputingEnvironment,
- git_diff: str,
- is_error_logged: bool = True,
-) -> None:
- """Apply a diff to repo with conflict markers."""
- if git_diff.strip() == "":
- return
- if not git_diff.endswith("\n"):
- # git requires a newline at the end of the patch
- git_diff += "\n"
- patch_file = _get_temp_patch_file()
- try:
- await computing_environment.write_file(patch_file, git_diff)
- await computing_environment.run_command(
- [
- "bash",
- "-c",
- f"git add . && git apply --verbose {patch_file} || git apply -3 --verbose {patch_file}",
- ],
- is_error_logged=is_error_logged,
- )
- except RunCommandError as e:
- raise PatchApplicationError(f"Failed to apply patch: {e}") from e
- finally:
- await computing_environment.delete_file(patch_file)
-
-
-async def is_repo_conflicted(computing_environment: ComputingEnvironment) -> bool:
- output = await computing_environment.run_git(["status"], is_error_logged=False, check=False)
- if "Unmerged paths:" in output:
- return True
- return False
-
-
-async def get_head_hash(computing_environment: ComputingEnvironment) -> str:
- """Get the hash of the current HEAD commit."""
- git_hash = await computing_environment.run_git(["rev-parse", "HEAD"])
- assert len(git_hash) == 40, f"Expected 40-character git hash, got {git_hash}"
- return git_hash
-
-
-async def get_parent_commit_hash(computing_environment: ComputingEnvironment, commit_hash: str) -> str:
- """Get the parent commit hash of the given commit hash."""
- git_hash = await computing_environment.run_git(["rev-parse", f"{commit_hash}^"])
- assert len(git_hash) == 40, f"Expected 40-character git hash, got {git_hash}"
- return git_hash
-
-
-async def get_most_recent_sibling_branch_of_branch(
- computing_environment: ComputingEnvironment, target_branch_name: str
-) -> tuple[str, ...] | None:
- """Get the most recent sibling branch of the given branch name.
-
- This is the first branch that shares a common ancestor commit with the given branch name.
- Note, that it is possible that there are multiple branches that share the most recent common ancestor.
- In this case, we return all of them.
-
- Also if there are no sibling branches (either at all or within the max lookback), we return None.
- """
- # FIXME: this is imprecise -- it really just needs to be head 2, but I didn't want to figure out the regex...
- output = await computing_environment.run_command(
- [
- "bash",
- "-c",
- f'git log --decorate=full --oneline {target_branch_name} | grep "refs/heads/" | head -n 10',
- ],
- check=True,
- )
- for line in output.splitlines():
- branch_list_string = line.split(" (", maxsplit=1)[-1].split(")", maxsplit=1)[0]
- parsed_branch_names = []
- for branch_name_string in branch_list_string.split(", "):
- parsed_branch_name = branch_name_string.rsplit(" ")[-1]
- if parsed_branch_name.startswith("refs/heads") and parsed_branch_name != f"refs/heads/{target_branch_name}":
- parsed_branch_names.append(parsed_branch_name.replace("refs/heads/", "", 1))
- # otherwise, just return the first one
- if len(parsed_branch_names) > 0:
- return tuple(parsed_branch_names)
- return None
-
-
-async def get_nth_commit_ago(
- computing_environment: ComputingEnvironment,
- commit_hash: str,
- n: int,
- is_error_logged: bool = True,
-) -> str:
- """Get the nth commit ago of the given commit hash."""
- git_hash = await computing_environment.run_git(["rev-parse", f"{commit_hash}~{n}"], is_error_logged=is_error_logged)
- assert len(git_hash) == 40, f"Expected 40-character git hash, got {git_hash}"
- return git_hash
-
-
-async def get_initial_repo_commit_hash(computing_environment: ComputingEnvironment, commit_hash: str = "HEAD") -> str:
- """Get the initial commit hash of the repo.
-
- As written, if invoked on an empty repo with no commits (immediately after `git init`), this fails with:
- `fatal: ambiguous argument 'HEAD': unknown revision or path not in the working tree.`
- """
- # --max-parents=0: only consider commits with no parents
- # --date-order: sort by date (newest first)
- output = await computing_environment.run_git(["rev-list", "--max-parents=0", commit_hash, "--date-order"])
- # assume the oldest commit with no parents is the initial repo commit
- all_root_commits = output.splitlines()
- root_commit = all_root_commits[-1]
- assert len(root_commit) == 40, f"Expected 40-character git hash, got {root_commit}"
- return root_commit
-
-
-async def get_upto_nth_commit_ago(computing_environment: ComputingEnvironment, commit_hash: str, n: int) -> str:
- """Get the commit hash of the upto nth commit ago of the given commit hash.
-
- If the commit history is shorter than n, it will return the first commit.
- """
- try:
- return await get_nth_commit_ago(computing_environment, commit_hash, n, is_error_logged=False)
- except RunCommandError as e:
- if "unknown revision or path not in the working tree" in e.stderr:
- return await get_initial_repo_commit_hash(computing_environment, commit_hash)
- raise
-
-
-async def get_commit_message(computing_environment: ComputingEnvironment, commit_hash: str) -> str:
- """Get the commit message of the given commit hash."""
- return await computing_environment.run_git(["log", "-1", "--pretty=%B", commit_hash])
-
-
-async def get_commit_count_between_hashes(
- computing_environment: ComputingEnvironment, old_hash: str, new_hash: str
-) -> int:
- """Get the number of commits between two hashes."""
- output = await computing_environment.run_git(["rev-list", "--count", f"{old_hash}..{new_hash}"])
- return int(output.strip())
-
-
-async def fetch_and_checkout_hash(
- computing_environment: ComputingEnvironment,
- git_hash: str,
- is_error_logged: bool = True,
-) -> None:
- await computing_environment.run_command(
- ["bash", "-c", f"git fetch origin {git_hash} && git checkout {git_hash}"],
- is_error_logged=is_error_logged,
- )
-
-
-async def fetch_ref_and_checkout_hash(
- computing_environment: ComputingEnvironment,
- ref: str,
- git_hash: str,
- is_error_logged: bool = True,
-) -> None:
- await computing_environment.run_command(
- ["bash", "-c", f"git fetch origin {ref} && git checkout {git_hash}"],
- is_error_logged=is_error_logged,
- )
-
-
-@retry(
- stop=stop_after_attempt(5),
- wait=wait_random_exponential(min=0.5, max=30, exp_base=2),
- reraise=True,
- retry=retry_all(retry_if_exception_type(RunCommandError)),
- before_sleep=log_before_sleep,
-)
-async def wait_for_git_hash_to_checkout(computing_environment: ComputingEnvironment, git_hash: str) -> None:
- await fetch_and_checkout_hash(computing_environment, git_hash, is_error_logged=False)
-
-
-async def wait_for_git_hash_with_ref_to_checkout(
- computing_environment: ComputingEnvironment,
- git_hash: str,
- ref: str,
- timeout: float = 20.0,
-) -> None:
- with Section(f"Checking out git hash {git_hash} with ref {ref}", log_level="DEBUG"):
- start_time = time.monotonic()
- while True:
- try:
- await fetch_ref_and_checkout_hash(computing_environment, ref, git_hash, is_error_logged=False)
- break
- except RunCommandError as exc:
- if time.monotonic() - start_time > timeout:
- log_exception(
- exc,
- "Timeout reached: Git hash {git_hash} is not available after {timeout} seconds.",
- git_hash=git_hash,
- timeout=timeout,
- )
- raise TimeoutError(f"Timeout reached: Git hash {git_hash} is not available.") from exc
- logger.info("Waiting for git hash {} to be available...", git_hash)
- await asyncio.sleep(0.5)
-
-
-async def force_checkout_git_hash_immediate_on_branch(
- computing_environment: ComputingEnvironment, git_hash: str, branch_name: str
-) -> None:
- # Here we clear any uncommited changes, then change branch, before checking out the new hash
- # We need to do these three steps so we don't affect the currently checked out branch
- command = f"git reset --hard && git checkout -B {branch_name} && git reset --hard {git_hash}"
- logger.debug("Running command: {}", command)
- await run_command_with_retry_on_git_lock_error(
- computing_environment,
- [
- "bash",
- "-c",
- command,
- ],
- )
-
-
-async def get_git_folder_paths(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the paths of all the git folders in the repo."""
- result = await computing_environment.run_command(["ls", ".git"])
- return tuple(result.splitlines())
-
-
-async def apply_patch_via_git(
- computing_environment: ComputingEnvironment, git_diff: str, is_error_logged: bool
-) -> None:
- """Apply a diff to repo."""
- if git_diff.strip() == "":
- return
- patch_file = _get_temp_patch_file()
- await computing_environment.write_file(patch_file, git_diff)
- # NOTE: --allow-empty is necessary because the patch may be empty, or result in no changes
- # update (2024-11-22) --allow-empty is not available in the git version in our devcontainers
- # so we have to do a janky error check below
- try:
- await computing_environment.run_git(("apply", "--verbose", str(patch_file)), is_error_logged=is_error_logged)
- except RunCommandError as e:
- raise PatchApplicationError(f"Failed to apply patch: {e}") from e
- finally:
- await computing_environment.delete_file(patch_file)
-
-
-async def get_git_diff(
- computing_environment: ComputingEnvironment,
- commit_hash: str | None = None,
- staged: bool = False,
- is_error_logged: bool = True,
- include_binary: bool = True,
-) -> str:
- """Get the diff for the current repo state."""
- # make sure `is_stripped=False` otherwise patch can be invalid
- command = ["diff", "--full-index"]
- if include_binary:
- # Without --binary, diffs of binary files will just contain a summary statement such as "Binary files a/file.bin and b/file.bin differ".
- # Such diffs cannot be applied, but are useful for inclusion in LLM prompts.
- command.append("--binary")
- if staged:
- command.append("--staged")
- if commit_hash:
- command.append(commit_hash)
- return await computing_environment.run_git(command, is_stripped=False, is_error_logged=is_error_logged)
-
-
-async def get_diff_between_hashes(computing_environment: ComputingEnvironment, old_hash: str, new_hash: str) -> str:
- """Get the diff between two git hashes."""
- # make sure `is_stripped=False` otherwise patch can be invalid
- return await computing_environment.run_git(
- ["diff", "--full-index", "--binary", old_hash, new_hash], is_stripped=False
- )
-
-
-async def get_patch_for_commit(computing_environment: ComputingEnvironment, commit_hash: str) -> str:
- """Get the patch for a given commit hash."""
- return await computing_environment.run_git(["show", "--pretty=format:", "--patch", commit_hash], is_stripped=False)
-
-
-async def get_untracked_files(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the untracked files in the repo."""
- result = await computing_environment.run_git(["ls-files", "--others", "--exclude-standard"], is_error_logged=False)
- return tuple([line.strip() for line in result.splitlines() if line.strip()])
-
-
-async def get_untracked_file_diff(
- computing_environment: ComputingEnvironment,
- file_path: str,
- include_binary: bool = True,
-) -> str:
- """Get the diff for a untracked file.
-
- Note this function will raise a RunCommandError if the there is no diff for the untracked file or if there
- is another error running the command. So it is best to use this function after checking that the file is untracked
- using get_untracked_files function.
- """
- command = ["diff", "--no-index"]
- if include_binary:
- command.append("--binary")
- untracked_diff = await computing_environment.run_git(
- command + ["/dev/null", str(file_path)],
- # Unfortunately, `--no-index` implies `--exit-code`, which will cause git diff to return an exit code of 1
- # if the diff is not empty. So we can't use check=True here. We'll check for an empty output to detect failures.
- check=False,
- is_error_logged=True,
- is_stripped=False,
- )
- if not untracked_diff:
- raise RunCommandError(f"Unable to diff untracked file {file_path}")
- return untracked_diff
-
-
-async def get_staged_unstaged_and_combined_diffs(
- computing_environment: ComputingEnvironment, base_commit: str = "HEAD"
-) -> tuple[str, str, str]:
- """Get the staged diff, the unstaged diff, and the combined diff"""
- staged_diff = await get_git_diff(computing_environment, staged=True)
- unstaged_diff = await get_git_diff(computing_environment, staged=False)
- combined_diff = await get_git_diff(computing_environment, commit_hash=base_commit)
- return staged_diff, unstaged_diff, combined_diff
-
-
-async def get_unmerged_blob_hashes(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the blob hashes of all the unmerged files in the repo."""
- result = await computing_environment.run_command(
- ["bash", "-c", "git ls-files --unmerged | awk '{print $2}' | sort -u"],
- check=False,
- )
- return tuple(line.strip() for line in result.splitlines() if line.strip() != "")
-
-
-async def get_staged_blob_hashes(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the blob hashes of all the staged files in the repo."""
- staged_blob_hashes = await computing_environment.run_command(
- [
- "bash",
- "-c",
- 'staged_blobs=$(git diff --full-index --binary --cached --name-only --diff-filter=ACMRT | while read file; do git ls-files --stage "$file" | awk \'{print $2}\'; done); echo "$staged_blobs"',
- ]
- )
- return tuple(line.strip() for line in staged_blob_hashes.splitlines() if line.strip() != "")
-
-
-async def get_blob_content_by_hash(computing_environment: ComputingEnvironment, blob_hash: str) -> bytes:
- """Get the content of a blob by its hash."""
- result = await computing_environment.run_git(["cat-file", "-p", blob_hash], is_stripped=False)
- return result.encode("utf-8")
-
-
-async def get_unmerged_and_staged_blob_contents_by_hash(
- computing_environment: ComputingEnvironment,
-) -> dict[str, bytes]:
- """Get the contents of all the unmerged and staged blobs in the repo."""
- unmerged_blob_hashes = await get_unmerged_blob_hashes(computing_environment)
- staged_blob_hashes = await get_staged_blob_hashes(computing_environment)
- all_relevant_blob_hashes = unmerged_blob_hashes + staged_blob_hashes
- blob_content_tasks_by_hash = {
- blob_hash: get_blob_content_by_hash(computing_environment, blob_hash) for blob_hash in all_relevant_blob_hashes
- }
- blob_contents = await asyncio.gather(*blob_content_tasks_by_hash.values())
- return {
- blob_hash: blob_content for blob_hash, blob_content in zip(blob_content_tasks_by_hash.keys(), blob_contents)
- }
-
-
-async def write_blob_content(computing_environment: ComputingEnvironment, blob_hash: str, blob_content: bytes) -> None:
- """Write the content of a blob to the repo."""
- # write the blob content to a temp file
- temp_file = anyio.Path(f"/tmp/{blob_hash}")
- try:
- await computing_environment.write_file(temp_file, blob_content, mode="wb")
- # write the blob to the repo
- result = await computing_environment.run_git(["hash-object", "-w", str(temp_file)])
- assert result.strip() == blob_hash, f"Expected blob hash {blob_hash}, got {result.strip()}"
- finally:
- await computing_environment.delete_file(temp_file)
-
-
-async def write_blob_content_by_hash(
- computing_environment: ComputingEnvironment, blob_content_by_hash: dict[str, bytes]
-) -> None:
- """Write the content of all the blobs to the repo."""
- tasks = []
- for blob_hash, blob_content in blob_content_by_hash.items():
- tasks.append(write_blob_content(computing_environment, blob_hash, blob_content))
- await asyncio.gather(*tasks)
-
-
-async def get_modified_files_with_conflicts(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the modified files with conflicts."""
- commands = [
- "diff --check --staged --full-index --binary",
- "diff --check --full-index --binary",
- ]
- conflicted_files = set()
- for command in commands:
- result = await computing_environment.run_git(command.split(), check=False, is_error_logged=False)
- # output is of the form:
- # test.txt:2: leftover conflict marker
-
- for line in result.splitlines():
- parts = line.split(":", maxsplit=1)
- if len(parts) == 2:
- file_path, message = parts
- if "leftover conflict marker" in message.strip():
- conflicted_files.add(file_path)
- return tuple(conflicted_files)
-
-
-async def get_conflicted_pathnames(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the pathnames of all the conflicted files in the repo."""
- result = await computing_environment.run_git(["diff", "--full-index", "--binary", "--name-only", "--diff-filter=U"])
- return tuple(result.splitlines())
-
-
-async def get_conflicted_contents_by_path(
- computing_environment: ComputingEnvironment,
-) -> dict[str, bytes]:
- """Get the contents of all the conflicted files in the repo."""
- conflicted_files = await get_conflicted_pathnames(computing_environment)
- conflicted_contents_by_path: dict[str, bytes] = {}
- for file_path in conflicted_files:
- content = await computing_environment.read_file(file_path, mode="rb")
- assert isinstance(content, bytes), f"Expected bytes, got {type(content)}"
- conflicted_contents_by_path[file_path] = content
- return conflicted_contents_by_path
-
-
-async def get_modified_pathnames(
- computing_environment: ComputingEnvironment,
-) -> tuple[str, ...]:
- """Get the pathnames of all the modified files in the repo."""
- result = await computing_environment.run_command(["bash", "-c", "git status --porcelain | awk '{print $2}'"])
- return tuple(result.splitlines())
-
-
-async def get_modified_file_contents_by_path(
- computing_environment: ComputingEnvironment,
-) -> dict[str, bytes]:
- """Get the contents of all the modified files in the repo."""
- modified_pathnames = await get_modified_pathnames(computing_environment)
- modified_file_contents_by_path: dict[str, bytes] = {}
- for pathname in modified_pathnames:
- content = await computing_environment.read_file(pathname, mode="rb")
- assert isinstance(content, bytes), f"Expected bytes, got {type(content)}"
- modified_file_contents_by_path[pathname] = content
- return modified_file_contents_by_path
-
-
-async def get_repo_url(computing_environment: ComputingEnvironment) -> str:
- repo_url = await computing_environment.run_git(["remote", "get-url", "origin"])
- if repo_url.startswith("git@"):
- # convert ssh url to https
- repo_url = repo_url.replace(":", "/")
- repo_url = f"https://{repo_url[4:]}"
- if "https://oauth2:" in repo_url:
- # remove the oauth2 prefix
- # repo_url is something like https://oauth2:{token}@gitlab.com/.../.git
- # change it to https://gitlab.com/.../.git
- suffix = repo_url.split("@")[-1]
- repo_url = "https://" + suffix
- return repo_url
-
-
-async def get_main_branch_name_for_repo(
- computing_environment: ComputingEnvironment, default_branch: str | None = None
-) -> str:
- """Get the name of the main branch for the repo.
-
- Attempts to detect whether the repository uses 'main', 'master', or another name
- as its primary branch by checking for common branch names in order of preference.
- """
- possible_main_branches = ["main", "master", "trunk", "development"]
-
- if default_branch is not None and default_branch not in possible_main_branches:
- possible_main_branches.insert(0, default_branch)
-
- # First check if any of the common main branch names exist
- # and return the first one that does
- for branch in possible_main_branches:
- if await is_branch_exists(computing_environment, branch):
- return branch
-
- # If we couldn't find a common main branch, try to determine the default branch
- # This gets the branch that HEAD points to in a newly cloned repo
- default_remote_branch = await computing_environment.run_git(
- ["symbolic-ref", "refs/remotes/origin/HEAD"], is_error_logged=False
- )
- if default_remote_branch:
- # Format is typically refs/remotes/origin/main, so extract the last part
- default_remote_branch = default_remote_branch.strip().split("/")[-1]
- return default_remote_branch
- raise ValueError("Could not detect main branch for repo.")
diff --git a/imbue_core/imbue_core/computing_environment/data_types.py b/imbue_core/imbue_core/computing_environment/data_types.py
@@ -1,31 +0,0 @@
-from __future__ import annotations
-
-import subprocess
-from pathlib import Path
-from typing import Any
-
-import anyio
-
-# Use AnyPath type to match Sanctum
-AnyPath = Path | str | anyio.Path
-
-
-class RunCommandError(subprocess.CalledProcessError):
- """Custom exception for errors encountered during Git commands."""
-
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- self.cwd = kwargs.get("cwd", None)
- if "cwd" in kwargs:
- del kwargs["cwd"]
- super().__init__(*args, **kwargs)
-
- def __str__(self) -> str:
- return f"Command `{self.cmd}` returned non-zero exit status {self.returncode}.\nOutput: {self.stdout}\nError: {self.stderr}\nCWD: {self.cwd}"
-
-
-class PatchApplicationError(Exception):
- """Custom exception for errors encountered during patch application."""
-
-
-class FailedToMakeCommitError(Exception):
- """Custom exception for errors encountered during commit creation."""
diff --git a/imbue_core/imbue_core/error_utils.py b/imbue_core/imbue_core/error_utils.py
@@ -1,28 +0,0 @@
-"""Error handling utilities."""
-
-import sys
-import traceback
-
-import traceback_with_variables
-from traceback_with_variables import Format
-
-
-def get_traceback_with_vars(exception: BaseException | None = None) -> str:
-
- # be careful of potential performance regressions with increasing these limits
- tb_format = Format(max_value_str_len=100_000, max_exc_str_len=2_000_000)
- if exception is None:
- # no exception passed in; get the current exception. this will still be None if not in an exception handler
- exception = sys.exception()
- try:
- if exception is not None:
- # we are in an exception handler, use that for the traceback
- # for some reason this breaks when casting to an `Exception`, so just using type: ignore
- return traceback_with_variables.format_exc(exception, fmt=tb_format) # type: ignore
- else:
- # not in an exception handler, just get the current stack
- return traceback_with_variables.format_cur_tb(fmt=tb_format)
- except Exception as e:
- return (
- f"got exception while formatting traceback with `traceback_with_variables`: {traceback.format_exception(e)}"
- )
diff --git a/imbue_core/imbue_core/frozen_utils.py b/imbue_core/imbue_core/frozen_utils.py
@@ -6,10 +6,7 @@ from typing import Any
from typing import Iterable
from typing import Mapping
from typing import NoReturn
-from typing import Protocol
-from typing import Sequence
from typing import TYPE_CHECKING
-from typing import TypeAlias
from typing import TypeVar
from typing import cast
@@ -17,13 +14,8 @@ if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem
-class _SupportsLessThan(Protocol):
- def __lt__(self, __other: Any) -> bool: ...
-
-
T = TypeVar("T")
TV = TypeVar("TV")
-TK = TypeVar("TK", bound=_SupportsLessThan)
class FrozenMapping(Mapping[T, TV], ABC):
@@ -91,10 +83,6 @@ class FrozenDict(dict[T, TV], FrozenMapping[T, TV]):
return (FrozenDict, (dict(self),))
-def empty_mapping() -> FrozenDict[Any, Any]:
- return FrozenDict()
-
-
def deep_freeze_mapping(mapping: Mapping[T, TV]) -> FrozenDict[T, Any]:
return FrozenDict({key: cast(TV, _deep_freeze_any(value)) for key, value in mapping.items()})
@@ -118,24 +106,3 @@ def _deep_freeze_any(input_object: object) -> object:
return tuple(_freeze_iterable_values(input_object))
return input_object
-
-
-def deep_freeze_sequence(sequence: Sequence[T]) -> tuple[Any, ...]:
- return tuple(_freeze_iterable_values(sequence))
-
-
-# Recursive type alias that captures the possible types of JSON objects (e.g. from json.loads).
-JSON: TypeAlias = "str | int | bool | float | None | dict[str, JSON] | list[JSON]"
-
-
-# Immutable version of JSON.
-FrozenJSON: TypeAlias = "str | int | bool | float | None | FrozenDict[str, FrozenJSON] | tuple[FrozenJSON, ...]"
-
-
-def deep_freeze_json(json: JSON) -> FrozenJSON:
- if isinstance(json, dict):
- return FrozenDict({k: deep_freeze_json(v) for k, v in json.items()})
- elif isinstance(json, list):
- return tuple(deep_freeze_json(v) for v in json)
- else:
- return json
diff --git a/imbue_core/imbue_core/git.py b/imbue_core/imbue_core/git.py
@@ -1,587 +0,0 @@
-"""Utility abstractions for interacting with git repositories."""
-
-from __future__ import annotations
-
-import asyncio
-import contextlib
-import shlex
-import shutil
-import subprocess
-import sys
-from asyncio.subprocess import PIPE
-from asyncio.subprocess import STDOUT
-from contextlib import asynccontextmanager
-from io import StringIO
-from pathlib import Path
-from types import TracebackType
-from typing import Any
-from typing import AsyncGenerator
-from typing import AsyncIterator
-from typing import Self
-from typing import Sequence
-from typing import TYPE_CHECKING
-from typing import TextIO
-
-import anyio
-import attr
-from loguru import logger
-
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.async_utils import sync
-from imbue_core.async_utils import sync_contextmanager_func
-from imbue_core.computing_environment.computing_environment import assert_repo_is_clean
-from imbue_core.computing_environment.computing_environment import get_head_hash
-from imbue_core.computing_environment.computing_environment import git_add
-from imbue_core.computing_environment.computing_environment import is_repo_dirty
-from imbue_core.computing_environment.computing_environment import make_commit
-from imbue_core.computing_environment.computing_environment import (
- restore_all_staged_files,
-)
-from imbue_core.computing_environment.computing_environment import (
- restore_all_unstaged_changes,
-)
-from imbue_core.computing_environment.computing_environment import (
- run_command_with_retry_on_git_lock_error,
-)
-from imbue_core.computing_environment.data_types import AnyPath
-from imbue_core.computing_environment.data_types import RunCommandError
-
-if TYPE_CHECKING:
- # for proper file mode typing
- from _typeshed import OpenBinaryMode
- from _typeshed import OpenBinaryModeReading
- from _typeshed import OpenBinaryModeWriting
- from _typeshed import OpenTextMode
- from _typeshed import OpenTextModeReading
- from _typeshed import OpenTextModeWriting
-
-PYTHON_EXTENSION = ".py"
-
-
-def is_path_in_git_repo(path: Path) -> bool:
- """Check if a path is in a git repository."""
- if path.is_file():
- path = path.parent
- completed_process = subprocess.run(
- ["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
- stdin=subprocess.DEVNULL,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- check=False,
- )
- if completed_process.returncode != 0:
- return False
- result = completed_process.stdout.decode().strip()
- assert result in ("true", "false"), result
- return result == "true"
-
-
-def get_git_repo_root() -> Path:
- """Gets a Path to the current git repo root, assuming that our cwd is somewhere inside the repo."""
- completed_process = subprocess.run(
- ("git", "rev-parse", "--show-toplevel"),
- stdin=subprocess.DEVNULL,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- check=True,
- )
- root_dir = Path(completed_process.stdout.decode().strip())
- assert root_dir.is_dir(), f"{root_dir} must be a directory"
- return root_dir
-
-
-def get_git_repo_root_from_path(path: Path) -> Path:
- """Gets a Path to the git repo root for the given path."""
- if path.is_file():
- path = path.parent
- completed_process = subprocess.run(
- ["git", "-C", path, "rev-parse", "--show-toplevel"],
- stdin=subprocess.DEVNULL,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- check=True,
- )
- root_dir = Path(completed_process.stdout.decode().strip())
- assert root_dir.is_dir(), f"{root_dir} must be a directory"
- return root_dir
-
-
-@attr.s(auto_attribs=True, frozen=True)
-class LocalGitRepo:
- """
- DEPRECATED: Unless you either need asyncio or the ability to run git commands remotely on non-local compute environments, consider using SyncLocalGitRepo
- from simple_git.py instead. SimpleGitRepo provides a subset of the functions available through computing_environment.py + LocalGitRepo, but
- in a single class made for synchronous use.
- """
-
- base_path: Path
-
- @classmethod
- def build_from_cwd(cls) -> Self:
- """Create a `LocalGitRepo` instance from the current working directory."""
- return cls(get_git_repo_root())
-
- async def run_git(
- self,
- command: Sequence[str],
- check: bool = True,
- cwd: AnyPath | None = None,
- is_error_logged: bool = True,
- is_stripped: bool = True,
- retry_on_git_lock_error: bool = True,
- ) -> str:
- """Run a git command in the repo.
-
- Example:
- ```
- git_repo.run_git("status")
- ```
- """
- # TODO: check for whether hooks should actually be run when we call this function
- # Note: this used to be within an asyncio lock to prevent the program from concurrently running git commands.
- # This lock was removed since it was within global state, a dangerous pattern, and wasn't preventing other users from interacting with the git repo.
- command = ["git"] + list(command)
- if not retry_on_git_lock_error:
- result = await self.run_command(command, check=check, is_error_logged=is_error_logged, cwd=cwd)
- else:
- result = await run_command_with_retry_on_git_lock_error(
- self, command, check=check, is_error_logged=is_error_logged, cwd=cwd
- )
- if is_stripped:
- return result.strip()
- return result
-
- sync_run_git = sync(run_git)
-
- async def run_command(
- self,
- command: Sequence[str],
- check: bool = True,
- secrets: dict[str, str] | None = None,
- cwd: AnyPath | None = None,
- is_error_logged: bool = True,
- ) -> str:
- """Run a command in the repo.
-
- Note, this can be used to run any command, not just git.
- """
- command_string = shlex.join(command)
- logger.trace(
- f"Running command: {command_string=} from cwd={cwd or self.base_path} with {secrets=} {check=} {is_error_logged=}"
- )
- proc = await asyncio.create_subprocess_exec(
- *command,
- cwd=cwd or self.base_path,
- stdin=subprocess.DEVNULL,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env=secrets,
- )
- # note, need to be carefull not to strip() lines since whitespace may be important (e.g. for diffs)
- # return joined lines since mostly we only use the output for logging, and this way we arn't
- # passing around lots of lists. Also it's easy to parse by lines if needed
- stdout_bytes, stderr_bytes = await proc.communicate()
- try:
- stdout = stdout_bytes.decode("UTF-8")
- except UnicodeDecodeError as e:
- # If we don't encounter this, it likely means something was fixed upstream and we can safely delete
- log_exception(
- e,
- "Command {command_string} failed to decode stdout, replacing any invalid bytes which could lead to problems later",
- command_string=command_string,
- )
- stdout = stdout_bytes.decode("UTF-8", errors="replace")
- stderr = stderr_bytes.decode("UTF-8")
- if check and proc.returncode != 0:
- error_message = f"command run from cwd={self.base_path} failed with exit code {proc.returncode} and stdout:\n{stdout}\nstderr:\n{stderr}"
- if is_error_logged:
- logger.error(
- f"command attempted: '{command_string}' from cwd={self.base_path}\nerror message: {error_message}"
- )
- # this should not be None, but do this to satisfy type checker, int or None we throw the same error
- returncode = proc.returncode or -1
- raise RunCommandError(
- cmd=command_string,
- stderr=stderr,
- returncode=returncode,
- cwd=cwd or self.base_path,
- )
- return stdout
-
- @contextlib.asynccontextmanager
- async def _open_file(
- self,
- relative_path: AnyPath,
- cwd: AnyPath | None = None,
- mode: OpenTextMode | OpenBinaryMode = "r",
- mkdir_if_missing: bool = True,
- ) -> AsyncGenerator[anyio.AsyncFile[Any], None]:
- logger.trace("opening file {} in cwd {} with mode {}", relative_path, cwd, mode)
- if cwd is not None:
- sb_file_path = str(Path(cwd) / relative_path)
- else:
- sb_file_path = str(self.base_path / relative_path)
-
- if mkdir_if_missing:
- parent_dir = anyio.Path(sb_file_path).parent
- await parent_dir.mkdir(parents=True, exist_ok=True)
-
- f: anyio.AsyncFile[Any] | None = None
- try:
- f = await anyio.Path(sb_file_path).open(mode=mode) # type: ignore
- yield f
- finally:
- if f is not None:
- await f.aclose()
-
- async def write_file(
- self,
- relative_path: AnyPath,
- content: str | bytes | None,
- cwd: AnyPath | None = None,
- mode: OpenTextModeWriting | OpenBinaryModeWriting = "w",
- mkdir_if_missing: bool = True,
- ) -> None:
- if content is None:
- await self.delete_file(relative_path, cwd=cwd)
- return
-
- async with self._open_file(relative_path, cwd=cwd, mode=mode, mkdir_if_missing=mkdir_if_missing) as f:
- logger.trace("writing to file {} in cwd {} with mode {}", relative_path, cwd, mode)
- # pyre-fixme[6]: content can be bytes
- await f.write(content)
-
- async def delete_file(self, relative_path: AnyPath, cwd: AnyPath | None = None) -> None:
- logger.trace("deleting the file {} in cwd {}", relative_path, cwd)
- if cwd is not None:
- sb_file_path = str(Path(cwd) / relative_path)
- else:
- sb_file_path = str(self.base_path / relative_path)
- await anyio.Path(sb_file_path).unlink()
-
- async def read_file(
- self,
- relative_path: AnyPath,
- cwd: AnyPath | None = None,
- mode: OpenTextModeReading | OpenBinaryModeReading = "r",
- mkdir_if_missing: bool = True,
- ) -> str | bytes:
- async with self._open_file(relative_path, cwd=cwd, mode=mode, mkdir_if_missing=mkdir_if_missing) as f:
- logger.trace("reading file {} in cwd {} with mode {}", relative_path, cwd, mode)
- content = await f.read()
- assert isinstance(content, str) or isinstance(content, bytes)
- return content
-
- async def head_hash(self) -> str:
- """Get the hash of the current HEAD commit."""
- return await get_head_hash(self)
-
- async def is_git_repo(self) -> bool:
- """Check that repo is valid git repo."""
- return await anyio.Path(self.base_path / ".git").exists()
-
- sync_is_git_repo = sync(is_git_repo)
-
- async def assert_clean(self) -> None:
- await assert_repo_is_clean(self)
-
- sync_assert_clean = sync(assert_clean)
-
- async def configure_git(
- self,
- git_user_name: str | None = None,
- git_user_email: str | None = None,
- initial_commit_message: str = "initial commit",
- is_recreating: bool = False,
- ) -> None:
- """Configure git repo with user name and email."""
- if is_recreating:
- if await self.is_git_repo():
- await asyncio.to_thread(shutil.rmtree, self.base_path / ".git")
-
- # order here is important
- # ref https://stackoverflow.com/questions/11656761/git-please-tell-me-who-you-are-error?noredirect=1
- await self.run_git(("init",))
- if git_user_name:
- await self.run_git(("config", "user.name", f"'{git_user_name}'"))
- if git_user_email:
- await self.run_git(("config", "user.email", f"'{git_user_email}'"))
- await self.run_git(("add", "."))
- await self.run_git(("commit", "-m", f"'{initial_commit_message}'"))
- branch_name = await self.run_git(("symbolic-ref", "HEAD"))
- if not branch_name == "refs/heads/main":
- # rename master to main for consistency
- await self.run_git(("branch", "-m", "master", "main"))
-
- sync_configure_git = sync(configure_git)
-
- @asynccontextmanager
- async def temporary_commit(
- self,
- tag_prefix: str,
- commit_message: str,
- raise_on_head_hash_change: bool = False,
- ) -> AsyncIterator[str]:
- """Context manager to make a temporary commit and tag in the repo."""
- await self.run_git(("commit", "-am", commit_message, "--allow-empty", "--no-verify"))
- head_hash = await self.head_hash()
- tag = f"{tag_prefix}/{head_hash}"
- await self.run_git(("tag", "--force", tag))
- await self.run_git(("push", "origin", tag, "--no-verify"))
- try:
- yield head_hash
- finally:
- # This is susceptible to a race condition (if the user makes a commit between the time we check the head hash and the time we reset the state).
- # So it's important to keep any block that uses this context manager short - make the commit, copy it to the controller, and work there. Don't hold the repo hostage.
- current_head_hash = await self.head_hash()
- if current_head_hash != head_hash and raise_on_head_hash_change:
- raise AssertionError(
- f"Head hash has changed from {head_hash} to {current_head_hash} since the temporary commit was made. Giving up on resetting git state, please address this manually."
- )
- else:
- await self.run_git(("reset", "HEAD~"))
-
- sync_temporary_commit = sync_contextmanager_func(temporary_commit)
-
- async def copy_repo(self, new_repo_path: Path, exists_ok: bool = True) -> "LocalGitRepo":
- """Make a full copy of this repo in a new directory.
-
- Note, this will copy all the files in the repo into a new local directory, but will not handle
- configuring the new directory as a git repo.
- """
- if await anyio.Path(new_repo_path).exists():
- if not exists_ok:
- raise FileExistsError(
- f"New repo path '{new_repo_path} already exists. Set `exists_ok=True` if you are happy overwriting it, otherwise select new path."
- )
- await asyncio.to_thread(shutil.rmtree, new_repo_path)
- await asyncio.to_thread(
- shutil.copytree,
- self.base_path,
- new_repo_path,
- dirs_exist_ok=True,
- ignore=shutil.ignore_patterns(".git", ".gitsecret"),
- )
- return LocalGitRepo(new_repo_path)
-
- sync_copy_repo = sync(copy_repo)
-
- async def is_path_in_repo(self, file_path: str | Path | anyio.Path) -> bool:
- """Check whether a given file path is within this repo.
-
- FIXME: It doesn't seem entirely necessary to enumerate all of the files with a particular extension
- just to check if a single file (whose path we know) is in the repo.
- """
- if isinstance(file_path, (str, Path)):
- file_path = anyio.Path(file_path)
- extension = file_path.suffix
- return file_path in await self.get_all_files_by_extension(extension=extension)
-
- async def _get_file_path(self, file_path: str | Path) -> anyio.Path:
- path = anyio.Path(file_path)
- if not path.is_absolute():
- path = anyio.Path(self.base_path / path)
- assert await path.exists(), f"File {path} does not exist."
- return path
-
- async def safely_read_file_from_repo(self, file_path: str | Path) -> str:
- """Safely read file from repo."""
- path = await self._get_file_path(file_path)
- assert await self.is_path_in_repo(path), f"File {path} is not in repo."
- return await path.read_text()
-
- sync_safely_read_file_from_repo = sync(safely_read_file_from_repo)
-
- async def get_all_files_by_extension(self, extension: str = PYTHON_EXTENSION) -> tuple[Path, ...]:
- """Get absolute path of all files in the repo with given extension."""
- paths: list[Path] = []
- async for path in anyio.Path(self.base_path).rglob(f"*{extension}"):
- paths.append(Path(path))
- return tuple(paths)
-
-
-@attr.s(auto_attribs=True, frozen=True, kw_only=True)
-class WritableLocalGitRepo(LocalGitRepo):
- """A Local Git Repo with support for modifying files and reseting to an initial state.
-
- Note, this does not handle creating a copy of an existing repo, or anything. Rather it adds some additional
- functionality for actually writing to files in the repo. For the creation of a separate copy of an existing
- repo which supports making changes without affecting the main repo see the `temp_writable_local_git_repo` context
- manager.
-
- It is also recommended the `build_from_repo` function when creating a WritableLocalGitRepo, as this will
- make sure that any untracked and uncommited changes are managed correctly.
- """
-
- initial_git_hash: str
- stash_git_hash: str | None
-
- @classmethod
- async def build_from_repo(cls, repo: LocalGitRepo) -> "WritableLocalGitRepo":
- """Create a writable repo from an local repo."""
- init_hash = await repo.head_hash()
- if await is_repo_dirty(repo):
- await repo.run_git(("add", "."))
- stash_hash = await make_commit(repo, "stashing uncommited and untracked changes")
- else:
- stash_hash = None
-
- return cls(
- base_path=repo.base_path,
- initial_git_hash=init_hash,
- stash_git_hash=stash_hash,
- )
-
- async def _setup(self) -> None:
- init_hash = await self.head_hash()
- expected_hash = self.stash_git_hash or self.initial_git_hash
- assert init_hash == expected_hash, "git repo is not currently at expected commit"
- assert await self.is_git_repo(), f"{self.base_path} is not a git repo"
- await self.assert_clean()
-
- async def reset(self) -> None:
- """Reset the repo to the state it was in when this class was created."""
- await restore_all_staged_files(self)
- await restore_all_unstaged_changes(self)
- if self.stash_git_hash:
- # hard to reset to commit with stashed untracked and uncommited changes
- await self.run_git(("reset", "--hard", self.stash_git_hash))
- # soft reset to return to initial commit but keep the untracked and uncommited changes
- await self.run_git(("reset", "--soft", self.initial_git_hash))
- # unstage untracked and uncommited changes
- await restore_all_staged_files(self)
- else:
- await self.run_git(("reset", "--hard", self.initial_git_hash))
- await self.run_git(("clean", "-f"))
- await self.run_git(("checkout", self.initial_git_hash))
-
- current_git_hash = await self.head_hash()
- assert (
- current_git_hash == self.initial_git_hash
- ), f"base branch changed, current git hash ({current_git_hash}) != initial git hash ({self.initial_git_hash})"
- await self.assert_clean()
-
- async def apply_change_to_file(self, file_path: str | Path, new_contents: str) -> None:
- """Apply change to a single file."""
- path = await self._get_file_path(file_path)
- assert await self.is_path_in_repo(path), f"File {path} is not in repo."
- await path.write_text(new_contents)
- await git_add(self, str(path))
-
- async def __aenter__(self) -> "WritableLocalGitRepo":
- await self._setup()
- return self
-
- async def __aexit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- await self.reset()
-
-
-async def copy_files_from_one_repo_to_another(
- src_repo_path: Path, dst_repo_path: Path, relative_file_paths: Sequence[str | Path]
-) -> None:
- """Copies files from src to dst repo using the relative file paths."""
- for relative_path in relative_file_paths:
- src_file_path = src_repo_path / relative_path
- dst_file_path = anyio.Path(dst_repo_path / relative_path)
- # make sure necessary directories exist in destination
- await dst_file_path.parent.mkdir(parents=True, exist_ok=True)
- await asyncio.to_thread(shutil.copy2, src_file_path, dst_file_path)
-
-
-def get_repo_url_from_folder(repo_path: Path) -> str:
- try:
- repo_url = subprocess.check_output(
- ["git", "remote", "get-url", "origin"],
- cwd=repo_path,
- universal_newlines=True,
- ).strip()
- except subprocess.CalledProcessError:
- raise
- else:
- if repo_url.startswith("git@"):
- # convert ssh url to https
- repo_url = repo_url.replace(":", "/")
- repo_url = f"https://{repo_url[4:]}"
- if "https://oauth2:" in repo_url:
- # remove the oauth2 prefix
- # repo_url is something like https://oauth2:{token}@gitlab.com/.../.git
- # change it to https://gitlab.com/.../.git
- # This will happen if repo was originallycloned using oauth2
- suffix = repo_url.split("@")[-1]
- repo_url = "https://" + suffix
- return repo_url
-
-
-def get_repo_base_path() -> Path:
- working_directory = Path(__file__).parent
- try:
- return Path(
- _run_command_and_capture_output(["git", "rev-parse", "--show-toplevel"], cwd=working_directory).strip()
- )
- except subprocess.CalledProcessError as e:
- try:
- return working_directory.parents[1]
- except IndexError:
- raise UnableToFindRepoBase() from e
-
-
-def _run_command_and_capture_output(args: Sequence[str], cwd: Path | None = None) -> str:
- arg_str = " ".join(shlex.quote(arg) for arg in args)
- print(f"Running command: {arg_str}", file=sys.stderr)
- with subprocess.Popen(args, text=True, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
- with StringIO() as output:
- _handle_output(proc, output, sys.stderr)
- if proc.wait() != 0:
- raise subprocess.CalledProcessError(proc.returncode, cmd=args, output=output.getvalue())
- return output.getvalue()
-
-
-class UnableToFindRepoBase(Exception):
- """Raised when the base of the repository cannot be found."""
-
-
-def _handle_output(process: subprocess.Popen[str], *files: TextIO) -> None:
- process_stdout = process.stdout
- assert process_stdout is not None
- while True:
- output = process_stdout.read(1)
- if output:
- for f in files:
- f.write(output)
- elif process.poll() is not None:
- break
-
-
-def get_diff_without_index(diff: str) -> str:
- new_lines = []
- for line in diff.splitlines():
- if line.startswith("index "):
- # We replace index lines with "index 0000000..0000000 100644" because:
- # - `0000000..0000000` ensures no real object hashes are referenced, making the diff neutral.
- # - `100644` is the standard file mode for non-executable files in git diffs, ensuring compatibility.
- # - This keeps the diff format valid while removing specific index information.
- new_lines.append("index 0000000..0000000 100644")
- else:
- new_lines.append(line)
- return "\n".join(new_lines).strip()
-
-
-def is_diffs_without_index_equal(diff_1: str, diff_2: str) -> bool:
- return get_diff_without_index(diff_1) == get_diff_without_index(diff_2)
-
-
-# Copy-pasted from imbue to avoid moving the whole hammers machinery over to imbue-core.
-async def get_lines_from_process(shell_command: str, is_exit_code_validated: bool = True, **kwargs: Any) -> list[str]:
- p = await asyncio.create_subprocess_shell(shell_command, stdin=PIPE, stdout=PIPE, stderr=STDOUT, **kwargs)
- lines = [x.decode("UTF-8") for x in (await p.communicate())[0].splitlines()]
- if is_exit_code_validated:
- joined_lines = "\n".join(lines)
- assert (
- p.returncode == 0
- ), f"command failed: {shell_command}\nwith output:\n{joined_lines} with exit code {p.returncode}"
- return lines
diff --git a/imbue_core/imbue_core/git_data_types.py b/imbue_core/imbue_core/git_data_types.py
@@ -1,35 +0,0 @@
-from datetime import datetime
-from typing import Annotated
-
-from pydantic.functional_validators import PlainValidator
-
-from imbue_core.pydantic_serialization import FrozenModel
-from imbue_core.pydantic_serialization import SerializableModel
-
-
-def _validate_git_timestamp(value: str) -> str:
- try:
- datetime.fromisoformat(value)
- return value
- except ValueError:
- raise ValueError(f"Invalid git timestamp: {value}")
-
-
-class CommitTimestamp(SerializableModel):
- author_ts: Annotated[str, PlainValidator(_validate_git_timestamp)]
- committer_ts: Annotated[str, PlainValidator(_validate_git_timestamp)]
-
-
-class CommitMetadata(FrozenModel):
- commit: str
- tree_hash: str
- message: str
- commit_time: CommitTimestamp
-
- @property
- def body(self) -> str:
- return self.message.split("\n", 1)[-1]
-
- @property
- def subject(self) -> str:
- return self.message.split("\n", 1)[0]
diff --git a/imbue_core/imbue_core/ids.py b/imbue_core/imbue_core/ids.py
@@ -1,44 +0,0 @@
-from typing import Any
-from typing import Self
-
-from pydantic import GetCoreSchemaHandler
-from pydantic_core import core_schema
-
-
-class NonEmptyStr(str):
- # pyre-fixme[11]: pyre seems to have some trouble with Self in some specific cases, including type[Self]
- def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self:
- value = str.__new__(cls, *args, **kwargs)
- if len(value) == 0:
- raise ValueError("NonEmptyStr cannot be empty")
- return value
-
- @classmethod
- def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
- """
- Support transparently deserializing strings into ObjectID instances and vice versa.
- """
- return core_schema.no_info_before_validator_function(
- lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value),
- core_schema.union_schema(
- [
- core_schema.is_instance_schema(cls),
- core_schema.str_schema(),
- ]
- ),
- serialization=core_schema.plain_serializer_function_ser_schema(
- lambda instance: str(instance), return_schema=core_schema.str_schema()
- ),
- )
-
-
-class ExternalID(NonEmptyStr):
- pass
-
-
-class AssistantMessageID(ExternalID):
- pass
-
-
-class ToolUseID(ExternalID):
- pass
diff --git a/imbue_core/imbue_core/itertools.py b/imbue_core/imbue_core/itertools.py
@@ -1,17 +1,13 @@
import contextlib
import itertools
-from typing import AsyncGenerator
-from typing import Callable
from typing import Generator
from typing import Iterable
from typing import Sequence
from typing import TypeVar
-from typing import cast
from imbue_core.errors import ImbueError
T = TypeVar("T")
-TV = TypeVar("TV")
class ImbueItertoolsValueError(ImbueError, ValueError):
@@ -43,30 +39,6 @@ def first(iterable: Iterable[T]) -> T | None:
return next(iter(iterable), None)
-async def iterable_to_async(iterable: Iterable[T]) -> AsyncGenerator[T, None]:
- for item in iterable:
- yield item
-
-
-# TODO delete/migrate out computronium/computronium/universal/utils.py
-def generate_unique(
- source: Iterable[T],
- key: Callable[[T], TV] = cast(Callable[[T], TV], lambda item: item),
-) -> Generator[T, None, None]:
- unique = set()
- for item in source:
- value = key(item)
- if value in unique:
- continue
- yield item
- unique.add(value)
-
-
-def generate_flattened(iterable: Iterable[Iterable[T]]) -> Generator[T, None, None]:
- for item in iterable:
- yield from item
-
-
# TODO replace with itertools.batched when we can require Python 3.12+
def generate_chunks(iterable: Iterable[T], chunk_size: int) -> Generator[tuple[T, ...], None, None]:
"""Yield successive n-sized chunks from any iterable"""
diff --git a/imbue_core/imbue_core/llm_testing_utils.py b/imbue_core/imbue_core/llm_testing_utils.py
@@ -1,57 +0,0 @@
-from pathlib import Path
-
-from loguru import logger
-from syrupy.assertion import SnapshotAssertion
-
-from imbue_core.caching import AsyncCache
-from imbue_core.cattrs_serialization import deserialize_from_json
-from imbue_core.cattrs_serialization import serialize_to_json
-
-
-async def preload_llm_cache(persistent_cache_path: Path, temp_cache: AsyncCache) -> None:
- logger.info(
- "Loading existing cache from {persistent_cache_path}",
- persistent_cache_path=persistent_cache_path,
- )
- assert persistent_cache_path.exists(), f"Cache file {persistent_cache_path} does not exist."
- existing_data = deserialize_from_json(persistent_cache_path.read_text())
- async with temp_cache as cache:
- for key, value in existing_data.items():
- await cache.set(key, value)
-
-
-async def record_llm_responses_in_cache(temp_cache: AsyncCache, persistent_cache_path: Path) -> None:
- logger.info(
- "Updating cache (!!!) at {persistent_cache_path}",
- persistent_cache_path=persistent_cache_path,
- )
- async with temp_cache as cache:
- all_keys = await cache.get_all_keys()
- data = await cache.get_all(all_keys)
- if data:
- persistent_cache_path.parent.mkdir(parents=True, exist_ok=True)
- persistent_cache_path.write_text(serialize_to_json(data), encoding="utf-8")
-
-
-def _sanitize_snapshot_name(snapshot_name: str) -> str:
- return snapshot_name.replace("/", "").replace("\\", "")
-
-
-def get_cache_file_from_snapshot_core(snapshot: SnapshotAssertion, suffix: str) -> Path:
- # To prevent syrupy from cleaning the cache up immediately after written, we use this suffix and register it with pytest.
- # Make sure to add a line like `--snapshot-ignore-file-extensions={suffix}` to the project's pytest.ini
-
- # Goal here is a cache file per test, not per test-file.
- test_file = Path(snapshot.test_location.filepath)
- snapshot_dir = test_file.parent / "__snapshots__" / test_file.stem
- snapshot_dir.mkdir(parents=True, exist_ok=True)
- cache_file = snapshot_dir / f"{_sanitize_snapshot_name(snapshot.test_location.testname)}.{suffix}"
- return cache_file
-
-
-def get_cache_file_from_snapshot(snapshot: SnapshotAssertion) -> Path:
- return get_cache_file_from_snapshot_core(snapshot, "llm_cache_json")
-
-
-def get_count_tokens_cache_file_from_snapshot(snapshot: SnapshotAssertion) -> Path:
- return get_cache_file_from_snapshot_core(snapshot, "count_tokens_cache_json")
diff --git a/imbue_core/imbue_core/nested_evolver.py b/imbue_core/imbue_core/nested_evolver.py
@@ -2,8 +2,6 @@
One of the design goals is that mypy, autocomplete, and automatic refactoring work for the assignments made into these nested structures.
-See: https://imbue-ai.slack.com/archives/C05D0SM2RT5/p1726185313480779?thread_ts=1722865932.537289&cid=C05D0SM2RT5
-
If you make changes here and then the tests fail with:
```
E RecursionError: maximum recursion depth exceeded
@@ -17,7 +15,6 @@ import threading
from typing import Any
from typing import Callable
from typing import Generic
-from typing import TypeGuard
from typing import TypeVar
from typing import cast
@@ -28,7 +25,8 @@ from imbue_core.frozen_utils import FrozenDict
from imbue_core.pydantic_utils import model_update
_T = TypeVar("_T")
-ObjectType = TypeVar("ObjectType")
+
+_threading_local = threading.local()
def evolver(obj: _T) -> _T:
@@ -62,47 +60,6 @@ def chill(evolver: _T) -> _T:
return cast_evolver.chill()
-_threading_local = threading.local()
-
-
-# TODO: since mutate and thaw are stateful, if you call one without the other, you run into problems.
-def thaw(obj: _T) -> _T:
- global _threading_local
- if hasattr(_threading_local, "evolved_obj"):
- raise ValueError("Thaw does not support nested thawing.")
- # pyre-ignore[16]: we're deliberately setting evolved_obj for the first time here
- _threading_local.evolved_obj = evolver(obj)
- return _threading_local.evolved_obj
-
-
-# TODO: mypy complains because the input isn't anything related to ObjectType, but the output is.
-# This also means the type checking doesn't quite work since it can't infer the return type of this function correctly
-def mutate(dest: _T, src: Callable[[], _T]) -> ObjectType: # type: ignore
- assign(dest, src)
- try:
- # pyre-ignore[34]: we don't have generic functions yet, so pyre complains that ObjectType isn't in the input
- evolved_obj: ObjectType = _threading_local.evolved_obj # pyre-ignore[16]: pyre doesn't know about evolved_obj
- return chill(evolved_obj)
- except AttributeError as e:
- raise ValueError("You must call mutate on a thawed object") from e
- finally:
- delattr(_threading_local, "evolved_obj")
-
-
-def mutate_from_dict(dest: ObjectType, src: dict[str, Any]) -> ObjectType:
- # Warning: using this function doesn't provide mypy type checking at the call site, but it allows a single interface for attrs and pydantic classes
- # In most cases the above function should be used instead
- evolved_obj = evolver(dest)
- for key, value in src.items():
- assign(getattr(evolved_obj, key), lambda: value)
- return chill(evolved_obj)
-
-
-def evolver_isinstance(evolver: Any, cls: type[_T]) -> TypeGuard[_T]:
- assert isinstance(evolver, _Evolver) # Tricked you, type system!
- return evolver.isinstance(cls)
-
-
class _RegularValue:
regular_value: Any
@@ -146,7 +103,13 @@ class _FrozenDictValue:
class _Evolver(Generic[_T]):
# pyre-ignore[13]: pyre is confused by the trickery here
- _value: _RegularValue | _AttrValue | _TupleValue | _FrozenDictValue | _PydanticModelValue
+ _value: (
+ _RegularValue
+ | _AttrValue
+ | _TupleValue
+ | _FrozenDictValue
+ | _PydanticModelValue
+ )
def __init__(self, initial_value: _T) -> None:
super().__init__()
@@ -174,14 +137,18 @@ class _Evolver(Generic[_T]):
if item not in value.child_evolver_by_name:
child_obj = getattr(value.attr_value, item)
result = evolver(child_obj)
- assert isinstance(result, _Evolver), "Expose a lie to the type system."
+ assert isinstance(result, _Evolver), (
+ "Expose a lie to the type system."
+ )
value.child_evolver_by_name[item] = result
return value.child_evolver_by_name[item]
elif isinstance(value, _PydanticModelValue):
if item not in value.child_evolver_by_name:
child_obj = getattr(value.pydantic_model_value, item)
result = evolver(child_obj)
- assert isinstance(result, _Evolver), "Expose a lie to the type system."
+ assert isinstance(result, _Evolver), (
+ "Expose a lie to the type system."
+ )
value.child_evolver_by_name[item] = result
return value.child_evolver_by_name[item]
raise TypeError(
@@ -204,7 +171,9 @@ class _Evolver(Generic[_T]):
elif isinstance(value, _FrozenDictValue):
if key not in value.frozen_dict_evolvers:
# Presumably we're going to evolver_assign to this very soon.
- cast(_FrozenDictValue, self._value).frozen_dict_evolvers[key] = _Evolver(_RegularValue(None))
+ cast(_FrozenDictValue, self._value).frozen_dict_evolvers[key] = (
+ _Evolver(_RegularValue(None))
+ )
return cast(_FrozenDictValue, self._value).frozen_dict_evolvers[key]
raise TypeError(
f"You're using [square_brackets] access {key=} on an object of {type(self._value)=} that doesn't support this (should have been a mypy error)."
@@ -214,38 +183,38 @@ class _Evolver(Generic[_T]):
"""Recursively apply the recorded changes to the original object and return a new frozen instance."""
if isinstance(self._value, _AttrValue):
new_children: dict[str, Any] = {
- name: chill(child) for name, child in self._value.child_evolver_by_name.items()
+ name: chill(child)
+ for name, child in self._value.child_evolver_by_name.items()
}
assert attr.has(self._value.attr_value.__class__)
return cast(
_T,
- attr.evolve(cast(Any, cast(_AttrValue, self._value).attr_value), **new_children),
+ attr.evolve(
+ cast(Any, cast(_AttrValue, self._value).attr_value), **new_children
+ ),
)
elif isinstance(self._value, _PydanticModelValue):
return cast(
_T,
model_update(
self._value.pydantic_model_value,
- update={name: chill(child) for name, child in self._value.child_evolver_by_name.items()},
+ update={
+ name: chill(child)
+ for name, child in self._value.child_evolver_by_name.items()
+ },
),
)
elif isinstance(self._value, _TupleValue):
- return cast(_T, tuple(evolver.chill() for evolver in self._value.tuple_evolvers))
+ return cast(
+ _T, tuple(evolver.chill() for evolver in self._value.tuple_evolvers)
+ )
elif isinstance(self._value, _RegularValue):
return cast(_T, self._value.regular_value)
elif isinstance(self._value, _FrozenDictValue):
return cast(
_T,
- FrozenDict({k: v.chill() for k, v in self._value.frozen_dict_evolvers.items()}),
+ FrozenDict(
+ {k: v.chill() for k, v in self._value.frozen_dict_evolvers.items()}
+ ),
)
raise ValueError(f"This Evolver has no value to evolve, {type(self._value)=}.")
-
- def isinstance(self, cls: type[_T]) -> bool:
- """Check if the object being evolved is an instance of the given class."""
- if isinstance(self._value, _AttrValue):
- return isinstance(self._value.attr_value, cls)
- elif isinstance(self._value, _PydanticModelValue):
- return isinstance(self._value.pydantic_model_value, cls)
- elif isinstance(self._value, _FrozenDictValue):
- return cls == FrozenDict
- return False
diff --git a/imbue_core/imbue_core/pydantic_serialization.py b/imbue_core/imbue_core/pydantic_serialization.py
@@ -2,17 +2,12 @@ import threading
from typing import Any
from typing import TypeVar
from typing import cast
-from typing import get_args
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Discriminator
-from pydantic import GetCoreSchemaHandler
-from pydantic import Json
from pydantic.alias_generators import to_camel
-from pydantic_core import core_schema as pyd_core_schema
-from imbue_core.frozen_utils import FrozenDict
from imbue_core.nested_evolver import _Evolver
from imbue_core.nested_evolver import chill
from imbue_core.nested_evolver import evolver
@@ -28,7 +23,9 @@ class EvolvableModel:
# pyre-ignore[47]: pyre is not so easily tricked
def evolve(self: T, attribute: V, new_value: V) -> T:
# pyre-ignore[16]: pyre doesn't know about evolved_obj
- assert _threading_local.evolved_obj is not None, ".ref() must be called before evolve"
+ assert _threading_local.evolved_obj is not None, (
+ ".ref() must be called before evolve"
+ )
assert isinstance(attribute, _Evolver) # Tricked you, type system!
dest_evolver: _Evolver[T] = cast(_Evolver[T], attribute)
@@ -45,20 +42,6 @@ class EvolvableModel:
return _threading_local.evolved_obj
-class FrozenModel(EvolvableModel, BaseModel):
- """
- The base class for most internal data (that does not need to be serialized).
-
- We generally prefer to keep data immutable in order to avoid side effects, race conditions, etc
- """
-
- model_config = ConfigDict(
- frozen=True,
- extra="forbid",
- arbitrary_types_allowed=False,
- )
-
-
class MutableModel(BaseModel):
"""
The base class for any internal data that strictly must be mutable. Should be used sparingly.
@@ -100,23 +83,6 @@ class SerializableModel(EvolvableModel, BaseModel, Serializable):
pydantic_extra.clear()
-def model_dump(obj: BaseModel, is_camel_case: bool = False) -> dict:
- return obj.model_dump(by_alias=is_camel_case)
-
-
-def model_load(model_type: type[T], data: dict) -> T:
- return model_type.model_validate(data)
-
-
-def model_dump_json(obj: BaseModel | Json, is_camel_case: bool = False) -> str:
- # pyre-fixme[16]: pyre complains that obj can be pydantic.types.AnyType, which has no model_dump_json
- return obj.model_dump_json(by_alias=is_camel_case)
-
-
-def model_load_json(model_type: type[T], data: str) -> T:
- return model_type.model_validate_json(data)
-
-
# this is mostly here for the default cases.
# When you want to upgrade a model (and keep it backwards compatible), you can make a custom discriminator
# (eg, that looks for the old type name or converts the old class names)
@@ -150,26 +116,3 @@ def build_discriminator(
return getattr(obj, field_name)
return Discriminator(discriminator=discriminator)
-
-
-class PydanticFrozenDictAnnotation:
- @classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: Any, handler: GetCoreSchemaHandler
- ) -> pyd_core_schema.CoreSchema:
- def validate_from_dict(d: dict | FrozenDict) -> FrozenDict:
- return FrozenDict(d)
-
- frozendict_schema = pyd_core_schema.chain_schema(
- [
- # pyre-ignore[16]: pyre is confused by using dict as a type like this
- handler.generate_schema(dict[*get_args(source_type)]),
- pyd_core_schema.no_info_plain_validator_function(validate_from_dict),
- pyd_core_schema.is_instance_schema(FrozenDict),
- ]
- )
- return pyd_core_schema.json_or_python_schema(
- json_schema=frozendict_schema,
- python_schema=frozendict_schema,
- serialization=pyd_core_schema.plain_serializer_function_ser_schema(dict),
- )
diff --git a/imbue_core/imbue_core/retry_utils.py b/imbue_core/imbue_core/retry_utils.py
@@ -1,30 +0,0 @@
-from typing import Callable
-
-from loguru import logger
-from tenacity import RetryCallState
-
-
-def _log_before_sleep(retry_state: RetryCallState, log_fn: Callable[[str], None]) -> None:
- fn_name = getattr(retry_state.fn, "__name__", "unknown")
- sleep_time = retry_state.next_action.sleep if retry_state.next_action is not None else 0
- outcome = retry_state.outcome
- if outcome is not None:
- exception = outcome.exception()
- error_message = type(exception).__name__ + ": " + str(exception)
- else:
- error_message = "unknown"
- log_fn(
- f"Retrying {fn_name} in {sleep_time:.2f} seconds, attempt {retry_state.attempt_number} after error: {error_message}"
- )
-
-
-def log_before_sleep(retry_state: RetryCallState) -> None:
- _log_before_sleep(retry_state, logger.debug)
-
-
-def log_error_before_sleep(retry_state: RetryCallState) -> None:
- _log_before_sleep(retry_state, logger.error)
-
-
-def log_trace_before_sleep(retry_state: RetryCallState) -> None:
- _log_before_sleep(retry_state, logger.trace)
diff --git a/imbue_core/imbue_core/secrets_utils.py b/imbue_core/imbue_core/secrets_utils.py
@@ -61,21 +61,6 @@ def parse_secrets_file(filepath: str | pathlib.Path) -> dict[str, str]:
return out
-# TODO: this is gross and bad--we should make better handling for secrets.
-# Right now we read the necessary secrets out of the bashenv files
def get_secret(secret_name: str) -> str | None:
- value = os.environ.get(secret_name)
- if value is not None:
- return value
- secrets_files = (
- "science/secrets/environment_vars/bashenv.sh",
- "science/secrets/environment_vars/bashenv_secrets.sh",
- "science/secrets/environment_vars/common_vars.sh",
- )
- for file in secrets_files:
- if os.path.exists(file):
- secrets = parse_secrets_file(file)
- value = secrets.get(secret_name, None)
- if value is not None:
- return value
- return None
+ """Get a secret from environment variables."""
+ return os.environ.get(secret_name)
diff --git a/imbue_core/imbue_core/section.py b/imbue_core/imbue_core/section.py
@@ -1,131 +0,0 @@
-"""
-Provides a context manager for logging a potentially time-consuming process, or a "section".
-
-- Prints logs at start and end of a section.
-
-- Prints a Markdown-like heading for nested sections: "#" for top-level sections, "##" for one level down, and so on.
-
-- Emits structured logs for easier query.
-
-The SectionWrapper context manager can be used to time the __enter__ and __exit__ methods of an existing context manager.
-"""
-
-import contextlib
-import threading
-import time
-from asyncio import CancelledError
-from types import TracebackType
-from typing import TypeVar
-
-from loguru import logger
-
-_monotonic_base = time.monotonic()
-
-
-def _monotonic_time() -> float:
- """A wrapper around time.monotonic() to make the return values a bit smaller and easier to read by a human."""
- return time.monotonic() - _monotonic_base
-
-
-class _ThreadLocal(threading.local):
- def __init__(self) -> None:
- self.next_section_level: int = 0
-
-
-_thread_local = _ThreadLocal()
-
-
-class Section(contextlib.ContextDecorator):
- def __init__(self, message: str, log_level: int | str = "INFO") -> None:
- # TODO: loguru doesn't properly display integer log levels like e.g. logging.INFO
- self.message = message
- self.log_level = log_level
-
- def __enter__(self) -> "Section":
- level = _thread_local.next_section_level
- _thread_local.next_section_level += 1
- self.header = "#" * (level + 1) # pyre-ignore[16]
- self.start_monotonic_time = _monotonic_time() # pyre-ignore[16]
- start_clock_time = time.time()
- self.section = { # pyre-ignore[16]
- "name": self.message,
- "level": level,
- "start_monotonic_time": self.start_monotonic_time,
- "start_clock_time": start_clock_time,
- }
- logger.log(
- self.log_level,
- f"{self.header} Start: {self.message}",
- section=self.section,
- )
- return self
-
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- _thread_local.next_section_level -= 1
- finish_monotonic_time = _monotonic_time()
- finish_clock_time = time.time()
- # pyre-ignore[16]: we set this on __enter__
- duration_seconds = finish_monotonic_time - self.start_monotonic_time
- section = self.section | { # pyre-ignore[16]: we set this on __enter__
- "finish_monotonic_time": finish_monotonic_time,
- "finish_clock_time": finish_clock_time,
- "duration_seconds": duration_seconds,
- }
- self.elapsed = duration_seconds # pyre-ignore[16]
-
- header = self.header # pyre-ignore[16]: we set this on __enter__
-
- if exc_val is None:
- logger.log(
- self.log_level,
- f"{header} Done: {self.message} (took {duration_seconds:.2f} seconds)",
- section=section | {"result": "success"},
- )
- else:
- if isinstance(exc_val, CancelledError):
- logger.log(
- self.log_level,
- f"{header} Cancelled: {self.message} (took {duration_seconds:.2f} seconds)",
- section=section | {"result": "cancelled"},
- )
- else:
- logger.log(
- self.log_level,
- f"{header} Failed: {self.message} (within {duration_seconds:.2f} seconds)",
- section=section | {"result": "failed"},
- )
-
-
-T = TypeVar("T")
-
-
-# pyre-ignore[24]: pyre doesn't understand AbstractContextManager
-class SectionWrapper(contextlib.AbstractContextManager[T]):
- # pyre-ignore[24]
- def __init__(
- self,
- cm: contextlib.AbstractContextManager[T],
- enter_message: str,
- exit_message: str,
- ) -> None:
- self._cm = cm
- self._enter_message = enter_message
- self._exit_message = exit_message
-
- def __enter__(self) -> T:
- with Section(self._enter_message):
- return self._cm.__enter__()
-
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- with Section(self._exit_message):
- self._cm.__exit__(exc_type, exc_val, exc_tb)
diff --git a/imbue_core/imbue_core/serialization.py b/imbue_core/imbue_core/serialization.py
@@ -1,6 +1,5 @@
import builtins
import datetime
-import inspect
import json
from enum import Enum
from functools import cached_property
@@ -10,7 +9,6 @@ from pathlib import PosixPath
from traceback import format_tb
from types import TracebackType
from typing import Any
-from typing import Callable
from typing import Hashable
from typing import Iterable
from typing import Mapping
@@ -34,9 +32,9 @@ from imbue_core.fixed_traceback import FixedTraceback
from imbue_core.pydantic_serialization import SerializableModel
from imbue_core.serialization_types import Serializable
-assert (
- version("yasoo") == "0.12.6"
-), "This code was written for yasoo 0.12.6 and requires inheriting / monkeypatching the deserializer, so you probably don't want to use any other version without fixing TupleDeserializer"
+assert version("yasoo") == "0.12.6", (
+ "This code was written for yasoo 0.12.6 and requires inheriting / monkeypatching the deserializer, so you probably don't want to use any other version without fixing TupleDeserializer"
+)
T = TypeVar("T", bound=Hashable)
@@ -57,7 +55,12 @@ class TupleDeserializer(Deserializer):
return data
if isinstance(data, list):
list_types = self._get_list_types(obj_type, data)
- return tuple([self._deserialize(d, t, type_key, allow_extra_fields, all_globals) for t, d in list_types])
+ return tuple(
+ [
+ self._deserialize(d, t, type_key, allow_extra_fields, all_globals)
+ for t, d in list_types
+ ]
+ )
assert isinstance(data, dict), f"Expected a dict, but got {type(data)}"
@@ -65,7 +68,11 @@ class TupleDeserializer(Deserializer):
if type_key is not None:
type_data = data.get(type_key, None)
- if type_data is not None and type_data.startswith("builtins.") and type_data != "builtins.dict":
+ if (
+ type_data is not None
+ and type_data.startswith("builtins.")
+ and type_data != "builtins.dict"
+ ):
return data["value"]
# TODO: we need to potentially handle `builtins.dict`
@@ -81,10 +88,15 @@ class TupleDeserializer(Deserializer):
# return data["value"]
# TODO: remove this hack. Many of our sqlite files (search s3_sqlite_path) have FrozenDicts
- if isinstance(type_key, str) and data.get(type_key, None) == "flax.core.frozen_dict.FrozenDict":
+ if (
+ isinstance(type_key, str)
+ and data.get(type_key, None) == "flax.core.frozen_dict.FrozenDict"
+ ):
data[type_key] = "imbue_core.frozen_utils.FrozenMapping"
# we deliberately pass in a `None` type_key sometimes, which results in just returning obj_type
- obj_type = self._get_object_type(obj_type, data, type_key, all_globals) # pyre-ignore[6]
+ obj_type = self._get_object_type(
+ obj_type, data, type_key, all_globals
+ ) # pyre-ignore[6]
if type_key in data:
data.pop(type_key)
real_type, generic_args = normalize_type(obj_type, all_globals)
@@ -95,7 +107,9 @@ class TupleDeserializer(Deserializer):
bases = {ancestor for b in bases for ancestor in b.__bases__}
if not ignore_custom_deserializer:
- deserialization_method = self._custom_deserializers.get(obj_type, self._custom_deserializers.get(real_type))
+ deserialization_method = self._custom_deserializers.get(
+ obj_type, self._custom_deserializers.get(real_type)
+ )
if deserialization_method:
return deserialization_method(data)
for base_class, method in self._inheritance_deserializers.items():
@@ -119,7 +133,7 @@ class TupleDeserializer(Deserializer):
if e.name.lower() == value.lower():
return e
return real_type(value)
- # TODO (49780118-61e5-446b-b44b-cabb3ffc0ba2): serialization currently breaks with builtin.dicts and dicts with non-string keys
+ # TODO: serialization currently breaks with builtin.dicts and dicts with non-string keys
# if you have weird keys in your dict this branch won't be hit and your object won't be properly deserialized
elif issubclass(real_type, Mapping):
key_type = generic_args[0] if generic_args else None
@@ -141,9 +155,13 @@ class TupleDeserializer(Deserializer):
)
elif issubclass(real_type, Iterable):
# If we got here it means data is not a list, so obj_type came from the data itself and is safe to use
- return self._load_iterable(data, obj_type, type_key, allow_extra_fields, all_globals)
+ return self._load_iterable(
+ data, obj_type, type_key, allow_extra_fields, all_globals
+ )
elif real_type != obj_type:
- return self._deserialize(data, real_type, type_key, allow_extra_fields, external_globals)
+ return self._deserialize(
+ data, real_type, type_key, allow_extra_fields, external_globals
+ )
else:
raise
@@ -165,7 +183,9 @@ class TupleDeserializer(Deserializer):
# TODO: probably a good idea to ensure that all dicts are frozen as well...
class FrozenSerializer(Serializer):
- def __init__(self, force_serialization: bool, allow_unsafe_list_serialization: bool = False) -> None:
+ def __init__(
+ self, force_serialization: bool, allow_unsafe_list_serialization: bool = False
+ ) -> None:
super().__init__()
self._force_serialization = force_serialization
self._allow_unsafe_list_serialization = allow_unsafe_list_serialization
@@ -183,10 +203,12 @@ class FrozenSerializer(Serializer):
logger.info("Converting list to tuple for serialization: {}", obj)
obj = tuple(obj)
else:
- raise Exception(f"Lists are not allowed for serialization. Use tuples instead. Current iterable: {obj}")
- assert isinstance(
- obj, (tuple, frozenset, bytes)
- ), f"All iterables should be tuples or frozenset. Received {obj}"
+ raise Exception(
+ f"Lists are not allowed for serialization. Use tuples instead. Current iterable: {obj}"
+ )
+ assert isinstance(obj, (tuple, frozenset, bytes)), (
+ f"All iterables should be tuples or frozenset. Received {obj}"
+ )
return cast(
list[object],
tuple(
@@ -217,7 +239,9 @@ class FrozenSerializer(Serializer):
return obj # type: ignore
if globals:
- self._custom_serializers = resolve_types(self._custom_serializers, globals) # type: ignore
+ self._custom_serializers = resolve_types(
+ self._custom_serializers, globals
+ ) # type: ignore
result = self._serialize(
obj,
@@ -254,7 +278,9 @@ class SerializedException(SerializableModel):
traceback_dict: JsonTypeAlias
@classmethod
- def build(cls, exception: BaseException, traceback: TracebackType | None = None) -> "SerializedException":
+ def build(
+ cls, exception: BaseException, traceback: TracebackType | None = None
+ ) -> "SerializedException":
if traceback is None:
traceback = exception.__traceback__
assert traceback is not None, " ".join(
@@ -265,59 +291,16 @@ class SerializedException(SerializableModel):
)
return SerializedException( # pyre-fixme[28]: pyre doesn't understand pydantic
exception=get_fully_qualified_name_for_error(exception),
- args=tuple(_convert_serialized_exception_args(x, traceback) for x in exception.args),
+ args=tuple(
+ _convert_serialized_exception_args(x, traceback) for x in exception.args
+ ),
traceback_dict=FixedTraceback.from_tb(traceback).as_dict(),
)
- @cached_property
- def traceback(self) -> FixedTraceback | None:
- traceback_dict = self.traceback_dict
- if traceback_dict is None:
- return None
- return FixedTraceback.from_dict(traceback_dict)
-
- @cached_property
- def exception_module(self) -> str:
- if "." in self.exception:
- return self.exception.rsplit(".", maxsplit=1)[0]
- return ""
-
- @cached_property
- def exception_type(self) -> str:
- return self.exception.rsplit(".", maxsplit=1)[-1]
-
- @cached_property
- def exception_class(self) -> type[BaseException]:
- if self.exception_module:
- return cast(
- type[BaseException],
- getattr(import_module(self.exception_module), self.exception_type, None),
- )
- else:
- return cast(type[BaseException], getattr(builtins, self.exception_type, None))
-
- def construct_instance(self) -> BaseException:
- try:
- exception = self.exception_class(*cast(tuple[Serializable, ...], self.args))
- except TypeError as e:
- message_with_arg_info = (
- f"Failed to construct exception {self.exception_class} with args {self.args}.",
- "Ensure that the exception class is serializable and can be constructed with the provided args.",
- )
- raise TypeError(" ".join(message_with_arg_info)) from e
-
- return exception
-
- def as_formatted_traceback(self) -> str:
- if self.traceback is None:
- traceback_str = ""
- else:
- # pyre-ignore[6]: pyre doesn't know that FixedTraceback is a traceback (since it's not a TracebackType)
- traceback_str = "".join(format_tb(self.traceback))
- return f"Traceback (most recent call last):\n{traceback_str}\n{self.exception}: {self.args}"
-
-def _convert_serialized_exception_args(error: Serializable, traceback: TracebackType | None = None) -> JsonTypeAlias:
+def _convert_serialized_exception_args(
+ error: Serializable, traceback: TracebackType | None = None
+) -> JsonTypeAlias:
if isinstance(error, BaseException):
return SerializedException.build(error, traceback=traceback)
elif isinstance(error, (list, tuple)):
@@ -338,14 +321,24 @@ def _convert_to_json_serializable_with_better_errors(
return obj # type: ignore
if isinstance(obj, Mapping):
return {
- key: _convert_to_json_serializable_with_better_errors(value, f"{path}.{key}") for key, value in obj.items()
+ key: _convert_to_json_serializable_with_better_errors(
+ value, f"{path}.{key}"
+ )
+ for key, value in obj.items()
}
if isinstance(obj, Iterable):
- return [_convert_to_json_serializable_with_better_errors(item, f"{path}[{i}]") for i, item in enumerate(obj)]
- raise TypeError(f'Found object of type "{type(obj).__name__}" at {path} which cannot be serialized')
+ return [
+ _convert_to_json_serializable_with_better_errors(item, f"{path}[{i}]")
+ for i, item in enumerate(obj)
+ ]
+ raise TypeError(
+ f'Found object of type "{type(obj).__name__}" at {path} which cannot be serialized'
+ )
-SERIALIZER = FrozenSerializer(force_serialization=False, allow_unsafe_list_serialization=False)
+SERIALIZER = FrozenSerializer(
+ force_serialization=False, allow_unsafe_list_serialization=False
+)
DESERIALIZER = TupleDeserializer()
# note: you cannot change this without changing other calls to yasoo, this is its default
@@ -408,14 +401,14 @@ def serialize_datetime(data: datetime.datetime) -> dict:
@DESERIALIZER.register()
def deserialize_datetime(data: dict) -> datetime.datetime:
- return datetime.datetime.fromtimestamp(data["time"], datetime.timezone.utc if data.get("tzaware", None) else None)
-
+ return datetime.datetime.fromtimestamp(
+ data["time"], datetime.timezone.utc if data.get("tzaware", None) else None
+ )
-def serialize_to_dict(obj: Any) -> dict[str, Any]:
- return cast(dict[str, Any], SERIALIZER.serialize(obj))
-
-def serialize_to_json(obj: Any, indent: int | None = None, sort_keys: bool = False) -> str:
+def serialize_to_json(
+ obj: Any, indent: int | None = None, sort_keys: bool = False
+) -> str:
try:
return json.dumps(SERIALIZER.serialize(obj), indent=indent, sort_keys=sort_keys)
except Exception as e:
@@ -424,39 +417,8 @@ def serialize_to_json(obj: Any, indent: int | None = None, sort_keys: bool = Fal
def deserialize_from_json(data: str) -> Any:
try:
- return DESERIALIZER.deserialize(json.loads(data)) # pyre-ignore[20]: pyre doesn't understand deserialize
+ return DESERIALIZER.deserialize(
+ json.loads(data)
+ ) # pyre-ignore[20]: pyre doesn't understand deserialize
except Exception as e:
raise SerializationError(str(e)) from e
-
-
-def deserialize_from_dict(data: dict[str, Any]) -> Any:
- try:
- return DESERIALIZER.deserialize(data) # pyre-ignore[20]: pyre doesn't understand deserialize
- except Exception as e:
- raise SerializationError(str(e)) from e
-
-
-def deserialize_from_dict_with_type(data: dict[str, Any], obj_type: type[T]) -> T:
- try:
- result = DESERIALIZER.deserialize(data, obj_type=obj_type)
- assert isinstance(result, obj_type), f"Expected an object of type {obj_type}, but got {result}"
- return result
- except Exception as e:
- raise SerializationError(str(e)) from e
-
-
-def deserialize_from_json_with_type(data: str | bytes | bytearray, obj_type: type[T]) -> T:
- try:
- return deserialize_from_dict_with_type(json.loads(data), obj_type=obj_type)
- except Exception as e:
- raise SerializationError(str(e)) from e
-
-
-def get_serializable_properties(obj: Any) -> dict[str, Any]:
- members = inspect.getmembers(type(obj))
- marked_members = [name for name, member in members if is_serializable_property(member)]
- return {name: getattr(obj, name) for name in marked_members}
-
-
-def is_serializable_property(func: Callable) -> bool:
- return getattr(func, "_imbue_is_serializable_property", False)
diff --git a/imbue_core/imbue_core/simple_git.py b/imbue_core/imbue_core/simple_git.py
@@ -1,215 +0,0 @@
-"""This file implements a subset of the functionality found in git.py and compute_environment.py, but everything is synchronous."""
-
-import shlex
-import subprocess
-import time
-from pathlib import Path
-from typing import Sequence
-
-from loguru import logger
-
-from imbue_core.async_monkey_patches import log_exception
-from imbue_core.computing_environment.data_types import AnyPath
-from imbue_core.computing_environment.data_types import RunCommandError
-
-
-class SyncLocalGitRepo:
- """
- Provides different operations that you can perform over a git repository.
-
- Implements a subset of our async LocalGitRepo, but also pulls in some of the functions from computing_environment.py
- that normally would operate on a LocalGitRepo.
-
- Over time, we can probably replace LocalGitRepo and computing_environment more or less fully, as we're migrating
- code away from asyncio.
-
- Feel free to move additional functions from computing_environment.py into this class as needed. Usually, you just need
- to remove async/await keywords, and replace calls to compute_environment. member functions with self. calls.
- """
-
- _base_path: Path
-
- def __init__(self, base_path: Path) -> None:
- self._base_path = base_path
-
- @property
- def base_path(self) -> Path:
- """The base path of the git repo."""
- return self._base_path
-
- def run_git(
- self,
- command: Sequence[str],
- check: bool = True,
- cwd: AnyPath | None = None,
- is_error_logged: bool = True,
- is_stripped: bool = True,
- retry_on_git_lock_error: bool = True,
- ) -> str:
- """Run a git command in the repo.
-
- Example:
- ```
- git_repo.run_git("status")
- ```
- """
- command = ["git"] + list(command)
- if not retry_on_git_lock_error:
- result = self.run_command(command, check=check, is_error_logged=is_error_logged, cwd=cwd)
- else:
- result = self._run_command_with_retry_on_git_lock_error(
- command, check=check, is_error_logged=is_error_logged, cwd=cwd
- )
- if is_stripped:
- return result.strip()
- return result
-
- def run_command(
- self,
- command: Sequence[str],
- check: bool = True,
- secrets: dict[str, str] | None = None,
- cwd: AnyPath | None = None,
- is_error_logged: bool = True,
- ) -> str:
- """Run a command in the repo.
-
- Note, this can be used to run any command, not just git.
- """
- command_string = shlex.join(command)
- logger.trace(
- f"Running command: {command_string=} from cwd={cwd or self.base_path} with {secrets=} {check=} {is_error_logged=}"
- )
- completed_proc = subprocess.run(
- command,
- cwd=cwd or self._base_path,
- stdin=subprocess.DEVNULL,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env=secrets,
- )
- # note, need to be carefull not to strip() lines since whitespace may be important (e.g. for diffs)
- # return joined lines since mostly we only use the output for logging, and this way we arn't
- # passing around lots of lists. Also it's easy to parse by lines if needed
- try:
- stdout = completed_proc.stdout.decode("UTF-8")
- except UnicodeDecodeError as e:
- # If we don't encounter this, it likely means something was fixed upstream and we can safely delete
- log_exception(
- e,
- "Command {command_string} failed to decode stdout, replacing any invalid bytes which could lead to problems later",
- command_string=command_string,
- )
- stdout = completed_proc.stdout.decode("UTF-8", errors="replace")
- stderr = completed_proc.stderr.decode("UTF-8")
- if check and completed_proc.returncode != 0:
- error_message = f"command run from cwd={self.base_path} failed with exit code {completed_proc.returncode} and stdout:\n{stdout}\nstderr:\n{stderr}"
- if is_error_logged:
- logger.error(
- f"command attempted: '{command_string}' from cwd={self.base_path}\nerror message: {error_message}"
- )
- # this should not be None, but do this to satisfy type checker, int or None we throw the same error
- returncode = completed_proc.returncode or -1
- raise RunCommandError(
- cmd=command_string,
- stderr=stderr,
- returncode=returncode,
- cwd=cwd or self.base_path,
- )
- return stdout
-
- def get_git_diff(
- self,
- commit_hash: str | None = None,
- staged: bool = False,
- is_error_logged: bool = True,
- include_binary: bool = True,
- ) -> str:
- """Get the diff for the current repo state."""
- # make sure `is_stripped=False` otherwise patch can be invalid
- command = ["diff", "--full-index"]
- if include_binary:
- # Without --binary, diffs of binary files will just contain a summary statement such as "Binary files a/file.bin and b/file.bin differ".
- # Such diffs cannot be applied, but are useful for inclusion in LLM prompts.
- command.append("--binary")
- if staged:
- command.append("--staged")
- if commit_hash:
- command.append(commit_hash)
- return self.run_git(command, is_stripped=False, is_error_logged=is_error_logged)
-
- def get_untracked_files(self) -> tuple[str, ...]:
- """Get the untracked files in the repo."""
- result = self.run_git(["ls-files", "--others", "--exclude-standard"], is_error_logged=False)
- return tuple([line.strip() for line in result.splitlines() if line.strip()])
-
- def get_untracked_file_diff(self, file_path: str, include_binary: bool = True) -> str:
- """Get the diff for a untracked file.
-
- Note this function will raise a RunCommandError if the there is no diff for the untracked file or if there
- is another error running the command. So it is best to use this function after checking that the file is untracked
- using get_untracked_files function.
- """
- command = ["diff", "--no-index"]
- if include_binary:
- command.append("--binary")
- untracked_diff = self.run_git(
- command + ["/dev/null", str(file_path)],
- # Unfortunately, `--no-index` implies `--exit-code`, which will cause git diff to return an exit code of 1
- # if the diff is not empty. So we can't use check=True here. We'll check for an empty output to detect failures.
- check=False,
- is_error_logged=True,
- is_stripped=False,
- )
- if not untracked_diff:
- raise RunCommandError(f"Unable to diff untracked file {file_path}")
- return untracked_diff
-
- def is_commit_a_branch(self, commit_hash: str) -> bool:
- """Check if the given git ref is a branch."""
- try:
- self.run_git(
- ("show-ref", "--verify", "-q", f"refs/heads/{commit_hash}"),
- is_error_logged=False,
- check=True,
- )
- return True
- except RunCommandError as e:
- if e.returncode == 1:
- return False
- raise
-
- def get_merge_base(self, branch_name: str, target_branch: str) -> str:
- """Get the merge base of the given branch and target branch.
-
- The merge base is the most recent commit that is on both branches.
- """
- return self.run_git(["merge-base", branch_name, target_branch], is_error_logged=False)
-
- def _run_command_with_retry_on_git_lock_error(
- self,
- command: Sequence[str],
- check: bool = True,
- is_error_logged: bool = True,
- cwd: AnyPath | None = None,
- ) -> str:
- max_retries = 50
- retry_count = 0
- retry_delay = 0.1 # seconds
- while True:
- try:
- return self.run_command(
- command,
- check=check,
- is_error_logged=is_error_logged and retry_count >= max_retries,
- cwd=cwd,
- )
- except RunCommandError as e:
- error_message = str(e)
- if "fatal: Unable to create" in error_message and ".git/index.lock': File exists" in error_message:
- if retry_count >= max_retries:
- raise
- time.sleep(retry_delay)
- retry_count += 1
- else:
- raise
diff --git a/imbue_core/imbue_core/test_utils.py b/imbue_core/imbue_core/test_utils.py
@@ -1,88 +0,0 @@
-import contextlib
-import shutil
-import tempfile
-import time
-from pathlib import Path
-from typing import AsyncGenerator
-from typing import Callable
-from typing import Generator
-
-from loguru import logger
-from syrupy.assertion import SnapshotAssertion
-
-from imbue_core.agents.llm_apis.data_types import CachedCostedLanguageModelResponse
-from imbue_core.agents.llm_apis.data_types import CachedCostedModelResponse
-from imbue_core.agents.llm_apis.data_types import CachedCountTokensResponse
-from imbue_core.agents.llm_apis.llm_testing_utils import check_llm_responses_in_cache
-from imbue_core.caching import AsyncCache
-from imbue_core.llm_testing_utils import get_cache_file_from_snapshot
-from imbue_core.llm_testing_utils import get_count_tokens_cache_file_from_snapshot
-from imbue_core.llm_testing_utils import preload_llm_cache
-from imbue_core.llm_testing_utils import record_llm_responses_in_cache
-
-
-def info_if_not_quiet(quiet: bool, message: str) -> None:
- if not quiet:
- logger.info(message)
-
-
-@contextlib.contextmanager
-def create_temp_dir(root_dir: Path) -> Generator[Path, None, None]:
- with tempfile.TemporaryDirectory(dir=root_dir) as temp_dir:
- yield Path(temp_dir)
- shutil.rmtree(temp_dir)
-
-
-def wait_until(condition: Callable[[], bool], timeout: float = 5.0, interval: float = 0.5) -> None:
- start_time = time.monotonic()
- while True:
- if condition():
- return
- if time.monotonic() - start_time > timeout:
- raise TimeoutError("Condition not met within timeout period")
- time.sleep(interval)
-
-
-async def make_llm_cache_with_snapshot_core(
- snapshot: SnapshotAssertion,
- json_cache_file: Path,
- value_cls: type[CachedCostedModelResponse],
- quiet: bool = True,
- suffix: str = "",
-) -> AsyncGenerator[Path, None]:
- info_if_not_quiet(quiet, f"Using llm_cache_pathfixture: {json_cache_file=}")
-
- with tempfile.TemporaryDirectory() as cache_path_string:
- cache_path = Path(cache_path_string)
- cache_context = AsyncCache(cache_path, value_cls)
-
- if not snapshot.session.update_snapshots:
- if json_cache_file.exists():
- await preload_llm_cache(json_cache_file, cache_context)
-
- yield cache_path
- info_if_not_quiet(quiet, "Finished with llm_cache_pathfixture, updating cache if needed.")
-
- if snapshot.session.update_snapshots:
- await record_llm_responses_in_cache(cache_context, json_cache_file)
-
- await check_llm_responses_in_cache(snapshot, cache_context, suffix)
- info_if_not_quiet(quiet, "Finished with llm_cache_pathfixture, checking cache.")
-
-
-async def make_llm_cache_with_snapshot(snapshot: SnapshotAssertion, quiet: bool = True) -> AsyncGenerator[Path, None]:
- json_cache_file = get_cache_file_from_snapshot(snapshot)
- async for path in make_llm_cache_with_snapshot_core(
- snapshot, json_cache_file, CachedCostedLanguageModelResponse, quiet
- ):
- yield path
-
-
-async def make_count_tokens_cache_with_snapshot(
- snapshot: SnapshotAssertion, quiet: bool = True
-) -> AsyncGenerator[Path, None]:
- json_cache_file = get_count_tokens_cache_file_from_snapshot(snapshot)
- async for path in make_llm_cache_with_snapshot_core(
- snapshot, json_cache_file, CachedCountTokensResponse, quiet, "_count_tokens"
- ):
- yield path
diff --git a/imbue_core/pyproject.toml b/imbue_core/pyproject.toml
@@ -11,7 +11,6 @@ authors = [
{ name="Imbue", email="imbue@imbue.com" },
]
license = "MIT"
-# NOTE: This list is replicated in sculptor/pyproject.toml, and the copy there must be kept in sync.
dependencies = [
"anyio",
"attrs",
diff --git a/imbue_tools/README.md b/imbue_tools/README.md
@@ -1,5 +1,5 @@
# Purpose
-Shared functionality for imbue-cli tools like vet, imbue-retrieve, etc.
+Shared functionality for CLI tools like vet.
# Contents
- formatting git repos as LLM input
diff --git a/imbue_tools/imbue_tools/conftest.py b/imbue_tools/imbue_tools/conftest.py
@@ -1,44 +0,0 @@
-from pathlib import Path
-from typing import Callable
-from typing import Generator
-
-import pytest
-from pytest_asyncio import fixture as async_fixture
-from syrupy.assertion import SnapshotAssertion
-
-from imbue_core.agents.configs import LanguageModelGenerationConfig
-from imbue_core.agents.llm_apis.anthropic_api import AnthropicModelName
-from imbue_core.async_monkey_patches_test import explode_on_error # noqa: F401
-from imbue_core.test_repo_utils import make_simple_test_git_repo
-from imbue_core.test_utils import make_llm_cache_with_snapshot
-
-llm_cache_path = async_fixture(make_llm_cache_with_snapshot)
-
-
-# this is copied from sculptor/conftest.py
-# (it must be copied rather than imported because of the autouse)
-@pytest.fixture(autouse=True)
-def always_explode_on_error(
- explode_on_error: Callable[[], Generator[None, None, None]],
-) -> Generator[None, None, None]:
- """
- Ensures that we do not log errors or exceptions during testing.
-
- If your test is checking error handling behavior (and you expect to see a log_exception call),
- use the `expect_exact_logged_errors` decorator to suppress the logging of those errors.
- """
- yield
-
-
-def llm_config_for_test(
- llm_cache_path: Path, snapshot: SnapshotAssertion, is_caching_inputs: bool = False
-) -> LanguageModelGenerationConfig:
- return LanguageModelGenerationConfig(
- model_name=AnthropicModelName.CLAUDE_4_SONNET_2025_05_14,
- cache_path=llm_cache_path,
- is_running_offline=not snapshot.session.update_snapshots,
- is_caching_inputs=is_caching_inputs,
- )
-
-
-simple_test_git_repo = pytest.fixture(make_simple_test_git_repo)
diff --git a/imbue_tools/imbue_tools/get_conversation_history/get_conversation_history.py b/imbue_tools/imbue_tools/get_conversation_history/get_conversation_history.py
@@ -1,5 +1,4 @@
import json
-from pathlib import Path
from typing import assert_never
from loguru import logger
@@ -11,9 +10,6 @@ from vet_types.messages import ChatInputUserMessage
from vet_types.messages import ConversationMessageUnion
from vet_types.messages import ResponseBlockAgentMessage
-CONVERSATION_FILE_ENV_VAR = "CONVERSATION_FILE"
-TASK_SOURCE_BRANCH_ENV_VAR = "TASK_SOURCE_BRANCH"
-
class ConversationLoadingError(Exception):
pass
@@ -58,14 +54,6 @@ def format_conversation_history_for_prompt(
# === loading from file ===
-def load_conversation_history(
- conversation_file_path: Path,
-) -> tuple[ConversationMessageUnion, ...]:
- """Load a jsonl file into a list of conversation messages"""
- file_contents = conversation_file_path.read_text()
- return parse_conversation_history(file_contents)
-
-
def parse_conversation_history(
conversation_str: str,
) -> tuple[ConversationMessageUnion, ...]:
diff --git a/imbue_tools/imbue_tools/repo_utils/context_retrieval.py b/imbue_tools/imbue_tools/repo_utils/context_retrieval.py
@@ -1,9 +1,6 @@
-import asyncio
import threading
import time
-from contextlib import asynccontextmanager
from pathlib import Path
-from typing import AsyncGenerator
from typing import Generator
import pygit2
@@ -12,7 +9,6 @@ from pygit2.enums import ObjectType
from pygit2.repository import Repository
from imbue_core.async_utils import make_async
-from imbue_core.git import LocalGitRepo
from imbue_tools.repo_utils.diff_utils import apply_diffs_to_files
from imbue_tools.repo_utils.file_system import FileContents
from imbue_tools.repo_utils.file_system import InMemoryFileSystem
@@ -34,8 +30,6 @@ class RepoContextManager:
# We need the sync lock due to pygit2 being synchronous.
# It is mostly used for the blob data cache, but also for the repo contents by git hash cache.
self._lock = threading.Lock()
- # We need the async lock for tests TODO: we can probably remove this
- self._local_repo_async_lock: asyncio.Lock = asyncio.Lock()
@classmethod
def build(cls, repo_path: Path) -> "RepoContextManager":
@@ -125,12 +119,3 @@ class RepoContextManager:
else:
raise ValueError(f"Unexpected entry type in git tree: {entry.type}")
-
- @asynccontextmanager
- async def tmp_repo_context(self) -> AsyncGenerator[LocalGitRepo, None]:
- """
- This function is only used in tests
- TODO: we can probably remove it
- """
- async with self._local_repo_async_lock:
- yield LocalGitRepo(self.repo_path)
diff --git a/imbue_tools/imbue_tools/repo_utils/diff_utils.py b/imbue_tools/imbue_tools/repo_utils/diff_utils.py
@@ -8,7 +8,6 @@ from async_lru import alru_cache # type: ignore[undefined-attribute]: pyre on m
from loguru import logger
from imbue_tools.repo_utils.errors import DiffApplicationError
-from imbue_tools.repo_utils.errors import DiffCalculationError
from imbue_tools.repo_utils.file_system import FileContents
from imbue_tools.repo_utils.file_system import InMemoryFileSystem
from imbue_tools.repo_utils.file_system import SymlinkContents
@@ -18,65 +17,6 @@ from imbue_tools.repo_utils.file_system_utils import (
from imbue_tools.repo_utils.file_system_utils import (
temporary_local_dir_from_in_memory_file_system,
)
-from imbue_tools.repo_utils.file_system_utils import write_file_contents_to_dir
-
-
-class NonZeroReturncodeError(Exception):
- pass
-
-
-async def get_diff_between_files(old_file_contents: InMemoryFileSystem, new_file_contents: InMemoryFileSystem) -> str:
- with (
- tempfile.TemporaryDirectory() as old_repo_dir,
- tempfile.TemporaryDirectory() as new_repo_dir,
- ):
- # Get all changed file contents to prevent writing more than necessary
- changed_old_file_contents_dict = {}
- changed_new_file_contents_dict = {}
- old_file_contents_dict = old_file_contents.files
- new_file_contents_dict = new_file_contents.files
- for file_path in old_file_contents_dict.keys() | new_file_contents_dict.keys():
- if file_path not in old_file_contents_dict:
- changed_new_file_contents_dict[file_path] = new_file_contents_dict[file_path]
- elif file_path not in new_file_contents_dict:
- changed_old_file_contents_dict[file_path] = old_file_contents_dict[file_path]
- elif old_file_contents_dict[file_path] != new_file_contents_dict[file_path]:
- changed_old_file_contents_dict[file_path] = old_file_contents_dict[file_path]
- changed_new_file_contents_dict[file_path] = new_file_contents_dict[file_path]
-
- changed_old_file_contents = InMemoryFileSystem.build(changed_old_file_contents_dict)
- changed_new_file_contents = InMemoryFileSystem.build(changed_new_file_contents_dict)
-
- await write_file_contents_to_dir(changed_old_file_contents, old_repo_dir)
- await write_file_contents_to_dir(changed_new_file_contents, new_repo_dir)
-
- try:
- result = subprocess.run(
- (
- "git",
- "diff",
- "--no-index",
- "--relative",
- "--full-index",
- "--binary",
- old_repo_dir,
- new_repo_dir,
- ),
- capture_output=True,
- text=True,
- timeout=10.0,
- )
- if result.returncode == 0 or result.returncode == 1:
- diff = result.stdout
- else:
- raise NonZeroReturncodeError(f"git diff process returned with non-zero returncode {result.returncode}")
- except Exception as e:
- raise DiffCalculationError from e
-
- diff = diff.replace(old_repo_dir, "")
- diff = diff.replace(new_repo_dir, "")
-
- return diff
@alru_cache
diff --git a/imbue_tools/imbue_tools/repo_utils/errors.py b/imbue_tools/imbue_tools/repo_utils/errors.py
@@ -1,6 +1,3 @@
-from imbue_core.errors import ExpectedError
-
-
class PromptAssemblyError(Exception):
"""Raised when there is an error assembling the prompt."""
@@ -9,17 +6,5 @@ class ContextLengthExceededError(PromptAssemblyError):
"""Raised when the context length exceeds the maximum allowed length."""
-class InvalidVersionedConfigError(ExpectedError):
- pass
-
-
-class MissingVersionedConfigError(ExpectedError):
- pass
-
-
class DiffApplicationError(Exception):
pass
-
-
-class DiffCalculationError(Exception):
- pass
diff --git a/imbue_tools/imbue_tools/repo_utils/file_system.py b/imbue_tools/imbue_tools/repo_utils/file_system.py
@@ -4,7 +4,6 @@ from typing import Mapping
from imbue_core.frozen_utils import FrozenDict
from imbue_core.frozen_utils import deep_freeze_mapping
from imbue_core.pydantic_serialization import SerializableModel
-from imbue_tools.repo_utils.subrepo_formatting import BaseFilenamePattern
class SymlinkContents(SerializableModel):
@@ -70,33 +69,3 @@ def _try_decode_file_contents(contents: FileContents) -> DecodedTextFileContents
return contents.decode("utf-8")
except UnicodeDecodeError:
return None
-
-
-def filter_files_patterns(
- file_system: InMemoryFileSystem,
- include_patterns: tuple[str, ...] | None = None,
- exclude_patterns: tuple[str, ...] | None = None,
-) -> InMemoryFileSystem:
- """Filter all files based on include/exclude patterns.
- If an include pattern is provided, only files that match the include pattern will be included.
- If an exclude pattern is provided, files that match the exclude pattern will be excluded. If no include or exclude patterns are provided, all files will be included.
-
- Args:
- file_system: The file system to filter
- include_patterns: Glob patterns for files to include
- exclude_patterns: Glob patterns for files to exclude
-
- Returns:
- Filtered InMemoryFileSystem
- """
- include_spec = BaseFilenamePattern.from_lines(include_patterns or ())
- exclude_spec = BaseFilenamePattern.from_lines(exclude_patterns or ())
- filtered_files = {
- file_path: content
- for file_path, content in file_system.files.items()
- if (
- (not include_patterns or include_spec.match_file(file_path))
- and (not exclude_patterns or not exclude_spec.match_file(file_path))
- )
- }
- return InMemoryFileSystem.build(filtered_files)
diff --git a/imbue_tools/imbue_tools/repo_utils/find_relative_to.py b/imbue_tools/imbue_tools/repo_utils/find_relative_to.py
@@ -1,30 +0,0 @@
-from pathlib import Path
-
-from imbue_core.simple_git import SyncLocalGitRepo
-
-
-def find_relative_to_commit_hash(relative_to: str, repo_path: Path) -> str:
- """
- Find the commit hash to use as the source to compare against.
- - If relative_to is "HEAD", it will return the current commit hash.
- - If relative_to is a branch name, it will find the last common ancestor between that branch and the current state.
- - If relative_to is a commit hash or tag, it will return that commit hash.
- """
- repo = SyncLocalGitRepo(repo_path)
- if relative_to.startswith("HEAD"):
- # The current commit hash or relative to it (e.g. "HEAD~1")
- base_commit = repo.run_git(["rev-parse", relative_to], check=True)
- else:
- # Check if relative_to is the name of a branch.
- is_branch = repo.is_commit_a_branch(relative_to)
- if is_branch:
- # Yes, it's a branch.
- # Since we're comparing to a branch, this command will find the last common ancestor
- # between that branch and the current state. This is typically what we want for branches.
- # (Think of this as getting the diff that would be applied if this branch was to be merged into relative_to.)
- base_commit = repo.get_merge_base(relative_to, "HEAD")
- else:
- # Not a branch. relative_to might be a commit hash or tag.
- base_commit = relative_to
-
- return base_commit
diff --git a/imbue_tools/imbue_tools/repo_utils/stubify_file.py b/imbue_tools/imbue_tools/repo_utils/stubify_file.py
@@ -17,16 +17,6 @@ def check_on_body(stmt: cst.CSTNode, check: Callable[[cst.CSTNode], bool]) -> bo
return check(first_body_item)
-def is_assign(stmt: cst.CSTNode) -> bool:
- if not m.matches(stmt, m.SimpleStatementLine()):
- return False
- # pyre-ignore[16]: m.SimpleStatementLine has a body attribute which is a Sequence
- first_body_item = stmt.body[0]
- if m.matches(first_body_item, m.Assign()):
- return True
- return False
-
-
class CompressTransformer(cst.CSTTransformer):
DESCRIPTION = str = "Replaces function body with ..."
replacement_string = '"__FUNC_BODY_REPLACEMENT_STRING__"'
@@ -35,7 +25,9 @@ class CompressTransformer(cst.CSTTransformer):
self.keep_constant = keep_constant
self.keep_indent = keep_indent
- def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
+ def leave_Module(
+ self, original_node: cst.Module, updated_node: cst.Module
+ ) -> cst.Module:
new_body = [
stmt
for stmt in updated_node.body
@@ -43,12 +35,16 @@ class CompressTransformer(cst.CSTTransformer):
or m.matches(stmt, m.FunctionDef())
or (
self.keep_constant
- and check_on_body(stmt, lambda first_body_item: m.matches(first_body_item, m.Assign()))
+ and check_on_body(
+ stmt, lambda first_body_item: m.matches(first_body_item, m.Assign())
+ )
)
]
return updated_node.with_changes(body=new_body)
- def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
+ def leave_ClassDef(
+ self, original_node: cst.ClassDef, updated_node: cst.ClassDef
+ ) -> cst.ClassDef:
# Remove docstring in the class body
new_body = [
stmt
@@ -56,13 +52,18 @@ class CompressTransformer(cst.CSTTransformer):
if not check_on_body(
stmt,
lambda first_body_item: m.matches(first_body_item, m.Expr())
- or (hasattr(first_body_item, "value") and m.matches(first_body_item.value, m.SimpleString())),
+ or (
+ hasattr(first_body_item, "value")
+ and m.matches(first_body_item.value, m.SimpleString())
+ ),
)
]
# pyre-fixme[6]: cst.IndentedBlock has a body attribute which is a Sequence[BaseStatement], not a Sequence[BaseSmallStatement] like new_body
return updated_node.with_changes(body=cst.IndentedBlock(body=new_body))
- def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.BaseStatement:
+ def leave_FunctionDef(
+ self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
+ ) -> cst.BaseStatement:
if not self.keep_indent:
# replace with unindented statement
new_expr = cst.Expr(value=cst.SimpleString(value=self.replacement_string))
@@ -75,57 +76,9 @@ class CompressTransformer(cst.CSTTransformer):
new_expr = [
cst.Expr(value=cst.SimpleString(value=self.replacement_string)),
]
- return updated_node.with_changes(body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=new_expr)]))
-
-
-class GlobalVariableVisitor(cst.CSTVisitor):
- METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,)
-
- def __init__(self) -> None:
- self.assigns: list[list[Any]] = []
-
- def leave_Assign(self, original_node: cst.Assign) -> None:
- stmt = original_node
- start_pos = self.get_metadata(cst.metadata.PositionProvider, stmt).start
- end_pos = self.get_metadata(cst.metadata.PositionProvider, stmt).end
- self.assigns.append([stmt, start_pos, end_pos])
-
-
-def remove_lines(raw_code: str, remove_line_intervals: list[tuple[int, int]]) -> str:
- # TODO: speed up this function
- # remove_line_intervals.sort()
-
- # Remove lines
- new_code = ""
- for i, line in enumerate(raw_code.splitlines()):
- # intervals are one-based
- if not any(start <= i + 1 <= end for start, end in remove_line_intervals):
- new_code += line + "\n"
- if any(start == i + 1 for start, _ in remove_line_intervals):
- new_code += "...\n"
- return new_code
-
-
-def compress_assign_stmts(raw_code: str, total_lines: int = 30, prefix_lines: int = 10, suffix_lines: int = 10) -> str:
- try:
- tree = cst.parse_module(raw_code)
- except cst.ParserSyntaxError as e:
- logger.debug(
- "Failed to compress assign statements: {exception_class}, {exception}",
- exception_class=e.__class__.__name__,
- exception=e,
- )
- return raw_code
-
- wrapper = cst.metadata.MetadataWrapper(tree)
- visitor = GlobalVariableVisitor()
- wrapper.visit(visitor)
-
- remove_line_intervals = []
- for stmt in visitor.assigns:
- if stmt[2].line - stmt[1].line > total_lines:
- remove_line_intervals.append((stmt[1].line + prefix_lines, stmt[2].line - suffix_lines))
- return remove_lines(raw_code, remove_line_intervals)
+ return updated_node.with_changes(
+ body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=new_expr)])
+ )
def stubify_code_file(
@@ -133,10 +86,6 @@ def stubify_code_file(
raw_code: str,
keep_constant: bool = True,
keep_indent: bool = False,
- compress_assign: bool = False,
- total_lines: int = 30,
- prefix_lines: int = 10,
- suffix_lines: int = 10,
) -> str:
try:
tree = cst.parse_module(raw_code)
@@ -148,14 +97,6 @@ def stubify_code_file(
modified_tree = tree.visit(transformer)
code = modified_tree.code
- if compress_assign:
- code = compress_assign_stmts(
- code,
- total_lines=total_lines,
- prefix_lines=prefix_lines,
- suffix_lines=suffix_lines,
- )
-
if keep_indent:
code = code.replace(CompressTransformer.replacement_string + "\n", "...\n")
code = code.replace(CompressTransformer.replacement_string, "...\n")
diff --git a/imbue_tools/imbue_tools/repo_utils/subrepo_formatting.py b/imbue_tools/imbue_tools/repo_utils/subrepo_formatting.py
@@ -1,5 +1,4 @@
import functools
-from collections import defaultdict
from enum import Enum
from typing import Annotated
from typing import Iterable
@@ -239,28 +238,6 @@ def format_subrepo(formatted_repo_contents: Mapping[str, str]) -> str:
return escape_all_jinja_variables(escape_prompt_markers(repo_context_str))
-def format_subrepo_context_full(repo: Mapping[str, str]) -> str:
- """Like get_repo_context but there's no checking for context limits (so we can use the api checks instead)
- and the selected strategy is always the full repo contents."""
- formatted_repo_context = format_subrepo(
- format_all_for_agent(
- format_subrepo_context_into_filecontexts(
- full_repo_contents=repo,
- path_to_format_style=defaultdict(lambda: ContextFormatStyle.FULL_FILE, {}),
- )
- )
- )
-
- repo_context_core_prompt = formatted_subrepo_to_prompt(
- repo_context_str=formatted_repo_context,
- is_shortened=False,
- has_hidden_files=False,
- template=BASE_REPO_CONTEXT_TEMPLATE,
- )
-
- return repo_context_core_prompt
-
-
def formatted_subrepo_to_prompt(
repo_context_str: str, is_shortened: bool, has_hidden_files: bool, template: str
) -> str:
diff --git a/vet/api.py b/vet/api.py
@@ -39,7 +39,7 @@ def get_issues_with_raw_responses(
# should be not None and not empty
if conversation_history:
try:
- # TODO: we use the imbue verify config here, but we may want to configure this separately
+ # TODO: we use the config here, but we may want to configure this separately
goal = get_goal_from_conversation(conversation_history, config.language_model_generation_config)
logger.info("Generated goal from conversation history: {}", goal)
except Exception as e:
diff --git a/vet/conftest.py b/vet/conftest.py
@@ -9,8 +9,7 @@ from imbue_core.test_repo_utils import make_simple_test_git_repo
simple_test_git_repo = pytest.fixture(make_simple_test_git_repo)
-# this is copied from sculptor/conftest.py
-# (it must be copied rather than imported because of the autouse)
+# This fixture must be defined locally (not imported) because of the autouse flag.
@pytest.fixture(autouse=True)
def always_explode_on_error(
explode_on_error: Callable[[], Generator[None, None, None]],
diff --git a/vet/errors.py b/vet/errors.py
@@ -1,2 +1,19 @@
+import subprocess
+from typing import Any
+
+
class GitException(Exception):
pass
+
+
+class RunCommandError(subprocess.CalledProcessError):
+ """Custom exception for errors encountered during shell commands."""
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.cwd = kwargs.get("cwd", None)
+ if "cwd" in kwargs:
+ del kwargs["cwd"]
+ super().__init__(*args, **kwargs)
+
+ def __str__(self) -> str:
+ return f"Command `{self.cmd}` returned non-zero exit status {self.returncode}.\nOutput: {self.stdout}\nError: {self.stderr}\nCWD: {self.cwd}"
diff --git a/vet/git.py b/vet/git.py
@@ -0,0 +1,236 @@
+"""Git utilities for vet."""
+
+import shlex
+import subprocess
+import time
+from pathlib import Path
+from typing import Sequence
+
+import anyio
+from loguru import logger
+
+from imbue_core.async_monkey_patches import log_exception
+from vet.errors import RunCommandError
+
+# Flexible path type alias
+AnyPath = Path | str | anyio.Path
+
+
+class SyncLocalGitRepo:
+ """
+ Provides different operations that you can perform over a git repository.
+ """
+
+ _base_path: Path
+
+ def __init__(self, base_path: Path) -> None:
+ self._base_path = base_path
+
+ @property
+ def base_path(self) -> Path:
+ """The base path of the git repo."""
+ return self._base_path
+
+ def run_git(
+ self,
+ command: Sequence[str],
+ check: bool = True,
+ cwd: AnyPath | None = None,
+ is_error_logged: bool = True,
+ is_stripped: bool = True,
+ retry_on_git_lock_error: bool = True,
+ ) -> str:
+ """Run a git command in the repo.
+
+ Example:
+ ```
+ git_repo.run_git("status")
+ ```
+ """
+ command = ["git"] + list(command)
+ if not retry_on_git_lock_error:
+ result = self.run_command(command, check=check, is_error_logged=is_error_logged, cwd=cwd)
+ else:
+ result = self._run_command_with_retry_on_git_lock_error(
+ command, check=check, is_error_logged=is_error_logged, cwd=cwd
+ )
+ if is_stripped:
+ return result.strip()
+ return result
+
+ def run_command(
+ self,
+ command: Sequence[str],
+ check: bool = True,
+ secrets: dict[str, str] | None = None,
+ cwd: AnyPath | None = None,
+ is_error_logged: bool = True,
+ ) -> str:
+ """Run a command in the repo.
+
+ Note, this can be used to run any command, not just git.
+ """
+ command_string = shlex.join(command)
+ logger.trace(
+ f"Running command: {command_string=} from cwd={cwd or self.base_path} with {secrets=} {check=} {is_error_logged=}"
+ )
+ completed_proc = subprocess.run(
+ command,
+ cwd=cwd or self._base_path,
+ stdin=subprocess.DEVNULL,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ env=secrets,
+ )
+ # note, need to be carefull not to strip() lines since whitespace may be important (e.g. for diffs)
+ # return joined lines since mostly we only use the output for logging, and this way we arn't
+ # passing around lots of lists. Also it's easy to parse by lines if needed
+ try:
+ stdout = completed_proc.stdout.decode("UTF-8")
+ except UnicodeDecodeError as e:
+ # If we don't encounter this, it likely means something was fixed upstream and we can safely delete
+ log_exception(
+ e,
+ "Command {command_string} failed to decode stdout, replacing any invalid bytes which could lead to problems later",
+ command_string=command_string,
+ )
+ stdout = completed_proc.stdout.decode("UTF-8", errors="replace")
+ stderr = completed_proc.stderr.decode("UTF-8")
+ if check and completed_proc.returncode != 0:
+ error_message = f"command run from cwd={self.base_path} failed with exit code {completed_proc.returncode} and stdout:\n{stdout}\nstderr:\n{stderr}"
+ if is_error_logged:
+ logger.error(
+ f"command attempted: '{command_string}' from cwd={self.base_path}\nerror message: {error_message}"
+ )
+ # this should not be None, but do this to satisfy type checker, int or None we throw the same error
+ returncode = completed_proc.returncode or -1
+ raise RunCommandError(
+ cmd=command_string,
+ stderr=stderr,
+ returncode=returncode,
+ cwd=cwd or self.base_path,
+ )
+ return stdout
+
+ def get_git_diff(
+ self,
+ commit_hash: str | None = None,
+ staged: bool = False,
+ is_error_logged: bool = True,
+ include_binary: bool = True,
+ ) -> str:
+ """Get the diff for the current repo state."""
+ # make sure `is_stripped=False` otherwise patch can be invalid
+ command = ["diff", "--full-index"]
+ if include_binary:
+ # Without --binary, diffs of binary files will just contain a summary statement such as "Binary files a/file.bin and b/file.bin differ".
+ # Such diffs cannot be applied, but are useful for inclusion in LLM prompts.
+ command.append("--binary")
+ if staged:
+ command.append("--staged")
+ if commit_hash:
+ command.append(commit_hash)
+ return self.run_git(command, is_stripped=False, is_error_logged=is_error_logged)
+
+ def get_untracked_files(self) -> tuple[str, ...]:
+ """Get the untracked files in the repo."""
+ result = self.run_git(["ls-files", "--others", "--exclude-standard"], is_error_logged=False)
+ return tuple([line.strip() for line in result.splitlines() if line.strip()])
+
+ def get_untracked_file_diff(self, file_path: str, include_binary: bool = True) -> str:
+ """Get the diff for a untracked file.
+
+ Note this function will raise a RunCommandError if the there is no diff for the untracked file or if there
+ is another error running the command. So it is best to use this function after checking that the file is untracked
+ using get_untracked_files function.
+ """
+ command = ["diff", "--no-index"]
+ if include_binary:
+ command.append("--binary")
+ untracked_diff = self.run_git(
+ command + ["/dev/null", str(file_path)],
+ # Unfortunately, `--no-index` implies `--exit-code`, which will cause git diff to return an exit code of 1
+ # if the diff is not empty. So we can't use check=True here. We'll check for an empty output to detect failures.
+ check=False,
+ is_error_logged=True,
+ is_stripped=False,
+ )
+ if not untracked_diff:
+ raise RunCommandError(f"Unable to diff untracked file {file_path}")
+ return untracked_diff
+
+ def is_commit_a_branch(self, commit_hash: str) -> bool:
+ """Check if the given git ref is a branch."""
+ try:
+ self.run_git(
+ ("show-ref", "--verify", "-q", f"refs/heads/{commit_hash}"),
+ is_error_logged=False,
+ check=True,
+ )
+ return True
+ except RunCommandError as e:
+ if e.returncode == 1:
+ return False
+ raise
+
+ def get_merge_base(self, branch_name: str, target_branch: str) -> str:
+ """Get the merge base of the given branch and target branch.
+
+ The merge base is the most recent commit that is on both branches.
+ """
+ return self.run_git(["merge-base", branch_name, target_branch], is_error_logged=False)
+
+ def _run_command_with_retry_on_git_lock_error(
+ self,
+ command: Sequence[str],
+ check: bool = True,
+ is_error_logged: bool = True,
+ cwd: AnyPath | None = None,
+ ) -> str:
+ max_retries = 50
+ retry_count = 0
+ retry_delay = 0.1 # seconds
+ while True:
+ try:
+ return self.run_command(
+ command,
+ check=check,
+ is_error_logged=is_error_logged and retry_count >= max_retries,
+ cwd=cwd,
+ )
+ except RunCommandError as e:
+ error_message = str(e)
+ if "fatal: Unable to create" in error_message and ".git/index.lock': File exists" in error_message:
+ if retry_count >= max_retries:
+ raise
+ time.sleep(retry_delay)
+ retry_count += 1
+ else:
+ raise
+
+
+def find_relative_to_commit_hash(relative_to: str, repo_path: Path) -> str:
+ """
+ Find the commit hash to use as the source to compare against.
+ - If relative_to is "HEAD", it will return the current commit hash.
+ - If relative_to is a branch name, it will find the last common ancestor between that branch and the current state.
+ - If relative_to is a commit hash or tag, it will return that commit hash.
+ """
+ repo = SyncLocalGitRepo(repo_path)
+ if relative_to.startswith("HEAD"):
+ # The current commit hash or relative to it (e.g. "HEAD~1")
+ base_commit = repo.run_git(["rev-parse", relative_to], check=True)
+ else:
+ # Check if relative_to is the name of a branch.
+ is_branch = repo.is_commit_a_branch(relative_to)
+ if is_branch:
+ # Yes, it's a branch.
+ # Since we're comparing to a branch, this command will find the last common ancestor
+ # between that branch and the current state. This is typically what we want for branches.
+ # (Think of this as getting the diff that would be applied if this branch was to be merged into relative_to.)
+ base_commit = repo.get_merge_base(relative_to, "HEAD")
+ else:
+ # Not a branch. relative_to might be a commit hash or tag.
+ base_commit = relative_to
+
+ return base_commit
diff --git a/vet/issue_identifiers/agentic_issue_collation.py b/vet/issue_identifiers/agentic_issue_collation.py
@@ -132,7 +132,7 @@ def collate_issues_with_agent(
issues: The issues to collate.
identifier_inputs: The inputs which determine the content provided to the identifiers.
project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for imbue verify.
+ config: Settings
enabled_issue_codes: The issue types used by the issue identifiers.
Returns:
diff --git a/vet/issue_identifiers/base.py b/vet/issue_identifiers/base.py
@@ -39,7 +39,7 @@ class IssueIdentifier(SerializableModel, abc.ABC, Generic[T]):
Args:
identifier_inputs: The inputs which determine the content provided to the identifier.
project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for imbue verify.
+ config: Settings
Returns:
A generator of identified issues. When done iterating, returns the debug info.
diff --git a/vet/issue_identifiers/issue_deduplication.py b/vet/issue_identifiers/issue_deduplication.py
@@ -119,7 +119,7 @@ def deduplicate_issues(
Args:
issues: The issues to deduplicate.
- config: Settings for imbue verify.
+ config: Settings
enabled_issue_codes: The issue types used by the issue identifiers.
Returns:
diff --git a/vet/issue_identifiers/issue_evaluation.py b/vet/issue_identifiers/issue_evaluation.py
@@ -266,7 +266,7 @@ def filter_issues(
results: The issues to filter.
inputs: The inputs which determine the content provided to the evaluator.
project_context: Loaded data corresponding to the inputs, e.g. diffs or files.
- config: Settings for imbue verify.
+ config: Settings
Returns:
A generator of issues with the passes_filtration flag set.
diff --git a/vet/repo_utils.py b/vet/repo_utils.py
@@ -1,10 +1,10 @@
from pathlib import Path
from imbue_core.async_monkey_patches import log_exception
-from imbue_core.computing_environment.data_types import RunCommandError
-from imbue_core.simple_git import SyncLocalGitRepo
-from imbue_tools.repo_utils.find_relative_to import find_relative_to_commit_hash
from vet.errors import GitException
+from vet.errors import RunCommandError
+from vet.git import SyncLocalGitRepo
+from vet.git import find_relative_to_commit_hash
# Maximum length of LLM prompts used within vet in tokens, without the repository-specific context.
# Currently, the prompt is well under 10k tokens, but this value might need to be bumped up if we add a lot of additional
diff --git a/vet_types/vet_types/chat_state.py b/vet_types/vet_types/chat_state.py
@@ -56,9 +56,7 @@ class CommandBlock(ContentBlock):
object_type: str = "CommandBlock"
type: Literal["command"] = "command"
command: str
- is_automated: bool = Field(
- default=False, description="Whether the command is automated"
- )
+ is_automated: bool = Field(default=False, description="Whether the command is automated")
ToolInput = dict[str, Any]
@@ -69,17 +67,13 @@ class ToolUseBlock(ContentBlock):
type: Literal["tool_use"] = "tool_use"
id: ToolUseID = Field(..., description="Unique identifier for this tool use")
name: str = Field(..., description="Name of the tool being used")
- input: ToolInput = Field(
- default_factory=ToolInput, description="Input parameters for the tool"
- )
+ input: ToolInput = Field(default_factory=ToolInput, description="Input parameters for the tool")
class ToolResultContent(SerializableModel):
"""Base class for tool result content with type discriminator"""
- content_type: str = Field(
- ..., description="Type discriminator for tool result content"
- )
+ content_type: str = Field(..., description="Type discriminator for tool result content")
class SimpleToolContent(ToolResultContent):
@@ -114,15 +108,9 @@ class ToolResultBlock(ContentBlock):
type: Literal["tool_result"] = "tool_result"
tool_use_id: ToolUseID = Field(..., description="ID of the corresponding tool use")
tool_name: str = Field(..., description="Name of the tool that was used")
- invocation_string: str = Field(
- ..., description="String representation of how the tool was invoked"
- )
- content: ToolResultContentType = Field(
- ..., description="Result content from the tool execution"
- )
- is_error: bool = Field(
- default=False, description="Whether the tool execution resulted in an error"
- )
+ invocation_string: str = Field(..., description="String representation of how the tool was invoked")
+ content: ToolResultContentType = Field(..., description="Result content from the tool execution")
+ is_error: bool = Field(default=False, description="Whether the tool execution resulted in an error")
class WarningBlock(ContentBlock):
@@ -130,9 +118,7 @@ class WarningBlock(ContentBlock):
type: Literal["warning"] = "warning"
message: str = Field(..., description="Warning message")
traceback: str | None = Field(..., description="Warning traceback")
- warning_type: str | None = Field(
- ..., description="Type of warning, i.e. name of the exception that was raised"
- )
+ warning_type: str | None = Field(..., description="Type of warning, i.e. name of the exception that was raised")
class ErrorBlock(ContentBlock):
@@ -140,9 +126,7 @@ class ErrorBlock(ContentBlock):
type: Literal["error"] = "error"
message: str = Field(..., description="Error message")
traceback: str = Field(..., description="Error traceback")
- error_type: str = Field(
- ..., description="Type of error, i.e. name of the exception that was raised"
- )
+ error_type: str = Field(..., description="Type of error, i.e. name of the exception that was raised")
class FileBlock(ContentBlock):
diff --git a/vet_types/vet_types/ids.py b/vet_types/vet_types/ids.py
@@ -32,23 +32,17 @@ class ObjectID(TypeID, ABC):
# For convenience, don't require the caller to strip the prefix from existing IDs.
if prefix is not None:
if prefix != self.tag:
- raise TypeIDPrefixMismatchError(
- f"Expected prefix '{self.tag}', got '{prefix}'"
- )
+ raise TypeIDPrefixMismatchError(f"Expected prefix '{self.tag}', got '{prefix}'")
value = suffix
super().__init__(self.tag, value)
@classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: type, handler: GetCoreSchemaHandler
- ) -> core_schema.CoreSchema:
+ def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
"""
Support transparently deserializing strings into ObjectID instances and vice versa.
"""
return core_schema.no_info_before_validator_function(
- lambda raw_value: (
- cls(raw_value) if isinstance(raw_value, str) else raw_value
- ),
+ lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value),
core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
@@ -77,16 +71,12 @@ class NonEmptyStr(str):
return value
@classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: type, handler: GetCoreSchemaHandler
- ) -> core_schema.CoreSchema:
+ def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
"""
Support transparently deserializing strings into ObjectID instances and vice versa.
"""
return core_schema.no_info_before_validator_function(
- lambda raw_value: (
- cls(raw_value) if isinstance(raw_value, str) else raw_value
- ),
+ lambda raw_value: (cls(raw_value) if isinstance(raw_value, str) else raw_value),
core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
diff --git a/vet_types/vet_types/messages.py b/vet_types/vet_types/messages.py
@@ -36,20 +36,18 @@ class LLMModel(StrEnum):
class AgentMessageSource(StrEnum):
"""
- Messages can come the AGENT (in-container LLM), USER (chat messages & direct interactions),
- SCULPTOR_SYSTEM (multifaceted sculptor app and service code) and RUNNER (the process
- controlling a task on the server.)
+ Messages can come from the AGENT (LLM), USER (chat messages & direct interactions),
+ SYSTEM (app and service code) and RUNNER (the process controlling a task).
"""
- # Messages coming directly from the agent from inside the environment.
+ # Messages coming directly from the agent.
AGENT = "AGENT"
- # Messages coming directly from a user interacting with the interface, ie chat
+ # Messages coming directly from a user interacting with the interface, ie chat.
USER = "USER"
- # Messages coming from sculptor-mediated actions and automations, like local sync updates
- # or manual sync operations.
- SCULPTOR_SYSTEM = "SCULPTOR_SYSTEM"
+ # Messages coming from system-mediated actions and automations.
+ SYSTEM = "SYSTEM"
# Messages coming from the task runner wrapper, such as environment shutdown.
RUNNER = "RUNNER"
@@ -65,15 +63,11 @@ class Message(SerializableModel):
# the source of the message, which can be either the agent, user, or runner.
source: AgentMessageSource
# roughly when the message was created, in UTC.
- approximate_creation_time: datetime.datetime = Field(
- default_factory=get_current_time
- )
+ approximate_creation_time: datetime.datetime = Field(default_factory=get_current_time)
@property
def is_ephemeral(self) -> bool:
- raise NotImplementedError(
- "All messages must be subclassed off of PersistentMessage or EphemeralMessage"
- )
+ raise NotImplementedError("All messages must be subclassed off of PersistentMessage or EphemeralMessage")
class PersistentMessage(Message):