transport.py (6517B)
import json
import os
import subprocess
import threading
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from subprocess import PIPE
from typing import Any
from typing import ContextManager
from typing import Generator
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import Self
from typing import Sequence
from typing import TypeVar

from vet.imbue_core.agents.agent_api.data_types import AgentOptions
from vet.imbue_core.agents.agent_api.errors import AgentCLIConnectionError
from vet.imbue_core.agents.agent_api.errors import AgentCLIJSONDecodeError as SDKJSONDecodeError
from vet.imbue_core.agents.agent_api.errors import AgentCLINotFoundError
from vet.imbue_core.agents.agent_api.errors import AgentProcessError
from vet.imbue_core.pydantic_serialization import SerializableModel

TransportOptionsT = TypeVar("TransportOptionsT", bound=SerializableModel)

# Seconds to wait for the CLI process to exit (after terminate, and again after kill) and for the
# stderr reader thread to finish before giving up on it.
_SHUTDOWN_TIMEOUT_SECONDS = 5.0


class AgentTransport(ABC, Generic[TransportOptionsT]):
    """Abstract transport for Agent communication."""

    @classmethod
    @abstractmethod
    def build(cls, options: TransportOptionsT) -> ContextManager[Self]:
        """Build a transport from options.

        This is the main entry point for building a transport and managing its lifecycle.
        """

    @abstractmethod
    def send_request(self, messages: list[Any], agent_options: AgentOptions) -> None:
        """Send request to underlying agent via transport."""

    @abstractmethod
    def receive_messages(self) -> Iterator[dict[str, Any]]:
        """Receive messages from underlying agent via transport."""

    @abstractmethod
    def is_connected(self) -> bool:
        """Check if transport is connected."""


class AgentSubprocessCLITransportOptions(SerializableModel):
    """Options for AgentSubprocessCLITransport."""

    # Argument vector passed directly to subprocess.Popen (no shell).
    cmd: Sequence[str]
    # Working directory for the child process; None inherits the current directory.
    cwd: str | Path | None = None
    # Extra environment variables layered on top of os.environ for the child process.
    extra_env_vars: dict[str, str] | None = None


