nli-test.py (1970B)
"""Interactive NLI-based category check for tweets.

Each tweet read from stdin is paired with the hypothesis
"This text is about <category>." for every user-supplied category, and the
model's entailment probability is compared against a threshold to print
BLOCK or ALLOW per category.
"""
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "facebook/bart-large-mnli"
# You can swap with: microsoft/deberta-v3-large-mnli
# NOTE: other checkpoints may order their labels differently; the entailment
# index is therefore looked up from the model config (see _entailment_index).

# Fallback position of the "entailment" logit, matching the BART MNLI order
# (0 = contradiction, 1 = neutral, 2 = entailment).
DEFAULT_ENTAILMENT_INDEX = 2


def load_model():
    """Load the tokenizer and sequence-classification model for MODEL_NAME.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()
    return tokenizer, model


def _entailment_index(model):
    """Return the logit index of the entailment class for *model*.

    Reads ``model.config.label2id`` and picks the label whose name starts
    with "entail" (case-insensitive). Falls back to
    DEFAULT_ENTAILMENT_INDEX when the config carries no such label.
    """
    config = getattr(model, "config", None)
    label2id = getattr(config, "label2id", None) or {}
    for label, index in label2id.items():
        if str(label).lower().startswith("entail"):
            return int(index)
    return DEFAULT_ENTAILMENT_INDEX


def entailment_probability(logits, entailment_index=DEFAULT_ENTAILMENT_INDEX):
    """Return the entailment probability for a single premise/hypothesis pair.

    MNLI label order for BART:
        0 = contradiction
        1 = neutral
        2 = entailment

    Args:
        logits: Classification logits; the last dimension is the class
            dimension. Must describe exactly one pair, because the result is
            converted with ``.item()``.
        entailment_index: Position of the entailment class in the last
            dimension. Defaults to the BART MNLI order shown above.

    Returns:
        The softmax probability of the entailment class as a Python float.
    """
    probs = torch.softmax(logits, dim=-1)
    return probs[:, entailment_index].item()


def check_categories(tweet, categories, tokenizer, model):
    """Score how strongly *tweet* entails each category hypothesis.

    Args:
        tweet: Text used as the NLI premise.
        categories: Iterable of category names; each one is turned into the
            hypothesis "This text is about <category>.".
        tokenizer: Callable that encodes a (premise, hypothesis) pair into
            keyword arguments for *model*.
        model: Callable returning an object with a ``logits`` attribute.

    Returns:
        Dict mapping each category to its entailment probability.
    """
    # Resolve once per call: the label order is a property of the model.
    entailment_index = _entailment_index(model)
    results = {}

    with torch.no_grad():
        for cat in categories:
            hypothesis = f"This text is about {cat}."
            inputs = tokenizer(tweet, hypothesis, return_tensors="pt", truncation=True)
            outputs = model(**inputs)
            results[cat] = entailment_probability(outputs.logits, entailment_index)

    return results


def main():
    """Parse CLI arguments, load the model and classify tweets until EOF."""
    parser = argparse.ArgumentParser(
        description="Check if a tweet belongs to user-specified categories using NLI"
    )

    parser.add_argument(
        "--categories",
        nargs="+",
        required=True,
        help="List of categories to test"
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=0.1,
        help="Entailment threshold for flagging"
    )

    args = parser.parse_args()

    print("Loading model...")
    tokenizer, model = load_model()

    while True:
        try:
            tweet = input('Tweet: ')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session instead of dumping a traceback.
            print()
            break

        if not tweet.strip():
            # Nothing to classify; prompt again rather than scoring "".
            continue

        results = check_categories(
            tweet,
            args.categories,
            tokenizer,
            model
        )

        print("\nResults:")
        for cat, score in results.items():
            flag = "BLOCK" if score >= args.threshold else "ALLOW"
            print(f"{cat:20s} {score:.3f} → {flag}")


if __name__ == "__main__":
    main()