Major refactoring, better usage of NATS.

* Gracefully handle shutdowns.
* Support prefix for NATS subjects.
* Use NATS services for search.
* Much better use of NATS streams and consumers.
Syfaro 2023-08-14 23:17:16 -04:00
parent 26dca9a51d
commit 365512b0c2
6 changed files with 1088 additions and 775 deletions

Cargo.lock (generated): 1416 lines changed. File diff suppressed because it is too large.

Cargo.toml (client crate):

```diff
@@ -9,7 +9,7 @@ publish = false
 nats = ["async-nats"]
 
 [dependencies]
-async-nats = { version = "0.27", optional = true }
+async-nats = { version = "0.31.0", optional = true }
 futures = "0.3"
 opentelemetry = "0.18"
 opentelemetry-http = "0.7"
```

Client library source:

```diff
@@ -4,6 +4,7 @@ use crate::{SearchResult, SearchResults};
 #[derive(Clone)]
 pub struct BKApiNatsClient {
     client: async_nats::Client,
+    subject: String,
 }
 
 /// A hash and distance.
@@ -19,8 +20,11 @@ impl BKApiNatsClient {
     const NATS_SUBJECT: &str = "bkapi.search";
 
     /// Create a new client with a given NATS client.
-    pub fn new(client: async_nats::Client) -> Self {
-        Self { client }
+    pub fn new(client: async_nats::Client, prefix: &str) -> Self {
+        Self {
+            client,
+            subject: format!("{}.{}", prefix, Self::NATS_SUBJECT),
+        }
     }
 
     /// Search for a single hash.
@@ -48,7 +52,7 @@ impl BKApiNatsClient {
         let message = self
             .client
-            .request(Self::NATS_SUBJECT.to_string(), payload.into())
+            .request(self.subject.clone(), payload.into())
             .await?;
 
         let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();
```
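Putting the client changes together: construction now takes a prefix, and every request targets the prefixed subject. A minimal usage sketch, assuming the crate is imported as `bkapi_client` and a locally reachable NATS server (both assumptions, not shown in this diff):

```rust
use bkapi_client::BKApiNatsClient;

#[tokio::main]
async fn main() -> Result<(), async_nats::ConnectError> {
    // Any NATS connection works; the URL here is illustrative.
    let client = async_nats::connect("nats://127.0.0.1:4222").await?;

    // New in this commit: a required prefix, so requests go to
    // "staging.bkapi.search" rather than the bare "bkapi.search".
    let _bkapi = BKApiNatsClient::new(client, "staging");

    Ok(())
}
```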

Cargo.toml (service crate):

```diff
@@ -6,36 +6,30 @@ edition = "2018"
 publish = false
 
 [dependencies]
-dotenvy = "0.15"
-clap = { version = "4", features = ["derive", "env"] }
-thiserror = "1"
-tracing = "0.1"
-tracing-subscriber = { version = "0.3", features = ["env-filter"] }
-tracing-unwrap = "0.10"
-tracing-opentelemetry = "0.18"
-opentelemetry = { version = "0.18", features = ["rt-tokio"] }
-opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
-lazy_static = "1"
-prometheus = { version = "0.13", features = ["process"] }
-bk-tree = "0.4.0"
-hamming = "0.1"
-futures = "0.3"
-tokio = { version = "1", features = ["sync"] }
-serde = { version = "1", features = ["derive"] }
-serde_json = "1"
-actix-web = "4"
 actix-http = "3"
 actix-service = "2"
+actix-web = "4"
+async-nats = { version = "0.31", features = ["service"] }
+bincode = { version = "2.0.0-rc.3", features = ["serde"] }
+bk-tree = { version = "0.5.0", features = ["serde"] }
+clap = { version = "4", features = ["derive", "env"] }
+dotenvy = "0.15"
+flate2 = "1"
+futures = "0.3"
+hamming = "0.1"
+lazy_static = "1"
+opentelemetry = { version = "0.18", features = ["rt-tokio"] }
+prometheus = { version = "0.13", features = ["process"] }
+serde = { version = "1", features = ["derive"] }
+serde_json = "1"
+thiserror = "1"
+tokio = { version = "1", features = ["sync"] }
+tokio-util = { version = "0.7", features = ["io", "io-util"] }
+tracing = "0.1"
 tracing-actix-web = { version = "0.7", features = ["opentelemetry_0_18"] }
-async-nats = "0.27"
+tracing-opentelemetry = "0.18"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+tracing-unwrap = "0.10"
 
 foxlib = { git = "https://github.com/Syfaro/foxlib.git" }
```

main.rs (service crate):

```diff
@@ -6,9 +6,12 @@ use actix_web::{
     web::{self, Data},
     App, HttpResponse, HttpServer,
 };
+use async_nats::service::ServiceExt;
 use clap::Parser;
-use futures::StreamExt;
+use flate2::{write::GzEncoder, Compression};
+use futures::{StreamExt, TryStreamExt};
 use sqlx::postgres::PgPoolOptions;
+use tokio_util::sync::CancellationToken;
 use tracing::Instrument;
 use tracing_unwrap::ResultExt;
```
```diff
@@ -28,11 +31,27 @@ enum Error {
     #[error("listener got data that could not be decoded: {0}")]
     Data(serde_json::Error),
     #[error("nats encountered error: {0}")]
-    Nats(#[from] async_nats::Error),
+    Nats(#[from] NatsError),
     #[error("io error: {0}")]
     Io(#[from] std::io::Error),
 }
 
+#[derive(thiserror::Error, Debug)]
+enum NatsError {
+    #[error("{0}")]
+    Subscribe(#[from] async_nats::SubscribeError),
+    #[error("{0}")]
+    CreateStream(#[from] async_nats::jetstream::context::CreateStreamError),
+    #[error("{0}")]
+    Consumer(#[from] async_nats::jetstream::stream::ConsumerError),
+    #[error("{0}")]
+    Stream(#[from] async_nats::jetstream::consumer::StreamError),
+    #[error("{0}")]
+    Request(#[from] async_nats::jetstream::context::RequestError),
+    #[error("{0}")]
+    Generic(#[from] async_nats::Error),
+}
+
 #[derive(Parser, Clone)]
 struct Config {
     /// Host to listen for incoming HTTP requests.
```
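The nested error type keeps `Error` small while `?` still does all the conversions: each concrete `async_nats` error converts into a `NatsError` variant through `#[from]`, and `NatsError` converts into `Error::Nats` the same way. A hypothetical helper showing the two-step chain (not part of the diff):

```rust
// Hypothetical helper: SubscribeError -> NatsError::Subscribe via #[from],
// then NatsError -> Error::Nats via #[from] when `?` bubbles it up.
async fn subscribe_checked(client: &async_nats::Client) -> Result<(), Error> {
    let _subscriber = client
        .subscribe("bkapi.example".to_string())
        .await
        .map_err(NatsError::Subscribe)?;

    Ok(())
}
```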
```diff
@@ -58,11 +77,14 @@ struct Config {
     database_subscribe: Option<String>,
 
     /// The NATS host.
-    #[clap(long, env)]
+    #[clap(long, env, requires = "nats_prefix")]
     nats_host: Option<String>,
 
     /// The NATS NKEY.
     #[clap(long, env)]
     nats_nkey: Option<String>,
 
+    /// A prefix to use for NATS subjects.
+    #[clap(long, env)]
+    nats_prefix: Option<String>,
+
     /// Maximum distance permitted in queries.
     #[clap(long, env, default_value = "10")]
```
```diff
@@ -84,6 +106,8 @@ async fn main() {
     tracing::info!("starting bkapi");
 
+    let token = CancellationToken::new();
+
     let metrics_server = foxlib::MetricsServer::serve(config.metrics_host, false).await;
 
     let tree = tree::Tree::new();
```
```diff
@@ -115,24 +139,26 @@ async fn main() {
     let tree_clone = tree.clone();
     let config_clone = config.clone();
 
-    if let Some(subscription) = config.database_subscribe.clone() {
+    let mut listener_task = if let Some(subscription) = config.database_subscribe.clone() {
         tracing::info!("starting to listen for payloads from postgres");
-
-        let query = config.database_query.clone();
-
-        tokio::task::spawn(async move {
-            tree::listen_for_payloads_db(pool, subscription, query, tree_clone, sender)
-                .await
-                .unwrap_or_log();
-        });
+        tokio::spawn(tree::listen_for_payloads_db(
+            pool,
+            subscription,
+            config.database_query.clone(),
+            tree_clone,
+            sender,
+            token.clone(),
+        ))
     } else if let Some(client) = client.clone() {
         tracing::info!("starting to listen for payloads from nats");
-
-        tokio::task::spawn(async {
-            tree::listen_for_payloads_nats(config_clone, pool, client, tree_clone, sender)
-                .await
-                .unwrap_or_log();
-        });
+        tokio::spawn(tree::listen_for_payloads_nats(
+            config_clone,
+            pool,
+            client,
+            tree_clone,
+            sender,
+            token.clone(),
+        ))
     } else {
         panic!("no listener source available");
     };
```
```diff
@@ -146,19 +172,48 @@ async fn main() {
     metrics_server.set_ready(true);
 
     if let Some(client) = client {
-        let tree_clone = tree.clone();
-        let config_clone = config.clone();
+        let tree = tree.clone();
+        let config = config.clone();
+        let token = token.clone();
 
-        tokio::task::spawn(async move {
-            search_nats(client, tree_clone, config_clone)
+        tokio::spawn(async move {
+            search_nats(client, tree, config, token)
                 .await
                 .unwrap_or_log();
         });
     }
 
-    start_server(config, tree).await.unwrap_or_log();
+    let mut server = start_server(config, tree);
+    let server_handle = server.handle();
+
+    tokio::spawn({
+        let token = token.clone();
+
+        async move {
+            tokio::signal::ctrl_c()
+                .await
+                .expect_or_log("ctrl+c handler failed to install");
+
+            token.cancel();
+        }
+    });
+
+    tokio::select! {
+        _ = token.cancelled() => {
+            tracing::info!("got cancellation, stopping server");
+
+            let _ = tokio::join!(server_handle.stop(true), server, listener_task);
+        }
+        res = &mut listener_task => {
+            tracing::error!("listener task ended: {res:?}");
+
+            let _ = tokio::join!(server_handle.stop(true), server);
+        }
+        res = &mut server => {
+            tracing::error!("server ended: {res:?}");
+
+            let _ = tokio::join!(server_handle.stop(true), listener_task);
+        }
+    }
 }
 
-async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
+fn start_server(config: Config, tree: tree::Tree) -> actix_web::dev::Server {
     let tree = Data::new(tree);
     let config_data = Data::new(config.clone());
```
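The shutdown flow above boils down to a reusable skeleton: one shared `CancellationToken`, a ctrl+c watcher that cancels it, and a `select!` that waits for the token, the listener, or the server to finish first, then drains the rest. A standalone sketch of just that skeleton, with an illustrative worker in place of the real listener:

```rust
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();

    // Stand-in for the payload listener: runs until cancelled.
    let mut worker = tokio::spawn({
        let token = token.clone();
        async move { token.cancelled().await }
    });

    // Cancel everything on ctrl+c.
    tokio::spawn({
        let token = token.clone();
        async move {
            tokio::signal::ctrl_c().await.expect("ctrl+c handler failed");
            token.cancel();
        }
    });

    // Wait for cancellation or an unexpected exit, then drain the worker.
    tokio::select! {
        _ = token.cancelled() => {
            let _ = (&mut worker).await;
        }
        res = &mut worker => {
            eprintln!("worker ended early: {res:?}");
        }
    }
}
```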
```diff
@@ -190,12 +245,12 @@ async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
             .app_data(tree.clone())
             .app_data(config_data.clone())
             .service(search)
+            .service(dump)
     })
     .bind(&config.http_listen)
     .expect_or_log("bind failed")
+    .disable_signals()
     .run()
-    .await
-    .map_err(Error::Io)
 }
 
 #[derive(Debug, serde::Deserialize)]
```
```diff
@@ -242,56 +297,77 @@ struct SearchPayload {
     distance: u32,
 }
 
-#[tracing::instrument(skip(client, tree, config))]
+#[tracing::instrument(skip_all)]
 async fn search_nats(
     client: async_nats::Client,
     tree: tree::Tree,
     config: Config,
+    token: CancellationToken,
 ) -> Result<(), Error> {
     tracing::info!("subscribing to searches");
 
     let client = Arc::new(client);
     let max_distance = config.max_distance;
 
-    let mut sub = client
-        .queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
-        .await?;
+    let service = client
+        .add_service(async_nats::service::Config {
+            name: "bkapi-search".to_string(),
+            version: env!("CARGO_PKG_VERSION").to_string(),
+            description: None,
+            stats_handler: None,
+            metadata: None,
+        })
+        .await
+        .map_err(NatsError::Generic)?;
 
-    while let Some(message) = sub.next().await {
-        tracing::trace!("got search message");
+    let mut endpoint = service
+        .endpoint(format!("{}.bkapi.search", config.nats_prefix.unwrap()))
+        .await
+        .map_err(NatsError::Generic)?;
 
-        let reply = match message.reply {
-            Some(reply) => reply,
-            None => {
-                tracing::warn!("message had no reply subject, skipping");
-                continue;
-            }
-        };
+    loop {
+        tokio::select! {
+            Some(request) = endpoint.next() => {
+                tracing::trace!("got search message");
 
-        if let Err(err) = handle_search_nats(
-            max_distance,
-            client.clone(),
-            tree.clone(),
-            reply,
-            &message.payload,
-        )
-        .await
-        {
-            tracing::error!("could not handle nats search: {err}");
+                if let Err(err) = handle_search_nats(max_distance, tree.clone(), request).await {
+                    tracing::error!("could not handle nats search: {err}");
+                }
+            }
+            _ = token.cancelled() => {
+                tracing::info!("cancelled, stopping endpoint");
+                break;
+            }
         }
     }
 
+    if let Err(err) = endpoint.stop().await {
+        tracing::error!("could not stop endpoint: {err}");
+    }
+
     Ok(())
 }
 
 async fn handle_search_nats(
     max_distance: u32,
-    client: Arc<async_nats::Client>,
     tree: tree::Tree,
-    reply: String,
-    payload: &[u8],
+    request: async_nats::service::Request,
 ) -> Result<(), Error> {
-    let payloads: Vec<SearchPayload> = serde_json::from_slice(payload).map_err(Error::Data)?;
+    let payloads: Vec<SearchPayload> = match serde_json::from_slice(&request.message.payload) {
+        Ok(payloads) => payloads,
+        Err(err) => {
+            let err = Err(async_nats::service::error::Error {
+                status: err.to_string(),
+                code: 400,
+            });
+
+            if let Err(err) = request.respond(err).await {
+                tracing::error!("could not respond with error: {err}");
+            }
+
+            return Ok(());
+        }
+    };
 
     tokio::task::spawn(
         async move {
```
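Because the endpoint is an ordinary NATS service, any connection can query it without the bkapi-client crate. A hedged sketch of a raw request, assuming a deployment prefix of "prod" and the JSON shape implied by `SearchPayload` (an array of hash and distance pairs; the hash value is made up):

```rust
async fn raw_search(client: &async_nats::Client) -> Result<(), async_nats::Error> {
    // One entry per hash to look up; mirrors Vec<SearchPayload>.
    let payload = serde_json::json!([{ "hash": 8_374_236_i64, "distance": 3 }]);

    let response = client
        .request(
            "prod.bkapi.search".to_string(),
            serde_json::to_vec(&payload).unwrap().into(),
        )
        .await?;

    // The service replies with one list of matches per submitted hash.
    println!("{}", String::from_utf8_lossy(&response.payload));

    Ok(())
}
```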
```diff
@@ -302,16 +378,15 @@ async fn handle_search_nats(
             let results = tree.find(hashes).await;
 
-            if let Err(err) = client
-                .publish(
-                    reply,
-                    serde_json::to_vec(&results)
-                        .expect_or_log("results could not be serialized")
-                        .into(),
-                )
-                .await
-            {
-                tracing::error!("could not publish results: {err}");
+            let resp = serde_json::to_vec(&results).map(Into::into).map_err(|err| {
+                async_nats::service::error::Error {
+                    status: err.to_string(),
+                    code: 503,
+                }
+            });
+
+            if let Err(err) = request.respond(resp).await {
+                tracing::error!("could not respond: {err}");
             }
         }
         .in_current_span(),
```
```diff
@@ -319,3 +394,44 @@ async fn handle_search_nats(
     Ok(())
 }
+
+#[get("/dump")]
+async fn dump(tree: Data<tree::Tree>) -> HttpResponse {
+    let (wtr, rdr) = tokio::io::duplex(4096);
+
+    tokio::task::spawn_blocking(move || {
+        let tree = tree.tree.blocking_read();
+
+        let bridge = tokio_util::io::SyncIoBridge::new(wtr);
+        let mut compressor = GzEncoder::new(bridge, Compression::default());
+
+        if let Err(err) = bincode::serde::encode_into_std_write(
+            &*tree,
+            &mut compressor,
+            bincode::config::standard(),
+        ) {
+            tracing::error!("could not write tree to compressor: {err}");
+        }
+
+        match compressor.finish() {
+            Ok(mut file) => match file.shutdown() {
+                Ok(_) => tracing::info!("finished writing dump"),
+                Err(err) => tracing::error!("could not finish writing dump: {err}"),
+            },
+            Err(err) => {
+                tracing::error!("could not finish compressor: {err}");
+            }
+        }
+    });
+
+    let stream = tokio_util::codec::FramedRead::new(rdr, tokio_util::codec::BytesCodec::new())
+        .map_ok(|b| b.freeze());
+
+    HttpResponse::Ok()
+        .content_type("application/octet-stream")
+        .insert_header((
+            "content-disposition",
+            r#"attachment; filename="bkapi-dump.dat.gz""#,
+        ))
+        .streaming(stream)
+}
```
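The dump is gzipped bincode of the whole tree, made possible by the Serialize/Deserialize derives added to `Hamming` and `Node` below. Reading a downloaded dump back is the mirror image of the handler; a sketch assuming the same bk-tree, bincode, and flate2 versions, with `Node` and `Hamming` imported from the tree module and a hypothetical file path:

```rust
use std::fs::File;

use flate2::read::GzDecoder;

// `Node` and `Hamming` are the now-public types from the tree module.
fn load_dump(path: &str) -> bk_tree::BKTree<Node, Hamming> {
    let file = File::open(path).expect("could not open dump file");
    let mut decoder = GzDecoder::new(file);

    // Mirror of encode_into_std_write in the /dump handler.
    bincode::serde::decode_from_std_read(&mut decoder, bincode::config::standard())
        .expect("could not decode dump")
}
```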

tree.rs (service crate):

@ -1,12 +1,15 @@
use std::sync::Arc; use std::sync::Arc;
use async_nats::jetstream::consumer::DeliverPolicy;
use bk_tree::BKTree; use bk_tree::BKTree;
use futures::TryStreamExt; use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use sqlx::{postgres::PgListener, Pool, Postgres, Row}; use sqlx::{postgres::PgListener, Pool, Postgres, Row};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing_unwrap::ResultExt; use tracing_unwrap::ResultExt;
use crate::{Config, Error}; use crate::{Config, Error, NatsError};
lazy_static::lazy_static! { lazy_static::lazy_static! {
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap(); static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
```diff
@@ -18,7 +21,7 @@ lazy_static::lazy_static! {
 /// A BKTree wrapper to cover common operations.
 #[derive(Clone)]
 pub struct Tree {
-    tree: Arc<RwLock<BKTree<Node, Hamming>>>,
+    pub tree: Arc<RwLock<BKTree<Node, Hamming>>>,
 }
 
 /// A hash and distance pair. May be used for searching or in search results.
```
```diff
@@ -116,7 +119,6 @@ impl Tree {
             .start_timer();
 
         let results: Vec<_> = tree
             .find(&hash.into(), distance)
-            .into_iter()
             .map(|item| HashDistance {
                 distance: item.0,
                 hash: (*item.1).into(),
```
```diff
@@ -130,7 +132,8 @@
 }
 
 /// A hamming distance metric.
-struct Hamming;
+#[derive(Serialize, Deserialize)]
+pub struct Hamming;
 
 impl bk_tree::Metric<Node> for Hamming {
     fn distance(&self, a: &Node, b: &Node) -> u32 {
```
```diff
@@ -144,8 +147,8 @@ impl bk_tree::Metric<Node> for Hamming {
 }
 
 /// A value of a node in the BK tree.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-struct Node([u8; 8]);
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub struct Node([u8; 8]);
 
 impl From<i64> for Node {
     fn from(num: i64) -> Self {
```
```diff
@@ -173,6 +176,7 @@ pub(crate) async fn listen_for_payloads_db(
     query: String,
     tree: Tree,
     initial: futures::channel::oneshot::Sender<()>,
+    token: CancellationToken,
 ) -> Result<(), Error> {
     let mut initial = Some(initial);
```
```diff
@@ -193,15 +197,40 @@ pub(crate) async fn listen_for_payloads_db(
                 .expect_or_log("nothing listening for initial data");
         }
 
-        while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
-            tracing::trace!("got postgres payload");
-            process_payload(&tree, notification.payload().as_bytes()).await?;
-        }
+        loop {
+            tokio::select! {
+                res = listener.try_recv() => {
+                    match res {
+                        Ok(Some(notification)) => {
+                            tracing::trace!("got postgres payload");
+                            process_payload(&tree, notification.payload().as_bytes()).await?;
+                        }
+                        Ok(None) => {
+                            tracing::warn!("got none value from recv");
+                            break;
+                        }
+                        Err(err) => {
+                            tracing::error!("got recv error: {err}");
+                            break;
+                        }
+                    }
+                }
+                _ = token.cancelled() => {
+                    tracing::info!("got cancellation");
+                    break;
+                }
+            }
+        }
 
-        tracing::error!("disconnected from postgres listener, recreating tree");
-        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
+        if token.is_cancelled() {
+            tracing::info!("cancelled, stopping db listener");
+            return Ok(());
+        } else {
+            tracing::error!("disconnected from postgres listener, recreating tree");
+            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
+        }
     }
 }
 
 /// Listen for incoming payloads from NATS.
 #[tracing::instrument(skip_all)]
```
```diff
@@ -211,29 +240,45 @@ pub(crate) async fn listen_for_payloads_nats(
     client: async_nats::Client,
     tree: Tree,
     initial: futures::channel::oneshot::Sender<()>,
+    token: CancellationToken,
 ) -> Result<(), Error> {
     let jetstream = async_nats::jetstream::new(client);
     let mut initial = Some(initial);
 
     let stream = jetstream
         .get_or_create_stream(async_nats::jetstream::stream::Config {
-            name: "bkapi-hashes".to_string(),
-            subjects: vec!["bkapi.add".to_string()],
-            max_age: std::time::Duration::from_secs(60 * 60 * 24),
-            retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
+            name: format!(
+                "{}-bkapi-hashes",
+                config.nats_prefix.clone().unwrap().replace('.', "-")
+            ),
+            subjects: vec![format!("{}.bkapi.add", config.nats_prefix.unwrap())],
+            max_age: std::time::Duration::from_secs(60 * 30),
+            retention: async_nats::jetstream::stream::RetentionPolicy::Limits,
             ..Default::default()
         })
-        .await?;
+        .await
+        .map_err(NatsError::CreateStream)?;
 
-    loop {
-        let consumer = stream
-            .get_or_create_consumer(
-                "bkapi-consumer",
-                async_nats::jetstream::consumer::pull::Config {
-                    ..Default::default()
-                },
-            )
-            .await?;
+    // Because we're tracking the last sequence ID before we load tree data,
+    // we don't need to start the listener until after it's loaded. This
+    // prevents issues with a slow client but still retains every hash.
+    let mut seq = stream.cached_info().state.last_sequence;
+
+    let create_consumer = |stream: async_nats::jetstream::stream::Stream, start_sequence: u64| async move {
+        tracing::info!(start_sequence, "creating consumer");
+
+        stream
+            .create_consumer(async_nats::jetstream::consumer::pull::Config {
+                deliver_policy: if start_sequence > 0 {
+                    DeliverPolicy::ByStartSequence { start_sequence }
+                } else {
+                    DeliverPolicy::All
+                },
+                ..Default::default()
+            })
+            .await
+            .map_err(NatsError::Consumer)
+    };
 
     tree.reload(&pool, &config.database_query).await?;
```
```diff
@@ -243,21 +288,41 @@ pub(crate) async fn listen_for_payloads_nats(
             .expect_or_log("nothing listening for initial data");
     }
 
-        let mut messages = consumer.messages().await?;
+    loop {
+        let consumer = create_consumer(stream.clone(), seq).await?;
+
+        let messages = consumer
+            .messages()
+            .await
+            .map_err(NatsError::Stream)?
+            .take_until(token.cancelled());
+        tokio::pin!(messages);
 
         while let Ok(Some(message)) = messages.try_next().await {
-            tracing::trace!("got nats payload");
-            message.ack().await?;
             process_payload(&tree, &message.payload).await?;
+
+            message.ack().await.map_err(NatsError::Generic)?;
+
+            seq = message
+                .info()
+                .expect_or_log("message missing info")
+                .stream_sequence;
         }
 
-        tracing::error!("disconnected from nats listener, recreating tree");
-        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
+        if token.is_cancelled() {
+            tracing::info!("cancelled, stopping nats listener");
+            return Ok(());
+        } else {
+            tracing::error!("disconnected from nats listener");
+            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
+        }
     }
 }
 
 /// Process a payload from Postgres or NATS and add to the tree.
+#[tracing::instrument(skip_all)]
 async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
-    tracing::trace!("got payload: {}", String::from_utf8_lossy(payload));
-
     let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
     tracing::trace!("got hash: {}", payload.hash);
```
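For completeness, the producer side of the reworked stream: hashes arrive as JSON published to the prefixed add subject, and the JetStream ack confirms the stream stored the message before the publisher moves on. A sketch assuming a "prod" prefix and the `{"hash": ...}` payload shape that `process_payload` logs:

```rust
async fn add_hash(client: async_nats::Client) -> Result<(), async_nats::Error> {
    let jetstream = async_nats::jetstream::new(client);

    // Publish to "{prefix}.bkapi.add"; prefix and hash are illustrative.
    let ack = jetstream
        .publish(
            "prod.bkapi.add".to_string(),
            serde_json::json!({ "hash": 8_374_236_i64 }).to_string().into(),
        )
        .await?;

    // The second await resolves once JetStream has persisted the message.
    ack.await?;

    Ok(())
}
```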