class AgentSubprocessCLITransport(AgentTransport[AgentSubprocessCLITransportOptions]):
    """Subprocess transport using Coding Agent via a CLI.

    Requests are written to the child's stdin as newline-delimited JSON. Responses are read from
    its stdout, one JSON document per line. While ``receive_messages`` runs, stderr is collected on
    a background thread so it can be attached to ``AgentProcessError`` if the process fails.
    """

    def __init__(
        self,
        popen: subprocess.Popen[str],
    ) -> None:
        self._process = popen
        self._stdin_stream = popen.stdin
        self._stdout_stream = popen.stdout
        self._stderr_stream = popen.stderr

    @classmethod
    @contextmanager
    def build(cls, options: AgentSubprocessCLITransportOptions) -> Generator[Self, None, None]:
        """Start the CLI subprocess and yield a transport bound to it.

        On exit, a still-running process is terminated. If it does not exit within the shutdown
        timeout it is killed. The child's pipes are then closed.

        Raises:
            AgentCLINotFoundError: the executable named in ``options.cmd`` was not found.
            AgentCLIConnectionError: the process could not be started for any other reason.
        """
        extra_env_vars = options.extra_env_vars or {}
        try:
            popen = subprocess.Popen(
                options.cmd,
                stdin=PIPE,
                stdout=PIPE,
                stderr=PIPE,
                cwd=options.cwd,
                env={**os.environ, **extra_env_vars},
                # ensure output is line buffered
                bufsize=1,
                text=True,
                encoding="utf-8",
            )
        except FileNotFoundError as e:
            raise AgentCLINotFoundError(f"Agent CLI not found for: cmd={options.cmd}") from e
        except Exception as e:
            raise AgentCLIConnectionError(f"Failed to start Agent CLI via cmd={options.cmd}: {e}") from e

        try:
            yield cls(popen)
        finally:
            try:
                # Make sure to terminate the process if it is still running.
                if popen.poll() is None:
                    popen.terminate()
                    try:
                        popen.wait(timeout=_SHUTDOWN_TIMEOUT_SECONDS)
                    except subprocess.TimeoutExpired:
                        popen.kill()
                        popen.wait(timeout=_SHUTDOWN_TIMEOUT_SECONDS)
            finally:
                # Close the pipes even if stopping the process raised (e.g. the wait after kill()
                # timed out), so their file descriptors are not leaked.
                for stream in (popen.stdout, popen.stderr, popen.stdin):
                    if stream is not None:
                        stream.close()

    def send_request(self, messages: Iterable[dict[str, Any] | str], agent_options: AgentOptions) -> None:
        """Write each message to the CLI's stdin as one line of JSON and flush after each.

        A ``str`` message is also passed through ``json.dumps``, so it is sent as a quoted JSON
        string. ``agent_options`` is not used by this transport.

        Raises:
            AgentCLIConnectionError: stdin is unavailable, for example after ``write_stdin`` closed it.
        """
        process = self._process
        stdin_stream = self._stdin_stream
        if not process or not stdin_stream:
            raise AgentCLIConnectionError("Not connected")

        for message in messages:
            stdin_stream.write(json.dumps(message) + "\n")
            stdin_stream.flush()

    def write_stdin(self, text: str) -> None:
        """Write ``text`` verbatim to the CLI's stdin, then close stdin.

        Closing stdin signals end-of-input to the child. Afterwards this transport has no stdin,
        so further ``send_request``/``write_stdin`` calls raise ``AgentCLIConnectionError``.
        """
        stdin_stream = self._stdin_stream
        if not self._process or not stdin_stream:
            raise AgentCLIConnectionError("Not connected")

        stdin_stream.write(text)
        stdin_stream.flush()
        stdin_stream.close()
        self._stdin_stream = None

    def _read_stderr(self, output_buffer: list[str]) -> None:
        """Read stderr in background, appending each stripped line to ``output_buffer``.

        This runs on a helper thread and is best-effort. ``build`` may close the pipe during
        teardown while this thread is still iterating, and a closed stream raises ``ValueError``.
        That error, and ``OSError`` from the pipe itself, end the loop quietly instead of
        surfacing as an unhandled thread exception.
        """
        stderr_stream = self._stderr_stream
        if stderr_stream is None:
            return
        try:
            for line in stderr_stream:
                output_buffer.append(line.strip())
        except (OSError, ValueError):
            pass

    def receive_messages(self) -> Iterator[dict[str, Any]]:
        """Yield JSON values parsed from the CLI's stdout, one per non-blank line, until EOF.

        Lines that fail to parse are skipped as non-protocol output, unless they start with ``{``
        or ``[``. Such lines are treated as corrupt protocol output and raise. If the consumer
        closes the generator early, it stops without waiting for the process.

        Raises:
            AgentCLIConnectionError: the transport has no stdout stream.
            SDKJSONDecodeError: a line that looks like JSON failed to parse.
            AgentProcessError: after stdout reached EOF, the process exited with a non-zero code.
                The collected stderr is attached to the error.
        """
        process = self._process
        stdout_stream = self._stdout_stream
        if not process or not stdout_stream:
            raise AgentCLIConnectionError("Not connected")

        stderr_lines: list[str] = []
        # Collect stderr concurrently while stdout is consumed. The thread is a daemon so that a
        # reader left behind by an early-closed generator never blocks interpreter exit.
        stderr_read_thread = threading.Thread(target=self._read_stderr, args=(stderr_lines,), daemon=True)
        stderr_read_thread.start()

        for line in stdout_stream:
            line_str = line.strip()
            if not line_str:
                continue

            try:
                data = json.loads(line_str)
            except json.JSONDecodeError as e:
                if line_str.startswith(("{", "[")):
                    raise SDKJSONDecodeError(line_str, e) from e
                continue

            # If the consumer closes the generator here, GeneratorExit propagates out and the
            # join/wait below is skipped. Process teardown is left to `build`.
            yield data

        stderr_read_thread.join(timeout=_SHUTDOWN_TIMEOUT_SECONDS)
        returncode = process.wait()
        if returncode != 0:
            raise AgentProcessError(
                "CLI process failed",
                exit_code=returncode,
                stderr="\n".join(stderr_lines),
            )

    def is_connected(self) -> bool:
        """Return True while the CLI process is still running.

        This uses ``poll()`` rather than reading ``returncode`` directly. ``returncode`` is only
        updated when the process is polled or waited on, so an already-exited child would
        otherwise still be reported as connected.
        """
        process = self._process
        return process is not None and process.poll() is None