information-retrieval

Exploration of information retrieval topics
git clone git://git.laack.co/information-retrieval.git
Log | Files | Refs

researcher.cpp (6056B)


      1 #include <stdexcept>
      2 #include <iostream>
      3 #include "../include/nlohmann/json.hpp"
      4 #include "../include/openai.hpp"
      5 #include "../include/researcher.hpp"
      6 #include <cpr/cpr.h>
      7 
      8 
      9 std::string Researcher::downloadSite(std::string url){
     10     std::cout << "Downloading " << url << std::endl;
     11     
     12     std::string command = "lynx --dump --nolist --display_charset=US-ASCII " + url;
     13     
     14     FILE* pipe = popen(command.c_str(), "r");
     15     std::string result;
     16     char buffer[4096];
     17     while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
     18         result += buffer;
     19     }
     20     
     21     pclose(pipe);
     22     
     23     if (result.length() > 10000) {
     24         result = result.substr(0, 10000) + "\n... (content truncated)";
     25     }
     26     
     27     return "DOWNLOADED CONTENT:\n" + result;
     28 }
     29 
     30 std::vector<SearchResult> Researcher::webSearch(std::string url, std::string query){
     31 
     32     std::cout << "Web searching for: " + query << std::endl;
     33 
     34     std::vector<SearchResult> returnValues;
     35 
     36     cpr::Response r = cpr::Get(
     37         cpr::Url{url + "/search"},
     38         cpr::Parameters{{"q", query}, {"format", "json"}}
     39     );
     40     // TODO: Check status code
     41     //std::cout << r.status_code << std::endl;
     42     //std::cout << r.text << std::endl;
     43 
     44     nlohmann::json currentJson = nlohmann::json::parse(r.text);
     45     nlohmann::json results = currentJson["results"];
     46 
     47     for (int i = 0 ; i < results.size(); ++i){
     48         auto currentResult = SearchResult();
     49         currentResult.url = results[i]["url"].get<std::string>();
     50         currentResult.content= results[i]["content"].get<std::string>();
     51         currentResult.score = results[i]["score"].get<float>();
     52         currentResult.title = results[i]["title"].get<std::string>();
     53         returnValues.push_back(currentResult);
     54     }
     55 
     56     std::cout << "Found " << results.size() << " search results"<< std::endl;
     57     return returnValues;
     58 }
     59 
     60 
     61 
     62 nlohmann::json Researcher::sendUserMessage(openai::OpenAI* connection , std::string model, std::string message){
     63 
     64     if(messages.size() == 0){
     65         throw std::runtime_error("Cannot send first message, system prompt not set.");
     66     }
     67 
     68     nlohmann::json userMessage = {
     69         {"role", "user"},
     70         {"content", message}
     71     };
     72     messages.push_back(userMessage);
     73 
     74     nlohmann::json messageHistory = nlohmann::json::array();
     75     for (const auto& msg : messages){
     76         messageHistory.push_back(msg);
     77     }
     78 
     79     nlohmann::json payload = {
     80         {"model", model},
     81         {"messages", messages},
     82         {"temperature", 0}
     83     };
     84 
     85     auto chat = connection->chat.create(
     86         payload
     87     );
     88 
     89     return chat;
     90 }
     91 
     92 std::string Researcher::executeTool(std::string output){
     93     std::string returnStr = "";
     94     std::vector<std::string> keywords = {"Search_web", "Download_webpage", "Done"};
     95 
     96     for(size_t i = 0; i < output.size(); ++i){
     97         for(size_t x = 0; x < keywords.size(); ++x){
     98             std::string current_word = keywords[x];
     99             
    100             if(i + current_word.size() > output.size()) continue;
    101             
    102             if(output.substr(i, current_word.size()) == current_word){
    103                 
    104                 size_t paren_start = i + current_word.size();
    105                 
    106                 if(paren_start >= output.size() || output[paren_start] != '('){
    107                     continue;
    108                 }
    109                 
    110                 size_t content_start = paren_start + 1;
    111                 
    112                 size_t content_end = content_start;
    113                 int paren_depth = 1;
    114                 
    115                 while(content_end < output.size() && paren_depth > 0){
    116                     if(output[content_end] == '(') paren_depth++;
    117                     else if(output[content_end] == ')') paren_depth--;
    118                     if(paren_depth > 0) content_end++;
    119                 }
    120                 
    121                 std::string inside = "";
    122                 if(content_end > content_start){
    123                     inside = output.substr(content_start, content_end - content_start);
    124                 }
    125                 
    126                 if(inside.size() >= 2 && inside.front() == '"' && inside.back() == '"'){
    127                     inside = inside.substr(1, inside.size() - 2);
    128                 }
    129                 if(inside.size() >= 4 && inside.substr(0, 2) == "\\\"" && inside.substr(inside.size() - 2) == "\\\""){
    130                     inside = inside.substr(2, inside.size() - 4);
    131                 }
    132                 
    133                 if(current_word == "Search_web"){
    134                     std::vector<SearchResult> results = webSearch("https://searx.laack.co", inside);
    135                     returnStr += "WEB SEARCH RESULTS:\n";
    136                     for(size_t z = 0; z < results.size(); ++z){
    137                         returnStr += "URL: " + results[z].url + "\n";
    138                         returnStr += "TITLE: " + results[z].title + "\n";
    139                         returnStr += "CONTENT: " + results[z].content + "\n";
    140                     }
    141                 }
    142                 
    143                 if(current_word == "Download_webpage"){
    144                     returnStr += downloadSite(inside);
    145                 }
    146                 
    147                 if(current_word == "Done"){
    148                     return "DONE";
    149                 }
    150                 
    151                 i = content_end;
    152             }
    153         }
    154     }
    155     return returnStr;
    156 }
    157 
    158 std::string Researcher::getDocs(){
    159     std::string result = "";
    160     // skip system prompt.
    161     for(int i = 1; i < messages.size(); ++i){
    162         result += messages[i].dump(2);
    163         result += "\n\n";
    164     }
    165     return result;
    166 }
    167 
    168 std::string Researcher::getMessageFromChat(nlohmann::json json){
    169     std::string result = json["choices"][0]["message"]["content"].get<std::string>();
    170     return result;
    171 }
    172 
    173 void Researcher::setSystemPrompt(std::string prompt){
    174     if (messages.size() == 0){
    175         nlohmann::json userMessage = {
    176             {"role", "system"},
    177             {"content", prompt}
    178         };
    179         messages.push_back(userMessage);
    180     }
    181     else{
    182         throw std::runtime_error("Cannot set system prompt after first message is sent.");
    183     }
    184 
    185     return;
    186 }