vet

Mirror of Vet, an AI code review tool
git clone git://git.laack.co/vet.git
Log | Files | Refs | README | LICENSE

transport.py (6517B)


      1 import json
      2 import os
      3 import subprocess
      4 import threading
      5 from abc import ABC
      6 from abc import abstractmethod
      7 from contextlib import contextmanager
      8 from pathlib import Path
      9 from subprocess import PIPE
     10 from typing import Any
     11 from typing import ContextManager
     12 from typing import Generator
     13 from typing import Generic
     14 from typing import Iterable
     15 from typing import Iterator
     16 from typing import Self
     17 from typing import Sequence
     18 from typing import TypeVar
     19 
     20 from vet.imbue_core.agents.agent_api.data_types import AgentOptions
     21 from vet.imbue_core.agents.agent_api.errors import AgentCLIConnectionError
     22 from vet.imbue_core.agents.agent_api.errors import AgentCLIJSONDecodeError as SDKJSONDecodeError
     23 from vet.imbue_core.agents.agent_api.errors import AgentCLINotFoundError
     24 from vet.imbue_core.agents.agent_api.errors import AgentProcessError
     25 from vet.imbue_core.pydantic_serialization import SerializableModel
     26 
# Type parameter for the options model a transport is built from (see AgentTransport.build);
# bound to SerializableModel so every transport's options share that base class.
TransportOptionsT = TypeVar("TransportOptionsT", bound=SerializableModel)
     28 
     29 
     30 class AgentTransport(ABC, Generic[TransportOptionsT]):
     31     """Abstract transport for Agent communication."""
     32 
     33     @classmethod
     34     @abstractmethod
     35     def build(cls, options: TransportOptionsT) -> ContextManager[Self]:
     36         """Build a transport from options.
     37 
     38         This is the main entry point for building a transport and managing its lifecycle.
     39         """
     40 
     41     @abstractmethod
     42     def send_request(self, messages: list[Any], agent_options: AgentOptions) -> None:
     43         """Send request to underlying agent via transport."""
     44 
     45     @abstractmethod
     46     def receive_messages(self) -> Iterator[dict[str, Any]]:
     47         """Receive messages from underlying agent via transport."""
     48 
     49     @abstractmethod
     50     def is_connected(self) -> bool:
     51         """Check if transport is connected."""
     52 
     53 
class AgentSubprocessCLITransportOptions(SerializableModel):
    """Options for AgentSubprocessCLITransport.

    Consumed by ``AgentSubprocessCLITransport.build`` to launch the agent CLI process.
    """

    # Program and arguments; passed unchanged as the args of subprocess.Popen.
    cmd: Sequence[str]
    # Working directory for the child process; None leaves it in the parent's cwd.
    cwd: str | Path | None = None
    # Extra environment variables layered over os.environ (these win on conflicts).
    extra_env_vars: dict[str, str] | None = None
     60 
     61 
     62 class AgentSubprocessCLITransport(AgentTransport[AgentSubprocessCLITransportOptions]):
     63     """Subprocess transport using Coding Agent via a CLI."""
     64 
     65     def __init__(
     66         self,
     67         popen: subprocess.Popen[str],
     68     ) -> None:
     69         self._process = popen
     70         self._stdin_stream = popen.stdin
     71         self._stdout_stream = popen.stdout
     72         self._stderr_stream = popen.stderr
     73 
     74     @classmethod
     75     @contextmanager
     76     def build(cls, options: AgentSubprocessCLITransportOptions) -> Generator[Self, None, None]:
     77         extra_env_vars = options.extra_env_vars or {}
     78         try:
     79             popen = subprocess.Popen(
     80                 options.cmd,
     81                 stdin=PIPE,
     82                 stdout=PIPE,
     83                 stderr=PIPE,
     84                 cwd=options.cwd,
     85                 env={**os.environ, **extra_env_vars},
     86                 # ensure output is line buffered
     87                 bufsize=1,
     88                 text=True,
     89                 encoding="utf-8",
     90             )
     91         except FileNotFoundError as e:
     92             raise AgentCLINotFoundError(f"Agent CLI not found for: cmd={options.cmd}") from e
     93         except Exception as e:
     94             raise AgentCLIConnectionError(f"Failed to start Agent CLI via cmd={options.cmd}: {e}") from e
     95 
     96         try:
     97             yield cls(popen)
     98         finally:
     99             # Make sure to terminate the process if it is still running, and clean up the streams
    100             if popen.poll() is None:
    101                 try:
    102                     popen.terminate()
    103                     popen.wait(timeout=5.0)
    104                 except subprocess.TimeoutExpired:
    105                     popen.kill()
    106                     popen.wait(timeout=5.0)
    107             popen.stdout and popen.stdout.close()
    108             popen.stderr and popen.stderr.close()
    109             popen.stdin and popen.stdin.close()
    110 
    111     def send_request(self, messages: Iterable[dict[str, Any] | str], agent_options: AgentOptions) -> None:
    112         process = self._process
    113         stdin_stream = self._stdin_stream
    114         if not process or not stdin_stream:
    115             raise AgentCLIConnectionError("Not connected")
    116 
    117         for message in messages:
    118             stdin_stream.write(json.dumps(message) + "\n")
    119             stdin_stream.flush()
    120 
    121     def write_stdin(self, text: str) -> None:
    122         stdin_stream = self._stdin_stream
    123         if not self._process or not stdin_stream:
    124             raise AgentCLIConnectionError("Not connected")
    125 
    126         stdin_stream.write(text)
    127         stdin_stream.flush()
    128         stdin_stream.close()
    129         self._stdin_stream = None
    130 
    131     def _read_stderr(self, output_buffer: list[str]) -> None:
    132         """Read stderr in background."""
    133         stderr_stream = self._stderr_stream
    134         if stderr_stream:
    135             try:
    136                 for line in stderr_stream:
    137                     output_buffer.append(line.strip())
    138             except subprocess.SubprocessError:
    139                 pass
    140 
    141     def receive_messages(self) -> Iterator[dict[str, Any]]:
    142         process = self._process
    143         stdout_stream = self._stdout_stream
    144         if not process or not stdout_stream:
    145             raise AgentCLIConnectionError("Not connected")
    146 
    147         stderr_lines: list[str] = []
    148         stderr_read_thread = threading.Thread(target=self._read_stderr, args=(stderr_lines,))
    149         stderr_read_thread.start()
    150 
    151         try:
    152             for line in stdout_stream:
    153                 line_str = line.strip()
    154                 if not line_str:
    155                     continue
    156 
    157                 try:
    158                     data = json.loads(line_str)
    159                     try:
    160                         yield data
    161                     except GeneratorExit:
    162                         # Handle generator cleanup gracefully
    163                         return
    164                 except json.JSONDecodeError as e:
    165                     if line_str.startswith("{") or line_str.startswith("["):
    166                         raise SDKJSONDecodeError(line_str, e) from e
    167                     continue
    168 
    169         except subprocess.SubprocessError:
    170             pass
    171 
    172         stderr_read_thread.join(timeout=5.0)
    173         process.wait()
    174         if process.returncode is not None and process.returncode != 0:
    175             stderr_output = "\n".join(stderr_lines)
    176             raise AgentProcessError(
    177                 "CLI process failed",
    178                 exit_code=process.returncode,
    179                 stderr=stderr_output,
    180             )
    181 
    182     def is_connected(self) -> bool:
    183         process = self._process
    184         return process is not None and process.returncode is None