From 5d076328ecb9b560c3f601a3532efa69bcaceb73 Mon Sep 17 00:00:00 2001 From: Myzel394 <50424412+Myzel394@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:24:44 +0100 Subject: [PATCH] fix: Use mpsc to pass messages --- src/engines/brave.rs | 112 ++++++++++++-------------- src/engines/duckduckgo.rs | 157 ------------------------------------- src/engines/engine_base.rs | 115 ++++++++++++++++++++++----- src/helpers.rs | 31 +++----- src/main.rs | 72 ++++------------- 5 files changed, 172 insertions(+), 315 deletions(-) diff --git a/src/engines/brave.rs b/src/engines/brave.rs index aeff021..f33b254 100644 --- a/src/engines/brave.rs +++ b/src/engines/brave.rs @@ -1,12 +1,19 @@ // Search engine parser for Brave Search // This uses the clearnet, unlocalized version of the search engine. pub mod brave { + use std::sync::Arc; + + use futures::lock::Mutex; use lazy_static::lazy_static; use regex::Regex; + use tokio::sync::mpsc::Sender; use urlencoding::decode; use crate::{ - engines::engine_base::engine_base::{EngineBase, SearchEngine, SearchResult}, + engines::engine_base::engine_base::{ + EngineBase, EnginePositions, ResultsCollector, SearchEngine, SearchResult, + }, + helpers::helpers::build_default_client, utils::utils::decode_html_text, }; @@ -17,68 +24,40 @@ pub mod brave { static ref STRIP_HTML_TAGS: Regex = Regex::new(r#"<(?:"[^"]*"['"]*|'[^']*'['"]*|[^'">])+>"#).unwrap(); } + #[derive(Clone, Debug)] pub struct Brave { - pub completed: bool, - results_started: bool, - pub previous_block: String, - pub results: Vec, - } - - impl Brave { - fn slice_remaining_block(&mut self, start_position: &usize) { - let previous_block_bytes = self.previous_block.as_bytes().to_vec(); - let remaining_bytes = previous_block_bytes[*start_position..].to_vec(); - let remaining_text = String::from_utf8(remaining_bytes).unwrap(); - - self.previous_block.clear(); - self.previous_block.push_str(&remaining_text); - } - - pub fn new() -> Self { - Self { - results_started: false, - previous_block: String::new(), - results: vec![], - completed: false, - } - } + positions: EnginePositions, } impl EngineBase for Brave { - fn add_result(&mut self, result: crate::engines::engine_base::engine_base::SearchResult) { - self.results.push(result); - } - fn parse_next<'a>(&mut self) -> Option { - if self.results_started { - match SINGLE_RESULT.captures(&self.previous_block.to_owned()) { - Some(captures) => { - let title = decode(captures.name("title").unwrap().as_str()) - .unwrap() - .into_owned(); - let description_raw = - decode_html_text(captures.name("description").unwrap().as_str()) - .unwrap(); - let description = STRIP_HTML_TAGS - .replace_all(&description_raw, "") - .into_owned(); - let url = decode(captures.name("url").unwrap().as_str()) - .unwrap() - .into_owned(); + if self.positions.started { + if let Some(capture) = + SINGLE_RESULT.captures(&self.positions.previous_block.to_owned()) + { + let title = decode(capture.name("title").unwrap().as_str()) + .unwrap() + .into_owned(); + let description_raw = + decode_html_text(capture.name("description").unwrap().as_str()).unwrap(); + let description = STRIP_HTML_TAGS + .replace_all(&description_raw, "") + .into_owned(); + let url = decode(capture.name("url").unwrap().as_str()) + .unwrap() + .into_owned(); - let result = SearchResult { - title, - description, - url, - engine: SearchEngine::DuckDuckGo, - }; + let result = SearchResult { + title, + description, + url, + engine: SearchEngine::DuckDuckGo, + }; - let end_position = captures.get(0).unwrap().end(); - self.slice_remaining_block(&end_position); + let end_position = capture.get(0).unwrap().end(); + self.positions.slice_remaining_block(&end_position); - return Some(result); - } - None => {} + return Some(result); } } @@ -90,15 +69,28 @@ pub mod brave { let raw_text = String::from_utf8_lossy(&bytes); let text = STRIP.replace_all(&raw_text, " "); - if self.results_started { - self.previous_block.push_str(&text); + if self.positions.started { + self.positions.previous_block.push_str(&text); } else { - self.results_started = RESULTS_START.is_match(&text); + self.positions.started = RESULTS_START.is_match(&text); + } + } + } + + impl Brave { + pub fn new() -> Self { + Self { + positions: EnginePositions::new(), } } - async fn search(&mut self, query: &str) { - todo!() + pub async fn search(&mut self, query: &str, tx: Sender) { + let client = build_default_client(); + let request = client + .get(format!("https://search.brave.com/search?q={}", query)) + .send(); + + self.handle_request(request, tx).await; } } } diff --git a/src/engines/duckduckgo.rs b/src/engines/duckduckgo.rs index 77cbad4..fbe8866 100644 --- a/src/engines/duckduckgo.rs +++ b/src/engines/duckduckgo.rs @@ -1,14 +1,7 @@ // Search engine parser for DuckDuckGo pub mod duckduckgo { - use std::{ - io::{Read, Write}, - net::TcpStream, - sync::Arc, - }; - use lazy_static::lazy_static; use regex::Regex; - use rustls::RootCertStore; use urlencoding::decode; use crate::{ @@ -72,10 +65,6 @@ pub mod duckduckgo { // } impl EngineBase for DuckDuckGo { - fn add_result(&mut self, result: SearchResult) { - self.results.push(result); - } - fn parse_next<'a>(&mut self) -> Option { if self.results_started { match SINGLE_RESULT.captures(&self.previous_block.to_owned()) { @@ -123,152 +112,6 @@ pub mod duckduckgo { self.results_started = RESULTS_START.is_match(&text); } } - - // Searches DuckDuckGo for the given query - // Uses rustls as reqwest does not support accessing the raw packets - async fn search(&mut self, query: &str) { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let mut config = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - - // Allow using SSLKEYLOGFILE. - config.key_log = Arc::new(rustls::KeyLogFile::new()); - - let now = std::time::Instant::now(); - let server_name = "html.duckduckgo.com".try_into().unwrap(); - let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name).unwrap(); - let mut sock = TcpStream::connect("html.duckduckgo.com:443").unwrap(); - let mut tls = rustls::Stream::new(&mut conn, &mut sock); - tls.write_all( - concat!( - "POST /html/ HTTP/1.1\r\n", - "Host: html.duckduckgo.com\r\n", - "Connection: cloSe\r\n", - "Accept-Encoding: identity\r\n", - "Content-Length: 6\r\n", - // form data - "Content-Type: application/x-www-form-urlencoded\r\n", - "\r\n", - "q=test", - ) - .as_bytes(), - ) - .unwrap(); - let mut plaintext = Vec::new(); - dbg!(now.elapsed()); - - loop { - let mut buf = [0; 65535]; - tls.conn.complete_io(tls.sock); - let n = tls.conn.reader().read(&mut buf); - - if n.is_ok() { - dbg!(&n); - let n = n.unwrap(); - if n == 0 { - break; - } - println!("{}", "================="); - dbg!(now.elapsed()); - // println!("{}", String::from_utf8_lossy(&buf)); - plaintext.extend_from_slice(&buf); - } - } - - // let root_store = RootCertStore { - // roots: webpki_roots::TLS_SERVER_ROOTS.into(), - // }; - // - // let mut config = rustls::ClientConfig::builder() - // .with_root_certificates(root_store) - // .with_no_client_auth(); - // - // // Allow using SSLKEYLOGFILE. - // config.key_log = Arc::new(rustls::KeyLogFile::new()); - // - // let server_name = "html.duckduckgo.com".try_into().unwrap(); - // let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name).unwrap(); - // - // let mut sock = TcpStream::connect("html.duckduckgo.com:443").unwrap(); - // let mut tls = rustls::Stream::new(&mut conn, &mut sock); - // tls.write_all( - // concat!( - // "POST /html/ HTTP/1.1\r\n", - // "Host: html.duckduckgo.com\r\n", - // "Connection: close\r\n", - // "Accept-Encoding: identity\r\n", - // "Content-Length: 6\r\n", - // // form data - // "Content-Type: application/x-www-form-urlencoded\r\n", - // "\r\n", - // "q=test", - // ) - // .as_bytes(), - // ) - // .unwrap(); - // let ciphersuite = tls.conn.negotiated_cipher_suite().unwrap(); - // writeln!( - // &mut std::io::stderr(), - // "Current ciphersuite: {:?}", - // ciphersuite.suite() - // ) - // .unwrap(); - // - // // Iterate over the stream to read the response. - // loop { - // let mut buf = [0u8; 1024]; - // let n = tls.read(&mut buf).unwrap(); - // if n == 0 { - // break; - // } - // - // if let Some(result) = self.parse_packet(buf.iter()) { - // self.add_result(result); - // - // // Wait one second - // std::thread::sleep(std::time::Duration::from_millis(100)); - // } - // } - // - // while let Some(result) = self.parse_next() { - // self.add_result(result); - // } - // - // dbg!("done with searching"); - - // let client = reqwest::Client::new(); - // - // let now = std::time::Instant::now(); - // - // let mut stream = client - // .post("https://html.duckduckgo.com/html/") - // .header("Content-Type", "application/x-www-form-urlencoded") - // .body(format!("q={}", query)) - // .send() - // .await - // .unwrap() - // .bytes_stream(); - // - // let diff = now.elapsed(); - // dbg!(diff); - // - // while let Some(item) = stream.next().await { - // let packet = item.unwrap(); - // - // if let Some(result) = self.parse_packet(packet.iter()) { - // self.add_result(result); - // } - // } - // - // while let Some(result) = self.parse_next() { - // self.add_result(result); - // } - // - // let second_diff = now.elapsed(); - // dbg!(second_diff); - } } impl DuckDuckGo { diff --git a/src/engines/engine_base.rs b/src/engines/engine_base.rs index e33b7ad..23e1eaa 100644 --- a/src/engines/engine_base.rs +++ b/src/engines/engine_base.rs @@ -1,12 +1,11 @@ pub mod engine_base { use std::sync::Arc; - use bytes::Bytes; - - use futures::{lock::Mutex, Future, Stream, StreamExt}; + use futures::{lock::Mutex, Future, StreamExt}; use lazy_static::lazy_static; use regex::Regex; - use reqwest::{Client, Error, Response}; + use reqwest::{Error, Response}; + use tokio::sync::mpsc::Sender; lazy_static! { static ref STRIP: Regex = Regex::new(r"\s+").unwrap(); @@ -25,23 +24,19 @@ pub mod engine_base { pub engine: SearchEngine, } - pub trait EngineBase { - fn add_result(&mut self, result: SearchResult); + /// ResultsCollector collects results across multiple tasks + #[derive(Clone, Debug, Hash, Default)] + pub struct ResultsCollector { + pub started: bool, + pub previous_block: String, + results: Vec, + current_index: usize, + } + pub trait EngineBase { fn parse_next<'a>(&mut self) -> Option; fn push_packet<'a>(&mut self, packet: impl Iterator); - // fn push_packet<'a>(&mut self, packet: impl Iterator) { - // let bytes: Vec = packet.map(|bit| *bit).collect(); - // let raw_text = String::from_utf8_lossy(&bytes); - // let text = STRIP.replace_all(&raw_text, " "); - // - // if self.results_started { - // self.previous_block.push_str(&text); - // } else { - // self.results_started = RESULTS_START.is_match(&text); - // } - // } /// Push packet to internal block and return next available search result, if available fn parse_packet<'a>( @@ -53,11 +48,89 @@ pub mod engine_base { self.parse_next() } - async fn search(&mut self, query: &str); + async fn handle_request( + &mut self, + request: impl Future>, + tx: Sender, + ) { + let mut stream = request.await.unwrap().bytes_stream(); + + while let Some(chunk) = stream.next().await { + let buffer = chunk.unwrap(); + + self.push_packet(buffer.iter()); + + while let Some(result) = self.parse_next() { + tx.send(result).await; + } + } + + while let Some(result) = self.parse_next() { + tx.send(result).await; + } + } } - #[derive(Clone, Debug, Hash, Default)] - pub struct ResultsCollector { - results: Vec, + impl ResultsCollector { + pub fn new() -> Self { + Self { + results: Vec::new(), + current_index: 0, + previous_block: String::new(), + started: false, + } + } + + pub fn results(&self) -> &Vec { + &self.results + } + + pub fn add_result(&mut self, result: SearchResult) { + self.results.push(result); + } + + pub fn get_next_items(&self) -> &[SearchResult] { + if self.current_index >= self.results.len() { + return &[]; + } + + &self.results[self.current_index + 1..self.results.len()] + } + + pub fn update_index(&mut self) { + self.current_index = self.results.len() - 1; + } + + pub fn has_more_results(&self) -> bool { + if self.results.len() == 0 { + return true; + } + + self.current_index < self.results.len() - 1 + } + } + + #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] + pub struct EnginePositions { + pub previous_block: String, + pub started: bool, + } + + impl EnginePositions { + pub fn new() -> Self { + EnginePositions { + previous_block: String::new(), + started: false, + } + } + + pub fn slice_remaining_block(&mut self, start_position: &usize) { + let previous_block_bytes = self.previous_block.as_bytes().to_vec(); + let remaining_bytes = previous_block_bytes[*start_position..].to_vec(); + let remaining_text = String::from_utf8(remaining_bytes).unwrap(); + + self.previous_block.clear(); + self.previous_block.push_str(&remaining_text); + } } } diff --git a/src/helpers.rs b/src/helpers.rs index 1554642..65572e2 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -4,29 +4,18 @@ pub mod helpers { use std::sync::Arc; - use futures::{lock::Mutex, Future, StreamExt}; - use reqwest::{Error, Response}; + use bytes::Bytes; + use futures::{lock::Mutex, Future, Stream, StreamExt}; + use reqwest::{Client, ClientBuilder, Error, Response}; - use crate::engines::engine_base::engine_base::EngineBase; + use crate::engines::engine_base::engine_base::{EngineBase, ResultsCollector}; - pub async fn run_search( - request: impl Future>, - engine_ref: Arc>, - ) { - let response = request.await.unwrap(); + const DEFAULT_USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.3"; - let mut stream = response.bytes_stream(); - while let Some(chunk) = stream.next().await { - let buffer = chunk.unwrap(); - - let mut engine = engine_ref.lock().await; - - if let Some(result) = engine.parse_packet(buffer.iter()) { - engine.add_result(result); - - drop(engine); - tokio::task::yield_now().await; - } - } + pub fn build_default_client() -> Client { + ClientBuilder::new() + .user_agent(DEFAULT_USER_AGENT) + .build() + .unwrap() } } diff --git a/src/main.rs b/src/main.rs index 54e929a..97d283d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,15 @@ -use std::str; use std::sync::Arc; +use std::{str, thread}; use engines::brave::brave::Brave; +use engines::engine_base::engine_base::{ResultsCollector, SearchResult}; use futures::lock::Mutex; use lazy_static::lazy_static; -use reqwest::ClientBuilder; use rocket::response::content::{RawCss, RawHtml}; use rocket::response::stream::TextStream; use rocket::time::Instant; -use utils::utils::Yieldable; +use tokio::sync::mpsc; -use crate::helpers::helpers::run_search; use crate::static_files::static_files::read_file_contents; pub mod client; @@ -52,73 +51,34 @@ fn get_tailwindcss() -> RawCss<&'static str> { async fn hello<'a>(query: &str) -> RawHtml { let query_box = query.to_string(); - let completed_ref = Arc::new(Mutex::new(false)); - let completed_ref_writer = completed_ref.clone(); - let brave_ref = Arc::new(Mutex::new(Brave::new())); - let brave_ref_writer = brave_ref.clone(); - let mut brave_first_result_has_yielded = false; - let brave_first_result_start = Instant::now(); - let client = Arc::new(Box::new( - ClientBuilder::new().user_agent(USER_AGENT).build().unwrap(), - )); - let client_ref = client.clone(); + let mut first_result_yielded = false; + let first_result_start = Instant::now(); + + let (tx, mut rx) = mpsc::channel::(16); tokio::spawn(async move { - let request = client_ref - .get(format!("https://search.brave.com/search?q={}", query_box)) - .send(); + let mut brave = Brave::new(); - run_search(request, brave_ref_writer).await; - - let mut completed = completed_ref_writer.lock().await; - *completed = true; + brave.search(&query_box, tx).await; }); - let mut current_index = 0; - RawHtml(TextStream! { yield HTML_BEGINNING.to_string(); - loop { - let brave = brave_ref.lock().await; - - let len = brave.results.len(); - - if len == 0 { - drop(brave); - tokio::task::yield_now().await; - continue - } - - let completed = completed_ref.lock().await; - if *completed && current_index == len - 1 { - break - } - drop(completed); - - if !brave_first_result_has_yielded { - let diff = brave_first_result_start.elapsed().whole_milliseconds(); - brave_first_result_has_yielded = true; + while let Some(result) = rx.recv().await { + if !first_result_yielded { + let diff = first_result_start.elapsed().whole_milliseconds(); + first_result_yielded = true; yield format!("Time taken: {}ms", diff); } - for ii in (current_index + 1)..len { - let result = brave.results.get(ii).unwrap(); + let text = format!("
  • {}

    {}

  • ", &result.title, &result.description); - let text = format!("
  • {}

    {}

  • ", &result.title, &result.description); - - yield text.to_string(); - } - drop(brave); - tokio::task::yield_now().await; - - // [1] -> 0 - // 1 -> [1] - current_index = len - 1; + yield text.to_string(); } - let diff = brave_first_result_start.elapsed().whole_milliseconds(); + let diff = first_result_start.elapsed().whole_milliseconds(); yield format!("End taken: {}ms", diff); yield HTML_END.to_string(); })