mirror of
https://github.com/Syfaro/bkapi.git
synced 2024-11-05 14:44:29 +00:00
commit
510da555e6
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@ -63,15 +63,3 @@ jobs:
|
|||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
file: bkapi/Dockerfile
|
file: bkapi/Dockerfile
|
||||||
|
|
||||||
sourcegraph:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
- name: Generate LSIF data
|
|
||||||
uses: sourcegraph/lsif-rust-action@main
|
|
||||||
- name: Upload LSIF data
|
|
||||||
uses: sourcegraph/lsif-upload-action@master
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
/target
|
/target
|
||||||
|
.env
|
||||||
|
1331
Cargo.lock
generated
1331
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -6,14 +6,15 @@ edition = "2018"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-opentelemetry = "0.17"
|
tracing-opentelemetry = "0.18"
|
||||||
|
|
||||||
opentelemetry = "0.17"
|
opentelemetry = "0.18"
|
||||||
opentelemetry-http = "0.6"
|
opentelemetry-http = "0.7"
|
||||||
|
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
|
||||||
reqwest = { version = "0.11", features = ["json"] }
|
reqwest = { version = "0.11", features = ["json"] }
|
||||||
|
async-nats = "0.21"
|
||||||
|
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
@ -85,7 +85,7 @@ impl BKApiClient {
|
|||||||
) -> Result<Vec<SearchResults>, reqwest::Error> {
|
) -> Result<Vec<SearchResults>, reqwest::Error> {
|
||||||
let mut futs = futures::stream::FuturesOrdered::new();
|
let mut futs = futures::stream::FuturesOrdered::new();
|
||||||
for hash in hashes {
|
for hash in hashes {
|
||||||
futs.push(self.search(*hash, distance));
|
futs.push_back(self.search(*hash, distance));
|
||||||
}
|
}
|
||||||
|
|
||||||
futs.try_collect().await
|
futs.try_collect().await
|
||||||
@ -111,6 +111,77 @@ impl InjectContext for reqwest::RequestBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The BKApi client, operating over NATS instead of HTTP.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct BKApiNatsClient {
|
||||||
|
client: async_nats::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A hash and distance.
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct HashDistance {
|
||||||
|
hash: i64,
|
||||||
|
distance: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BKApiNatsClient {
|
||||||
|
const NATS_SUBJECT: &str = "bkapi.search";
|
||||||
|
|
||||||
|
/// Create a new client with a given NATS client.
|
||||||
|
pub fn new(client: async_nats::Client) -> Self {
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search for a single hash.
|
||||||
|
pub async fn search(
|
||||||
|
&self,
|
||||||
|
hash: i64,
|
||||||
|
distance: i32,
|
||||||
|
) -> Result<SearchResults, async_nats::Error> {
|
||||||
|
let hashes = [HashDistance {
|
||||||
|
hash,
|
||||||
|
distance: distance as u32,
|
||||||
|
}];
|
||||||
|
|
||||||
|
self.search_many(&hashes)
|
||||||
|
.await
|
||||||
|
.map(|mut results| results.remove(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search many hashes at once.
|
||||||
|
pub async fn search_many(
|
||||||
|
&self,
|
||||||
|
hashes: &[HashDistance],
|
||||||
|
) -> Result<Vec<SearchResults>, async_nats::Error> {
|
||||||
|
let payload = serde_json::to_vec(hashes).unwrap();
|
||||||
|
|
||||||
|
let message = self
|
||||||
|
.client
|
||||||
|
.request(Self::NATS_SUBJECT.to_string(), payload.into())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();
|
||||||
|
|
||||||
|
let results = results
|
||||||
|
.into_iter()
|
||||||
|
.zip(hashes)
|
||||||
|
.map(|(results, search)| SearchResults {
|
||||||
|
hash: search.hash,
|
||||||
|
distance: search.distance as u64,
|
||||||
|
hashes: results
|
||||||
|
.into_iter()
|
||||||
|
.map(|result| SearchResult {
|
||||||
|
hash: result.hash,
|
||||||
|
distance: result.distance as u64,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
fn get_test_endpoint() -> String {
|
fn get_test_endpoint() -> String {
|
||||||
|
@ -5,16 +5,17 @@ authors = ["Syfaro <syfaro@huefox.com>"]
|
|||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
envconfig = "0.10"
|
dotenvy = "0.15"
|
||||||
|
clap = { version = "4", features = ["derive", "env"] }
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
|
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
tracing-unwrap = "0.9"
|
tracing-unwrap = "0.9"
|
||||||
tracing-opentelemetry = "0.17"
|
tracing-opentelemetry = "0.18"
|
||||||
|
|
||||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
|
||||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
|
||||||
|
|
||||||
lazy_static = "1"
|
lazy_static = "1"
|
||||||
prometheus = { version = "0.13", features = ["process"] }
|
prometheus = { version = "0.13", features = ["process"] }
|
||||||
@ -31,8 +32,10 @@ serde_json = "1"
|
|||||||
actix-web = "4"
|
actix-web = "4"
|
||||||
actix-http = "3"
|
actix-http = "3"
|
||||||
actix-service = "2"
|
actix-service = "2"
|
||||||
tracing-actix-web = { version = "0.5", features = ["opentelemetry_0_17"] }
|
tracing-actix-web = { version = "0.6", features = ["opentelemetry_0_18"] }
|
||||||
|
|
||||||
|
async-nats = "0.21"
|
||||||
|
|
||||||
[dependencies.sqlx]
|
[dependencies.sqlx]
|
||||||
version = "0.5"
|
version = "0.6"
|
||||||
features = ["runtime-actix-rustls", "postgres"]
|
features = ["runtime-actix-rustls", "postgres"]
|
||||||
|
@ -6,11 +6,12 @@ use actix_web::{
|
|||||||
web::{self, Data},
|
web::{self, Data},
|
||||||
App, HttpResponse, HttpServer, Responder,
|
App, HttpResponse, HttpServer, Responder,
|
||||||
};
|
};
|
||||||
use envconfig::Envconfig;
|
use clap::Parser;
|
||||||
|
use futures::StreamExt;
|
||||||
use opentelemetry::KeyValue;
|
use opentelemetry::KeyValue;
|
||||||
use prometheus::{Encoder, TextEncoder};
|
use prometheus::{Encoder, TextEncoder};
|
||||||
use sqlx::postgres::PgPoolOptions;
|
use sqlx::postgres::PgPoolOptions;
|
||||||
use tokio::sync::RwLock;
|
use tracing::Instrument;
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_unwrap::ResultExt;
|
use tracing_unwrap::ResultExt;
|
||||||
|
|
||||||
@ -19,9 +20,6 @@ mod tree;
|
|||||||
lazy_static::lazy_static! {
|
lazy_static::lazy_static! {
|
||||||
static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
||||||
static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
|
||||||
|
|
||||||
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
|
|
||||||
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -32,34 +30,58 @@ enum Error {
|
|||||||
Listener(sqlx::Error),
|
Listener(sqlx::Error),
|
||||||
#[error("listener got data that could not be decoded: {0}")]
|
#[error("listener got data that could not be decoded: {0}")]
|
||||||
Data(serde_json::Error),
|
Data(serde_json::Error),
|
||||||
|
#[error("nats encountered error: {0}")]
|
||||||
|
Nats(#[from] async_nats::Error),
|
||||||
|
#[error("io error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Envconfig, Clone)]
|
#[derive(Parser, Clone)]
|
||||||
struct Config {
|
struct Config {
|
||||||
#[envconfig(default = "0.0.0.0:3000")]
|
/// Host to listen for incoming HTTP requests.
|
||||||
|
#[clap(long, env, default_value = "127.0.0.1:3000")]
|
||||||
http_listen: String,
|
http_listen: String,
|
||||||
#[envconfig(default = "127.0.0.1:6831")]
|
|
||||||
|
/// Jaeger agent endpoint for span collection.
|
||||||
|
#[clap(long, env, default_value = "127.0.0.1:6831")]
|
||||||
jaeger_agent: String,
|
jaeger_agent: String,
|
||||||
#[envconfig(default = "bkapi")]
|
/// Service name for spans.
|
||||||
|
#[clap(long, env, default_value = "bkapi")]
|
||||||
service_name: String,
|
service_name: String,
|
||||||
|
|
||||||
|
/// Database URL for fetching data.
|
||||||
|
#[clap(long, env)]
|
||||||
database_url: String,
|
database_url: String,
|
||||||
|
/// Query to perform to fetch initial values.
|
||||||
|
#[clap(long, env)]
|
||||||
database_query: String,
|
database_query: String,
|
||||||
database_subscribe: String,
|
|
||||||
#[envconfig(default = "false")]
|
|
||||||
database_is_unique: bool,
|
|
||||||
|
|
||||||
max_distance: Option<u32>,
|
/// If provided, the Postgres notification topic to subscribe to.
|
||||||
|
#[clap(long, env)]
|
||||||
|
database_subscribe: Option<String>,
|
||||||
|
|
||||||
|
/// The NATS host.
|
||||||
|
#[clap(long, env)]
|
||||||
|
nats_host: Option<String>,
|
||||||
|
/// The NATS NKEY.
|
||||||
|
#[clap(long, env)]
|
||||||
|
nats_nkey: Option<String>,
|
||||||
|
|
||||||
|
/// Maximum distance permitted in queries.
|
||||||
|
#[clap(long, env, default_value = "10")]
|
||||||
|
max_distance: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::main]
|
#[actix_web::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let config = Config::init_from_env().expect("could not load config");
|
let _ = dotenvy::dotenv();
|
||||||
|
|
||||||
|
let config = Config::parse();
|
||||||
configure_tracing(&config);
|
configure_tracing(&config);
|
||||||
|
|
||||||
tracing::info!("starting bkbase");
|
tracing::info!("starting bkapi");
|
||||||
|
|
||||||
let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming)));
|
let tree = tree::Tree::new();
|
||||||
|
|
||||||
tracing::trace!("connecting to postgres");
|
tracing::trace!("connecting to postgres");
|
||||||
let pool = PgPoolOptions::new()
|
let pool = PgPoolOptions::new()
|
||||||
@ -69,18 +91,46 @@ async fn main() {
|
|||||||
.expect_or_log("could not connect to database");
|
.expect_or_log("could not connect to database");
|
||||||
tracing::debug!("connected to postgres");
|
tracing::debug!("connected to postgres");
|
||||||
|
|
||||||
let http_listen = config.http_listen.clone();
|
|
||||||
|
|
||||||
let (sender, receiver) = futures::channel::oneshot::channel();
|
let (sender, receiver) = futures::channel::oneshot::channel();
|
||||||
|
|
||||||
tracing::info!("starting to listen for payloads");
|
let client = match (config.nats_host.as_deref(), config.nats_nkey.as_deref()) {
|
||||||
|
(Some(host), None) => Some(
|
||||||
|
async_nats::connect(host)
|
||||||
|
.await
|
||||||
|
.expect_or_log("could not connect to nats with no nkey"),
|
||||||
|
),
|
||||||
|
(Some(host), Some(nkey)) => Some(
|
||||||
|
async_nats::ConnectOptions::with_nkey(nkey.to_string())
|
||||||
|
.connect(host)
|
||||||
|
.await
|
||||||
|
.expect_or_log("could not connect to nats with nkey"),
|
||||||
|
),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
let tree_clone = tree.clone();
|
let tree_clone = tree.clone();
|
||||||
let config_clone = config.clone();
|
let config_clone = config.clone();
|
||||||
tokio::task::spawn(async {
|
if let Some(subscription) = config.database_subscribe.clone() {
|
||||||
tree::listen_for_payloads(pool, config_clone, tree_clone, sender)
|
tracing::info!("starting to listen for payloads from postgres");
|
||||||
|
|
||||||
|
let query = config.database_query.clone();
|
||||||
|
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
tree::listen_for_payloads_db(pool, subscription, query, tree_clone, sender)
|
||||||
.await
|
.await
|
||||||
.expect_or_log("listenting for updates failed");
|
.unwrap_or_log();
|
||||||
});
|
});
|
||||||
|
} else if let Some(client) = client.clone() {
|
||||||
|
tracing::info!("starting to listen for payloads from nats");
|
||||||
|
|
||||||
|
tokio::task::spawn(async {
|
||||||
|
tree::listen_for_payloads_nats(config_clone, pool, client, tree_clone, sender)
|
||||||
|
.await
|
||||||
|
.unwrap_or_log();
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
panic!("no listener source available");
|
||||||
|
};
|
||||||
|
|
||||||
tracing::info!("waiting for initial tree to load");
|
tracing::info!("waiting for initial tree to load");
|
||||||
receiver
|
receiver
|
||||||
@ -88,8 +138,56 @@ async fn main() {
|
|||||||
.expect_or_log("tree loading was dropped before completing");
|
.expect_or_log("tree loading was dropped before completing");
|
||||||
tracing::info!("initial tree loaded, starting server");
|
tracing::info!("initial tree loaded, starting server");
|
||||||
|
|
||||||
|
if let Some(client) = client {
|
||||||
|
let tree_clone = tree.clone();
|
||||||
|
let config_clone = config.clone();
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
search_nats(client, tree_clone, config_clone)
|
||||||
|
.await
|
||||||
|
.unwrap_or_log();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
start_server(config, tree).await.unwrap_or_log();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn configure_tracing(config: &Config) {
|
||||||
|
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
|
||||||
|
|
||||||
|
let env = std::env::var("ENVIRONMENT");
|
||||||
|
let env = if let Ok(env) = env.as_ref() {
|
||||||
|
env.as_str()
|
||||||
|
} else if cfg!(debug_assertions) {
|
||||||
|
"debug"
|
||||||
|
} else {
|
||||||
|
"release"
|
||||||
|
};
|
||||||
|
|
||||||
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
|
.with_endpoint(&config.jaeger_agent)
|
||||||
|
.with_service_name(&config.service_name)
|
||||||
|
.with_trace_config(opentelemetry::sdk::trace::config().with_resource(
|
||||||
|
opentelemetry::sdk::Resource::new(vec![
|
||||||
|
KeyValue::new("environment", env.to_owned()),
|
||||||
|
KeyValue::new("version", env!("CARGO_PKG_VERSION")),
|
||||||
|
]),
|
||||||
|
))
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio)
|
||||||
|
.expect("otel jaeger pipeline could not be created");
|
||||||
|
|
||||||
|
let trace = tracing_opentelemetry::layer().with_tracer(tracer);
|
||||||
|
tracing::subscriber::set_global_default(
|
||||||
|
tracing_subscriber::Registry::default()
|
||||||
|
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with(trace)
|
||||||
|
.with(tracing_subscriber::fmt::layer()),
|
||||||
|
)
|
||||||
|
.expect("tracing could not be configured");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
|
||||||
let tree = Data::new(tree);
|
let tree = Data::new(tree);
|
||||||
let config = Data::new(config);
|
let config_data = Data::new(config.clone());
|
||||||
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
App::new()
|
App::new()
|
||||||
@ -117,48 +215,32 @@ async fn main() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.app_data(tree.clone())
|
.app_data(tree.clone())
|
||||||
.app_data(config.clone())
|
.app_data(config_data.clone())
|
||||||
.service(search)
|
.service(search)
|
||||||
.service(health)
|
.service(health)
|
||||||
.service(metrics)
|
.service(metrics)
|
||||||
})
|
})
|
||||||
.bind(&http_listen)
|
.bind(&config.http_listen)
|
||||||
.expect_or_log("bind failed")
|
.expect_or_log("bind failed")
|
||||||
.run()
|
.run()
|
||||||
.await
|
.await
|
||||||
.expect_or_log("server failed");
|
.map_err(Error::Io)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn configure_tracing(config: &Config) {
|
#[get("/health")]
|
||||||
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
|
async fn health() -> impl Responder {
|
||||||
|
"OK"
|
||||||
|
}
|
||||||
|
|
||||||
let env = std::env::var("ENVIRONMENT");
|
#[get("/metrics")]
|
||||||
let env = if let Ok(env) = env.as_ref() {
|
async fn metrics() -> HttpResponse {
|
||||||
env.as_str()
|
let mut buffer = Vec::new();
|
||||||
} else if cfg!(debug_assertions) {
|
let encoder = TextEncoder::new();
|
||||||
"debug"
|
|
||||||
} else {
|
|
||||||
"release"
|
|
||||||
};
|
|
||||||
|
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
let metric_families = prometheus::gather();
|
||||||
.with_agent_endpoint(&config.jaeger_agent)
|
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||||
.with_service_name(&config.service_name)
|
|
||||||
.with_tags(vec![
|
|
||||||
KeyValue::new("environment", env.to_owned()),
|
|
||||||
KeyValue::new("version", env!("CARGO_PKG_VERSION")),
|
|
||||||
])
|
|
||||||
.install_batch(opentelemetry::runtime::Tokio)
|
|
||||||
.expect("otel jaeger pipeline could not be created");
|
|
||||||
|
|
||||||
let trace = tracing_opentelemetry::layer().with_tracer(tracer);
|
HttpResponse::Ok().body(buffer)
|
||||||
tracing::subscriber::set_global_default(
|
|
||||||
tracing_subscriber::Registry::default()
|
|
||||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
|
||||||
.with(trace)
|
|
||||||
.with(tracing_subscriber::fmt::layer()),
|
|
||||||
)
|
|
||||||
.expect("tracing could not be configured");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
@ -167,18 +249,12 @@ struct Query {
|
|||||||
distance: u32,
|
distance: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
|
||||||
struct HashDistance {
|
|
||||||
hash: i64,
|
|
||||||
distance: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
#[derive(serde::Serialize)]
|
||||||
struct SearchResponse {
|
struct SearchResponse {
|
||||||
hash: i64,
|
hash: i64,
|
||||||
distance: u32,
|
distance: u32,
|
||||||
|
|
||||||
hashes: Vec<HashDistance>,
|
hashes: Vec<tree::HashDistance>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/search")]
|
#[get("/search")]
|
||||||
@ -187,54 +263,104 @@ async fn search(
|
|||||||
query: web::Query<Query>,
|
query: web::Query<Query>,
|
||||||
tree: Data<tree::Tree>,
|
tree: Data<tree::Tree>,
|
||||||
config: Data<Config>,
|
config: Data<Config>,
|
||||||
) -> Result<HttpResponse, std::convert::Infallible> {
|
) -> HttpResponse {
|
||||||
let Query { hash, distance } = query.0;
|
let Query { hash, distance } = query.0;
|
||||||
let max_distance = config.max_distance;
|
let distance = distance.clamp(0, config.max_distance);
|
||||||
|
|
||||||
tracing::info!("searching for hash {} with distance {}", hash, distance);
|
let hashes = tree
|
||||||
|
.find([tree::HashDistance { hash, distance }])
|
||||||
if matches!(max_distance, Some(max_distance) if distance > max_distance) {
|
.await
|
||||||
return Ok(HttpResponse::BadRequest().body("distance is greater than max distance"));
|
.remove(0);
|
||||||
}
|
|
||||||
|
|
||||||
let tree = tree.read().await;
|
|
||||||
|
|
||||||
let duration = TREE_DURATION
|
|
||||||
.with_label_values(&[&distance.to_string()])
|
|
||||||
.start_timer();
|
|
||||||
let matches: Vec<HashDistance> = tree
|
|
||||||
.find(&hash.into(), distance)
|
|
||||||
.into_iter()
|
|
||||||
.map(|item| HashDistance {
|
|
||||||
distance: item.0,
|
|
||||||
hash: (*item.1).into(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let time = duration.stop_and_record();
|
|
||||||
|
|
||||||
tracing::debug!("found {} items in {} seconds", matches.len(), time);
|
|
||||||
|
|
||||||
let resp = SearchResponse {
|
let resp = SearchResponse {
|
||||||
hash,
|
hash,
|
||||||
distance,
|
distance,
|
||||||
hashes: matches,
|
hashes,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().json(resp))
|
HttpResponse::Ok().json(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/health")]
|
#[derive(serde::Deserialize)]
|
||||||
async fn health() -> impl Responder {
|
struct SearchPayload {
|
||||||
"OK"
|
hash: i64,
|
||||||
|
distance: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/metrics")]
|
#[tracing::instrument(skip(client, tree, config))]
|
||||||
async fn metrics() -> Result<HttpResponse, std::convert::Infallible> {
|
async fn search_nats(
|
||||||
let mut buffer = Vec::new();
|
client: async_nats::Client,
|
||||||
let encoder = TextEncoder::new();
|
tree: tree::Tree,
|
||||||
|
config: Config,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
tracing::info!("subscribing to searches");
|
||||||
|
|
||||||
let metric_families = prometheus::gather();
|
let client = Arc::new(client);
|
||||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
let max_distance = config.max_distance;
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().body(buffer))
|
let mut sub = client
|
||||||
|
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
while let Some(message) = sub.next().await {
|
||||||
|
tracing::trace!("got search message");
|
||||||
|
|
||||||
|
let reply = match message.reply {
|
||||||
|
Some(reply) => reply,
|
||||||
|
None => {
|
||||||
|
tracing::warn!("message had no reply subject, skipping");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(err) = handle_search_nats(
|
||||||
|
max_distance,
|
||||||
|
client.clone(),
|
||||||
|
tree.clone(),
|
||||||
|
reply,
|
||||||
|
&message.payload,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
tracing::error!("could not handle nats search: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_search_nats(
|
||||||
|
max_distance: u32,
|
||||||
|
client: Arc<async_nats::Client>,
|
||||||
|
tree: tree::Tree,
|
||||||
|
reply: String,
|
||||||
|
payload: &[u8],
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let payloads: Vec<SearchPayload> = serde_json::from_slice(payload).map_err(Error::Data)?;
|
||||||
|
|
||||||
|
tokio::task::spawn(
|
||||||
|
async move {
|
||||||
|
let hashes = payloads.into_iter().map(|payload| tree::HashDistance {
|
||||||
|
hash: payload.hash,
|
||||||
|
distance: payload.distance.clamp(0, max_distance),
|
||||||
|
});
|
||||||
|
|
||||||
|
let results = tree.find(hashes).await;
|
||||||
|
|
||||||
|
if let Err(err) = client
|
||||||
|
.publish(
|
||||||
|
reply,
|
||||||
|
serde_json::to_vec(&results)
|
||||||
|
.expect_or_log("results could not be serialized")
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
tracing::error!("could not publish results: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.in_current_span(),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,136 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use bk_tree::BKTree;
|
||||||
|
use futures::TryStreamExt;
|
||||||
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
|
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing_unwrap::ResultExt;
|
use tracing_unwrap::ResultExt;
|
||||||
|
|
||||||
use crate::{Config, Error};
|
use crate::{Config, Error};
|
||||||
|
|
||||||
pub(crate) type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
|
lazy_static::lazy_static! {
|
||||||
|
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
|
||||||
|
|
||||||
|
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
|
||||||
|
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A BKTree wrapper to cover common operations.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Tree {
|
||||||
|
tree: Arc<RwLock<BKTree<Node, Hamming>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A hash and distance pair. May be used for searching or in search results.
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
pub struct HashDistance {
|
||||||
|
pub hash: i64,
|
||||||
|
pub distance: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tree {
|
||||||
|
/// Create an empty tree.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tree: Arc::new(RwLock::new(BKTree::new(Hamming))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Replace tree contents with the results of a SQL query.
|
||||||
|
///
|
||||||
|
/// The tree is only replaced after it finishes loading, leaving stale/empty
|
||||||
|
/// data available while running.
|
||||||
|
pub(crate) async fn reload(&self, pool: &sqlx::PgPool, query: &str) -> Result<(), Error> {
|
||||||
|
let mut new_tree = BKTree::new(Hamming);
|
||||||
|
let mut rows = sqlx::query(query).fetch(pool);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let mut count = 0;
|
||||||
|
|
||||||
|
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
|
||||||
|
let node: Node = row.get::<i64, _>(0).into();
|
||||||
|
|
||||||
|
if new_tree.find_exact(&node).is_none() {
|
||||||
|
new_tree.add(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
count += 1;
|
||||||
|
if count % 250_000 == 0 {
|
||||||
|
tracing::debug!(count, "loaded more rows");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dur = std::time::Instant::now().duration_since(start);
|
||||||
|
tracing::info!(count, "completed loading rows in {:?}", dur);
|
||||||
|
|
||||||
|
let mut tree = self.tree.write().await;
|
||||||
|
*tree = new_tree;
|
||||||
|
|
||||||
|
TREE_ENTRIES.reset();
|
||||||
|
TREE_ENTRIES.inc_by(count);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a hash to the tree, returning if it already existed.
|
||||||
|
#[tracing::instrument(skip(self))]
|
||||||
|
pub async fn add(&self, hash: i64) -> bool {
|
||||||
|
let node = Node::from(hash);
|
||||||
|
|
||||||
|
let is_new_hash = {
|
||||||
|
let tree = self.tree.read().await;
|
||||||
|
tree.find_exact(&node).is_none()
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_new_hash {
|
||||||
|
let mut tree = self.tree.write().await;
|
||||||
|
tree.add(node);
|
||||||
|
TREE_ENTRIES.inc();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(is_new_hash, "added hash");
|
||||||
|
|
||||||
|
is_new_hash
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempt to find any number of hashes within the tree.
|
||||||
|
pub async fn find<H>(&self, hashes: H) -> Vec<Vec<HashDistance>>
|
||||||
|
where
|
||||||
|
H: IntoIterator<Item = HashDistance>,
|
||||||
|
{
|
||||||
|
let tree = self.tree.read().await;
|
||||||
|
|
||||||
|
hashes
|
||||||
|
.into_iter()
|
||||||
|
.map(|HashDistance { hash, distance }| Self::search(&tree, hash, distance))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search a read-locked tree for a hash with a given distance.
|
||||||
|
#[tracing::instrument(skip(tree))]
|
||||||
|
fn search(tree: &BKTree<Node, Hamming>, hash: i64, distance: u32) -> Vec<HashDistance> {
|
||||||
|
tracing::debug!("searching tree");
|
||||||
|
|
||||||
|
let duration = TREE_DURATION
|
||||||
|
.with_label_values(&[&distance.to_string()])
|
||||||
|
.start_timer();
|
||||||
|
let results: Vec<_> = tree
|
||||||
|
.find(&hash.into(), distance)
|
||||||
|
.into_iter()
|
||||||
|
.map(|item| HashDistance {
|
||||||
|
distance: item.0,
|
||||||
|
hash: (*item.1).into(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let time = duration.stop_and_record();
|
||||||
|
|
||||||
|
tracing::info!(time, results = results.len(), "found results");
|
||||||
|
results
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// A hamming distance metric.
|
/// A hamming distance metric.
|
||||||
pub(crate) struct Hamming;
|
struct Hamming;
|
||||||
|
|
||||||
impl bk_tree::Metric<Node> for Hamming {
|
impl bk_tree::Metric<Node> for Hamming {
|
||||||
fn distance(&self, a: &Node, b: &Node) -> u32 {
|
fn distance(&self, a: &Node, b: &Node) -> u32 {
|
||||||
@ -24,7 +145,7 @@ impl bk_tree::Metric<Node> for Hamming {
|
|||||||
|
|
||||||
/// A value of a node in the BK tree.
|
/// A value of a node in the BK tree.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub(crate) struct Node([u8; 8]);
|
struct Node([u8; 8]);
|
||||||
|
|
||||||
impl From<i64> for Node {
|
impl From<i64> for Node {
|
||||||
fn from(num: i64) -> Self {
|
fn from(num: i64) -> Self {
|
||||||
@ -38,106 +159,110 @@ impl From<Node> for i64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new BK tree and pull in all hashes from provided query.
|
/// A payload for a new hash to add to the index from Postgres or NATS.
|
||||||
///
|
|
||||||
/// This must be called after you have started a listener, otherwise items may
|
|
||||||
/// be lost.
|
|
||||||
async fn create_tree(
|
|
||||||
conn: &Pool<Postgres>,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
|
|
||||||
use futures::TryStreamExt;
|
|
||||||
|
|
||||||
tracing::warn!("creating new tree");
|
|
||||||
let mut tree = bk_tree::BKTree::new(Hamming);
|
|
||||||
let mut rows = sqlx::query(&config.database_query).fetch(conn);
|
|
||||||
|
|
||||||
let mut count = 0;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
|
|
||||||
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
|
|
||||||
let node: Node = row.get::<i64, _>(0).into();
|
|
||||||
|
|
||||||
// Avoid checking if each value is unique if we were told that the
|
|
||||||
// database query only returns unique values.
|
|
||||||
let timer = crate::TREE_ADD_DURATION.start_timer();
|
|
||||||
if config.database_is_unique || tree.find_exact(&node).is_none() {
|
|
||||||
tree.add(node);
|
|
||||||
}
|
|
||||||
timer.stop_and_record();
|
|
||||||
|
|
||||||
count += 1;
|
|
||||||
if count % 250_000 == 0 {
|
|
||||||
tracing::debug!(count, "loaded more rows");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let dur = std::time::Instant::now().duration_since(start);
|
|
||||||
|
|
||||||
tracing::info!(count, "completed loading rows in {:?}", dur);
|
|
||||||
Ok(tree)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct Payload {
|
struct Payload {
|
||||||
hash: i64,
|
hash: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Listen for incoming payloads.
|
/// Listen for incoming payloads from Postgres.
|
||||||
///
|
#[tracing::instrument(skip(conn, subscription, query, tree, initial))]
|
||||||
/// This will create a new tree to ensure all items are present. It will also
|
pub(crate) async fn listen_for_payloads_db(
|
||||||
/// automatically recreate trees as needed if the database connection is lost.
|
|
||||||
pub(crate) async fn listen_for_payloads(
|
|
||||||
conn: Pool<Postgres>,
|
conn: Pool<Postgres>,
|
||||||
config: Config,
|
subscription: String,
|
||||||
|
query: String,
|
||||||
tree: Tree,
|
tree: Tree,
|
||||||
initial: futures::channel::oneshot::Sender<()>,
|
initial: futures::channel::oneshot::Sender<()>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
|
let mut initial = Some(initial);
|
||||||
|
|
||||||
|
loop {
|
||||||
let mut listener = PgListener::connect_with(&conn)
|
let mut listener = PgListener::connect_with(&conn)
|
||||||
.await
|
.await
|
||||||
.map_err(Error::Listener)?;
|
.map_err(Error::Listener)?;
|
||||||
listener
|
listener
|
||||||
.listen(&config.database_subscribe)
|
.listen(&subscription)
|
||||||
.await
|
.await
|
||||||
.map_err(Error::Listener)?;
|
.map_err(Error::Listener)?;
|
||||||
|
|
||||||
let new_tree = create_tree(&conn, &config).await?;
|
tree.reload(&conn, &query).await?;
|
||||||
{
|
|
||||||
let mut tree = tree.write().await;
|
|
||||||
*tree = new_tree;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if let Some(initial) = initial.take() {
|
||||||
initial
|
initial
|
||||||
.send(())
|
.send(())
|
||||||
.expect_or_log("nothing listening for initial data");
|
.expect_or_log("nothing listening for initial data");
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
|
||||||
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
|
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
|
||||||
let payload: Payload =
|
tracing::trace!("got postgres payload");
|
||||||
serde_json::from_str(notification.payload()).map_err(Error::Data)?;
|
process_payload(&tree, notification.payload().as_bytes()).await?;
|
||||||
tracing::debug!(hash = payload.hash, "evaluating new payload");
|
|
||||||
|
|
||||||
let node: Node = payload.hash.into();
|
|
||||||
|
|
||||||
let _timer = crate::TREE_ADD_DURATION.start_timer();
|
|
||||||
|
|
||||||
let mut tree = tree.write().await;
|
|
||||||
if tree.find_exact(&node).is_some() {
|
|
||||||
tracing::trace!("hash already existed in tree");
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::trace!("hash did not exist, adding to tree");
|
tracing::error!("disconnected from postgres listener, recreating tree");
|
||||||
tree.add(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::error!("disconnected from listener, recreating tree");
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||||
let new_tree = create_tree(&conn, &config).await?;
|
|
||||||
{
|
|
||||||
let mut tree = tree.write().await;
|
|
||||||
*tree = new_tree;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Listen for incoming payloads from NATS.
|
||||||
|
#[tracing::instrument(skip(config, pool, client, tree, initial))]
|
||||||
|
pub(crate) async fn listen_for_payloads_nats(
|
||||||
|
config: Config,
|
||||||
|
pool: sqlx::PgPool,
|
||||||
|
client: async_nats::Client,
|
||||||
|
tree: Tree,
|
||||||
|
initial: futures::channel::oneshot::Sender<()>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let jetstream = async_nats::jetstream::new(client);
|
||||||
|
let mut initial = Some(initial);
|
||||||
|
|
||||||
|
let stream = jetstream
|
||||||
|
.get_or_create_stream(async_nats::jetstream::stream::Config {
|
||||||
|
name: "bkapi-hashes".to_string(),
|
||||||
|
subjects: vec!["bkapi.add".to_string()],
|
||||||
|
max_age: std::time::Duration::from_secs(60 * 60 * 24),
|
||||||
|
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let consumer = stream
|
||||||
|
.get_or_create_consumer(
|
||||||
|
"bkapi-consumer",
|
||||||
|
async_nats::jetstream::consumer::pull::Config {
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tree.reload(&pool, &config.database_query).await?;
|
||||||
|
|
||||||
|
if let Some(initial) = initial.take() {
|
||||||
|
initial
|
||||||
|
.send(())
|
||||||
|
.expect_or_log("nothing listening for initial data");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut messages = consumer.messages().await?;
|
||||||
|
|
||||||
|
while let Ok(Some(message)) = messages.try_next().await {
|
||||||
|
tracing::trace!("got nats payload");
|
||||||
|
message.ack().await?;
|
||||||
|
process_payload(&tree, &message.payload).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::error!("disconnected from nats listener, recreating tree");
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a payload from Postgres or NATS and add to the tree.
|
||||||
|
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
|
||||||
|
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
|
||||||
|
tracing::trace!("got hash: {}", payload.hash);
|
||||||
|
|
||||||
|
let _timer = TREE_ADD_DURATION.start_timer();
|
||||||
|
tree.add(payload.hash).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user