commit 1507d56ab81f28ecbd24025082ac2ff5c7ad7746
parent eb447c513be89e86ffa878fe0e9239a9978cd1e9
Author: Andrew Laack <andrew@laack.co>
Date: Tue, 17 Feb 2026 14:02:44 -0600
Playing with nli and smol agents
Diffstat:
4 files changed, 238 insertions(+), 0 deletions(-)
diff --git a/nli/nli-test.py b/nli/nli-test.py
@@ -0,0 +1,82 @@
+import argparse
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+MODEL_NAME = "facebook/bart-large-mnli"
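+# Another MNLI checkpoint can be swapped in (e.g. microsoft/deberta-v3-large-mnli),
+# but check its label order first: entailment_probability() treats index 2 as entailment.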
+
+def load_model():
+    """Load the tokenizer and sequence classifier named by MODEL_NAME.
+
+    Returns:
+        A ``(tokenizer, model)`` tuple. ``model.eval()`` has been called on
+        the model before it is returned.
+    """
+    nli_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    nli_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+    # Inference only: switch off training-time behaviour.
+    nli_model.eval()
+    return nli_tokenizer, nli_model
+
+
+def entailment_probability(logits, entailment_index=2):
+    """Return the softmax probability of the entailment class for one example.
+
+    The default index follows the MNLI label order for BART:
+        0 = contradiction
+        1 = neutral
+        2 = entailment
+    Pass ``entailment_index`` when using a checkpoint whose labels are
+    ordered differently.
+
+    Args:
+        logits: Classifier logits for a single example, shaped either
+            ``(num_labels,)`` or ``(1, num_labels)``.
+        entailment_index: Position of the entailment class on the last axis.
+
+    Returns:
+        The entailment probability as a Python float.
+
+    Note:
+        Only one example is supported per call; ``.item()`` needs a tensor
+        with exactly one element.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    # Index the last axis so both batched (1, C) and unbatched (C,) logits work.
+    return probs[..., entailment_index].item()
+
+
+def check_categories(tweet, categories, tokenizer, model):
+    """Score ``tweet`` against one NLI hypothesis per category.
+
+    Each category is turned into the hypothesis "This text is about <cat>."
+    and paired with the tweet as the premise; the model is run once per
+    category without gradient tracking.
+
+    Args:
+        tweet: The text to classify.
+        categories: Iterable of category names.
+        tokenizer: Callable that encodes a (premise, hypothesis) pair.
+        model: Callable whose output exposes ``.logits``.
+
+    Returns:
+        A dict mapping each category to its entailment probability.
+    """
+    def score_one(label):
+        encoded = tokenizer(
+            tweet,
+            f"This text is about {label}.",
+            return_tensors="pt",
+            truncation=True,
+        )
+        with torch.no_grad():
+            prediction = model(**encoded)
+        return entailment_probability(prediction.logits)
+
+    return {label: score_one(label) for label in categories}
+
+
+def main():
+    """Run an interactive prompt that flags tweets by category using NLI.
+
+    Command-line arguments:
+        --categories: One or more category names to test (required).
+        --threshold: Entailment probability at or above which a category is
+            reported as BLOCK (default 0.1).
+
+    The loop reads tweets from stdin until EOF (Ctrl-D) or Ctrl-C at the
+    prompt, then returns normally.
+    """
+    parser = argparse.ArgumentParser(
+        description="Check if a tweet belongs to user-specified categories using NLI"
+    )
+
+    parser.add_argument(
+        "--categories",
+        nargs="+",
+        required=True,
+        help="List of categories to test"
+    )
+
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        default=0.1,
+        help="Entailment threshold for flagging"
+    )
+
+    args = parser.parse_args()
+
+    print("Loading model...")
+    tokenizer, model = load_model()
+
+    while True:
+        try:
+            tweet = input('Tweet: ')
+        except (EOFError, KeyboardInterrupt):
+            # End of input or Ctrl-C at the prompt: exit without a traceback.
+            print()
+            break
+
+        # A blank line carries nothing to classify; prompt again.
+        if not tweet.strip():
+            continue
+
+        results = check_categories(
+            tweet,
+            args.categories,
+            tokenizer,
+            model
+        )
+
+        print("\nResults:")
+        for cat, score in results.items():
+            flag = "BLOCK" if score >= args.threshold else "ALLOW"
+            print(f"{cat:20s} {score:.3f} → {flag}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/smol/web/__pycache__/web_search.cpython-313.pyc b/smol/web/__pycache__/web_search.cpython-313.pyc
Binary files differ.
diff --git a/smol/web/search.py b/smol/web/search.py
@@ -0,0 +1,72 @@
+import pandas as pd
+import sys
+from datetime import datetime
+import os
+from smolagents import CodeAgent, OpenAIModel
+from web_search import WebSearchTool
+from web_search import WebVisitTool
+import os
+
+from smolagents import (
+ CodeAgent,
+ ToolCallingAgent,
+)
+
+
+# Mapping from tool-related message roles to chat roles, and a desktop-browser
+# User-Agent string.
+# NOTE(review): neither name is referenced in the visible part of this file —
+# confirm they are still needed.
+custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}
+user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
+
+def create_agent(model):
+    """Build the manager agent together with its web-browsing sub-agent.
+
+    Args:
+        model: The model object shared by both agents.
+
+    Returns:
+        A ``CodeAgent`` that manages a ``ToolCallingAgent`` named
+        ``search_agent``. Both agents receive the same web search and
+        page-visit tools.
+    """
+    tools = [WebSearchTool(), WebVisitTool()]
+    today = datetime.now().strftime('%B %d, %Y')
+
+    text_webbrowser_agent = ToolCallingAgent(
+        model=model,
+        tools=tools,
+        max_steps=200,
+        verbosity_level=2,
+        planning_interval=4,
+        name="search_agent",
+        description=f"""A team member that will search the internet to answer your question.
+ Ask him for all your questions that require browsing the web.
+ Provide him as much context as possible, in particular if you need to search on a specific timeframe!
+ And don't hesitate to provide him with a complex search task, like finding a difference between two webpages.
+ Your request must be a real sentence, not a google search! Like "Find me this information (...)" rather than a few keywords. Note that the date today is {today}
+ """,
+        provide_run_summary=True,
+    )
+    # The only tools handed to this agent are 'web_search' and 'visit_webpage',
+    # so the extra instructions must point at 'visit_webpage' (which handles
+    # PDFs) rather than a tool name the agent does not have.
+    text_webbrowser_agent.prompt_templates["managed_agent"]["task"] += """You can navigate to .txt online files.
+ If a non-html page is in another format, especially .pdf, use tool 'visit_webpage' to read it.
+ Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""
+
+    manager_agent = CodeAgent(
+        model=model,
+        tools=tools,
+        max_steps=12,
+        verbosity_level=2,
+        # NOTE(security): "*" authorises model-written code to import any
+        # module. Only run this in an isolated/sandboxed environment, or
+        # replace it with an explicit allow-list.
+        additional_authorized_imports=["*"],
+        planning_interval=4,
+        managed_agents=[text_webbrowser_agent],
+    )
+
+    return manager_agent
+
+
+if __name__ == "__main__":
+    # The OpenAI-style client is pointed at the Anthropic API base URL; the
+    # key is read from the environment and a missing key raises KeyError.
+    model = OpenAIModel(
+        model_id="claude-opus-4-5-20251101",
+        api_base="https://api.anthropic.com/v1/",
+        api_key=os.environ["ANTHROPIC_API_KEY"],
+    )
+
+    agent = create_agent(model)
+
+    # Single-question mode (disabled):
+    #answer = agent.run(sys.argv[1] + f"Note that the date today is {datetime.now().strftime('%B %d, %Y')}")
+
+    # Rows that fail to parse are skipped rather than aborting the whole run.
+    df = pd.read_csv('questions.csv', quotechar='"', on_bad_lines='skip')
+    # Number the output files with enumerate(): the previous hand-rolled
+    # counter was never incremented, so every answer overwrote '0.txt'.
+    for itr, (_, row) in enumerate(df.iterrows()):
+        question = row['question']
+        print(question)
+        today = datetime.now().strftime('%B %d, %Y')
+        answer = agent.run(question + f" Note that the date today is {today}")
+        # Explicit UTF-8 so non-ASCII answers do not depend on the platform's
+        # default encoding.
+        with open(f"{itr}.txt", 'w', encoding='utf-8') as f:
+            f.write(str(answer))
diff --git a/smol/web/web_search.py b/smol/web/web_search.py
@@ -0,0 +1,84 @@
+import requests
+from smolagents import Tool
+
+class WebSearchTool(Tool):
+    """Tool that queries a SearxNG-style JSON search endpoint.
+
+    ``forward`` never raises for network or response-format problems; it
+    returns a "Search failed: ..." string instead so the calling agent can
+    read the error and react.
+    """
+
+    name = "web_search"
+    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, links, and descriptions."
+    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+    output_type = "string"
+
+    def forward(self, query: str) -> str:
+        """Run ``query`` and return the results as markdown.
+
+        Args:
+            query: The search query to perform.
+
+        Returns:
+            One markdown section per result (title link plus description),
+            "No results found." for an empty result list, or a
+            "Search failed: ..." message on error.
+        """
+        src_params = {
+            'q': query,
+            'format': 'json'
+        }
+        search_url = 'https://searx.laack.co/search'
+
+        try:
+            # Without a timeout an unresponsive server would block the agent
+            # indefinitely; 15 s matches the page-visit tool in this file.
+            response = requests.get(search_url, params=src_params, timeout=15)
+            response.raise_for_status()
+            res_list = response.json()['results']
+        except (requests.RequestException, ValueError, KeyError) as e:
+            # ValueError covers a body that is not valid JSON on requests
+            # versions whose decode error is not a RequestException.
+            return f"Search failed: {e}"
+
+        markdown_results = []
+        for result in res_list:
+            title = result.get('title', 'No title')
+            url = result.get('url', '')
+            content = result.get('content', 'No description')
+            markdown_results.append(f"### [{title}]({url})\n{content}\n")
+
+        return "\n".join(markdown_results) if markdown_results else "No results found."
+
+import requests
+from smolagents import Tool
+from bs4 import BeautifulSoup
+from pypdf import PdfReader
+from io import BytesIO
+
+class WebVisitTool(Tool):
+    """Tool that downloads a URL and returns its readable text.
+
+    HTML pages are stripped of script/style/navigation elements; PDF
+    documents have their page text extracted. Output is capped by
+    ``_truncate``. Fetch and PDF-parsing errors are returned as strings
+    rather than raised.
+    """
+
+    name = "visit_webpage"
+    description = "Visits a webpage or PDF at the given URL and returns its text content. Supports HTML pages and PDF documents."
+    inputs = {"url": {"type": "string", "description": "The URL of the webpage or PDF to visit."}}
+    output_type = "string"
+
+    def forward(self, url: str) -> str:
+        """Fetch ``url`` and return its text content.
+
+        Args:
+            url: The URL of the webpage or PDF to visit.
+
+        Returns:
+            The extracted (possibly truncated) text, or a
+            "Failed to fetch URL: ..." / "Failed to parse PDF: ..." message.
+        """
+        headers = {
+            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+        }
+
+        try:
+            response = requests.get(url, headers=headers, timeout=15)
+            response.raise_for_status()
+        except requests.RequestException as e:
+            return f"Failed to fetch URL: {e}"
+
+        content_type = response.headers.get('Content-Type', '').lower()
+
+        if self._looks_like_pdf(url, content_type, response.content):
+            return self._parse_pdf(response.content)
+
+        # When the server declares no charset, requests falls back to Latin-1
+        # for text/* responses, which garbles UTF-8 pages; use the encoding
+        # detected from the body instead.
+        if 'charset' not in content_type:
+            response.encoding = response.apparent_encoding
+
+        return self._parse_html(response.text)
+
+    def _looks_like_pdf(self, url: str, content_type: str, body: bytes) -> bool:
+        """Return True when the response should be parsed as a PDF."""
+        # Drop any fragment and query string so 'file.pdf?download=1' still matches.
+        path = url.lower().split('#', 1)[0].split('?', 1)[0]
+        return (
+            'application/pdf' in content_type
+            or path.endswith('.pdf')
+            # PDF files begin with this signature, whatever the server says.
+            or body.startswith(b'%PDF-')
+        )
+
+    def _parse_pdf(self, content: bytes) -> str:
+        """Extract text from every page of a PDF given as raw bytes."""
+        try:
+            reader = PdfReader(BytesIO(content))
+            # extract_text() may yield nothing for a page; treat that as "".
+            text = "\n".join(page.extract_text() or "" for page in reader.pages)
+            return self._truncate(text)
+        except Exception as e:
+            # Deliberately broad: any parser failure is reported to the agent
+            # as text instead of aborting the tool call.
+            return f"Failed to parse PDF: {e}"
+
+    def _parse_html(self, html: str) -> str:
+        """Return the visible text of an HTML document."""
+        soup = BeautifulSoup(html, 'html.parser')
+        # Remove non-content elements before extracting text.
+        for element in soup(['script', 'style', 'nav', 'footer', 'header']):
+            element.decompose()
+        text = soup.get_text(separator='\n', strip=True)
+        return self._truncate(text)
+
+    def _truncate(self, text: str, max_chars: int = 15000) -> str:
+        """Cap ``text`` at ``max_chars`` characters, marking any cut.
+
+        Returns "No readable content found." when ``text`` is empty.
+        """
+        if len(text) > max_chars:
+            return text[:max_chars] + "\n\n[Content truncated...]"
+        return text if text else "No readable content found."