Code cleanup.

This commit is contained in:
Syfaro 2022-10-12 22:18:46 -04:00
parent 876a6bee55
commit 45ecb8e5ad
2 changed files with 151 additions and 126 deletions

View File

@ -11,7 +11,7 @@ use futures::StreamExt;
use opentelemetry::KeyValue;
use prometheus::{Encoder, TextEncoder};
use sqlx::postgres::PgPoolOptions;
use tokio::sync::RwLock;
use tracing::Instrument;
use tracing_subscriber::layer::SubscriberExt;
use tracing_unwrap::ResultExt;
@ -20,8 +20,6 @@ mod tree;
lazy_static::lazy_static! {
static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
}
#[derive(thiserror::Error, Debug)]
@ -83,7 +81,7 @@ async fn main() {
tracing::info!("starting bkapi");
let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming)));
let tree = tree::Tree::new();
tracing::trace!("connecting to postgres");
let pool = PgPoolOptions::new()
@ -235,14 +233,14 @@ async fn health() -> impl Responder {
}
#[get("/metrics")]
async fn metrics() -> Result<HttpResponse, std::convert::Infallible> {
async fn metrics() -> HttpResponse {
let mut buffer = Vec::new();
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
encoder.encode(&metric_families, &mut buffer).unwrap();
Ok(HttpResponse::Ok().body(buffer))
HttpResponse::Ok().body(buffer)
}
#[derive(Debug, serde::Deserialize)]
@ -251,18 +249,12 @@ struct Query {
distance: u32,
}
#[derive(serde::Serialize)]
struct HashDistance {
hash: i64,
distance: u32,
}
#[derive(serde::Serialize)]
struct SearchResponse {
hash: i64,
distance: u32,
hashes: Vec<HashDistance>,
hashes: Vec<tree::HashDistance>,
}
#[get("/search")]
@ -275,9 +267,10 @@ async fn search(
let Query { hash, distance } = query.0;
let distance = distance.clamp(0, config.max_distance);
let tree = tree.read().await;
let hashes = search_tree(&tree, hash, distance);
drop(tree);
let hashes = tree
.find([tree::HashDistance { hash, distance }])
.await
.remove(0);
let resp = SearchResponse {
hash,
@ -294,6 +287,7 @@ struct SearchPayload {
distance: u32,
}
#[tracing::instrument(skip(client, tree, config))]
async fn search_nats(
client: async_nats::Client,
tree: tree::Tree,
@ -301,6 +295,8 @@ async fn search_nats(
) -> Result<(), Error> {
tracing::info!("subscribing to searches");
let client = Arc::new(client);
let mut sub = client
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
.await?;
@ -311,7 +307,7 @@ async fn search_nats(
let reply = match message.reply {
Some(reply) => reply,
None => {
tracing::warn!("message had no reply subject");
tracing::warn!("message had no reply subject, skipping");
continue;
}
};
@ -320,51 +316,26 @@ async fn search_nats(
serde_json::from_slice(&message.payload).map_err(Error::Data)?;
let tree = tree.clone();
let config = config.clone();
let client = client.clone();
let max_distance = config.max_distance;
tokio::task::spawn(async move {
let tree = tree.read().await;
tokio::task::spawn(
async move {
let hashes = payloads.into_iter().map(|payload| tree::HashDistance {
hash: payload.hash,
distance: payload.distance.clamp(0, max_distance),
});
let results: Vec<_> = payloads
.into_iter()
.map(|payload| (payload.hash, payload.distance.clamp(0, config.max_distance)))
.map(|(hash, distance)| search_tree(&tree, hash, distance))
.collect();
drop(tree);
let results = tree.find(hashes).await;
client
.publish(reply, serde_json::to_vec(&results).unwrap_or_log().into())
.await
.unwrap_or_log();
});
}
.in_current_span(),
);
}
Ok(())
}
#[tracing::instrument(skip(tree))]
fn search_tree(
tree: &bk_tree::BKTree<tree::Node, tree::Hamming>,
hash: i64,
distance: u32,
) -> Vec<HashDistance> {
tracing::debug!("searching tree");
let duration = TREE_DURATION
.with_label_values(&[&distance.to_string()])
.start_timer();
let results: Vec<_> = tree
.find(&hash.into(), distance)
.into_iter()
.map(|item| HashDistance {
distance: item.0,
hash: (*item.1).into(),
})
.collect();
let time = duration.stop_and_record();
tracing::info!(time, results = results.len(), "found results");
results
}

View File

@ -1,5 +1,6 @@
use std::sync::Arc;
use bk_tree::BKTree;
use futures::TryStreamExt;
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
use tokio::sync::RwLock;
@ -9,12 +10,121 @@ use crate::{Config, Error};
lazy_static::lazy_static! {
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
}
pub(crate) type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
/// A BKTree wrapper to cover common operations.
#[derive(Clone)]
pub struct Tree {
    // The underlying BK-tree behind an async RwLock so it can be searched
    // concurrently and swapped out wholesale on reload; Arc makes Clone cheap.
    tree: Arc<RwLock<BKTree<Node, Hamming>>>,
}
/// A hash and distance pair. May be used for searching or in search results.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
pub struct HashDistance {
    /// The hash value as stored in the tree.
    pub hash: i64,
    /// Hamming distance: the maximum allowed when searching, or the actual
    /// distance from the query hash in results.
    pub distance: u32,
}
impl Tree {
    /// Create an empty tree.
    pub fn new() -> Self {
        Self {
            tree: Arc::new(RwLock::new(BKTree::new(Hamming))),
        }
    }

    /// Replace tree contents with the results of a SQL query.
    ///
    /// The tree is only replaced after it finishes loading, leaving stale/empty
    /// data available while running.
    pub(crate) async fn reload(&self, pool: &sqlx::PgPool, query: &str) -> Result<(), Error> {
        let mut new_tree = BKTree::new(Hamming);
        let mut rows = sqlx::query(query).fetch(pool);

        let start = std::time::Instant::now();
        let mut count = 0;

        while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
            // Column 0 is expected to be the i64 hash.
            let node: Node = row.get::<i64, _>(0).into();

            // The query may return duplicate hashes; only store each once.
            if new_tree.find_exact(&node).is_none() {
                new_tree.add(node);
            }

            count += 1;
            if count % 250_000 == 0 {
                tracing::debug!(count, "loaded more rows");
            }
        }

        tracing::info!(count, "completed loading rows in {:?}", start.elapsed());

        // Only take the write lock for the final swap, so readers are blocked
        // for as little time as possible.
        let mut tree = self.tree.write().await;
        *tree = new_tree;

        Ok(())
    }

    /// Add a hash to the tree, returning if it was newly added.
    #[tracing::instrument(skip(self))]
    pub async fn add(&self, hash: i64) -> bool {
        let node = Node::from(hash);

        // Optimistically check with only a read lock so concurrent searches
        // are not blocked when the hash already exists.
        let mut is_new_hash = {
            let tree = self.tree.read().await;
            tree.find_exact(&node).is_none()
        };

        if is_new_hash {
            let mut tree = self.tree.write().await;
            // Re-check under the write lock: another task may have inserted
            // the same hash between dropping the read lock and acquiring the
            // write lock, which would otherwise add a duplicate node.
            is_new_hash = tree.find_exact(&node).is_none();
            if is_new_hash {
                tree.add(node);
            }
        }

        tracing::info!(is_new_hash, "added hash");

        is_new_hash
    }

    /// Attempt to find any number of hashes within the tree.
    ///
    /// Results are returned in the same order as the input hashes, one result
    /// set per input.
    pub async fn find<H>(&self, hashes: H) -> Vec<Vec<HashDistance>>
    where
        H: IntoIterator<Item = HashDistance>,
    {
        // A single read lock is held for the entire batch.
        let tree = self.tree.read().await;

        hashes
            .into_iter()
            .map(|HashDistance { hash, distance }| Self::search(&tree, hash, distance))
            .collect()
    }

    /// Search a read-locked tree for a hash with a given distance.
    #[tracing::instrument(skip(tree))]
    fn search(tree: &BKTree<Node, Hamming>, hash: i64, distance: u32) -> Vec<HashDistance> {
        tracing::debug!("searching tree");

        // Time the search, labeled by requested distance.
        let duration = TREE_DURATION
            .with_label_values(&[&distance.to_string()])
            .start_timer();

        let results: Vec<_> = tree
            .find(&hash.into(), distance)
            .into_iter()
            .map(|item| HashDistance {
                // bk_tree yields (distance, &node) pairs.
                distance: item.0,
                hash: (*item.1).into(),
            })
            .collect();

        let time = duration.stop_and_record();
        tracing::info!(time, results = results.len(), "found results");

        results
    }
}
/// A hamming distance metric.
pub(crate) struct Hamming;
struct Hamming;
impl bk_tree::Metric<Node> for Hamming {
fn distance(&self, a: &Node, b: &Node) -> u32 {
@ -29,7 +139,7 @@ impl bk_tree::Metric<Node> for Hamming {
/// A value of a node in the BK tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct Node([u8; 8]);
struct Node([u8; 8]);
impl From<i64> for Node {
fn from(num: i64) -> Self {
@ -43,50 +153,14 @@ impl From<Node> for i64 {
}
}
/// Create a new BK tree and pull in all hashes from provided query.
///
/// This must be called after you have started a listener, otherwise items may
/// be lost.
async fn create_tree(
conn: &Pool<Postgres>,
query: &str,
) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
tracing::warn!("creating new tree");
let mut tree = bk_tree::BKTree::new(Hamming);
let mut rows = sqlx::query(query).fetch(conn);
let mut count = 0;
let start = std::time::Instant::now();
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
let node: Node = row.get::<i64, _>(0).into();
if tree.find_exact(&node).is_none() {
tree.add(node);
}
count += 1;
if count % 250_000 == 0 {
tracing::debug!(count, "loaded more rows");
}
}
let dur = std::time::Instant::now().duration_since(start);
tracing::info!(count, "completed loading rows in {:?}", dur);
Ok(tree)
}
/// A payload for a new hash to add to the index from Postgres or NATS.
#[derive(serde::Deserialize)]
struct Payload {
    // The hash to insert into the tree.
    hash: i64,
}
/// Listen for incoming payloads.
///
/// This will create a new tree to ensure all items are present. It will also
/// automatically recreate trees as needed if the database connection is lost.
/// Listen for incoming payloads from Postgres.
#[tracing::instrument(skip(conn, subscription, query, tree, initial))]
pub(crate) async fn listen_for_payloads_db(
conn: Pool<Postgres>,
subscription: String,
@ -105,11 +179,7 @@ pub(crate) async fn listen_for_payloads_db(
.await
.map_err(Error::Listener)?;
let new_tree = create_tree(&conn, &query).await?;
{
let mut tree = tree.write().await;
*tree = new_tree;
}
tree.reload(&conn, &query).await?;
if let Some(initial) = initial.take() {
initial
@ -127,6 +197,8 @@ pub(crate) async fn listen_for_payloads_db(
}
}
/// Listen for incoming payloads from NATS.
#[tracing::instrument(skip(config, pool, client, tree, initial))]
pub(crate) async fn listen_for_payloads_nats(
config: Config,
pool: sqlx::PgPool,
@ -134,14 +206,12 @@ pub(crate) async fn listen_for_payloads_nats(
tree: Tree,
initial: futures::channel::oneshot::Sender<()>,
) -> Result<(), Error> {
static STREAM_NAME: &str = "bkapi-hashes";
let jetstream = async_nats::jetstream::new(client);
let mut initial = Some(initial);
let stream = jetstream
.get_or_create_stream(async_nats::jetstream::stream::Config {
name: STREAM_NAME.to_string(),
name: "bkapi-hashes".to_string(),
subjects: vec!["bkapi.add".to_string()],
max_age: std::time::Duration::from_secs(60 * 60 * 24),
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
@ -159,11 +229,7 @@ pub(crate) async fn listen_for_payloads_nats(
)
.await?;
let new_tree = create_tree(&pool, &config.database_query).await?;
{
let mut tree = tree.write().await;
*tree = new_tree;
}
tree.reload(&pool, &config.database_query).await?;
if let Some(initial) = initial.take() {
initial
@ -184,25 +250,13 @@ pub(crate) async fn listen_for_payloads_nats(
}
}
/// Process a payload from Postgres or NATS and add to the tree.
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
tracing::trace!("got hash: {}", payload.hash);
let node: Node = payload.hash.into();
let _timer = TREE_ADD_DURATION.start_timer();
let is_new_hash = {
let tree = tree.read().await;
tree.find_exact(&node).is_none()
};
if is_new_hash {
let mut tree = tree.write().await;
tree.add(node);
}
tracing::debug!(hash = payload.hash, is_new_hash, "processed incoming hash");
tree.add(payload.hash).await;
Ok(())
}