Merge pull request #1 from Syfaro/nats

Support NATS
This commit is contained in:
Syfaro 2022-10-17 23:50:43 -04:00 committed by GitHub
commit 510da555e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1374 additions and 690 deletions

View File

@ -63,15 +63,3 @@ jobs:
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
file: bkapi/Dockerfile file: bkapi/Dockerfile
sourcegraph:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Generate LSIF data
uses: sourcegraph/lsif-rust-action@main
- name: Upload LSIF data
uses: sourcegraph/lsif-upload-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target /target
.env

1331
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -6,14 +6,15 @@ edition = "2018"
[dependencies] [dependencies]
tracing = "0.1" tracing = "0.1"
tracing-opentelemetry = "0.17" tracing-opentelemetry = "0.18"
opentelemetry = "0.17" opentelemetry = "0.18"
opentelemetry-http = "0.6" opentelemetry-http = "0.7"
futures = "0.3" futures = "0.3"
reqwest = { version = "0.11", features = ["json"] } reqwest = { version = "0.11", features = ["json"] }
async-nats = "0.21"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"

View File

@ -85,7 +85,7 @@ impl BKApiClient {
) -> Result<Vec<SearchResults>, reqwest::Error> { ) -> Result<Vec<SearchResults>, reqwest::Error> {
let mut futs = futures::stream::FuturesOrdered::new(); let mut futs = futures::stream::FuturesOrdered::new();
for hash in hashes { for hash in hashes {
futs.push(self.search(*hash, distance)); futs.push_back(self.search(*hash, distance));
} }
futs.try_collect().await futs.try_collect().await
@ -111,6 +111,77 @@ impl InjectContext for reqwest::RequestBuilder {
} }
} }
/// The BKApi client, operating over NATS instead of HTTP.
#[derive(Clone)]
pub struct BKApiNatsClient {
client: async_nats::Client,
}
/// A hash and distance.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct HashDistance {
hash: i64,
distance: u32,
}
impl BKApiNatsClient {
const NATS_SUBJECT: &str = "bkapi.search";
/// Create a new client with a given NATS client.
pub fn new(client: async_nats::Client) -> Self {
Self { client }
}
/// Search for a single hash.
pub async fn search(
&self,
hash: i64,
distance: i32,
) -> Result<SearchResults, async_nats::Error> {
let hashes = [HashDistance {
hash,
distance: distance as u32,
}];
self.search_many(&hashes)
.await
.map(|mut results| results.remove(0))
}
/// Search many hashes at once.
pub async fn search_many(
&self,
hashes: &[HashDistance],
) -> Result<Vec<SearchResults>, async_nats::Error> {
let payload = serde_json::to_vec(hashes).unwrap();
let message = self
.client
.request(Self::NATS_SUBJECT.to_string(), payload.into())
.await?;
let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();
let results = results
.into_iter()
.zip(hashes)
.map(|(results, search)| SearchResults {
hash: search.hash,
distance: search.distance as u64,
hashes: results
.into_iter()
.map(|result| SearchResult {
hash: result.hash,
distance: result.distance as u64,
})
.collect(),
})
.collect();
Ok(results)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
fn get_test_endpoint() -> String { fn get_test_endpoint() -> String {

View File

@ -5,16 +5,17 @@ authors = ["Syfaro <syfaro@huefox.com>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
envconfig = "0.10" dotenvy = "0.15"
clap = { version = "4", features = ["derive", "env"] }
thiserror = "1" thiserror = "1"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-unwrap = "0.9" tracing-unwrap = "0.9"
tracing-opentelemetry = "0.17" tracing-opentelemetry = "0.18"
opentelemetry = { version = "0.17", features = ["rt-tokio"] } opentelemetry = { version = "0.18", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] } opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
lazy_static = "1" lazy_static = "1"
prometheus = { version = "0.13", features = ["process"] } prometheus = { version = "0.13", features = ["process"] }
@ -31,8 +32,10 @@ serde_json = "1"
actix-web = "4" actix-web = "4"
actix-http = "3" actix-http = "3"
actix-service = "2" actix-service = "2"
tracing-actix-web = { version = "0.5", features = ["opentelemetry_0_17"] } tracing-actix-web = { version = "0.6", features = ["opentelemetry_0_18"] }
async-nats = "0.21"
[dependencies.sqlx] [dependencies.sqlx]
version = "0.5" version = "0.6"
features = ["runtime-actix-rustls", "postgres"] features = ["runtime-actix-rustls", "postgres"]

View File

@ -6,11 +6,12 @@ use actix_web::{
web::{self, Data}, web::{self, Data},
App, HttpResponse, HttpServer, Responder, App, HttpResponse, HttpServer, Responder,
}; };
use envconfig::Envconfig; use clap::Parser;
use futures::StreamExt;
use opentelemetry::KeyValue; use opentelemetry::KeyValue;
use prometheus::{Encoder, TextEncoder}; use prometheus::{Encoder, TextEncoder};
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use tokio::sync::RwLock; use tracing::Instrument;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_unwrap::ResultExt; use tracing_unwrap::ResultExt;
@ -19,9 +20,6 @@ mod tree;
lazy_static::lazy_static! { lazy_static::lazy_static! {
static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap(); static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap(); static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
} }
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -32,34 +30,58 @@ enum Error {
Listener(sqlx::Error), Listener(sqlx::Error),
#[error("listener got data that could not be decoded: {0}")] #[error("listener got data that could not be decoded: {0}")]
Data(serde_json::Error), Data(serde_json::Error),
#[error("nats encountered error: {0}")]
Nats(#[from] async_nats::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
} }
#[derive(Envconfig, Clone)] #[derive(Parser, Clone)]
struct Config { struct Config {
#[envconfig(default = "0.0.0.0:3000")] /// Host to listen for incoming HTTP requests.
#[clap(long, env, default_value = "127.0.0.1:3000")]
http_listen: String, http_listen: String,
#[envconfig(default = "127.0.0.1:6831")]
/// Jaeger agent endpoint for span collection.
#[clap(long, env, default_value = "127.0.0.1:6831")]
jaeger_agent: String, jaeger_agent: String,
#[envconfig(default = "bkapi")] /// Service name for spans.
#[clap(long, env, default_value = "bkapi")]
service_name: String, service_name: String,
/// Database URL for fetching data.
#[clap(long, env)]
database_url: String, database_url: String,
/// Query to perform to fetch initial values.
#[clap(long, env)]
database_query: String, database_query: String,
database_subscribe: String,
#[envconfig(default = "false")]
database_is_unique: bool,
max_distance: Option<u32>, /// If provided, the Postgres notification topic to subscribe to.
#[clap(long, env)]
database_subscribe: Option<String>,
/// The NATS host.
#[clap(long, env)]
nats_host: Option<String>,
/// The NATS NKEY.
#[clap(long, env)]
nats_nkey: Option<String>,
/// Maximum distance permitted in queries.
#[clap(long, env, default_value = "10")]
max_distance: u32,
} }
#[actix_web::main] #[actix_web::main]
async fn main() { async fn main() {
let config = Config::init_from_env().expect("could not load config"); let _ = dotenvy::dotenv();
let config = Config::parse();
configure_tracing(&config); configure_tracing(&config);
tracing::info!("starting bkbase"); tracing::info!("starting bkapi");
let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming))); let tree = tree::Tree::new();
tracing::trace!("connecting to postgres"); tracing::trace!("connecting to postgres");
let pool = PgPoolOptions::new() let pool = PgPoolOptions::new()
@ -69,18 +91,46 @@ async fn main() {
.expect_or_log("could not connect to database"); .expect_or_log("could not connect to database");
tracing::debug!("connected to postgres"); tracing::debug!("connected to postgres");
let http_listen = config.http_listen.clone();
let (sender, receiver) = futures::channel::oneshot::channel(); let (sender, receiver) = futures::channel::oneshot::channel();
tracing::info!("starting to listen for payloads"); let client = match (config.nats_host.as_deref(), config.nats_nkey.as_deref()) {
(Some(host), None) => Some(
async_nats::connect(host)
.await
.expect_or_log("could not connect to nats with no nkey"),
),
(Some(host), Some(nkey)) => Some(
async_nats::ConnectOptions::with_nkey(nkey.to_string())
.connect(host)
.await
.expect_or_log("could not connect to nats with nkey"),
),
_ => None,
};
let tree_clone = tree.clone(); let tree_clone = tree.clone();
let config_clone = config.clone(); let config_clone = config.clone();
tokio::task::spawn(async { if let Some(subscription) = config.database_subscribe.clone() {
tree::listen_for_payloads(pool, config_clone, tree_clone, sender) tracing::info!("starting to listen for payloads from postgres");
let query = config.database_query.clone();
tokio::task::spawn(async move {
tree::listen_for_payloads_db(pool, subscription, query, tree_clone, sender)
.await .await
.expect_or_log("listenting for updates failed"); .unwrap_or_log();
}); });
} else if let Some(client) = client.clone() {
tracing::info!("starting to listen for payloads from nats");
tokio::task::spawn(async {
tree::listen_for_payloads_nats(config_clone, pool, client, tree_clone, sender)
.await
.unwrap_or_log();
});
} else {
panic!("no listener source available");
};
tracing::info!("waiting for initial tree to load"); tracing::info!("waiting for initial tree to load");
receiver receiver
@ -88,8 +138,56 @@ async fn main() {
.expect_or_log("tree loading was dropped before completing"); .expect_or_log("tree loading was dropped before completing");
tracing::info!("initial tree loaded, starting server"); tracing::info!("initial tree loaded, starting server");
if let Some(client) = client {
let tree_clone = tree.clone();
let config_clone = config.clone();
tokio::task::spawn(async move {
search_nats(client, tree_clone, config_clone)
.await
.unwrap_or_log();
});
}
start_server(config, tree).await.unwrap_or_log();
}
fn configure_tracing(config: &Config) {
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
let env = std::env::var("ENVIRONMENT");
let env = if let Ok(env) = env.as_ref() {
env.as_str()
} else if cfg!(debug_assertions) {
"debug"
} else {
"release"
};
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_endpoint(&config.jaeger_agent)
.with_service_name(&config.service_name)
.with_trace_config(opentelemetry::sdk::trace::config().with_resource(
opentelemetry::sdk::Resource::new(vec![
KeyValue::new("environment", env.to_owned()),
KeyValue::new("version", env!("CARGO_PKG_VERSION")),
]),
))
.install_batch(opentelemetry::runtime::Tokio)
.expect("otel jaeger pipeline could not be created");
let trace = tracing_opentelemetry::layer().with_tracer(tracer);
tracing::subscriber::set_global_default(
tracing_subscriber::Registry::default()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(trace)
.with(tracing_subscriber::fmt::layer()),
)
.expect("tracing could not be configured");
}
async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
let tree = Data::new(tree); let tree = Data::new(tree);
let config = Data::new(config); let config_data = Data::new(config.clone());
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
@ -117,48 +215,32 @@ async fn main() {
} }
}) })
.app_data(tree.clone()) .app_data(tree.clone())
.app_data(config.clone()) .app_data(config_data.clone())
.service(search) .service(search)
.service(health) .service(health)
.service(metrics) .service(metrics)
}) })
.bind(&http_listen) .bind(&config.http_listen)
.expect_or_log("bind failed") .expect_or_log("bind failed")
.run() .run()
.await .await
.expect_or_log("server failed"); .map_err(Error::Io)
} }
fn configure_tracing(config: &Config) { #[get("/health")]
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); async fn health() -> impl Responder {
"OK"
}
let env = std::env::var("ENVIRONMENT"); #[get("/metrics")]
let env = if let Ok(env) = env.as_ref() { async fn metrics() -> HttpResponse {
env.as_str() let mut buffer = Vec::new();
} else if cfg!(debug_assertions) { let encoder = TextEncoder::new();
"debug"
} else {
"release"
};
let tracer = opentelemetry_jaeger::new_pipeline() let metric_families = prometheus::gather();
.with_agent_endpoint(&config.jaeger_agent) encoder.encode(&metric_families, &mut buffer).unwrap();
.with_service_name(&config.service_name)
.with_tags(vec![
KeyValue::new("environment", env.to_owned()),
KeyValue::new("version", env!("CARGO_PKG_VERSION")),
])
.install_batch(opentelemetry::runtime::Tokio)
.expect("otel jaeger pipeline could not be created");
let trace = tracing_opentelemetry::layer().with_tracer(tracer); HttpResponse::Ok().body(buffer)
tracing::subscriber::set_global_default(
tracing_subscriber::Registry::default()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(trace)
.with(tracing_subscriber::fmt::layer()),
)
.expect("tracing could not be configured");
} }
#[derive(Debug, serde::Deserialize)] #[derive(Debug, serde::Deserialize)]
@ -167,18 +249,12 @@ struct Query {
distance: u32, distance: u32,
} }
#[derive(serde::Serialize)]
struct HashDistance {
hash: i64,
distance: u32,
}
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
struct SearchResponse { struct SearchResponse {
hash: i64, hash: i64,
distance: u32, distance: u32,
hashes: Vec<HashDistance>, hashes: Vec<tree::HashDistance>,
} }
#[get("/search")] #[get("/search")]
@ -187,54 +263,104 @@ async fn search(
query: web::Query<Query>, query: web::Query<Query>,
tree: Data<tree::Tree>, tree: Data<tree::Tree>,
config: Data<Config>, config: Data<Config>,
) -> Result<HttpResponse, std::convert::Infallible> { ) -> HttpResponse {
let Query { hash, distance } = query.0; let Query { hash, distance } = query.0;
let max_distance = config.max_distance; let distance = distance.clamp(0, config.max_distance);
tracing::info!("searching for hash {} with distance {}", hash, distance); let hashes = tree
.find([tree::HashDistance { hash, distance }])
if matches!(max_distance, Some(max_distance) if distance > max_distance) { .await
return Ok(HttpResponse::BadRequest().body("distance is greater than max distance")); .remove(0);
}
let tree = tree.read().await;
let duration = TREE_DURATION
.with_label_values(&[&distance.to_string()])
.start_timer();
let matches: Vec<HashDistance> = tree
.find(&hash.into(), distance)
.into_iter()
.map(|item| HashDistance {
distance: item.0,
hash: (*item.1).into(),
})
.collect();
let time = duration.stop_and_record();
tracing::debug!("found {} items in {} seconds", matches.len(), time);
let resp = SearchResponse { let resp = SearchResponse {
hash, hash,
distance, distance,
hashes: matches, hashes,
}; };
Ok(HttpResponse::Ok().json(resp)) HttpResponse::Ok().json(resp)
} }
#[get("/health")] #[derive(serde::Deserialize)]
async fn health() -> impl Responder { struct SearchPayload {
"OK" hash: i64,
distance: u32,
} }
#[get("/metrics")] #[tracing::instrument(skip(client, tree, config))]
async fn metrics() -> Result<HttpResponse, std::convert::Infallible> { async fn search_nats(
let mut buffer = Vec::new(); client: async_nats::Client,
let encoder = TextEncoder::new(); tree: tree::Tree,
config: Config,
) -> Result<(), Error> {
tracing::info!("subscribing to searches");
let metric_families = prometheus::gather(); let client = Arc::new(client);
encoder.encode(&metric_families, &mut buffer).unwrap(); let max_distance = config.max_distance;
Ok(HttpResponse::Ok().body(buffer)) let mut sub = client
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
.await?;
while let Some(message) = sub.next().await {
tracing::trace!("got search message");
let reply = match message.reply {
Some(reply) => reply,
None => {
tracing::warn!("message had no reply subject, skipping");
continue;
}
};
if let Err(err) = handle_search_nats(
max_distance,
client.clone(),
tree.clone(),
reply,
&message.payload,
)
.await
{
tracing::error!("could not handle nats search: {err}");
}
}
Ok(())
}
async fn handle_search_nats(
max_distance: u32,
client: Arc<async_nats::Client>,
tree: tree::Tree,
reply: String,
payload: &[u8],
) -> Result<(), Error> {
let payloads: Vec<SearchPayload> = serde_json::from_slice(payload).map_err(Error::Data)?;
tokio::task::spawn(
async move {
let hashes = payloads.into_iter().map(|payload| tree::HashDistance {
hash: payload.hash,
distance: payload.distance.clamp(0, max_distance),
});
let results = tree.find(hashes).await;
if let Err(err) = client
.publish(
reply,
serde_json::to_vec(&results)
.expect_or_log("results could not be serialized")
.into(),
)
.await
{
tracing::error!("could not publish results: {err}");
}
}
.in_current_span(),
);
Ok(())
} }

