mirror of
https://github.com/Syfaro/bkapi.git
synced 2024-11-21 14:34:08 +00:00
Code cleanup.
This commit is contained in:
parent
876a6bee55
commit
45ecb8e5ad
@ -11,7 +11,7 @@ use futures::StreamExt;
|
||||
use opentelemetry::KeyValue;
|
||||
use prometheus::{Encoder, TextEncoder};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::Instrument;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_unwrap::ResultExt;
|
||||
|
||||
@ -20,8 +20,6 @@ mod tree;
|
||||
lazy_static::lazy_static! {
|
||||
static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
||||
static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
||||
|
||||
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -83,7 +81,7 @@ async fn main() {
|
||||
|
||||
tracing::info!("starting bkapi");
|
||||
|
||||
let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming)));
|
||||
let tree = tree::Tree::new();
|
||||
|
||||
tracing::trace!("connecting to postgres");
|
||||
let pool = PgPoolOptions::new()
|
||||
@ -235,14 +233,14 @@ async fn health() -> impl Responder {
|
||||
}
|
||||
|
||||
#[get("/metrics")]
|
||||
async fn metrics() -> Result<HttpResponse, std::convert::Infallible> {
|
||||
async fn metrics() -> HttpResponse {
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = TextEncoder::new();
|
||||
|
||||
let metric_families = prometheus::gather();
|
||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||
|
||||
Ok(HttpResponse::Ok().body(buffer))
|
||||
HttpResponse::Ok().body(buffer)
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
@ -251,18 +249,12 @@ struct Query {
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct HashDistance {
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct SearchResponse {
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
|
||||
hashes: Vec<HashDistance>,
|
||||
hashes: Vec<tree::HashDistance>,
|
||||
}
|
||||
|
||||
#[get("/search")]
|
||||
@ -275,9 +267,10 @@ async fn search(
|
||||
let Query { hash, distance } = query.0;
|
||||
let distance = distance.clamp(0, config.max_distance);
|
||||
|
||||
let tree = tree.read().await;
|
||||
let hashes = search_tree(&tree, hash, distance);
|
||||
drop(tree);
|
||||
let hashes = tree
|
||||
.find([tree::HashDistance { hash, distance }])
|
||||
.await
|
||||
.remove(0);
|
||||
|
||||
let resp = SearchResponse {
|
||||
hash,
|
||||
@ -294,6 +287,7 @@ struct SearchPayload {
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(client, tree, config))]
|
||||
async fn search_nats(
|
||||
client: async_nats::Client,
|
||||
tree: tree::Tree,
|
||||
@ -301,6 +295,8 @@ async fn search_nats(
|
||||
) -> Result<(), Error> {
|
||||
tracing::info!("subscribing to searches");
|
||||
|
||||
let client = Arc::new(client);
|
||||
|
||||
let mut sub = client
|
||||
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
|
||||
.await?;
|
||||
@ -311,7 +307,7 @@ async fn search_nats(
|
||||
let reply = match message.reply {
|
||||
Some(reply) => reply,
|
||||
None => {
|
||||
tracing::warn!("message had no reply subject");
|
||||
tracing::warn!("message had no reply subject, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@ -320,51 +316,26 @@ async fn search_nats(
|
||||
serde_json::from_slice(&message.payload).map_err(Error::Data)?;
|
||||
|
||||
let tree = tree.clone();
|
||||
let config = config.clone();
|
||||
let client = client.clone();
|
||||
let max_distance = config.max_distance;
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let tree = tree.read().await;
|
||||
tokio::task::spawn(
|
||||
async move {
|
||||
let hashes = payloads.into_iter().map(|payload| tree::HashDistance {
|
||||
hash: payload.hash,
|
||||
distance: payload.distance.clamp(0, max_distance),
|
||||
});
|
||||
|
||||
let results: Vec<_> = payloads
|
||||
.into_iter()
|
||||
.map(|payload| (payload.hash, payload.distance.clamp(0, config.max_distance)))
|
||||
.map(|(hash, distance)| search_tree(&tree, hash, distance))
|
||||
.collect();
|
||||
|
||||
drop(tree);
|
||||
let results = tree.find(hashes).await;
|
||||
|
||||
client
|
||||
.publish(reply, serde_json::to_vec(&results).unwrap_or_log().into())
|
||||
.await
|
||||
.unwrap_or_log();
|
||||
});
|
||||
}
|
||||
.in_current_span(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(tree))]
|
||||
fn search_tree(
|
||||
tree: &bk_tree::BKTree<tree::Node, tree::Hamming>,
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
) -> Vec<HashDistance> {
|
||||
tracing::debug!("searching tree");
|
||||
|
||||
let duration = TREE_DURATION
|
||||
.with_label_values(&[&distance.to_string()])
|
||||
.start_timer();
|
||||
let results: Vec<_> = tree
|
||||
.find(&hash.into(), distance)
|
||||
.into_iter()
|
||||
.map(|item| HashDistance {
|
||||
distance: item.0,
|
||||
hash: (*item.1).into(),
|
||||
})
|
||||
.collect();
|
||||
let time = duration.stop_and_record();
|
||||
|
||||
tracing::info!(time, results = results.len(), "found results");
|
||||
results
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use bk_tree::BKTree;
|
||||
use futures::TryStreamExt;
|
||||
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
|
||||
use tokio::sync::RwLock;
|
||||
@ -9,12 +10,121 @@ use crate::{Config, Error};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
|
||||
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
|
||||
}
|
||||
|
||||
pub(crate) type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
|
||||
/// A BKTree wrapper to cover common operations.
|
||||
#[derive(Clone)]
|
||||
pub struct Tree {
|
||||
tree: Arc<RwLock<BKTree<Node, Hamming>>>,
|
||||
}
|
||||
|
||||
/// A hash and distance pair. May be used for searching or in search results.
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct HashDistance {
|
||||
pub hash: i64,
|
||||
pub distance: u32,
|
||||
}
|
||||
|
||||
impl Tree {
|
||||
/// Create an empty tree.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tree: Arc::new(RwLock::new(BKTree::new(Hamming))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace tree contents with the results of a SQL query.
|
||||
///
|
||||
/// The tree is only replaced after it finishes loading, leaving stale/empty
|
||||
/// data available while running.
|
||||
pub(crate) async fn reload(&self, pool: &sqlx::PgPool, query: &str) -> Result<(), Error> {
|
||||
let mut new_tree = BKTree::new(Hamming);
|
||||
let mut rows = sqlx::query(query).fetch(pool);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let mut count = 0;
|
||||
|
||||
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
|
||||
let node: Node = row.get::<i64, _>(0).into();
|
||||
|
||||
if new_tree.find_exact(&node).is_none() {
|
||||
new_tree.add(node);
|
||||
}
|
||||
|
||||
count += 1;
|
||||
if count % 250_000 == 0 {
|
||||
tracing::debug!(count, "loaded more rows");
|
||||
}
|
||||
}
|
||||
|
||||
let dur = std::time::Instant::now().duration_since(start);
|
||||
tracing::info!(count, "completed loading rows in {:?}", dur);
|
||||
|
||||
let mut tree = self.tree.write().await;
|
||||
*tree = new_tree;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a hash to the tree, returning if it already existed.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn add(&self, hash: i64) -> bool {
|
||||
let node = Node::from(hash);
|
||||
|
||||
let is_new_hash = {
|
||||
let tree = self.tree.read().await;
|
||||
tree.find_exact(&node).is_none()
|
||||
};
|
||||
|
||||
if is_new_hash {
|
||||
let mut tree = self.tree.write().await;
|
||||
tree.add(node);
|
||||
}
|
||||
|
||||
tracing::info!(is_new_hash, "added hash");
|
||||
|
||||
is_new_hash
|
||||
}
|
||||
|
||||
/// Attempt to find any number of hashes within the tree.
|
||||
pub async fn find<H>(&self, hashes: H) -> Vec<Vec<HashDistance>>
|
||||
where
|
||||
H: IntoIterator<Item = HashDistance>,
|
||||
{
|
||||
let tree = self.tree.read().await;
|
||||
|
||||
hashes
|
||||
.into_iter()
|
||||
.map(|HashDistance { hash, distance }| Self::search(&tree, hash, distance))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Search a read-locked tree for a hash with a given distance.
|
||||
#[tracing::instrument(skip(tree))]
|
||||
fn search(tree: &BKTree<Node, Hamming>, hash: i64, distance: u32) -> Vec<HashDistance> {
|
||||
tracing::debug!("searching tree");
|
||||
|
||||
let duration = TREE_DURATION
|
||||
.with_label_values(&[&distance.to_string()])
|
||||
.start_timer();
|
||||
let results: Vec<_> = tree
|
||||
.find(&hash.into(), distance)
|
||||
.into_iter()
|
||||
.map(|item| HashDistance {
|
||||
distance: item.0,
|
||||
hash: (*item.1).into(),
|
||||
})
|
||||
.collect();
|
||||
let time = duration.stop_and_record();
|
||||
|
||||
tracing::info!(time, results = results.len(), "found results");
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
/// A hamming distance metric.
|
||||
pub(crate) struct Hamming;
|
||||
struct Hamming;
|
||||
|
||||
impl bk_tree::Metric<Node> for Hamming {
|
||||
fn distance(&self, a: &Node, b: &Node) -> u32 {
|
||||
@ -29,7 +139,7 @@ impl bk_tree::Metric<Node> for Hamming {
|
||||
|
||||
/// A value of a node in the BK tree.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct Node([u8; 8]);
|
||||
struct Node([u8; 8]);
|
||||
|
||||
impl From<i64> for Node {
|
||||
fn from(num: i64) -> Self {
|
||||
@ -43,50 +153,14 @@ impl From<Node> for i64 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new BK tree and pull in all hashes from provided query.
|
||||
///
|
||||
/// This must be called after you have started a listener, otherwise items may
|
||||
/// be lost.
|
||||
async fn create_tree(
|
||||
conn: &Pool<Postgres>,
|
||||
query: &str,
|
||||
) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
|
||||
tracing::warn!("creating new tree");
|
||||
let mut tree = bk_tree::BKTree::new(Hamming);
|
||||
let mut rows = sqlx::query(query).fetch(conn);
|
||||
|
||||
let mut count = 0;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
|
||||
let node: Node = row.get::<i64, _>(0).into();
|
||||
|
||||
if tree.find_exact(&node).is_none() {
|
||||
tree.add(node);
|
||||
}
|
||||
|
||||
count += 1;
|
||||
if count % 250_000 == 0 {
|
||||
tracing::debug!(count, "loaded more rows");
|
||||
}
|
||||
}
|
||||
|
||||
let dur = std::time::Instant::now().duration_since(start);
|
||||
|
||||
tracing::info!(count, "completed loading rows in {:?}", dur);
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// A payload for a new hash to add to the index from Postgres or NATS.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Payload {
|
||||
hash: i64,
|
||||
}
|
||||
|
||||
/// Listen for incoming payloads.
|
||||
///
|
||||
/// This will create a new tree to ensure all items are present. It will also
|
||||
/// automatically recreate trees as needed if the database connection is lost.
|
||||
/// Listen for incoming payloads from Postgres.
|
||||
#[tracing::instrument(skip(conn, subscription, query, tree, initial))]
|
||||
pub(crate) async fn listen_for_payloads_db(
|
||||
conn: Pool<Postgres>,
|
||||
subscription: String,
|
||||
@ -105,11 +179,7 @@ pub(crate) async fn listen_for_payloads_db(
|
||||
.await
|
||||
.map_err(Error::Listener)?;
|
||||
|
||||
let new_tree = create_tree(&conn, &query).await?;
|
||||
{
|
||||
let mut tree = tree.write().await;
|
||||
*tree = new_tree;
|
||||
}
|
||||
tree.reload(&conn, &query).await?;
|
||||
|
||||
if let Some(initial) = initial.take() {
|
||||
initial
|
||||
@ -127,6 +197,8 @@ pub(crate) async fn listen_for_payloads_db(
|
||||
}
|
||||
}
|
||||
|
||||
/// Listen for incoming payloads from NATS.
|
||||
#[tracing::instrument(skip(config, pool, client, tree, initial))]
|
||||
pub(crate) async fn listen_for_payloads_nats(
|
||||
config: Config,
|
||||
pool: sqlx::PgPool,
|
||||
@ -134,14 +206,12 @@ pub(crate) async fn listen_for_payloads_nats(
|
||||
tree: Tree,
|
||||
initial: futures::channel::oneshot::Sender<()>,
|
||||
) -> Result<(), Error> {
|
||||
static STREAM_NAME: &str = "bkapi-hashes";
|
||||
|
||||
let jetstream = async_nats::jetstream::new(client);
|
||||
let mut initial = Some(initial);
|
||||
|
||||
let stream = jetstream
|
||||
.get_or_create_stream(async_nats::jetstream::stream::Config {
|
||||
name: STREAM_NAME.to_string(),
|
||||
name: "bkapi-hashes".to_string(),
|
||||
subjects: vec!["bkapi.add".to_string()],
|
||||
max_age: std::time::Duration::from_secs(60 * 60 * 24),
|
||||
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
|
||||
@ -159,11 +229,7 @@ pub(crate) async fn listen_for_payloads_nats(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let new_tree = create_tree(&pool, &config.database_query).await?;
|
||||
{
|
||||
let mut tree = tree.write().await;
|
||||
*tree = new_tree;
|
||||
}
|
||||
tree.reload(&pool, &config.database_query).await?;
|
||||
|
||||
if let Some(initial) = initial.take() {
|
||||
initial
|
||||
@ -184,25 +250,13 @@ pub(crate) async fn listen_for_payloads_nats(
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a payload from Postgres or NATS and add to the tree.
|
||||
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
|
||||
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
|
||||
tracing::trace!("got hash: {}", payload.hash);
|
||||
|
||||
let node: Node = payload.hash.into();
|
||||
|
||||
let _timer = TREE_ADD_DURATION.start_timer();
|
||||
|
||||
let is_new_hash = {
|
||||
let tree = tree.read().await;
|
||||
tree.find_exact(&node).is_none()
|
||||
};
|
||||
|
||||
if is_new_hash {
|
||||
let mut tree = tree.write().await;
|
||||
tree.add(node);
|
||||
}
|
||||
|
||||
tracing::debug!(hash = payload.hash, is_new_hash, "processed incoming hash");
|
||||
tree.add(payload.hash).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user