vet

Mirror of Vet, an AI code review tool
git clone git://git.laack.co/vet.git
Log | Files | Refs | README | LICENSE

interaction.py (4363B)


      1 from typing import Sequence
      2 
      3 from vet.imbue_core.agents.agent_api.data_types import AgentAssistantMessage
      4 from vet.imbue_core.agents.agent_api.data_types import AgentMessage
      5 from vet.imbue_core.agents.agent_api.data_types import AgentOptions
      6 from vet.imbue_core.agents.agent_api.data_types import AgentToolResultBlock
      7 from vet.imbue_core.agents.agent_api.data_types import AgentToolUseBlock
      8 from vet.imbue_core.agents.agent_api.data_types import AgentUserMessage
      9 from vet.imbue_core.agents.agent_api.data_types import ToolUseRecord
     10 from vet.imbue_core.pydantic_serialization import SerializableModel
     11 
     12 
     13 class AgentInteraction:
     14     """A class for tracking an ongoing interaction with an agent.
     15 
     16     Note that this class is not thread-safe.
     17     """
     18 
     19     def __init__(self, prompt: str, options: AgentOptions) -> None:
     20         self.prompt = prompt
     21         self.options = options
     22         self.messages: list[AgentMessage] = []
     23         self.tool_use_records: list[ToolUseRecord] = []
     24         self._unresolved_tool_use_requests: list[AgentToolUseBlock] = []
     25 
     26     def put(self, message: AgentMessage) -> None:
     27         self.messages.append(message)
     28 
     29         if isinstance(message, AgentAssistantMessage):
     30             for assistant_content_block in message.content:
     31                 if isinstance(assistant_content_block, AgentToolUseBlock):
     32                     self._unresolved_tool_use_requests.append(assistant_content_block)
     33         elif isinstance(message, AgentUserMessage) and isinstance(message.content, list):
     34             for content_block in message.content:
     35                 if isinstance(content_block, AgentToolResultBlock):
     36                     remaining_unresolved_requests = []
     37                     for request in self._unresolved_tool_use_requests:
     38                         if request.id == content_block.tool_use_id:
     39                             self.tool_use_records.append(
     40                                 ToolUseRecord(
     41                                     request_message=request,
     42                                     result_message=content_block,
     43                                 )
     44                             )
     45                         else:
     46                             remaining_unresolved_requests.append(request)
     47                     self._unresolved_tool_use_requests = remaining_unresolved_requests
     48 
     49     def find_tool_use_record_by_command(self, command: str, by_most_recent: bool = True) -> ToolUseRecord | None:
     50         """Look for tool use request and result messages by the tool command.
     51 
     52         If by_most_recent is True, the records are searched in reverse order (most recent first).
     53         """
     54         return _find_tool_use_record_by_command(self.tool_use_records, command, by_most_recent)
     55 
     56 
     57 class AgentInteractionRecord(SerializableModel):
     58     """A serializable record of a completed agent interaction.
     59 
     60     This is meant to be used for storing a completed log in a database or cache.
     61     """
     62 
     63     prompt: str
     64     options: AgentOptions
     65     messages: tuple[AgentMessage, ...]
     66     tool_use_records: tuple[ToolUseRecord, ...]
     67 
     68     @classmethod
     69     def from_agent_interaction(cls, agent_interaction: AgentInteraction) -> "AgentInteractionRecord":
     70         return cls(
     71             prompt=agent_interaction.prompt,
     72             options=agent_interaction.options,
     73             messages=tuple(agent_interaction.messages),
     74             tool_use_records=tuple(agent_interaction.tool_use_records),
     75         )
     76 
     77     def find_tool_use_record_by_command(self, command: str, by_most_recent: bool = True) -> ToolUseRecord | None:
     78         """Look for tool use request and result messages by the tool command.
     79 
     80         If by_most_recent is True, the records are searched in reverse order (most recent first).
     81         """
     82         return _find_tool_use_record_by_command(self.tool_use_records, command, by_most_recent)
     83 
     84 
     85 def _find_tool_use_record_by_command(
     86     tool_use_records: Sequence[ToolUseRecord], command: str, reverse: bool = True
     87 ) -> ToolUseRecord | None:
     88     """Look for tool use request and result messages by the tool command.
     89 
     90     If reverse is True, the records are searched in reverse order (most recent first).
     91     """
     92     for record in reversed(tool_use_records) if reverse else tool_use_records:
     93         tool_input = record.tool_input
     94         if "command" in tool_input and tool_input["command"] == command:
     95             return record
     96     return None