mirror of
https://github.com/Syfaro/bkapi.git
synced 2024-11-21 14:34:08 +00:00
commit
510da555e6
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@ -63,15 +63,3 @@ jobs:
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: bkapi/Dockerfile
|
||||
|
||||
sourcegraph:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Generate LSIF data
|
||||
uses: sourcegraph/lsif-rust-action@main
|
||||
- name: Upload LSIF data
|
||||
uses: sourcegraph/lsif-upload-action@master
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
/target
|
||||
.env
|
||||
|
1331
Cargo.lock
generated
1331
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -6,14 +6,15 @@ edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.17"
|
||||
tracing-opentelemetry = "0.18"
|
||||
|
||||
opentelemetry = "0.17"
|
||||
opentelemetry-http = "0.6"
|
||||
opentelemetry = "0.18"
|
||||
opentelemetry-http = "0.7"
|
||||
|
||||
futures = "0.3"
|
||||
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
async-nats = "0.21"
|
||||
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
@ -85,7 +85,7 @@ impl BKApiClient {
|
||||
) -> Result<Vec<SearchResults>, reqwest::Error> {
|
||||
let mut futs = futures::stream::FuturesOrdered::new();
|
||||
for hash in hashes {
|
||||
futs.push(self.search(*hash, distance));
|
||||
futs.push_back(self.search(*hash, distance));
|
||||
}
|
||||
|
||||
futs.try_collect().await
|
||||
@ -111,6 +111,77 @@ impl InjectContext for reqwest::RequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// The BKApi client, operating over NATS instead of HTTP.
|
||||
#[derive(Clone)]
|
||||
pub struct BKApiNatsClient {
|
||||
client: async_nats::Client,
|
||||
}
|
||||
|
||||
/// A hash and distance.
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
pub struct HashDistance {
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
impl BKApiNatsClient {
|
||||
const NATS_SUBJECT: &str = "bkapi.search";
|
||||
|
||||
/// Create a new client with a given NATS client.
|
||||
pub fn new(client: async_nats::Client) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
|
||||
/// Search for a single hash.
|
||||
pub async fn search(
|
||||
&self,
|
||||
hash: i64,
|
||||
distance: i32,
|
||||
) -> Result<SearchResults, async_nats::Error> {
|
||||
let hashes = [HashDistance {
|
||||
hash,
|
||||
distance: distance as u32,
|
||||
}];
|
||||
|
||||
self.search_many(&hashes)
|
||||
.await
|
||||
.map(|mut results| results.remove(0))
|
||||
}
|
||||
|
||||
/// Search many hashes at once.
|
||||
pub async fn search_many(
|
||||
&self,
|
||||
hashes: &[HashDistance],
|
||||
) -> Result<Vec<SearchResults>, async_nats::Error> {
|
||||
let payload = serde_json::to_vec(hashes).unwrap();
|
||||
|
||||
let message = self
|
||||
.client
|
||||
.request(Self::NATS_SUBJECT.to_string(), payload.into())
|
||||
.await?;
|
||||
|
||||
let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();
|
||||
|
||||
let results = results
|
||||
.into_iter()
|
||||
.zip(hashes)
|
||||
.map(|(results, search)| SearchResults {
|
||||
hash: search.hash,
|
||||
distance: search.distance as u64,
|
||||
hashes: results
|
||||
.into_iter()
|
||||
.map(|result| SearchResult {
|
||||
hash: result.hash,
|
||||
distance: result.distance as u64,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
fn get_test_endpoint() -> String {
|
||||
|
@ -5,16 +5,17 @@ authors = ["Syfaro <syfaro@huefox.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
envconfig = "0.10"
|
||||
dotenvy = "0.15"
|
||||
clap = { version = "4", features = ["derive", "env"] }
|
||||
thiserror = "1"
|
||||
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tracing-unwrap = "0.9"
|
||||
tracing-opentelemetry = "0.17"
|
||||
tracing-opentelemetry = "0.18"
|
||||
|
||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
||||
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
|
||||
|
||||
lazy_static = "1"
|
||||
prometheus = { version = "0.13", features = ["process"] }
|
||||
@ -31,8 +32,10 @@ serde_json = "1"
|
||||
actix-web = "4"
|
||||
actix-http = "3"
|
||||
actix-service = "2"
|
||||
tracing-actix-web = { version = "0.5", features = ["opentelemetry_0_17"] }
|
||||
tracing-actix-web = { version = "0.6", features = ["opentelemetry_0_18"] }
|
||||
|
||||
async-nats = "0.21"
|
||||
|
||||
[dependencies.sqlx]
|
||||
version = "0.5"
|
||||
version = "0.6"
|
||||
features = ["runtime-actix-rustls", "postgres"]
|
||||
|
@ -6,11 +6,12 @@ use actix_web::{
|
||||
web::{self, Data},
|
||||
App, HttpResponse, HttpServer, Responder,
|
||||
};
|
||||
use envconfig::Envconfig;
|
||||
use clap::Parser;
|
||||
use futures::StreamExt;
|
||||
use opentelemetry::KeyValue;
|
||||
use prometheus::{Encoder, TextEncoder};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::Instrument;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_unwrap::ResultExt;
|
||||
|
||||
@ -19,9 +20,6 @@ mod tree;
|
||||
lazy_static::lazy_static! {
|
||||
static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
||||
static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
||||
|
||||
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
|
||||
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -32,34 +30,58 @@ enum Error {
|
||||
Listener(sqlx::Error),
|
||||
#[error("listener got data that could not be decoded: {0}")]
|
||||
Data(serde_json::Error),
|
||||
#[error("nats encountered error: {0}")]
|
||||
Nats(#[from] async_nats::Error),
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Envconfig, Clone)]
|
||||
#[derive(Parser, Clone)]
|
||||
struct Config {
|
||||
#[envconfig(default = "0.0.0.0:3000")]
|
||||
/// Host to listen for incoming HTTP requests.
|
||||
#[clap(long, env, default_value = "127.0.0.1:3000")]
|
||||
http_listen: String,
|
||||
#[envconfig(default = "127.0.0.1:6831")]
|
||||
|
||||
/// Jaeger agent endpoint for span collection.
|
||||
#[clap(long, env, default_value = "127.0.0.1:6831")]
|
||||
jaeger_agent: String,
|
||||
#[envconfig(default = "bkapi")]
|
||||
/// Service name for spans.
|
||||
#[clap(long, env, default_value = "bkapi")]
|
||||
service_name: String,
|
||||
|
||||
/// Database URL for fetching data.
|
||||
#[clap(long, env)]
|
||||
database_url: String,
|
||||
/// Query to perform to fetch initial values.
|
||||
#[clap(long, env)]
|
||||
database_query: String,
|
||||
database_subscribe: String,
|
||||
#[envconfig(default = "false")]
|
||||
database_is_unique: bool,
|
||||
|
||||
max_distance: Option<u32>,
|
||||
/// If provided, the Postgres notification topic to subscribe to.
|
||||
#[clap(long, env)]
|
||||
database_subscribe: Option<String>,
|
||||
|
||||
/// The NATS host.
|
||||
#[clap(long, env)]
|
||||
nats_host: Option<String>,
|
||||
/// The NATS NKEY.
|
||||
#[clap(long, env)]
|
||||
nats_nkey: Option<String>,
|
||||
|
||||
/// Maximum distance permitted in queries.
|
||||
#[clap(long, env, default_value = "10")]
|
||||
max_distance: u32,
|
||||
}
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() {
|
||||
let config = Config::init_from_env().expect("could not load config");
|
||||
let _ = dotenvy::dotenv();
|
||||
|
||||
let config = Config::parse();
|
||||
configure_tracing(&config);
|
||||
|
||||
tracing::info!("starting bkbase");
|
||||
tracing::info!("starting bkapi");
|
||||
|
||||
let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming)));
|
||||
let tree = tree::Tree::new();
|
||||
|
||||
tracing::trace!("connecting to postgres");
|
||||
let pool = PgPoolOptions::new()
|
||||
@ -69,18 +91,46 @@ async fn main() {
|
||||
.expect_or_log("could not connect to database");
|
||||
tracing::debug!("connected to postgres");
|
||||
|
||||
let http_listen = config.http_listen.clone();
|
||||
|
||||
let (sender, receiver) = futures::channel::oneshot::channel();
|
||||
|
||||
tracing::info!("starting to listen for payloads");
|
||||
let client = match (config.nats_host.as_deref(), config.nats_nkey.as_deref()) {
|
||||
(Some(host), None) => Some(
|
||||
async_nats::connect(host)
|
||||
.await
|
||||
.expect_or_log("could not connect to nats with no nkey"),
|
||||
),
|
||||
(Some(host), Some(nkey)) => Some(
|
||||
async_nats::ConnectOptions::with_nkey(nkey.to_string())
|
||||
.connect(host)
|
||||
.await
|
||||
.expect_or_log("could not connect to nats with nkey"),
|
||||
),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let tree_clone = tree.clone();
|
||||
let config_clone = config.clone();
|
||||
tokio::task::spawn(async {
|
||||
tree::listen_for_payloads(pool, config_clone, tree_clone, sender)
|
||||
.await
|
||||
.expect_or_log("listenting for updates failed");
|
||||
});
|
||||
if let Some(subscription) = config.database_subscribe.clone() {
|
||||
tracing::info!("starting to listen for payloads from postgres");
|
||||
|
||||
let query = config.database_query.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
tree::listen_for_payloads_db(pool, subscription, query, tree_clone, sender)
|
||||
.await
|
||||
.unwrap_or_log();
|
||||
});
|
||||
} else if let Some(client) = client.clone() {
|
||||
tracing::info!("starting to listen for payloads from nats");
|
||||
|
||||
tokio::task::spawn(async {
|
||||
tree::listen_for_payloads_nats(config_clone, pool, client, tree_clone, sender)
|
||||
.await
|
||||
.unwrap_or_log();
|
||||
});
|
||||
} else {
|
||||
panic!("no listener source available");
|
||||
};
|
||||
|
||||
tracing::info!("waiting for initial tree to load");
|
||||
receiver
|
||||
@ -88,8 +138,56 @@ async fn main() {
|
||||
.expect_or_log("tree loading was dropped before completing");
|
||||
tracing::info!("initial tree loaded, starting server");
|
||||
|
||||
if let Some(client) = client {
|
||||
let tree_clone = tree.clone();
|
||||
let config_clone = config.clone();
|
||||
tokio::task::spawn(async move {
|
||||
search_nats(client, tree_clone, config_clone)
|
||||
.await
|
||||
.unwrap_or_log();
|
||||
});
|
||||
}
|
||||
|
||||
start_server(config, tree).await.unwrap_or_log();
|
||||
}
|
||||
|
||||
fn configure_tracing(config: &Config) {
|
||||
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
|
||||
|
||||
let env = std::env::var("ENVIRONMENT");
|
||||
let env = if let Ok(env) = env.as_ref() {
|
||||
env.as_str()
|
||||
} else if cfg!(debug_assertions) {
|
||||
"debug"
|
||||
} else {
|
||||
"release"
|
||||
};
|
||||
|
||||
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||
.with_endpoint(&config.jaeger_agent)
|
||||
.with_service_name(&config.service_name)
|
||||
.with_trace_config(opentelemetry::sdk::trace::config().with_resource(
|
||||
opentelemetry::sdk::Resource::new(vec![
|
||||
KeyValue::new("environment", env.to_owned()),
|
||||
KeyValue::new("version", env!("CARGO_PKG_VERSION")),
|
||||
]),
|
||||
))
|
||||
.install_batch(opentelemetry::runtime::Tokio)
|
||||
.expect("otel jaeger pipeline could not be created");
|
||||
|
||||
let trace = tracing_opentelemetry::layer().with_tracer(tracer);
|
||||
tracing::subscriber::set_global_default(
|
||||
tracing_subscriber::Registry::default()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(trace)
|
||||
.with(tracing_subscriber::fmt::layer()),
|
||||
)
|
||||
.expect("tracing could not be configured");
|
||||
}
|
||||
|
||||
async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
|
||||
let tree = Data::new(tree);
|
||||
let config = Data::new(config);
|
||||
let config_data = Data::new(config.clone());
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
@ -117,48 +215,32 @@ async fn main() {
|
||||
}
|
||||
})
|
||||
.app_data(tree.clone())
|
||||
.app_data(config.clone())
|
||||
.app_data(config_data.clone())
|
||||
.service(search)
|
||||
.service(health)
|
||||
.service(metrics)
|
||||
})
|
||||
.bind(&http_listen)
|
||||
.bind(&config.http_listen)
|
||||
.expect_or_log("bind failed")
|
||||
.run()
|
||||
.await
|
||||
.expect_or_log("server failed");
|
||||
.map_err(Error::Io)
|
||||
}
|
||||
|
||||
fn configure_tracing(config: &Config) {
|
||||
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
|
||||
#[get("/health")]
|
||||
async fn health() -> impl Responder {
|
||||
"OK"
|
||||
}
|
||||
|
||||
let env = std::env::var("ENVIRONMENT");
|
||||
let env = if let Ok(env) = env.as_ref() {
|
||||
env.as_str()
|
||||
} else if cfg!(debug_assertions) {
|
||||
"debug"
|
||||
} else {
|
||||
"release"
|
||||
};
|
||||
#[get("/metrics")]
|
||||
async fn metrics() -> HttpResponse {
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = TextEncoder::new();
|
||||
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_agent_endpoint(&config.jaeger_agent)
|
||||
.with_service_name(&config.service_name)
|
||||
.with_tags(vec![
|
||||
KeyValue::new("environment", env.to_owned()),
|
||||
KeyValue::new("version", env!("CARGO_PKG_VERSION")),
|
||||
])
|
||||
.install_batch(opentelemetry::runtime::Tokio)
|
||||
.expect("otel jaeger pipeline could not be created");
|
||||
let metric_families = prometheus::gather();
|
||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||
|
||||
let trace = tracing_opentelemetry::layer().with_tracer(tracer);
|
||||
tracing::subscriber::set_global_default(
|
||||
tracing_subscriber::Registry::default()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(trace)
|
||||
.with(tracing_subscriber::fmt::layer()),
|
||||
)
|
||||
.expect("tracing could not be configured");
|
||||
HttpResponse::Ok().body(buffer)
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
@ -167,18 +249,12 @@ struct Query {
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct HashDistance {
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct SearchResponse {
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
|
||||
hashes: Vec<HashDistance>,
|
||||
hashes: Vec<tree::HashDistance>,
|
||||
}
|
||||
|
||||
#[get("/search")]
|
||||
@ -187,54 +263,104 @@ async fn search(
|
||||
query: web::Query<Query>,
|
||||
tree: Data<tree::Tree>,
|
||||
config: Data<Config>,
|
||||
) -> Result<HttpResponse, std::convert::Infallible> {
|
||||
) -> HttpResponse {
|
||||
let Query { hash, distance } = query.0;
|
||||
let max_distance = config.max_distance;
|
||||
let distance = distance.clamp(0, config.max_distance);
|
||||
|
||||
tracing::info!("searching for hash {} with distance {}", hash, distance);
|
||||
|
||||
if matches!(max_distance, Some(max_distance) if distance > max_distance) {
|
||||
return Ok(HttpResponse::BadRequest().body("distance is greater than max distance"));
|
||||
}
|
||||
|
||||
let tree = tree.read().await;
|
||||
|
||||
let duration = TREE_DURATION
|
||||
.with_label_values(&[&distance.to_string()])
|
||||
.start_timer();
|
||||
let matches: Vec<HashDistance> = tree
|
||||
.find(&hash.into(), distance)
|
||||
.into_iter()
|
||||
.map(|item| HashDistance {
|
||||
distance: item.0,
|
||||
hash: (*item.1).into(),
|
||||
})
|
||||
.collect();
|
||||
let time = duration.stop_and_record();
|
||||
|
||||
tracing::debug!("found {} items in {} seconds", matches.len(), time);
|
||||
let hashes = tree
|
||||
.find([tree::HashDistance { hash, distance }])
|
||||
.await
|
||||
.remove(0);
|
||||
|
||||
let resp = SearchResponse {
|
||||
hash,
|
||||
distance,
|
||||
hashes: matches,
|
||||
hashes,
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(resp))
|
||||
HttpResponse::Ok().json(resp)
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
async fn health() -> impl Responder {
|
||||
"OK"
|
||||
#[derive(serde::Deserialize)]
|
||||
struct SearchPayload {
|
||||
hash: i64,
|
||||
distance: u32,
|
||||
}
|
||||
|
||||
#[get("/metrics")]
|
||||
async fn metrics() -> Result<HttpResponse, std::convert::Infallible> {
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = TextEncoder::new();
|
||||
#[tracing::instrument(skip(client, tree, config))]
|
||||
async fn search_nats(
|
||||
client: async_nats::Client,
|
||||
tree: tree::Tree,
|
||||
config: Config,
|
||||
) -> Result<(), Error> {
|
||||
tracing::info!("subscribing to searches");
|
||||
|
||||
let metric_families = prometheus::gather();
|
||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||
let client = Arc::new(client);
|
||||
let max_distance = config.max_distance;
|
||||
|
||||
Ok(HttpResponse::Ok().body(buffer))
|
||||
let mut sub = client
|
||||
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
|
||||
.await?;
|
||||
|
||||
while let Some(message) = sub.next().await {
|
||||
tracing::trace!("got search message");
|
||||
|
||||
let reply = match message.reply {
|
||||
Some(reply) => reply,
|
||||
None => {
|
||||
tracing::warn!("message had no reply subject, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = handle_search_nats(
|
||||
max_distance,
|
||||
client.clone(),
|
||||
tree.clone(),
|
||||
reply,
|
||||
&message.payload,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("could not handle nats search: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_search_nats(
|
||||
max_distance: u32,
|
||||
client: Arc<async_nats::Client>,
|
||||
tree: tree::Tree,
|
||||
reply: String,
|
||||
payload: &[u8],
|
||||
) -> Result<(), Error> {
|
||||
let payloads: Vec<SearchPayload> = serde_json::from_slice(payload).map_err(Error::Data)?;
|
||||
|
||||
tokio::task::spawn(
|
||||
async move {
|
||||
let hashes = payloads.into_iter().map(|payload| tree::HashDistance {
|
||||
hash: payload.hash,
|
||||
distance: payload.distance.clamp(0, max_distance),
|
||||
});
|
||||
|
||||
let results = tree.find(hashes).await;
|
||||
|
||||
if let Err(err) = client
|
||||
.publish(
|
||||
reply,
|
||||
serde_json::to_vec(&results)
|
||||
.expect_or_log("results could not be serialized")
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("could not publish results: {err}");
|
||||
}
|
||||
}
|
||||
.in_current_span(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,15 +1,136 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use bk_tree::BKTree;
|
||||
use futures::TryStreamExt;
|
||||
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing_unwrap::ResultExt;
|
||||
|
||||
use crate::{Config, Error};
|
||||
|
||||
pub(crate) type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
|
||||
lazy_static::lazy_static! {
|
||||
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
|
||||
|
||||
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
|
||||
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
|
||||
}
|
||||
|
||||
/// A BKTree wrapper to cover common operations.
|
||||
#[derive(Clone)]
|
||||
pub struct Tree {
|
||||
tree: Arc<RwLock<BKTree<Node, Hamming>>>,
|
||||
}
|
||||
|
||||
/// A hash and distance pair. May be used for searching or in search results.
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct HashDistance {
|
||||
pub hash: i64,
|
||||
pub distance: u32,
|
||||
}
|
||||
|
||||
impl Tree {
|
||||
/// Create an empty tree.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tree: Arc::new(RwLock::new(BKTree::new(Hamming))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace tree contents with the results of a SQL query.
|
||||
///
|
||||
/// The tree is only replaced after it finishes loading, leaving stale/empty
|
||||
/// data available while running.
|
||||
pub(crate) async fn reload(&self, pool: &sqlx::PgPool, query: &str) -> Result<(), Error> {
|
||||
let mut new_tree = BKTree::new(Hamming);
|
||||
let mut rows = sqlx::query(query).fetch(pool);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let mut count = 0;
|
||||
|
||||
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
|
||||
let node: Node = row.get::<i64, _>(0).into();
|
||||
|
||||
if new_tree.find_exact(&node).is_none() {
|
||||
new_tree.add(node);
|
||||
}
|
||||
|
||||
count += 1;
|
||||
if count % 250_000 == 0 {
|
||||
tracing::debug!(count, "loaded more rows");
|
||||
}
|
||||
}
|
||||
|
||||
let dur = std::time::Instant::now().duration_since(start);
|
||||
tracing::info!(count, "completed loading rows in {:?}", dur);
|
||||
|
||||
let mut tree = self.tree.write().await;
|
||||
*tree = new_tree;
|
||||
|
||||
TREE_ENTRIES.reset();
|
||||
TREE_ENTRIES.inc_by(count);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a hash to the tree, returning if it already existed.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn add(&self, hash: i64) -> bool {
|
||||
let node = Node::from(hash);
|
||||
|
||||
let is_new_hash = {
|
||||
let tree = self.tree.read().await;
|
||||
tree.find_exact(&node).is_none()
|
||||
};
|
||||
|
||||
if is_new_hash {
|
||||
let mut tree = self.tree.write().await;
|
||||
tree.add(node);
|
||||
TREE_ENTRIES.inc();
|
||||
}
|
||||
|
||||
tracing::info!(is_new_hash, "added hash");
|
||||
|
||||
is_new_hash
|
||||
}
|
||||
|
||||
/// Attempt to find any number of hashes within the tree.
|
||||
pub async fn find<H>(&self, hashes: H) -> Vec<Vec<HashDistance>>
|
||||
where
|
||||
H: IntoIterator<Item = HashDistance>,
|
||||
{
|
||||
let tree = self.tree.read().await;
|
||||
|
||||
hashes
|
||||
.into_iter()
|
||||
.map(|HashDistance { hash, distance }| Self::search(&tree, hash, distance))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Search a read-locked tree for a hash with a given distance.
|
||||
#[tracing::instrument(skip(tree))]
|
||||
fn search(tree: &BKTree<Node, Hamming>, hash: i64, distance: u32) -> Vec<HashDistance> {
|
||||
tracing::debug!("searching tree");
|
||||
|
||||
let duration = TREE_DURATION
|
||||
.with_label_values(&[&distance.to_string()])
|
||||
.start_timer();
|
||||
let results: Vec<_> = tree
|
||||
.find(&hash.into(), distance)
|
||||
.into_iter()
|
||||
.map(|item| HashDistance {
|
||||
distance: item.0,
|
||||
hash: (*item.1).into(),
|
||||
})
|
||||
.collect();
|
||||
let time = duration.stop_and_record();
|
||||
|
||||
tracing::info!(time, results = results.len(), "found results");
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
/// A hamming distance metric.
|
||||
pub(crate) struct Hamming;
|
||||
struct Hamming;
|
||||
|
||||
impl bk_tree::Metric<Node> for Hamming {
|
||||
fn distance(&self, a: &Node, b: &Node) -> u32 {
|
||||
@ -24,7 +145,7 @@ impl bk_tree::Metric<Node> for Hamming {
|
||||
|
||||
/// A value of a node in the BK tree.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct Node([u8; 8]);
|
||||
struct Node([u8; 8]);
|
||||
|
||||
impl From<i64> for Node {
|
||||
fn from(num: i64) -> Self {
|
||||
@ -38,106 +159,110 @@ impl From<Node> for i64 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new BK tree and pull in all hashes from provided query.
|
||||
///
|
||||
/// This must be called after you have started a listener, otherwise items may
|
||||
/// be lost.
|
||||
async fn create_tree(
|
||||
conn: &Pool<Postgres>,
|
||||
config: &Config,
|
||||
) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
|
||||
use futures::TryStreamExt;
|
||||
|
||||
tracing::warn!("creating new tree");
|
||||
let mut tree = bk_tree::BKTree::new(Hamming);
|
||||
let mut rows = sqlx::query(&config.database_query).fetch(conn);
|
||||
|
||||
let mut count = 0;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
|
||||
let node: Node = row.get::<i64, _>(0).into();
|
||||
|
||||
// Avoid checking if each value is unique if we were told that the
|
||||
// database query only returns unique values.
|
||||
let timer = crate::TREE_ADD_DURATION.start_timer();
|
||||
if config.database_is_unique || tree.find_exact(&node).is_none() {
|
||||
tree.add(node);
|
||||
}
|
||||
timer.stop_and_record();
|
||||
|
||||
count += 1;
|
||||
if count % 250_000 == 0 {
|
||||
tracing::debug!(count, "loaded more rows");
|
||||
}
|
||||
}
|
||||
|
||||
let dur = std::time::Instant::now().duration_since(start);
|
||||
|
||||
tracing::info!(count, "completed loading rows in {:?}", dur);
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// A payload for a new hash to add to the index from Postgres or NATS.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Payload {
|
||||
hash: i64,
|
||||
}
|
||||
|
||||
/// Listen for incoming payloads.
|
||||
///
|
||||
/// This will create a new tree to ensure all items are present. It will also
|
||||
/// automatically recreate trees as needed if the database connection is lost.
|
||||
pub(crate) async fn listen_for_payloads(
|
||||
/// Listen for incoming payloads from Postgres.
|
||||
#[tracing::instrument(skip(conn, subscription, query, tree, initial))]
|
||||
pub(crate) async fn listen_for_payloads_db(
|
||||
conn: Pool<Postgres>,
|
||||
config: Config,
|
||||
subscription: String,
|
||||
query: String,
|
||||
tree: Tree,
|
||||
initial: futures::channel::oneshot::Sender<()>,
|
||||
) -> Result<(), Error> {
|
||||
let mut listener = PgListener::connect_with(&conn)
|
||||
.await
|
||||
.map_err(Error::Listener)?;
|
||||
listener
|
||||
.listen(&config.database_subscribe)
|
||||
.await
|
||||
.map_err(Error::Listener)?;
|
||||
|
||||
let new_tree = create_tree(&conn, &config).await?;
|
||||
{
|
||||
let mut tree = tree.write().await;
|
||||
*tree = new_tree;
|
||||
}
|
||||
|
||||
initial
|
||||
.send(())
|
||||
.expect_or_log("nothing listening for initial data");
|
||||
let mut initial = Some(initial);
|
||||
|
||||
loop {
|
||||
let mut listener = PgListener::connect_with(&conn)
|
||||
.await
|
||||
.map_err(Error::Listener)?;
|
||||
listener
|
||||
.listen(&subscription)
|
||||
.await
|
||||
.map_err(Error::Listener)?;
|
||||
|
||||
tree.reload(&conn, &query).await?;
|
||||
|
||||
if let Some(initial) = initial.take() {
|
||||
initial
|
||||
.send(())
|
||||
.expect_or_log("nothing listening for initial data");
|
||||
}
|
||||
|
||||
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
|
||||
let payload: Payload =
|
||||
serde_json::from_str(notification.payload()).map_err(Error::Data)?;
|
||||
tracing::debug!(hash = payload.hash, "evaluating new payload");
|
||||
|
||||
let node: Node = payload.hash.into();
|
||||
|
||||
let _timer = crate::TREE_ADD_DURATION.start_timer();
|
||||
|
||||
let mut tree = tree.write().await;
|
||||
if tree.find_exact(&node).is_some() {
|
||||
tracing::trace!("hash already existed in tree");
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::trace!("hash did not exist, adding to tree");
|
||||
tree.add(node);
|
||||
tracing::trace!("got postgres payload");
|
||||
process_payload(&tree, notification.payload().as_bytes()).await?;
|
||||
}
|
||||
|
||||
tracing::error!("disconnected from listener, recreating tree");
|
||||
tracing::error!("disconnected from postgres listener, recreating tree");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
let new_tree = create_tree(&conn, &config).await?;
|
||||
{
|
||||
let mut tree = tree.write().await;
|
||||
*tree = new_tree;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Listen for incoming payloads from NATS.
|
||||
#[tracing::instrument(skip(config, pool, client, tree, initial))]
|
||||
pub(crate) async fn listen_for_payloads_nats(
|
||||
config: Config,
|
||||
pool: sqlx::PgPool,
|
||||
client: async_nats::Client,
|
||||
tree: Tree,
|
||||
initial: futures::channel::oneshot::Sender<()>,
|
||||
) -> Result<(), Error> {
|
||||
let jetstream = async_nats::jetstream::new(client);
|
||||
let mut initial = Some(initial);
|
||||
|
||||
let stream = jetstream
|
||||
.get_or_create_stream(async_nats::jetstream::stream::Config {
|
||||
name: "bkapi-hashes".to_string(),
|
||||
subjects: vec!["bkapi.add".to_string()],
|
||||
max_age: std::time::Duration::from_secs(60 * 60 * 24),
|
||||
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let consumer = stream
|
||||
.get_or_create_consumer(
|
||||
"bkapi-consumer",
|
||||
async_nats::jetstream::consumer::pull::Config {
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
tree.reload(&pool, &config.database_query).await?;
|
||||
|
||||
if let Some(initial) = initial.take() {
|
||||
initial
|
||||
.send(())
|
||||
.expect_or_log("nothing listening for initial data");
|
||||
}
|
||||
|
||||
let mut messages = consumer.messages().await?;
|
||||
|
||||
while let Ok(Some(message)) = messages.try_next().await {
|
||||
tracing::trace!("got nats payload");
|
||||
message.ack().await?;
|
||||
process_payload(&tree, &message.payload).await?;
|
||||
}
|
||||
|
||||
tracing::error!("disconnected from nats listener, recreating tree");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a payload from Postgres or NATS and add to the tree.
|
||||
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
|
||||
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
|
||||
tracing::trace!("got hash: {}", payload.hash);
|
||||
|
||||
let _timer = TREE_ADD_DURATION.start_timer();
|
||||
tree.add(payload.hash).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user