mirror of
https://github.com/Syfaro/bkapi.git
synced 2024-11-05 14:44:29 +00:00
Major refactoring, better usage of NATS.
* Gracefully handle shutdowns. * Support prefix for NATS subjects. * Use NATS services for search. * Much better use of NATS stream/consumers.
This commit is contained in:
parent
26dca9a51d
commit
365512b0c2
1416
Cargo.lock
generated
1416
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -9,7 +9,7 @@ publish = false
|
|||||||
nats = ["async-nats"]
|
nats = ["async-nats"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-nats = { version = "0.27", optional = true }
|
async-nats = { version = "0.31.0", optional = true }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
opentelemetry = "0.18"
|
opentelemetry = "0.18"
|
||||||
opentelemetry-http = "0.7"
|
opentelemetry-http = "0.7"
|
||||||
|
@ -4,6 +4,7 @@ use crate::{SearchResult, SearchResults};
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct BKApiNatsClient {
|
pub struct BKApiNatsClient {
|
||||||
client: async_nats::Client,
|
client: async_nats::Client,
|
||||||
|
subject: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A hash and distance.
|
/// A hash and distance.
|
||||||
@ -19,8 +20,11 @@ impl BKApiNatsClient {
|
|||||||
const NATS_SUBJECT: &str = "bkapi.search";
|
const NATS_SUBJECT: &str = "bkapi.search";
|
||||||
|
|
||||||
/// Create a new client with a given NATS client.
|
/// Create a new client with a given NATS client.
|
||||||
pub fn new(client: async_nats::Client) -> Self {
|
pub fn new(client: async_nats::Client, prefix: &str) -> Self {
|
||||||
Self { client }
|
Self {
|
||||||
|
client,
|
||||||
|
subject: format!("{}.{}", prefix, Self::NATS_SUBJECT),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Search for a single hash.
|
/// Search for a single hash.
|
||||||
@ -48,7 +52,7 @@ impl BKApiNatsClient {
|
|||||||
|
|
||||||
let message = self
|
let message = self
|
||||||
.client
|
.client
|
||||||
.request(Self::NATS_SUBJECT.to_string(), payload.into())
|
.request(self.subject.clone(), payload.into())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();
|
let results: Vec<Vec<HashDistance>> = serde_json::from_slice(&message.payload).unwrap();
|
||||||
|
@ -6,36 +6,30 @@ edition = "2018"
|
|||||||
publish = false
|
publish = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
dotenvy = "0.15"
|
|
||||||
clap = { version = "4", features = ["derive", "env"] }
|
|
||||||
thiserror = "1"
|
|
||||||
|
|
||||||
tracing = "0.1"
|
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
|
||||||
tracing-unwrap = "0.10"
|
|
||||||
tracing-opentelemetry = "0.18"
|
|
||||||
|
|
||||||
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
|
|
||||||
opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
|
|
||||||
|
|
||||||
lazy_static = "1"
|
|
||||||
prometheus = { version = "0.13", features = ["process"] }
|
|
||||||
|
|
||||||
bk-tree = "0.4.0"
|
|
||||||
hamming = "0.1"
|
|
||||||
|
|
||||||
futures = "0.3"
|
|
||||||
tokio = { version = "1", features = ["sync"] }
|
|
||||||
|
|
||||||
serde = { version = "1", features = ["derive"] }
|
|
||||||
serde_json = "1"
|
|
||||||
|
|
||||||
actix-web = "4"
|
|
||||||
actix-http = "3"
|
actix-http = "3"
|
||||||
actix-service = "2"
|
actix-service = "2"
|
||||||
|
actix-web = "4"
|
||||||
|
async-nats = { version = "0.31", features = ["service"] }
|
||||||
|
bincode = { version = "2.0.0-rc.3", features = ["serde"] }
|
||||||
|
bk-tree = { version = "0.5.0", features = ["serde"] }
|
||||||
|
clap = { version = "4", features = ["derive", "env"] }
|
||||||
|
dotenvy = "0.15"
|
||||||
|
flate2 = "1"
|
||||||
|
futures = "0.3"
|
||||||
|
hamming = "0.1"
|
||||||
|
lazy_static = "1"
|
||||||
|
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
|
||||||
|
prometheus = { version = "0.13", features = ["process"] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
thiserror = "1"
|
||||||
|
tokio = { version = "1", features = ["sync"] }
|
||||||
|
tokio-util = { version = "0.7", features = ["io", "io-util"] }
|
||||||
|
tracing = "0.1"
|
||||||
tracing-actix-web = { version = "0.7", features = ["opentelemetry_0_18"] }
|
tracing-actix-web = { version = "0.7", features = ["opentelemetry_0_18"] }
|
||||||
|
tracing-opentelemetry = "0.18"
|
||||||
async-nats = "0.27"
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
tracing-unwrap = "0.10"
|
||||||
|
|
||||||
foxlib = { git = "https://github.com/Syfaro/foxlib.git" }
|
foxlib = { git = "https://github.com/Syfaro/foxlib.git" }
|
||||||
|
|
||||||
|
@ -6,9 +6,12 @@ use actix_web::{
|
|||||||
web::{self, Data},
|
web::{self, Data},
|
||||||
App, HttpResponse, HttpServer,
|
App, HttpResponse, HttpServer,
|
||||||
};
|
};
|
||||||
|
use async_nats::service::ServiceExt;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use futures::StreamExt;
|
use flate2::{write::GzEncoder, Compression};
|
||||||
|
use futures::{StreamExt, TryStreamExt};
|
||||||
use sqlx::postgres::PgPoolOptions;
|
use sqlx::postgres::PgPoolOptions;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
use tracing_unwrap::ResultExt;
|
use tracing_unwrap::ResultExt;
|
||||||
|
|
||||||
@ -28,11 +31,27 @@ enum Error {
|
|||||||
#[error("listener got data that could not be decoded: {0}")]
|
#[error("listener got data that could not be decoded: {0}")]
|
||||||
Data(serde_json::Error),
|
Data(serde_json::Error),
|
||||||
#[error("nats encountered error: {0}")]
|
#[error("nats encountered error: {0}")]
|
||||||
Nats(#[from] async_nats::Error),
|
Nats(#[from] NatsError),
|
||||||
#[error("io error: {0}")]
|
#[error("io error: {0}")]
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
enum NatsError {
|
||||||
|
#[error("{0}")]
|
||||||
|
Subscribe(#[from] async_nats::SubscribeError),
|
||||||
|
#[error("{0}")]
|
||||||
|
CreateStream(#[from] async_nats::jetstream::context::CreateStreamError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Consumer(#[from] async_nats::jetstream::stream::ConsumerError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Stream(#[from] async_nats::jetstream::consumer::StreamError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Request(#[from] async_nats::jetstream::context::RequestError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Generic(#[from] async_nats::Error),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Clone)]
|
#[derive(Parser, Clone)]
|
||||||
struct Config {
|
struct Config {
|
||||||
/// Host to listen for incoming HTTP requests.
|
/// Host to listen for incoming HTTP requests.
|
||||||
@ -58,11 +77,14 @@ struct Config {
|
|||||||
database_subscribe: Option<String>,
|
database_subscribe: Option<String>,
|
||||||
|
|
||||||
/// The NATS host.
|
/// The NATS host.
|
||||||
#[clap(long, env)]
|
#[clap(long, env, requires = "nats_prefix")]
|
||||||
nats_host: Option<String>,
|
nats_host: Option<String>,
|
||||||
/// The NATS NKEY.
|
/// The NATS NKEY.
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
nats_nkey: Option<String>,
|
nats_nkey: Option<String>,
|
||||||
|
/// A prefix to use for NATS subjects.
|
||||||
|
#[clap(long, env)]
|
||||||
|
nats_prefix: Option<String>,
|
||||||
|
|
||||||
/// Maximum distance permitted in queries.
|
/// Maximum distance permitted in queries.
|
||||||
#[clap(long, env, default_value = "10")]
|
#[clap(long, env, default_value = "10")]
|
||||||
@ -84,6 +106,8 @@ async fn main() {
|
|||||||
|
|
||||||
tracing::info!("starting bkapi");
|
tracing::info!("starting bkapi");
|
||||||
|
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
|
||||||
let metrics_server = foxlib::MetricsServer::serve(config.metrics_host, false).await;
|
let metrics_server = foxlib::MetricsServer::serve(config.metrics_host, false).await;
|
||||||
|
|
||||||
let tree = tree::Tree::new();
|
let tree = tree::Tree::new();
|
||||||
@ -115,24 +139,26 @@ async fn main() {
|
|||||||
|
|
||||||
let tree_clone = tree.clone();
|
let tree_clone = tree.clone();
|
||||||
let config_clone = config.clone();
|
let config_clone = config.clone();
|
||||||
if let Some(subscription) = config.database_subscribe.clone() {
|
let mut listener_task = if let Some(subscription) = config.database_subscribe.clone() {
|
||||||
tracing::info!("starting to listen for payloads from postgres");
|
tracing::info!("starting to listen for payloads from postgres");
|
||||||
|
tokio::spawn(tree::listen_for_payloads_db(
|
||||||
let query = config.database_query.clone();
|
pool,
|
||||||
|
subscription,
|
||||||
tokio::task::spawn(async move {
|
config.database_query.clone(),
|
||||||
tree::listen_for_payloads_db(pool, subscription, query, tree_clone, sender)
|
tree_clone,
|
||||||
.await
|
sender,
|
||||||
.unwrap_or_log();
|
token.clone(),
|
||||||
});
|
))
|
||||||
} else if let Some(client) = client.clone() {
|
} else if let Some(client) = client.clone() {
|
||||||
tracing::info!("starting to listen for payloads from nats");
|
tracing::info!("starting to listen for payloads from nats");
|
||||||
|
tokio::spawn(tree::listen_for_payloads_nats(
|
||||||
tokio::task::spawn(async {
|
config_clone,
|
||||||
tree::listen_for_payloads_nats(config_clone, pool, client, tree_clone, sender)
|
pool,
|
||||||
.await
|
client,
|
||||||
.unwrap_or_log();
|
tree_clone,
|
||||||
});
|
sender,
|
||||||
|
token.clone(),
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("no listener source available");
|
panic!("no listener source available");
|
||||||
};
|
};
|
||||||
@ -146,19 +172,48 @@ async fn main() {
|
|||||||
metrics_server.set_ready(true);
|
metrics_server.set_ready(true);
|
||||||
|
|
||||||
if let Some(client) = client {
|
if let Some(client) = client {
|
||||||
let tree_clone = tree.clone();
|
let tree = tree.clone();
|
||||||
let config_clone = config.clone();
|
let config = config.clone();
|
||||||
tokio::task::spawn(async move {
|
let token = token.clone();
|
||||||
search_nats(client, tree_clone, config_clone)
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
search_nats(client, tree, config, token)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_log();
|
.unwrap_or_log();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
start_server(config, tree).await.unwrap_or_log();
|
let mut server = start_server(config, tree);
|
||||||
|
let server_handle = server.handle();
|
||||||
|
|
||||||
|
tokio::spawn({
|
||||||
|
let token = token.clone();
|
||||||
|
|
||||||
|
async move {
|
||||||
|
tokio::signal::ctrl_c()
|
||||||
|
.await
|
||||||
|
.expect_or_log("ctrl+c handler failed to install");
|
||||||
|
token.cancel();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = token.cancelled() => {
|
||||||
|
tracing::info!("got cancellation, stopping server");
|
||||||
|
let _ = tokio::join!(server_handle.stop(true), server, listener_task);
|
||||||
|
}
|
||||||
|
res = &mut listener_task => {
|
||||||
|
tracing::error!("listener task ended: {res:?}");
|
||||||
|
let _ = tokio::join!(server_handle.stop(true), server);
|
||||||
|
}
|
||||||
|
res = &mut server => {
|
||||||
|
tracing::error!("server ended: {res:?}");
|
||||||
|
let _ = tokio::join!(server_handle.stop(true), listener_task);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
|
fn start_server(config: Config, tree: tree::Tree) -> actix_web::dev::Server {
|
||||||
let tree = Data::new(tree);
|
let tree = Data::new(tree);
|
||||||
let config_data = Data::new(config.clone());
|
let config_data = Data::new(config.clone());
|
||||||
|
|
||||||
@ -190,12 +245,12 @@ async fn start_server(config: Config, tree: tree::Tree) -> Result<(), Error> {
|
|||||||
.app_data(tree.clone())
|
.app_data(tree.clone())
|
||||||
.app_data(config_data.clone())
|
.app_data(config_data.clone())
|
||||||
.service(search)
|
.service(search)
|
||||||
|
.service(dump)
|
||||||
})
|
})
|
||||||
.bind(&config.http_listen)
|
.bind(&config.http_listen)
|
||||||
.expect_or_log("bind failed")
|
.expect_or_log("bind failed")
|
||||||
|
.disable_signals()
|
||||||
.run()
|
.run()
|
||||||
.await
|
|
||||||
.map_err(Error::Io)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
@ -242,56 +297,77 @@ struct SearchPayload {
|
|||||||
distance: u32,
|
distance: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(client, tree, config))]
|
#[tracing::instrument(skip_all)]
|
||||||
async fn search_nats(
|
async fn search_nats(
|
||||||
client: async_nats::Client,
|
client: async_nats::Client,
|
||||||
tree: tree::Tree,
|
tree: tree::Tree,
|
||||||
config: Config,
|
config: Config,
|
||||||
|
token: CancellationToken,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
tracing::info!("subscribing to searches");
|
tracing::info!("subscribing to searches");
|
||||||
|
|
||||||
let client = Arc::new(client);
|
let client = Arc::new(client);
|
||||||
let max_distance = config.max_distance;
|
let max_distance = config.max_distance;
|
||||||
|
|
||||||
let mut sub = client
|
let service = client
|
||||||
.queue_subscribe("bkapi.search".to_string(), "bkapi-search".to_string())
|
.add_service(async_nats::service::Config {
|
||||||
.await?;
|
name: "bkapi-search".to_string(),
|
||||||
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||||
|
description: None,
|
||||||
|
stats_handler: None,
|
||||||
|
metadata: None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(NatsError::Generic)?;
|
||||||
|
|
||||||
while let Some(message) = sub.next().await {
|
let mut endpoint = service
|
||||||
|
.endpoint(format!("{}.bkapi.search", config.nats_prefix.unwrap()))
|
||||||
|
.await
|
||||||
|
.map_err(NatsError::Generic)?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Some(request) = endpoint.next() => {
|
||||||
tracing::trace!("got search message");
|
tracing::trace!("got search message");
|
||||||
|
|
||||||
let reply = match message.reply {
|
if let Err(err) = handle_search_nats(max_distance, tree.clone(), request).await {
|
||||||
Some(reply) => reply,
|
|
||||||
None => {
|
|
||||||
tracing::warn!("message had no reply subject, skipping");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(err) = handle_search_nats(
|
|
||||||
max_distance,
|
|
||||||
client.clone(),
|
|
||||||
tree.clone(),
|
|
||||||
reply,
|
|
||||||
&message.payload,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::error!("could not handle nats search: {err}");
|
tracing::error!("could not handle nats search: {err}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = token.cancelled() => {
|
||||||
|
tracing::info!("cancelled, stopping endpoint");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = endpoint.stop().await {
|
||||||
|
tracing::error!("could not stop endpoint: {err}");
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_search_nats(
|
async fn handle_search_nats(
|
||||||
max_distance: u32,
|
max_distance: u32,
|
||||||
client: Arc<async_nats::Client>,
|
|
||||||
tree: tree::Tree,
|
tree: tree::Tree,
|
||||||
reply: String,
|
request: async_nats::service::Request,
|
||||||
payload: &[u8],
|
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let payloads: Vec<SearchPayload> = serde_json::from_slice(payload).map_err(Error::Data)?;
|
let payloads: Vec<SearchPayload> = match serde_json::from_slice(&request.message.payload) {
|
||||||
|
Ok(payloads) => payloads,
|
||||||
|
Err(err) => {
|
||||||
|
let err = Err(async_nats::service::error::Error {
|
||||||
|
status: err.to_string(),
|
||||||
|
code: 400,
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Err(err) = request.respond(err).await {
|
||||||
|
tracing::error!("could not respond with error: {err}");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
tokio::task::spawn(
|
tokio::task::spawn(
|
||||||
async move {
|
async move {
|
||||||
@ -302,16 +378,15 @@ async fn handle_search_nats(
|
|||||||
|
|
||||||
let results = tree.find(hashes).await;
|
let results = tree.find(hashes).await;
|
||||||
|
|
||||||
if let Err(err) = client
|
let resp = serde_json::to_vec(&results).map(Into::into).map_err(|err| {
|
||||||
.publish(
|
async_nats::service::error::Error {
|
||||||
reply,
|
status: err.to_string(),
|
||||||
serde_json::to_vec(&results)
|
code: 503,
|
||||||
.expect_or_log("results could not be serialized")
|
}
|
||||||
.into(),
|
});
|
||||||
)
|
|
||||||
.await
|
if let Err(err) = request.respond(resp).await {
|
||||||
{
|
tracing::error!("could not respond: {err}");
|
||||||
tracing::error!("could not publish results: {err}");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.in_current_span(),
|
.in_current_span(),
|
||||||
@ -319,3 +394,44 @@ async fn handle_search_nats(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[get("/dump")]
|
||||||
|
async fn dump(tree: Data<tree::Tree>) -> HttpResponse {
|
||||||
|
let (wtr, rdr) = tokio::io::duplex(4096);
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let tree = tree.tree.blocking_read();
|
||||||
|
|
||||||
|
let bridge = tokio_util::io::SyncIoBridge::new(wtr);
|
||||||
|
let mut compressor = GzEncoder::new(bridge, Compression::default());
|
||||||
|
|
||||||
|
if let Err(err) = bincode::serde::encode_into_std_write(
|
||||||
|
&*tree,
|
||||||
|
&mut compressor,
|
||||||
|
bincode::config::standard(),
|
||||||
|
) {
|
||||||
|
tracing::error!("could not write tree to compressor: {err}");
|
||||||
|
}
|
||||||
|
|
||||||
|
match compressor.finish() {
|
||||||
|
Ok(mut file) => match file.shutdown() {
|
||||||
|
Ok(_) => tracing::info!("finished writing dump"),
|
||||||
|
Err(err) => tracing::error!("could not finish writing dump: {err}"),
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
tracing::error!("could not finish compressor: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = tokio_util::codec::FramedRead::new(rdr, tokio_util::codec::BytesCodec::new())
|
||||||
|
.map_ok(|b| b.freeze());
|
||||||
|
|
||||||
|
HttpResponse::Ok()
|
||||||
|
.content_type("application/octet-stream")
|
||||||
|
.insert_header((
|
||||||
|
"content-disposition",
|
||||||
|
r#"attachment; filename="bkapi-dump.dat.gz""#,
|
||||||
|
))
|
||||||
|
.streaming(stream)
|
||||||
|
}
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_nats::jetstream::consumer::DeliverPolicy;
|
||||||
use bk_tree::BKTree;
|
use bk_tree::BKTree;
|
||||||
use futures::TryStreamExt;
|
use futures::{StreamExt, TryStreamExt};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
|
use sqlx::{postgres::PgListener, Pool, Postgres, Row};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing_unwrap::ResultExt;
|
use tracing_unwrap::ResultExt;
|
||||||
|
|
||||||
use crate::{Config, Error};
|
use crate::{Config, Error, NatsError};
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
lazy_static::lazy_static! {
|
||||||
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
|
static ref TREE_ENTRIES: prometheus::IntCounter = prometheus::register_int_counter!("bkapi_tree_entries", "Total number of entries within tree").unwrap();
|
||||||
@ -18,7 +21,7 @@ lazy_static::lazy_static! {
|
|||||||
/// A BKTree wrapper to cover common operations.
|
/// A BKTree wrapper to cover common operations.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Tree {
|
pub struct Tree {
|
||||||
tree: Arc<RwLock<BKTree<Node, Hamming>>>,
|
pub tree: Arc<RwLock<BKTree<Node, Hamming>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A hash and distance pair. May be used for searching or in search results.
|
/// A hash and distance pair. May be used for searching or in search results.
|
||||||
@ -116,7 +119,6 @@ impl Tree {
|
|||||||
.start_timer();
|
.start_timer();
|
||||||
let results: Vec<_> = tree
|
let results: Vec<_> = tree
|
||||||
.find(&hash.into(), distance)
|
.find(&hash.into(), distance)
|
||||||
.into_iter()
|
|
||||||
.map(|item| HashDistance {
|
.map(|item| HashDistance {
|
||||||
distance: item.0,
|
distance: item.0,
|
||||||
hash: (*item.1).into(),
|
hash: (*item.1).into(),
|
||||||
@ -130,7 +132,8 @@ impl Tree {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A hamming distance metric.
|
/// A hamming distance metric.
|
||||||
struct Hamming;
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct Hamming;
|
||||||
|
|
||||||
impl bk_tree::Metric<Node> for Hamming {
|
impl bk_tree::Metric<Node> for Hamming {
|
||||||
fn distance(&self, a: &Node, b: &Node) -> u32 {
|
fn distance(&self, a: &Node, b: &Node) -> u32 {
|
||||||
@ -144,8 +147,8 @@ impl bk_tree::Metric<Node> for Hamming {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A value of a node in the BK tree.
|
/// A value of a node in the BK tree.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
struct Node([u8; 8]);
|
pub struct Node([u8; 8]);
|
||||||
|
|
||||||
impl From<i64> for Node {
|
impl From<i64> for Node {
|
||||||
fn from(num: i64) -> Self {
|
fn from(num: i64) -> Self {
|
||||||
@ -173,6 +176,7 @@ pub(crate) async fn listen_for_payloads_db(
|
|||||||
query: String,
|
query: String,
|
||||||
tree: Tree,
|
tree: Tree,
|
||||||
initial: futures::channel::oneshot::Sender<()>,
|
initial: futures::channel::oneshot::Sender<()>,
|
||||||
|
token: CancellationToken,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let mut initial = Some(initial);
|
let mut initial = Some(initial);
|
||||||
|
|
||||||
@ -193,15 +197,40 @@ pub(crate) async fn listen_for_payloads_db(
|
|||||||
.expect_or_log("nothing listening for initial data");
|
.expect_or_log("nothing listening for initial data");
|
||||||
}
|
}
|
||||||
|
|
||||||
while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
res = listener.try_recv() => {
|
||||||
|
match res {
|
||||||
|
Ok(Some(notification)) => {
|
||||||
tracing::trace!("got postgres payload");
|
tracing::trace!("got postgres payload");
|
||||||
process_payload(&tree, notification.payload().as_bytes()).await?;
|
process_payload(&tree, notification.payload().as_bytes()).await?;
|
||||||
}
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
tracing::warn!("got none value from recv");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
tracing::error!("got recv error: {err}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = token.cancelled() => {
|
||||||
|
tracing::info!("got cancellation");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.is_cancelled() {
|
||||||
|
tracing::info!("cancelled, stopping db listener");
|
||||||
|
return Ok(());
|
||||||
|
} else {
|
||||||
tracing::error!("disconnected from postgres listener, recreating tree");
|
tracing::error!("disconnected from postgres listener, recreating tree");
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Listen for incoming payloads from NATS.
|
/// Listen for incoming payloads from NATS.
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
@ -211,29 +240,45 @@ pub(crate) async fn listen_for_payloads_nats(
|
|||||||
client: async_nats::Client,
|
client: async_nats::Client,
|
||||||
tree: Tree,
|
tree: Tree,
|
||||||
initial: futures::channel::oneshot::Sender<()>,
|
initial: futures::channel::oneshot::Sender<()>,
|
||||||
|
token: CancellationToken,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let jetstream = async_nats::jetstream::new(client);
|
let jetstream = async_nats::jetstream::new(client);
|
||||||
let mut initial = Some(initial);
|
let mut initial = Some(initial);
|
||||||
|
|
||||||
let stream = jetstream
|
let stream = jetstream
|
||||||
.get_or_create_stream(async_nats::jetstream::stream::Config {
|
.get_or_create_stream(async_nats::jetstream::stream::Config {
|
||||||
name: "bkapi-hashes".to_string(),
|
name: format!(
|
||||||
subjects: vec!["bkapi.add".to_string()],
|
"{}-bkapi-hashes",
|
||||||
max_age: std::time::Duration::from_secs(60 * 60 * 24),
|
config.nats_prefix.clone().unwrap().replace('.', "-")
|
||||||
retention: async_nats::jetstream::stream::RetentionPolicy::Interest,
|
),
|
||||||
|
subjects: vec![format!("{}.bkapi.add", config.nats_prefix.unwrap())],
|
||||||
|
max_age: std::time::Duration::from_secs(60 * 30),
|
||||||
|
retention: async_nats::jetstream::stream::RetentionPolicy::Limits,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
})
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(NatsError::CreateStream)?;
|
||||||
|
|
||||||
loop {
|
// Because we're tracking the last sequence ID before we load tree data,
|
||||||
let consumer = stream
|
// we don't need to start the listener until after it's loaded. This
|
||||||
.get_or_create_consumer(
|
// prevents issues with a slow client but still retains every hash.
|
||||||
"bkapi-consumer",
|
let mut seq = stream.cached_info().state.last_sequence;
|
||||||
async_nats::jetstream::consumer::pull::Config {
|
|
||||||
..Default::default()
|
let create_consumer = |stream: async_nats::jetstream::stream::Stream, start_sequence: u64| async move {
|
||||||
|
tracing::info!(start_sequence, "creating consumer");
|
||||||
|
|
||||||
|
stream
|
||||||
|
.create_consumer(async_nats::jetstream::consumer::pull::Config {
|
||||||
|
deliver_policy: if start_sequence > 0 {
|
||||||
|
DeliverPolicy::ByStartSequence { start_sequence }
|
||||||
|
} else {
|
||||||
|
DeliverPolicy::All
|
||||||
},
|
},
|
||||||
)
|
..Default::default()
|
||||||
.await?;
|
})
|
||||||
|
.await
|
||||||
|
.map_err(NatsError::Consumer)
|
||||||
|
};
|
||||||
|
|
||||||
tree.reload(&pool, &config.database_query).await?;
|
tree.reload(&pool, &config.database_query).await?;
|
||||||
|
|
||||||
@ -243,21 +288,41 @@ pub(crate) async fn listen_for_payloads_nats(
|
|||||||
.expect_or_log("nothing listening for initial data");
|
.expect_or_log("nothing listening for initial data");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut messages = consumer.messages().await?;
|
loop {
|
||||||
|
let consumer = create_consumer(stream.clone(), seq).await?;
|
||||||
|
|
||||||
|
let messages = consumer
|
||||||
|
.messages()
|
||||||
|
.await
|
||||||
|
.map_err(NatsError::Stream)?
|
||||||
|
.take_until(token.cancelled());
|
||||||
|
tokio::pin!(messages);
|
||||||
|
|
||||||
while let Ok(Some(message)) = messages.try_next().await {
|
while let Ok(Some(message)) = messages.try_next().await {
|
||||||
tracing::trace!("got nats payload");
|
|
||||||
message.ack().await?;
|
|
||||||
process_payload(&tree, &message.payload).await?;
|
process_payload(&tree, &message.payload).await?;
|
||||||
|
|
||||||
|
message.ack().await.map_err(NatsError::Generic)?;
|
||||||
|
seq = message
|
||||||
|
.info()
|
||||||
|
.expect_or_log("message missing info")
|
||||||
|
.stream_sequence;
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::error!("disconnected from nats listener, recreating tree");
|
if token.is_cancelled() {
|
||||||
|
tracing::info!("cancelled, stopping nats listener");
|
||||||
|
return Ok(());
|
||||||
|
} else {
|
||||||
|
tracing::error!("disconnected from nats listener");
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Process a payload from Postgres or NATS and add to the tree.
|
/// Process a payload from Postgres or NATS and add to the tree.
|
||||||
|
#[tracing::instrument(skip_all)]
|
||||||
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
|
async fn process_payload(tree: &Tree, payload: &[u8]) -> Result<(), Error> {
|
||||||
|
tracing::trace!("got payload: {}", String::from_utf8_lossy(payload));
|
||||||
|
|
||||||
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
|
let payload: Payload = serde_json::from_slice(payload).map_err(Error::Data)?;
|
||||||
tracing::trace!("got hash: {}", payload.hash);
|
tracing::trace!("got hash: {}", payload.hash);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user