Major refactoring, better usage of NATS.

* Gracefully handle shutdowns.
* Support prefix for NATS subjects.
* Use NATS services for search.
* Much better use of NATS stream/consumers.
This commit is contained in:
Syfaro 2023-08-14 23:17:16 -04:00
parent 26dca9a51d
commit 365512b0c2
6 changed files with 1088 additions and 775 deletions

1416
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@ publish = false
nats = ["async-nats"]
[dependencies]
async-nats = { version = "0.27", optional = true }
async-nats = { version = "0.31.0", optional = true }
futures = "0.3"
opentelemetry = "0.18"
opentelemetry-http = "0.7"

View File

@ -4,6 +4,7 @@ use crate::{SearchResult, SearchResults};
#[derive(Clone)]
pub struct BKApiNatsClient {
client: async_nats::Client,
subject: String,
}
/// A hash and distance.
@ -19,8 +20,11 @@ impl BKApiNatsClient {
const NATS_SUBJECT: &str = "bkapi.search";
/// Create a new client with a given NATS client.
pub fn new(client: async_nats::Client) -> Self {
Self { client }
pub fn new(client: async_nats::Client, prefix: &str) -> Self {
Self {
client,
subject: format!("{}.{}", prefix, Self::NATS_SUBJECT),
}
}
/// Search for a single hash.
@ -48,7 +52,7 @@ impl BKApiNatsClient {
let message = self
.client
.request(Self::NATS_SUBJECT.to_string(), payload.into())
.request(self.subject.clone(), payload.into())
.await?;
let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();

View File

@ -6,36 +6,30 @@ edition = "2018"
publish = false
[dependencies]
dotenvy = "0.15"
clap = { version = "4", features = ["derive", "env"] }
thiserror = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-unwrap = "0.10"
tracing-opentelemetry = "0.18"
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
lazy_static = "1"
prometheus = { version = "0.13", features = ["process"] }
bk-tree = "0.4.0"
hamming = "0.1"
futures = "0.3"
tokio = { version = "1", features = ["sync"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
actix-web = "4"
actix-http = "3"
actix-service = "2"
actix-web = "4"
async-nats = { version = "0.31", features = ["service"] }
bincode = { version = "2.0.0-rc.3", features = ["serde"] }
bk-tree = { version = "0.5.0", features = ["serde"] }
clap = { version = "4", features = ["derive", "env"] }
dotenvy = "0.15"
flate2 = "1"
futures = "0.3"
hamming = "0.1"
lazy_static = "1"
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
prometheus = { version = "0.13", features = ["process"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
tokio = { version = "1", features = ["sync"] }
tokio-util = { version = "0.7", features = ["io", "io-util"] }
tracing = "0.1"
tracing-actix-web = { version = "0.7", features = ["opentelemetry_0_18"] }
async-nats = "0.27"
tracing-opentelemetry = "0.18"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-unwrap = "0.10"
foxlib = { git = "https://github.com/Syfaro/foxlib.git" }

View File

@ -6,9 +6,12 @@ use actix_web::{
web::{self, Data},
App, HttpResponse, HttpServer,
};
use async_nats::service::ServiceExt;
use clap::Parser;
use futures::StreamExt;
use flate2::{write::GzEncoder, Compression};
use futures::{StreamExt, TryStreamExt};
use sqlx::postgres::PgPoolOptions;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use tracing_unwrap::ResultExt;
@ -28,11 +31,27 @@ enum Error {
#[error("listener got data that could not be decoded: {0}")]
Data(serde_json::Error),
#[error("nats encountered error: {0}")]
Nats(#[from] async_nats::Error),
Nats(#[from] NatsError),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
}
/// The distinct error types returned by `async_nats`, gathered into one enum
/// so they can be carried by a single error variant.
///
/// Every variant wraps exactly one `async_nats` error type, can be built from
/// it through `From` (the `#[from]` attribute), and displays as the wrapped
/// error's own message (`{0}`), adding no text of its own.
#[derive(thiserror::Error, Debug)]
enum NatsError {
/// Wraps `async_nats::SubscribeError`.
#[error("{0}")]
Subscribe(#[from] async_nats::SubscribeError),
/// Wraps the JetStream context's `CreateStreamError`.
#[error("{0}")]
CreateStream(#[from] async_nats::jetstream::context::CreateStreamError),
/// Wraps the JetStream stream's `ConsumerError`.
#[error("{0}")]
Consumer(#[from] async_nats::jetstream::stream::ConsumerError),
/// Wraps the JetStream consumer's `StreamError`.
#[error("{0}")]
Stream(#[from] async_nats::jetstream::consumer::StreamError),
/// Wraps the JetStream context's `RequestError`.
#[error("{0}")]
Request(#[from] async_nats::jetstream::context::RequestError),
/// Wraps the crate-level `async_nats::Error`, for calls that do not return
/// one of the more specific error types above.
#[error("{0}")]
Generic(#[from] async_nats::Error),
}
#[derive(Parser, Clone)]
struct Config {
/// Host to listen for incoming HTTP requests.
@ -58,11 +77,14 @@ struct Config {
database_subscribe: Option<String>,
/// The NATS host.
#[clap(long, env)]
#[clap(long, env, requires = "nats_prefix")]
nats_host: Option<String>,
/// The NATS NKEY.
#[clap(long, env)]
nats_nkey: Option<String>,
/// A prefix to use for NATS subjects.
#[clap(long, env)]
nats_prefix: Option<String>,
/// Maximum distance permitted in queries.
#[clap(long, env, default_value = "10")]
@ -84,6 +106,8 @@ async fn main() {
tracing::info!("starting bkapi");
let token = CancellationToken::new();
let metrics_server = foxlib::MetricsServer::serve(config.metrics_host, false).await;
let tree = tree::Tree::new();
@ -115,24 +139,26 @@ async fn main() {
let tree_clone = tree.clone();
let config_clone = config.clone();
if let Some(subscription) = config.database_subscribe.clone() {
let mut listener_task = if let Some(subscription) = config.database_subscribe.clone() {
tracing::info!("starting to listen for payloads from postgres");
let query = config.database_query.clone();
tokio::task::spawn(async move {
tree::listen_for_payloads_db(pool, subscription, query, tree_clone, sender)
.await
.unwrap_or_log();
});
tokio::spawn(tree::listen_for_payloads_db(
pool,
subscription,
config.database_query.clone(),
tree_clone,
sender,
token.clone(),
))
} else if let Some(client) = client.clone() {
tracing::info!("starting to listen for payloads from nats");
tokio::task::spawn(async {
tree::listen_for_payloads_nats(config_clone, pool, client, tree_clone, sender)
.await
.unwrap_or_log();
});
tokio::spawn(tree::listen_for_payloads_nats(
config_clone,
pool,
client,
tree_clone,
sender,
token.clone(),
))
} else {
panic!("no listener source available");
};
@ -146,19 +172,48 @@ async fn main() {
metrics_server.set_ready(true);
if let Some(client) = client {
let tree_clone = tree.clone();
let config_clone = config.clone();
tokio::task::spawn(async move {
search_nats(client, tree_clone, config_clone)
let tree = tree.clone();
let config = config.clone();
let token = token.clone();
tokio::spawn(async move {
search_nats(client, tree, config, token)
.await
.unwrap_or_log();
});
}
start_server(config, tree).await.unwrap_or_log();
let mut server = start_server(config, tree);
let server_handle = server.handle();
tokio::spawn({
let token = token.clone();
async move {
tokio::signal::ctrl_c()
.await
.expect_or_log("ctrl+c handler failed to install");
token.cancel();
}
});
tokio::select! {
_ = token.cancelled() => {
tracing::info!("got cancellation, stopping server");
let _ = tokio::join!(server_handle.stop(true), server, listener_task);
}
res = &mut listener_task => {
tracing::error!("listener task ended: {res:?}");
let _ = tokio::join!(server_handle.stop(true), server);
}
res = &mut server => {
tracing::error!("server ended: {res:?}");
let _ = tokio::join!(server_handle.stop(true), listener_task);
}
}
}
async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
fn start_server(config: Config, tree: tree::Tree) -> actix_web::dev::Server {
let tree = Data::new(tree);
let config_data = Data::new(config.clone());
@ -190,12 +245,12 @@ async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
.app_data(tree.clone())
.app_data(config_data.clone())
.service(search)
.service(dump)
})
.bind(&config.http_listen)
.expect_or_log("bind failed")
.disable_signals()
.run()
.await
.map_err(Error::Io)
}
#[derive(Debug, serde::Deserialize)]
@ -242,56 +297,77 @@ struct SearchPayload {
distance: u32,
}
#[tracing::instrument(skip(client, tree, config))]
#[tracing::instrument(skip_all)]
async fn search_nats(
client: async_nats::Client,
tree: tree::Tree,
config: Config,
token: CancellationToken,
) -> Result<(), Error> {
tracing::info!("subscribing to searches");
let client = Arc::new(client);
let max_distance = config.max_distance;
let mut sub = client
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
.await?;
let service = client
.add_service(async_nats::service::Config {
name: "bkapi-search".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
description: None,
stats_handler: None,
metadata: None,
})
.await
.map_err(NatsError::Generic)?;
while let Some(message) = sub.next().await {
let mut endpoint = service
.endpoint(format!("{}.bkapi.search", config.nats_prefix.unwrap()))
.await
.map_err(NatsError::Generic)?;
loop {
tokio::select! {
Some(request) = endpoint.next() => {
tracing::trace!("got search message");
let reply = match message.reply {
Some(reply) => reply,
None => {
tracing::warn!("message had no reply subject, skipping");
continue;
}
};
if let Err(err) = handle_search_nats(
max_distance,
client.clone(),
tree.clone(),
reply,
&message.payload,
)
.await
{
if let Err(err) = handle_search_nats(max_distance, tree.clone(), request).await {
tracing::error!("could not handle nats search: {err}");
}
}
_ = token.cancelled() => {
tracing::info!("cancelled, stopping endpoint");
break;
}
}
}
if let Err(err) = endpoint.stop().await {
tracing::error!("could not stop endpoint: {err}");
}
Ok(())
}
async fn handle_search_nats(
max_distance: u32,
client: Arc<async_nats::Client>,
tree: tree::Tree,
reply: String,
payload: &[u8],
request: async_nats::service::Request,
) -> Result<(), Error> {
let payloads: Vec<SearchPayload> = serde_json::from_slice(payload).map_err(Error::Data)?;
let payloads: Vec<SearchPayload> = match serde_json::from_slice(&request.message.payload) {
Ok(payloads) => payloads,
Err(err) => {
let err = Err(async_nats::service::error::Error {
status: err.to_string(),
code: 400,
});
if let Err(err) = request.respond(err).await {
tracing::error!("could not respond with error: {err}");
}
return Ok(());
}
};
tokio::task::spawn(
async move {
@ -302,16 +378,15 @@ async fn handle_search_nats(
let results = tree.find(hashes).await;
if let Err(err) = client
.publish(
reply,
serde_json::to_vec(&results)
.expect_or_log("results could not be serialized")
.into(),
)
.await
{
tracing::error!("could not publish results: {err}");
let resp = serde_json::to_vec(&results).map(Into::into).map_err(|err| {
async_nats::service::error::Error {
status: err.to_string(),
code: 503,
}
});
if let Err(err) = request.respond(resp).await {
tracing::error!("could not respond: {err}");
}
}
.in_current_span(),
@ -319,3 +394,44 @@ async fn handle_search_nats(
Ok(())
}
/// HTTP handler for `GET /dump`: streams the in-memory tree as a
/// gzip-compressed bincode serialization, sent as a file attachment named
/// `bkapi-dump.dat.gz`.
///
/// Serialization runs on a blocking task that is not awaited here, so the
/// `200 OK` response is built and returned before encoding has finished.
/// Failures inside the task are logged only.
#[get("/dump")]
async fn dump(tree: Data<tree::Tree>) -> HttpResponse {
// In-memory pipe: the blocking task writes into `wtr` and the response body
// reads from `rdr`. 4096 is the size argument given to `duplex`
// (presumably the pipe's buffer capacity in bytes — confirm against tokio docs).
let (wtr, rdr) = tokio::io::duplex(4096);
// The join handle is dropped, so nothing waits on this task.
tokio::task::spawn_blocking(move || {
// This read guard is only dropped at the end of the closure, i.e. after the
// whole dump has been written (or has failed).
// NOTE(review): writers to the tree presumably wait for that long — confirm
// this is acceptable when the HTTP client reads slowly.
let tree = tree.tree.blocking_read();
// Adapts the async write half of the pipe to `std::io::Write` so the
// synchronous gzip encoder can write into it.
let bridge = tokio_util::io::SyncIoBridge::new(wtr);
let mut compressor = GzEncoder::new(bridge, Compression::default());
if let Err(err) = bincode::serde::encode_into_std_write(
&*tree,
&mut compressor,
bincode::config::standard(),
) {
// NOTE(review): the error is logged but execution continues to `finish()`
// below, so a partially written dump is still finalized — confirm intent.
tracing::error!("could not write tree to compressor: {err}");
}
// `finish()` hands back the inner writer (the bridge), which is then shut
// down; presumably this is what signals end-of-stream to `rdr` — TODO confirm.
match compressor.finish() {
Ok(mut file) => match file.shutdown() {
Ok(_) => tracing::info!("finished writing dump"),
Err(err) => tracing::error!("could not finish writing dump: {err}"),
},
Err(err) => {
tracing::error!("could not finish compressor: {err}");
}
}
});
// Turn the read half into a stream of byte chunks; `freeze()` converts each
// chunk into the immutable form passed to `streaming` below.
let stream = tokio_util::codec::FramedRead::new(rdr, tokio_util::codec::BytesCodec::new())
.map_ok(|b| b.freeze());
HttpResponse::Ok()
.content_type("application/octet-stream")
.insert_header((
"content-disposition",
r#"attachment; filename="bkapi-dump.dat.gz""#,
))
.streaming(stream)
}

View File

@ -1,12 +1,15 @@
use std::sync::Arc;
use async_nats::jetstream::consumer::DeliverPolicy;
use bk_tree::BKTree;
use futures::TryStreamExt;
use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing_unwrap::ResultExt;
use crate::{Config, Error};
use crate::{Config, Error, NatsError};
lazy_static::lazy_static! {
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
@ -18,7 +21,7 @@ lazy_static::lazy_static! {
/// A BKTree wrapper to cover common operations.
#[derive(Clone)]
pub struct Tree {
tree: Arc<RwLock<BKTree<Node, Hamming>>>,
pub tree: Arc<RwLock<BKTree<Node, Hamming>>>,
}
/// A hash and distance pair. May be used for searching or in search results.
@ -116,7 +119,6 @@ impl Tree {
.start_timer();
let results: Vec<_> = tree
.find(&hash.into(), distance)
.into_iter()
.map(|item| HashDistance {
distance: item.0,
hash: (*item.1).into(),
@ -130,7 +132,8 @@ impl Tree {
}
/// A hamming distance metric.
struct Hamming;
#[derive(Serialize, Deserialize)]
pub struct Hamming;
impl bk_tree::Metric<Node> for Hamming {
fn distance(&self, a: &Node, b: &Node) -> u32 {
@ -144,8 +147,8 @@ impl bk_tree::Metric<Node> for Hamming {
}
/// A value of a node in the BK tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Node([u8; 8]);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Node([u8; 8]);
impl From<i64> for Node {
fn from(num: i64) -> Self {
@ -173,6 +176,7 @@ pub(crate) async fn listen_for_payloads_db(
query: String,
tree: Tree,
initial: futures::channel::oneshot::Sender<()>,
token: CancellationToken,
) -> Result<(), Error> {
let mut initial = Some(initial);
@ -193,14 +197,39 @@ pub(crate) async fn listen_for_payloads_db(
.expect_or_log("nothing listening for initial data");
}
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
loop {
tokio::select! {
res = listener.try_recv() => {
match res {
Ok(Some(notification)) => {
tracing::trace!("got postgres payload");
process_payload(&tree, notification.payload().as_bytes()).await?;
}
Ok(None) => {
tracing::warn!("got none value from recv");
break;
}
Err(err) => {
tracing::error!("got recv error: {err}");
break;
}
}
}
_ = token.cancelled() => {
tracing::info!("got cancellation");
break;
}
}
}
if token.is_cancelled() {
tracing::info!("cancelled, stopping db listener");
return Ok(());
} else {
tracing::error!("disconnected from postgres listener, recreating tree");
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}
}
}
/// Listen for incoming payloads from NATS.
@ -211,29 +240,45 @@ pub(crate) async fn listen_for_payloads_nats(
client: async_nats::Client,
tree: Tree,
initial: futures::channel::oneshot::Sender<()>,
token: CancellationToken,
) -> Result<(), Error> {
let jetstream = async_nats::jetstream::new(client);
let mut initial = Some(initial);
let stream = jetstream
.get_or_create_stream(async_nats::jetstream::stream::Config {
name: "bkapi-hashes".to_string(),
subjects: vec!["bkapi.add".to_string()],
max_age: std::time::Duration::from_secs(60 * 60 * 24),
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
name: format!(
"{}-bkapi-hashes",
config.nats_prefix.clone().unwrap().replace('.', "-")
),
subjects: vec![format!("{}.bkapi.add", config.nats_prefix.unwrap())],
max_age: std::time::Duration::from_secs(60 * 30),
retention: async_nats::jetstream::stream::RetentionPolicy::Limits,
..Default::default()
})
.await?;
.await
.map_err(NatsError::CreateStream)?;
loop {
let consumer = stream
.get_or_create_consumer(
"bkapi-consumer",
async_nats::jetstream::consumer::pull::Config {
..Default::default()
// Because we're tracking the last sequence ID before we load tree data,
// we don't need to start the listener until after it's loaded. This
// prevents issues with a slow client but still retains every hash.
let mut seq = stream.cached_info().state.last_sequence;
let create_consumer = |stream: async_nats::jetstream::stream::Stream, start_sequence: u64| async move {
tracing::info!(start_sequence, "creating consumer");
stream
.create_consumer(async_nats::jetstream::consumer::pull::Config {
deliver_policy: if start_sequence > 0 {
DeliverPolicy::ByStartSequence { start_sequence }
} else {
DeliverPolicy::All
},
)
.await?;
..Default::default()
})
.await
.map_err(NatsError::Consumer)
};
tree.reload(&pool, &config.database_query).await?;
@ -243,21 +288,41 @@ pub(crate) async fn listen_for_payloads_nats(
.expect_or_log("nothing listening for initial data");
}
let mut messages = consumer.messages().await?;
loop {
let consumer = create_consumer(stream.clone(), seq).await?;
let messages = consumer
.messages()
.await
.map_err(NatsError::Stream)?
.take_until(token.cancelled());
tokio::pin!(messages);
while let Ok(Some(message)) = messages.try_next().await {
tracing::trace!("got nats payload");
message.ack().await?;
process_payload(&tree, &message.payload).await?;
message.ack().await.map_err(NatsError::Generic)?;
seq = message
.info()
.expect_or_log("message missing info")
.stream_sequence;
}
tracing::error!("disconnected from nats listener, recreating tree");
if token.is_cancelled() {
tracing::info!("cancelled, stopping nats listener");
return Ok(());
} else {
tracing::error!("disconnected from nats listener");
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}
}
}
/// Process a payload from Postgres or NATS and add to the tree.
#[tracing::instrument(skip_all)]
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
tracing::trace!("got payload: {}", String::from_utf8_lossy(payload));
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
tracing::trace!("got hash: {}", payload.hash);