information-retrieval

Exploration of information retrieval topics
git clone git://git.laack.co/information-retrieval.git
Log | Files | Refs

nli-test.py (1970B)


      1 import argparse
      2 import torch
      3 from transformers import AutoTokenizer, AutoModelForSequenceClassification
      4 
# Hugging Face model identifier passed to from_pretrained() in load_model().
MODEL_NAME = "facebook/bart-large-mnli"
# You can swap with: microsoft/deberta-v3-large-mnli
# NOTE(review): entailment_probability() reads the entailment score from label
# index 2 -- confirm that any replacement model uses the same label order.
      7 
      8 def load_model():
      9     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     10     model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
     11     model.eval()
     12     return tokenizer, model
     13 
     14 
     15 def entailment_probability(logits):
     16     """
     17     MNLI label order for BART:
     18     0 = contradiction
     19     1 = neutral
     20     2 = entailment
     21     """
     22     probs = torch.softmax(logits, dim=-1)
     23     return probs[:, 2].item()
     24 
     25 
     26 def check_categories(tweet, categories, tokenizer, model):
     27     results = {}
     28 
     29     for cat in categories:
     30         hypothesis = f"This text is about {cat}."
     31         inputs = tokenizer(tweet, hypothesis, return_tensors="pt", truncation=True)
     32 
     33         with torch.no_grad():
     34             outputs = model(**inputs)
     35 
     36         score = entailment_probability(outputs.logits)
     37         results[cat] = score
     38 
     39     return results
     40 
     41 
     42 def main():
     43     parser = argparse.ArgumentParser(
     44         description="Check if a tweet belongs to user-specified categories using NLI"
     45     )
     46 
     47     parser.add_argument(
     48         "--categories",
     49         nargs="+",
     50         required=True,
     51         help="List of categories to test"
     52     )
     53 
     54     parser.add_argument(
     55         "--threshold",
     56         type=float,
     57         default=0.1,
     58         help="Entailment threshold for flagging"
     59     )
     60 
     61     args = parser.parse_args()
     62 
     63     print("Loading model...")
     64     tokenizer, model = load_model()
     65 
     66     while True:
     67         tweet = input('Tweet: ')
     68         results = check_categories(
     69             tweet,
     70             args.categories,
     71             tokenizer,
     72             model
     73         )
     74 
     75         print("\nResults:")
     76         for cat, score in results.items():
     77             flag = "BLOCK" if score >= args.threshold else "ALLOW"
     78             print(f"{cat:20s} {score:.3f} → {flag}")
     79 
     80 
# Run the interactive classifier only when executed as a script, not on import.
if __name__ == "__main__":
    main()