View File

@ -1,15 +1,136 @@
use std::sync::Arc; use std::sync::Arc;
use bk_tree::BKTree;
use futures::TryStreamExt;
use sqlx::{postgres::PgListener, Pool, Postgres, Row}; use sqlx::{postgres::PgListener, Pool, Postgres, Row};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing_unwrap::ResultExt; use tracing_unwrap::ResultExt;
use crate::{Config, Error}; use crate::{Config, Error};
pub(crate) type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>; lazy_static::lazy_static! {
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
static ref TREE_ADD_DURATION: prometheus::Histogram = prometheus::register_histogram!("bkapi_tree_add_duration_seconds", "Duration to add new item to tree").unwrap();
static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
}
/// A BKTree wrapper to cover common operations.
#[derive(Clone)]
pub struct Tree {
tree: Arc<RwLock<BKTree<Node, Hamming>>>,
}
/// A hash and distance pair. May be used for searching or in search results.
#[derive(serde::Serialize)]
pub struct HashDistance {
pub hash: i64,
pub distance: u32,
}
impl Tree {
/// Create an empty tree.
pub fn new() -> Self {
Self {
tree: Arc::new(RwLock::new(BKTree::new(Hamming))),
}
}
/// Replace tree contents with the results of a SQL query.
///
/// The tree is only replaced after it finishes loading, leaving stale/empty
/// data available while running.
pub(crate) async fn reload(&self, pool: &sqlx::PgPool, query: &str) -> Result<(), Error> {
let mut new_tree = BKTree::new(Hamming);
let mut rows = sqlx::query(query).fetch(pool);
let start = std::time::Instant::now();
let mut count = 0;
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
let node: Node = row.get::<i64, _>(0).into();
if new_tree.find_exact(&node).is_none() {
new_tree.add(node);
}
count += 1;
if count % 250_000 == 0 {
tracing::debug!(count, "loaded more rows");
}
}
let dur = std::time::Instant::now().duration_since(start);
tracing::info!(count, "completed loading rows in {:?}", dur);
let mut tree = self.tree.write().await;
*tree = new_tree;
TREE_ENTRIES.reset();
TREE_ENTRIES.inc_by(count);
Ok(())
}
/// Add a hash to the tree, returning if it already existed.
#[tracing::instrument(skip(self))]
pub async fn add(&self, hash: i64) -> bool {
let node = Node::from(hash);
let is_new_hash = {
let tree = self.tree.read().await;
tree.find_exact(&node).is_none()
};
if is_new_hash {
let mut tree = self.tree.write().await;
tree.add(node);
TREE_ENTRIES.inc();
}
tracing::info!(is_new_hash, "added hash");
is_new_hash
}
/// Attempt to find any number of hashes within the tree.
pub async fn find<H>(&self, hashes: H) -> Vec<Vec<HashDistance>>
where
H: IntoIterator<Item = HashDistance>,
{
let tree = self.tree.read().await;
hashes
.into_iter()
.map(|HashDistance { hash, distance }| Self::search(&tree, hash, distance))
.collect()
}
/// Search a read-locked tree for a hash with a given distance.
#[tracing::instrument(skip(tree))]
fn search(tree: &BKTree<Node, Hamming>, hash: i64, distance: u32) -> Vec<HashDistance> {
tracing::debug!("searching tree");
let duration = TREE_DURATION
.with_label_values(&[&distance.to_string()])
.start_timer();
let results: Vec<_> = tree
.find(&hash.into(), distance)
.into_iter()
.map(|item| HashDistance {
distance: item.0,
hash: (*item.1).into(),
})
.collect();
let time = duration.stop_and_record();
tracing::info!(time, results = results.len(), "found results");
results
}
}
/// A hamming distance metric. /// A hamming distance metric.
pub(crate) struct Hamming; struct Hamming;
impl bk_tree::Metric<Node> for Hamming { impl bk_tree::Metric<Node> for Hamming {
fn distance(&self, a: &Node, b: &Node) -> u32 { fn distance(&self, a: &Node, b: &Node) -> u32 {
@ -24,7 +145,7 @@ impl bk_tree::Metric<Node> for Hamming {
/// A value of a node in the BK tree. /// A value of a node in the BK tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct Node([u8; 8]); struct Node([u8; 8]);
impl From<i64> for Node { impl From<i64> for Node {
fn from(num: i64) -> Self { fn from(num: i64) -> Self {
@ -38,106 +159,110 @@ impl From<Node> for i64 {
} }
} }
/// Create a new BK tree and pull in all hashes from provided query. /// A payload for a new hash to add to the index from Postgres or NATS.
///
/// This must be called after you have started a listener, otherwise items may
/// be lost.
async fn create_tree(
conn: &Pool<Postgres>,
config: &Config,
) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
use futures::TryStreamExt;
tracing::warn!("creating new tree");
let mut tree = bk_tree::BKTree::new(Hamming);
let mut rows = sqlx::query(&config.database_query).fetch(conn);
let mut count = 0;
let start = std::time::Instant::now();
while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
let node: Node = row.get::<i64, _>(0).into();
// Avoid checking if each value is unique if we were told that the
// database query only returns unique values.
let timer = crate::TREE_ADD_DURATION.start_timer();
if config.database_is_unique || tree.find_exact(&node).is_none() {
tree.add(node);
}
timer.stop_and_record();
count += 1;
if count % 250_000 == 0 {
tracing::debug!(count, "loaded more rows");
}
}
let dur = std::time::Instant::now().duration_since(start);
tracing::info!(count, "completed loading rows in {:?}", dur);
Ok(tree)
}
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct Payload { struct Payload {
hash: i64, hash: i64,
} }
/// Listen for incoming payloads. /// Listen for incoming payloads from Postgres.
/// #[tracing::instrument(skip(conn, subscription, query, tree, initial))]
/// This will create a new tree to ensure all items are present. It will also pub(crate) async fn listen_for_payloads_db(
/// automatically recreate trees as needed if the database connection is lost.
pub(crate) async fn listen_for_payloads(
conn: Pool<Postgres>, conn: Pool<Postgres>,
config: Config, subscription: String,
query: String,
tree: Tree, tree: Tree,
initial: futures::channel::oneshot::Sender<()>, initial: futures::channel::oneshot::Sender<()>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut initial = Some(initial);
loop {
let mut listener = PgListener::connect_with(&conn) let mut listener = PgListener::connect_with(&conn)
.await .await
.map_err(Error::Listener)?; .map_err(Error::Listener)?;
listener listener
.listen(&config.database_subscribe) .listen(&subscription)
.await .await
.map_err(Error::Listener)?; .map_err(Error::Listener)?;
let new_tree = create_tree(&conn, &config).await?; tree.reload(&conn, &query).await?;
{
let mut tree = tree.write().await;
*tree = new_tree;
}
if let Some(initial) = initial.take() {
initial initial
.send(()) .send(())
.expect_or_log("nothing listening for initial data"); .expect_or_log("nothing listening for initial data");
}
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
tracing::trace!("got postgres payload");
process_payload(&tree, notification.payload().as_bytes()).await?;
}
tracing::error!("disconnected from postgres listener, recreating tree");
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}
}
/// Listen for incoming payloads from NATS.
#[tracing::instrument(skip(config, pool, client, tree, initial))]
pub(crate) async fn listen_for_payloads_nats(
config: Config,
pool: sqlx::PgPool,
client: async_nats::Client,
tree: Tree,
initial: futures::channel::oneshot::Sender<()>,
) -> Result<(), Error> {
let jetstream = async_nats::jetstream::new(client);
let mut initial = Some(initial);
let stream = jetstream
.get_or_create_stream(async_nats::jetstream::stream::Config {
name: "bkapi-hashes".to_string(),
subjects: vec!["bkapi.add".to_string()],
max_age: std::time::Duration::from_secs(60 * 60 * 24),
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
..Default::default()
})
.await?;
loop { loop {
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? { let consumer = stream
let payload: Payload = .get_or_create_consumer(
serde_json::from_str(notification.payload()).map_err(Error::Data)?; "bkapi-consumer",
tracing::debug!(hash = payload.hash, "evaluating new payload"); async_nats::jetstream::consumer::pull::Config {
..Default::default()
},
)
.await?;
let node: Node = payload.hash.into(); tree.reload(&pool, &config.database_query).await?;
let _timer = crate::TREE_ADD_DURATION.start_timer(); if let Some(initial) = initial.take() {
initial
let mut tree = tree.write().await; .send(())
if tree.find_exact(&node).is_some() { .expect_or_log("nothing listening for initial data");
tracing::trace!("hash already existed in tree");
continue;
} }
tracing::trace!("hash did not exist, adding to tree"); let mut messages = consumer.messages().await?;
tree.add(node);
while let Ok(Some(message)) = messages.try_next().await {
tracing::trace!("got nats payload");
message.ack().await?;
process_payload(&tree, &message.payload).await?;
} }
tracing::error!("disconnected from listener, recreating tree"); tracing::error!("disconnected from nats listener, recreating tree");
tokio::time::sleep(std::time::Duration::from_secs(10)).await; tokio::time::sleep(std::time::Duration::from_secs(10)).await;
let new_tree = create_tree(&conn, &config).await?;
{
let mut tree = tree.write().await;
*tree = new_tree;
} }
} }
/// Process a payload from Postgres or NATS and add to the tree.
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
tracing::trace!("got hash: {}", payload.hash);
let _timer = TREE_ADD_DURATION.start_timer();
tree.add(payload.hash).await;
Ok(())
} }