interaction.py (4363B)
from typing import Sequence

from vet.imbue_core.agents.agent_api.data_types import AgentAssistantMessage
from vet.imbue_core.agents.agent_api.data_types import AgentMessage
from vet.imbue_core.agents.agent_api.data_types import AgentOptions
from vet.imbue_core.agents.agent_api.data_types import AgentToolResultBlock
from vet.imbue_core.agents.agent_api.data_types import AgentToolUseBlock
from vet.imbue_core.agents.agent_api.data_types import AgentUserMessage
from vet.imbue_core.agents.agent_api.data_types import ToolUseRecord
from vet.imbue_core.pydantic_serialization import SerializableModel


class AgentInteraction:
    """In-progress log of a single conversation with an agent.

    Every message passed to `put` is kept in `messages`. Tool-use requests issued by
    the assistant are held as pending until a user message carries a tool result
    with a matching id; the matched pair is then stored in `tool_use_records`.

    This class is not thread-safe.
    """

    def __init__(self, prompt: str, options: AgentOptions) -> None:
        self.prompt = prompt
        self.options = options
        # Every message received so far, in arrival order.
        self.messages: list[AgentMessage] = []
        # Request/result pairs that have been matched up by tool-use id.
        self.tool_use_records: list[ToolUseRecord] = []
        # Tool-use requests still waiting for a result, in the order they were issued.
        self._unresolved_tool_use_requests: list[AgentToolUseBlock] = []

    def put(self, message: AgentMessage) -> None:
        """Record `message` and update the tool-use bookkeeping.

        - Assistant messages: each `AgentToolUseBlock` in the content becomes pending.
        - User messages whose content is a list: each `AgentToolResultBlock` resolves
          the pending requests that share its tool-use id.
        - Any other message (including a user message with non-list content) is only
          appended to `messages`.
        """
        self.messages.append(message)

        if isinstance(message, AgentAssistantMessage):
            self._unresolved_tool_use_requests.extend(
                block for block in message.content if isinstance(block, AgentToolUseBlock)
            )
            return

        # Only user messages with structured (list) content can carry tool results.
        if not isinstance(message, AgentUserMessage) or not isinstance(message.content, list):
            return

        for block in message.content:
            if isinstance(block, AgentToolResultBlock):
                self._resolve_tool_use_requests(block)

    def _resolve_tool_use_requests(self, result_block: AgentToolResultBlock) -> None:
        """Pair `result_block` with every pending request whose id equals its `tool_use_id`.

        Matched requests become `ToolUseRecord`s; unmatched ones stay pending in their
        original order. A result that matches no pending request produces no record.
        """
        still_pending: list[AgentToolUseBlock] = []
        for pending in self._unresolved_tool_use_requests:
            if pending.id == result_block.tool_use_id:
                self.tool_use_records.append(
                    ToolUseRecord(request_message=pending, result_message=result_block)
                )
                continue
            still_pending.append(pending)
        self._unresolved_tool_use_requests = still_pending

    def find_tool_use_record_by_command(self, command: str, by_most_recent: bool = True) -> ToolUseRecord | None:
        """Return the first resolved tool use whose input "command" equals `command`.

        When `by_most_recent` is True the newest record is checked first; otherwise the
        records are checked oldest first. Returns None when nothing matches.
        """
        return _find_tool_use_record_by_command(self.tool_use_records, command, reverse=by_most_recent)


class AgentInteractionRecord(SerializableModel):
    """Serializable snapshot of a completed agent interaction.

    Intended for storing a finished interaction log, e.g. in a database or cache.
    Message and tool-use sequences are held as tuples.
    """

    prompt: str
    options: AgentOptions
    messages: tuple[AgentMessage, ...]
    tool_use_records: tuple[ToolUseRecord, ...]

    @classmethod
    def from_agent_interaction(cls, agent_interaction: AgentInteraction) -> "AgentInteractionRecord":
        """Build a record from the current state of `agent_interaction`.

        The message and tool-use lists are copied into tuples, so messages appended to
        the live interaction afterwards do not show up in the record.
        """
        snapshot_messages = tuple(agent_interaction.messages)
        snapshot_records = tuple(agent_interaction.tool_use_records)
        return cls(
            prompt=agent_interaction.prompt,
            options=agent_interaction.options,
            messages=snapshot_messages,
            tool_use_records=snapshot_records,
        )

    def find_tool_use_record_by_command(self, command: str, by_most_recent: bool = True) -> ToolUseRecord | None:
        """Return the first stored tool use whose input "command" equals `command`.

        When `by_most_recent` is True the newest record is checked first; otherwise the
        records are checked oldest first. Returns None when nothing matches.
        """
        return _find_tool_use_record_by_command(self.tool_use_records, command, reverse=by_most_recent)


def _find_tool_use_record_by_command(
    tool_use_records: Sequence[ToolUseRecord], command: str, reverse: bool = True
) -> ToolUseRecord | None:
    """Return the first record in `tool_use_records` whose input "command" equals `command`.

    With `reverse` True the sequence is scanned from the end (most recent first);
    otherwise from the start. Returns None when no record matches.
    """
    search_order = reversed(tool_use_records) if reverse else tool_use_records
    return next(
        (record for record in search_order if _record_has_command(record, command)),
        None,
    )


def _record_has_command(record: ToolUseRecord, command: str) -> bool:
    """Tell whether `record`'s tool input has a "command" entry equal to `command`."""
    tool_input = record.tool_input
    return "command" in tool_input and tool_input["command"] == command