information-retrieval

Exploration of information retrieval topics
git clone git://git.laack.co/information-retrieval.git
Log | Files | Refs

main.cpp (6307B)


#include <cstdlib>
#include <ctime>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <cpr/cpr.h>

#include "../include/nlohmann/json.hpp"
#include "../include/openai.hpp"
#include "prompts.cpp"
      8 
      9 struct SearchResult{
     10     std::string url;
     11     std::string title;
     12     std::string content;
     13     float score;
     14 };
     15 
     16 nlohmann::json sendUserMessage(openai::OpenAI* connection , std::string model, std::string message){
     17     nlohmann::json payload = {
     18         {"model", model},
     19         {"messages", {{{"role", "user"}, {"content", message}}}},
     20         {"temperature", 0}
     21     };
     22 
     23     auto chat = connection->chat.create(
     24         payload
     25     );
     26     return chat;
     27 }
     28 
     29 std::string getMessageFromChat(nlohmann::json json){
     30     std::string result = json["choices"][0]["message"]["content"].dump(2);
     31     return result;
     32 }
     33 
     34 std::string getResearcherPrompt(){
     35     std::string returnStr(researchPrompt);
     36 
     37     std::time_t t = std::time(nullptr);
     38     std::tm* now = std::localtime(&t);
     39     std::string dateReplacement = std::to_string(now->tm_mon + 1) + "-" 
     40                                 + std::to_string(now->tm_mday) + "-" 
     41                                 + std::to_string(now->tm_year + 1900);
     42     std::string toReplace = "$DATE";
     43 
     44     std::size_t dateLocation = returnStr.find(toReplace);
     45     returnStr.replace(dateLocation, toReplace.length(), dateReplacement);
     46 
     47     return returnStr;
     48 }
     49 
     50 
     51 std::vector<std::pair<std::string, std::string>> parseToolCalls(std::string response){
     52 
     53     std::vector<std::string> rawToolCalls;
     54 
     55     std::stringstream ss(response);
     56     std::string line;
     57 
     58     while (getline(ss, line)) {
     59         if (!line.empty()) {
     60             rawToolCalls.push_back(line);
     61         }
     62     }
     63 
     64 
     65     auto pairs = std::vector<std::pair<std::string, std::string>>();
     66 
     67     for(int i = 0 ; i < rawToolCalls.size(); ++i){
     68         
     69         std::string current = rawToolCalls[i];
     70 
     71         int space = current.find(' ');
     72         int parenthesis = current.find('(');
     73 
     74         std::string toolcall;
     75         std::string remaining;
     76 
     77 
     78         // TODO: Do prefix matching
     79         if (space != std::string::npos){
     80             toolcall = current.substr(1, space - 1);
     81             std::string unclean = current.substr(space + 1, current.size() - space - 2);
     82             for (size_t i = 0; i < unclean.size(); ++i) {
     83                 if (unclean[i] == '\\' && i + 1 < unclean.size() && unclean[i + 1] == '"') {
     84                     continue;
     85                 }
     86                 remaining += unclean[i];
     87             }
     88             
     89         }
     90         else {
     91             toolcall = "done";
     92         }
     93 
     94         auto pair = std::pair(toolcall, remaining);
     95         pairs.push_back(pair);
     96     }
     97     return pairs;
     98 }
     99 
    100 
    101 
    102 // I hate cpr.
    103 std::vector<SearchResult> searchSearxng(std::string url, std::string query){
    104 
    105     std::vector<SearchResult> returnValues;
    106 
    107     cpr::Response r = cpr::Get(
    108         cpr::Url{url + "/search"},
    109         cpr::Parameters{{"q", query}, {"format", "json"}}
    110     );
    111     // TODO: Check status code
    112     //std::cout << r.status_code << std::endl;
    113     //std::cout << r.text << std::endl;
    114 
    115     nlohmann::json currentJson = nlohmann::json::parse(r.text);
    116     nlohmann::json results = currentJson["results"];
    117 
    118     for (int i = 0 ; i < results.size(); ++i){
    119         auto currentResult = SearchResult();
    120         currentResult.url = results[i]["url"].get<std::string>();
    121         currentResult.content= results[i]["content"].get<std::string>();
    122         currentResult.score = results[i]["score"].get<float>();
    123         currentResult.title = results[i]["title"].get<std::string>();
    124         returnValues.push_back(currentResult);
    125     }
    126 
    127     return returnValues;
    128 }
    129 
    130 int main(int argc, char* argv[]) {
    131 
    132 
    133     openai::start();
    134     openai::OpenAI connection = openai::OpenAI();
    135 
    136     std::string envVariable = "ANTHROPIC_API_KEY";
    137     std::string model = "claude-opus-4-5-20251101";
    138     std::string baseURL = "https://api.anthropic.com/v1/";
    139     std::string priorRuns = "";
    140     int linkNum = 1;
    141 
    142     connection.setToken(getenv(envVariable.c_str()));
    143     connection.setBaseUrl(baseURL);
    144 
    145     std::string chatMessage;
    146     if(argc == 2){
    147         chatMessage = argv[1];
    148     }
    149     else{
    150         std::getline(std::cin, chatMessage);
    151     }
    152 
    153     while(true){
    154         std::vector<std::pair<std::string, std::string>> toolCallsAndParams;
    155         std::string message;
    156 
    157         std::string prompt = getResearcherPrompt();
    158         message = prompt + chatMessage;
    159 
    160         nlohmann::json response = sendUserMessage(&connection, model, priorRuns + message);
    161         toolCallsAndParams = parseToolCalls(getMessageFromChat(response));
    162 
    163 
    164         for(int i = 0 ; i < toolCallsAndParams.size(); ++i){
    165             std::pair<std::string, std::string> current = toolCallsAndParams[i];
    166             if(current.first.compare("web_search") == 0){
    167                 //std::cout << current.second << std::endl;
    168                 std::vector<std::string> params = nlohmann::json::parse(current.second).get<std::vector<std::string>>();
    169                 for(int i = 0 ; i < params.size(); ++i){
    170                     std::cout << "Searching the web for " << params[i] << std::endl;
    171                     //std::cout << params[i] << std::endl;
    172                     // TODO: add citations
    173                     std::vector<SearchResult> searchRes = searchSearxng("https://searx.laack.co", params[i]);
    174 
    175                     priorRuns += "Searching for " + params[i];
    176                     for (int i = 0 ; i < searchRes.size(); ++i){
    177                         priorRuns += "Results from " + searchRes[i].url + " [" + std::to_string(linkNum) + "]:\n" + searchRes[i].content + "\n";
    178                         linkNum += 1;
    179                     }
    180                 }
    181             }
    182             else{
    183                 if(current.first == "done"){
    184                     // TODO: Add a final summarization prompt
    185                     std::string summary = "Consider the above context, and use it to answer the following question:";
    186                     std::string final = "\n Ensure your final answer makes extensive references to resources with [1] syntax.";
    187 
    188                     nlohmann::json response = sendUserMessage(&connection, model, priorRuns + summary + chatMessage + final);
    189                     std::cout << getMessageFromChat(response) << std::endl;
    190                     return 0;
    191                 }
    192             }
    193         }
    194     }
    195 }