From 45ecb8e5ad3cddeaa9042140a022ea735a7b5659 Mon Sep 17 00:00:00 2001 From: Syfaro Date: Wed, 12 Oct 2022 22:18:46 -0400 Subject: [PATCH] Code cleanup. --- bkapi/src/main.rs | 85 +++++++------------- bkapi/src/tree.rs | 192 +++++++++++++++++++++++++++++----------------- 2 files changed, 151 insertions(+), 126 deletions(-) diff --git a/bkapi/src/main.rs b/bkapi/src/main.rs index 789a8b2..168d6d4 100644 --- a/bkapi/src/main.rs +++ b/bkapi/src/main.rs @@ -11,7 +11,7 @@ use futures::StreamExt; use opentelemetry::KeyValue; use prometheus::{Encoder, TextEncoder}; use sqlx::postgres::PgPoolOptions; -use tokio::sync::RwLock; +use tracing::Instrument; use tracing_subscriber::layer::SubscriberExt; use tracing_unwrap::ResultExt; @@ -20,8 +20,6 @@ mod tree; lazy_static::lazy_static! { static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap(); static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap(); - - static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap(); } #[derive(thiserror::Error, Debug)] @@ -83,7 +81,7 @@ async fn main() { tracing::info!("starting bkapi"); - let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming))); + let tree = tree::Tree::new(); tracing::trace!("connecting to postgres"); let pool = PgPoolOptions::new() @@ -235,14 +233,14 @@ async fn health() -> impl Responder { } #[get("/metrics")] -async fn metrics() -> Result { +async fn metrics() -> HttpResponse { let mut buffer = Vec::new(); let encoder = TextEncoder::new(); let metric_families = prometheus::gather(); encoder.encode(&metric_families, &mut buffer).unwrap(); - Ok(HttpResponse::Ok().body(buffer)) + HttpResponse::Ok().body(buffer) } #[derive(Debug, serde::Deserialize)] @@ -251,18 +249,12 @@ struct Query { distance: u32, } -#[derive(serde::Serialize)] -struct HashDistance { - hash: i64, - distance: u32, -} - #[derive(serde::Serialize)] struct SearchResponse { hash: i64, distance: u32, - hashes: Vec, + hashes: Vec, } #[get("/search")] @@ -275,9 +267,10 @@ async fn search( let Query { hash, distance } = query.0; let distance = distance.clamp(0, config.max_distance); - let tree = tree.read().await; - let hashes = search_tree(&tree, hash, distance); - drop(tree); + let hashes = tree + .find([tree::HashDistance { hash, distance }]) + .await + .remove(0); let resp = SearchResponse { hash, @@ -294,6 +287,7 @@ struct SearchPayload { distance: u32, } +#[tracing::instrument(skip(client, tree, config))] async fn search_nats( client: async_nats::Client, tree: tree::Tree, @@ -301,6 +295,8 @@ async fn search_nats( ) -> Result<(), Error> { tracing::info!("subscribing to searches"); + let client = Arc::new(client); + let mut sub = client .queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string()) .await?; @@ -311,7 +307,7 @@ async fn search_nats( let reply = match message.reply { Some(reply) => reply, None => { - tracing::warn!("message had no reply subject"); + tracing::warn!("message had no reply subject, skipping"); continue; } }; @@ -320,51 +316,26 @@ async fn search_nats( serde_json::from_slice(&message.payload).map_err(Error::Data)?; let tree = tree.clone(); - let config = config.clone(); let client = client.clone(); + let max_distance = config.max_distance; - tokio::task::spawn(async move { - let tree = tree.read().await; + tokio::task::spawn( + async move { + let hashes = payloads.into_iter().map(|payload| tree::HashDistance { + hash: payload.hash, + distance: payload.distance.clamp(0, max_distance), + }); - let results: Vec<_> = payloads - .into_iter() - .map(|payload| (payload.hash, payload.distance.clamp(0, config.max_distance))) - .map(|(hash, distance)| search_tree(&tree, hash, distance)) - .collect(); + let results = tree.find(hashes).await; - drop(tree); - - client - .publish(reply, serde_json::to_vec(&results).unwrap_or_log().into()) - .await - .unwrap_or_log(); - }); + client + .publish(reply, serde_json::to_vec(&results).unwrap_or_log().into()) + .await + .unwrap_or_log(); + } + .in_current_span(), + ); } Ok(()) } - -#[tracing::instrument(skip(tree))] -fn search_tree( - tree: &bk_tree::BKTree, - hash: i64, - distance: u32, -) -> Vec { - tracing::debug!("searching tree"); - - let duration = TREE_DURATION - .with_label_values(&[&distance.to_string()]) - .start_timer(); - let results: Vec<_> = tree - .find(&hash.into(), distance) - .into_iter() - .map(|item| HashDistance { - distance: item.0, - hash: (*item.1).into(), - }) - .collect(); - let time = duration.stop_and_record(); - - tracing::info!(time, results = results.len(), "found results"); - results -} diff --git a/bkapi/src/tree.rs b/bkapi/src/tree.rs index 1e7ecb8..1ad3a9f 100644 --- a/bkapi/src/tree.rs +++ b/bkapi/src/tree.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use bk_tree::BKTree; use futures::TryStreamExt; use sqlx::{postgres::PgListener, Pool, Postgres, Row}; use tokio::sync::RwLock; @@ -9,12 +10,121 @@ use crate::{Config, Error}; lazy_static::lazy_static! { static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap(); + static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap(); } -pub(crate) type Tree = Arc>>; +/// A BKTree wrapper to cover common operations. +#[derive(Clone)] +pub struct Tree { + tree: Arc>>, +} + +/// A hash and distance pair. May be used for searching or in search results. +#[derive(serde::Serialize)] +pub struct HashDistance { + pub hash: i64, + pub distance: u32, +} + +impl Tree { + /// Create an empty tree. + pub fn new() -> Self { + Self { + tree: Arc::new(RwLock::new(BKTree::new(Hamming))), + } + } + + /// Replace tree contents with the results of a SQL query. + /// + /// The tree is only replaced after it finishes loading, leaving stale/empty + /// data available while running. + pub(crate) async fn reload(&self, pool: &sqlx::PgPool, query: &str) -> Result<(), Error> { + let mut new_tree = BKTree::new(Hamming); + let mut rows = sqlx::query(query).fetch(pool); + + let start = std::time::Instant::now(); + let mut count = 0; + + while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? { + let node: Node = row.get::(0).into(); + + if new_tree.find_exact(&node).is_none() { + new_tree.add(node); + } + + count += 1; + if count % 250_000 == 0 { + tracing::debug!(count, "loaded more rows"); + } + } + + let dur = std::time::Instant::now().duration_since(start); + tracing::info!(count, "completed loading rows in {:?}", dur); + + let mut tree = self.tree.write().await; + *tree = new_tree; + + Ok(()) + } + + /// Add a hash to the tree, returning if it already existed. + #[tracing::instrument(skip(self))] + pub async fn add(&self, hash: i64) -> bool { + let node = Node::from(hash); + + let is_new_hash = { + let tree = self.tree.read().await; + tree.find_exact(&node).is_none() + }; + + if is_new_hash { + let mut tree = self.tree.write().await; + tree.add(node); + } + + tracing::info!(is_new_hash, "added hash"); + + is_new_hash + } + + /// Attempt to find any number of hashes within the tree. + pub async fn find(&self, hashes: H) -> Vec> + where + H: IntoIterator, + { + let tree = self.tree.read().await; + + hashes + .into_iter() + .map(|HashDistance { hash, distance }| Self::search(&tree, hash, distance)) + .collect() + } + + /// Search a read-locked tree for a hash with a given distance. + #[tracing::instrument(skip(tree))] + fn search(tree: &BKTree, hash: i64, distance: u32) -> Vec { + tracing::debug!("searching tree"); + + let duration = TREE_DURATION + .with_label_values(&[&distance.to_string()]) + .start_timer(); + let results: Vec<_> = tree + .find(&hash.into(), distance) + .into_iter() + .map(|item| HashDistance { + distance: item.0, + hash: (*item.1).into(), + }) + .collect(); + let time = duration.stop_and_record(); + + tracing::info!(time, results = results.len(), "found results"); + results + } +} /// A hamming distance metric. -pub(crate) struct Hamming; +struct Hamming; impl bk_tree::Metric for Hamming { fn distance(&self, a: &Node, b: &Node) -> u32 { @@ -29,7 +139,7 @@ impl bk_tree::Metric for Hamming { /// A value of a node in the BK tree. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct Node([u8; 8]); +struct Node([u8; 8]); impl From for Node { fn from(num: i64) -> Self { @@ -43,50 +153,14 @@ impl From for i64 { } } -/// Create a new BK tree and pull in all hashes from provided query. -/// -/// This must be called after you have started a listener, otherwise items may -/// be lost. -async fn create_tree( - conn: &Pool, - query: &str, -) -> Result, Error> { - tracing::warn!("creating new tree"); - let mut tree = bk_tree::BKTree::new(Hamming); - let mut rows = sqlx::query(query).fetch(conn); - - let mut count = 0; - - let start = std::time::Instant::now(); - - while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? { - let node: Node = row.get::(0).into(); - - if tree.find_exact(&node).is_none() { - tree.add(node); - } - - count += 1; - if count % 250_000 == 0 { - tracing::debug!(count, "loaded more rows"); - } - } - - let dur = std::time::Instant::now().duration_since(start); - - tracing::info!(count, "completed loading rows in {:?}", dur); - Ok(tree) -} - +/// A payload for a new hash to add to the index from Postgres or NATS. #[derive(serde::Deserialize)] struct Payload { hash: i64, } -/// Listen for incoming payloads. -/// -/// This will create a new tree to ensure all items are present. It will also -/// automatically recreate trees as needed if the database connection is lost. +/// Listen for incoming payloads from Postgres. +#[tracing::instrument(skip(conn, subscription, query, tree, initial))] pub(crate) async fn listen_for_payloads_db( conn: Pool, subscription: String, @@ -105,11 +179,7 @@ pub(crate) async fn listen_for_payloads_db( .await .map_err(Error::Listener)?; - let new_tree = create_tree(&conn, &query).await?; - { - let mut tree = tree.write().await; - *tree = new_tree; - } + tree.reload(&conn, &query).await?; if let Some(initial) = initial.take() { initial @@ -127,6 +197,8 @@ pub(crate) async fn listen_for_payloads_db( } } +/// Listen for incoming payloads from NATS. +#[tracing::instrument(skip(config, pool, client, tree, initial))] pub(crate) async fn listen_for_payloads_nats( config: Config, pool: sqlx::PgPool, @@ -134,14 +206,12 @@ pub(crate) async fn listen_for_payloads_nats( tree: Tree, initial: futures::channel::oneshot::Sender<()>, ) -> Result<(), Error> { - static STREAM_NAME: &str = "bkapi-hashes"; - let jetstream = async_nats::jetstream::new(client); let mut initial = Some(initial); let stream = jetstream .get_or_create_stream(async_nats::jetstream::stream::Config { - name: STREAM_NAME.to_string(), + name: "bkapi-hashes".to_string(), subjects: vec!["bkapi.add".to_string()], max_age: std::time::Duration::from_secs(60 * 60 * 24), retention: async_nats::jetstream::stream::RetentionPolicy::Interest, @@ -159,11 +229,7 @@ pub(crate) async fn listen_for_payloads_nats( ) .await?; - let new_tree = create_tree(&pool, &config.database_query).await?; - { - let mut tree = tree.write().await; - *tree = new_tree; - } + tree.reload(&pool, &config.database_query).await?; if let Some(initial) = initial.take() { initial @@ -184,25 +250,13 @@ pub(crate) async fn listen_for_payloads_nats( } } +/// Process a payload from Postgres or NATS and add to the tree. async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> { let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?; tracing::trace!("got hash: {}", payload.hash); - let node: Node = payload.hash.into(); - let _timer = TREE_ADD_DURATION.start_timer(); - - let is_new_hash = { - let tree = tree.read().await; - tree.find_exact(&node).is_none() - }; - - if is_new_hash { - let mut tree = tree.write().await; - tree.add(node); - } - - tracing::debug!(hash = payload.hash, is_new_hash, "processed incoming hash"); + tree.add(payload.hash).await; Ok(()) }