Mirror of https://github.com/Syfaro/bkapi.git, synced 2024-11-05 14:44:29 +00:00
Switch to actix, refactoring.

commit 3b3d1db2dc
parent 196ea5a2e2
Cargo.lock · 1369 changes (generated; diff suppressed because it is too large)
Cargo.toml · 13 changes

@@ -13,8 +13,7 @@ tracing-subscriber = "0.2"
 tracing-unwrap = "0.9"
 tracing-opentelemetry = "0.14"
 
-opentelemetry = { version = "0.15", features = ["rt-async-std"] }
-opentelemetry-semantic-conventions = "0.7.0"
+opentelemetry = { version = "0.15", features = ["rt-tokio"] }
 opentelemetry-jaeger = "0.14"
 
 lazy_static = "1"
@@ -24,14 +23,16 @@ bk-tree = "0.4.0"
 hamming = "0.1"
 
 futures = "0.3"
-async-std = { version = "1", features = ["attributes"] }
+tokio = { version = "1", features = ["sync"] }
 
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 
-tide = "0.16"
-tide-tracing = "0.0.11"
+actix-web = "4.0.0-beta.8"
+actix-http = "3.0.0-beta.8"
+actix-service = "2"
+tracing-actix-web = { version = "0.4.0-beta.9", features = ["opentelemetry_0_15"] }
 
 [dependencies.sqlx]
 version = "0.5"
-features = ["runtime-async-std-rustls", "postgres"]
+features = ["runtime-actix-rustls", "postgres"]
src/main.rs · 317 changes

@@ -1,17 +1,20 @@
 use std::sync::Arc;
 
-use async_std::sync::{RwLock, RwLockUpgradableReadGuard};
+use actix_service::Service;
+use actix_web::{
+    get,
+    web::{self, Data},
+    App, HttpResponse, HttpServer, Responder,
+};
 use envconfig::Envconfig;
 use opentelemetry::KeyValue;
-use sqlx::{
-    postgres::{PgListener, PgPoolOptions},
-    Pool, Postgres, Row,
-};
-use tide::Request;
+use prometheus::{Encoder, TextEncoder};
+use sqlx::postgres::PgPoolOptions;
+use tokio::sync::RwLock;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_unwrap::ResultExt;
 
-mod middlewares;
+mod tree;
 
 lazy_static::lazy_static! {
     static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
@@ -31,8 +34,6 @@ enum Error {
     Data(serde_json::Error),
 }
 
-type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
-
 #[derive(Envconfig, Clone)]
 struct Config {
     #[envconfig(default = "0.0.0.0:3000")]
@@ -51,40 +52,84 @@ struct Config {
     max_distance: Option<u32>,
 }
 
-/// A hamming distance metric.
-struct Hamming;
-
-impl bk_tree::Metric<Node> for Hamming {
-    fn distance(&self, a: &Node, b: &Node) -> u32 {
-        hamming::distance_fast(&a.0, &b.0).expect_or_log("hashes did not have same byte alignment")
-            as u32
-    }
-
-    fn threshold_distance(&self, a: &Node, b: &Node, _threshold: u32) -> Option<u32> {
-        Some(self.distance(a, b))
-    }
-}
-
-/// A value of a node in the BK tree.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-struct Node([u8; 8]);
-
-impl From<i64> for Node {
-    fn from(num: i64) -> Self {
-        Self(num.to_be_bytes())
-    }
-}
-
-impl From<Node> for i64 {
-    fn from(node: Node) -> Self {
-        i64::from_be_bytes(node.0)
-    }
-}
-
-#[async_std::main]
+#[actix_web::main]
 async fn main() {
     let config = Config::init_from_env().expect("could not load config");
+    configure_tracing(&config);
+
+    tracing::info!("starting bkbase");
+
+    let tree: tree::Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(tree::Hamming)));
+
+    tracing::trace!("connecting to postgres");
+    let pool = PgPoolOptions::new()
+        .max_connections(2)
+        .connect(&config.database_url)
+        .await
+        .expect_or_log("could not connect to database");
+    tracing::debug!("connected to postgres");
+
+    let http_listen = config.http_listen.clone();
+
+    let (sender, receiver) = futures::channel::oneshot::channel();
+
+    tracing::info!("starting to listen for payloads");
+    let tree_clone = tree.clone();
+    let config_clone = config.clone();
+    tokio::task::spawn(async {
+        tree::listen_for_payloads(pool, config_clone, tree_clone, sender)
+            .await
+            .expect_or_log("listenting for updates failed");
+    });
+
+    tracing::info!("waiting for initial tree to load");
+    receiver
+        .await
+        .expect_or_log("tree loading was dropped before completing");
+    tracing::info!("initial tree loaded, starting server");
+
+    let tree = Data::new(tree);
+    let config = Data::new(config);
+
+    HttpServer::new(move || {
+        App::new()
+            .wrap(tracing_actix_web::TracingLogger::default())
+            .wrap_fn(|req, srv| {
+                let path = req.path().to_owned();
+                let method = req.method().to_string();
+
+                let start = std::time::Instant::now();
+                let fut = srv.call(req);
+
+                async move {
+                    let res = fut.await?;
+                    let end = std::time::Instant::now().duration_since(start);
+
+                    let status_code = res.status().as_u16().to_string();
+
+                    let labels: Vec<&str> = vec![&path, &method, &status_code];
+                    HTTP_REQUEST_COUNT.with_label_values(&labels).inc();
+                    HTTP_REQUEST_DURATION
+                        .with_label_values(&labels)
+                        .observe(end.as_secs_f64());
+
+                    Ok(res)
+                }
+            })
+            .app_data(tree.clone())
+            .app_data(config.clone())
+            .service(search)
+            .service(health)
+            .service(metrics)
+    })
+    .bind(&http_listen)
+    .expect_or_log("bind failed")
+    .run()
+    .await
+    .expect_or_log("server failed");
+}
+
+fn configure_tracing(config: &Config) {
     opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
 
     let env = std::env::var("ENVIRONMENT");
@@ -103,10 +148,10 @@
             KeyValue::new("environment", env.to_owned()),
             KeyValue::new("version", env!("CARGO_PKG_VERSION")),
         ])
-        .install_batch(opentelemetry::runtime::AsyncStd)
+        .install_batch(opentelemetry::runtime::Tokio)
         .expect("otel jaeger pipeline could not be created");
 
-    let trace = tracing_opentelemetry::layer().with_tracer(tracer.clone());
+    let trace = tracing_opentelemetry::layer().with_tracer(tracer);
     tracing::subscriber::set_global_default(
         tracing_subscriber::Registry::default()
             .with(tracing_subscriber::EnvFilter::from_default_env())
@@ -114,59 +159,15 @@
             .with(tracing_subscriber::fmt::layer()),
     )
     .expect("tracing could not be configured");
-
-    tracing::info!("starting bkbase");
-
-    tracing::debug!("loaded config");
-
-    let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
-
-    let pool = PgPoolOptions::new()
-        .max_connections(2)
-        .connect(&config.database_url)
-        .await
-        .expect_or_log("could not connect to database");
-    tracing::debug!("connected to postgres");
-
-    let http_listen = config.http_listen.clone();
-    let max_distance = config.max_distance;
-
-    let (sender, receiver) = futures::channel::oneshot::channel();
-
-    tracing::info!("starting to listen for payloads");
-    let tree_clone = tree.clone();
-    async_std::task::spawn(async {
-        listen_for_payloads(pool, config, tree_clone, sender)
-            .await
-            .expect_or_log("listenting for updates failed");
-    });
-
-    tracing::info!("waiting for initial tree to load");
-    receiver
-        .await
-        .expect_or_log("tree loading was dropped before completing");
-    tracing::info!("initial tree loaded, starting server");
-
-    let mut app = tide::with_state(State { tree, max_distance });
-    app.with(middlewares::TideOpentelemMiddleware::new(tracer));
-    app.with(tide_tracing::TraceMiddleware::new());
-    app.with(middlewares::TidePrometheusMiddleware);
-
-    app.at("/search").get(search);
-    app.at("/health").get(|_| async { Ok("OK") });
-
-    app.listen(&http_listen)
-        .await
-        .expect_or_log("could not start web server");
 }
 
 #[derive(Clone)]
 struct State {
-    tree: Tree,
+    tree: tree::Tree,
     max_distance: Option<u32>,
 }
 
-#[derive(serde::Deserialize)]
+#[derive(Debug, serde::Deserialize)]
 struct Query {
     hash: i64,
     distance: u32,
@@ -186,21 +187,23 @@ struct SearchResponse {
     hashes: Vec<HashDistance>,
 }
 
-#[tracing::instrument(skip(req))]
-async fn search(req: Request<State>) -> tide::Result {
-    let state = req.state();
+#[get("/search")]
+#[tracing::instrument(skip(query, tree, config), fields(query = ?query.0))]
+async fn search(
+    query: web::Query<Query>,
+    tree: Data<tree::Tree>,
+    config: Data<Config>,
+) -> Result<HttpResponse, std::convert::Infallible> {
+    let Query { hash, distance } = query.0;
+    let max_distance = config.max_distance;
 
-    let Query { hash, distance } = req.query()?;
     tracing::info!("searching for hash {} with distance {}", hash, distance);
 
-    if matches!(state.max_distance, Some(max_distance) if distance > max_distance) {
-        return Err(tide::Error::from_str(
-            400,
-            "Distance is greater than max distance",
-        ));
+    if matches!(max_distance, Some(max_distance) if distance > max_distance) {
+        return Ok(HttpResponse::BadRequest().body("distance is greater than max distance"));
     }
 
-    let tree = state.tree.read().await;
+    let tree = tree.read().await;
 
     let duration = TREE_DURATION
         .with_label_values(&[&distance.to_string()])
@@ -223,113 +226,21 @@ async fn search(req: Request<State>) -> tide::Result {
         hashes: matches,
     };
 
-    Ok(serde_json::to_string(&resp)?.into())
+    Ok(HttpResponse::Ok().json(resp))
 }
 
-/// Create a new BK tree and pull in all hashes from provided query.
-///
-/// This must be called after you have started a listener, otherwise items may
-/// be lost.
-async fn create_tree(
-    conn: &Pool<Postgres>,
-    config: &Config,
-) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
-    use futures::TryStreamExt;
-
-    tracing::warn!("creating new tree");
-
-    let mut tree = bk_tree::BKTree::new(Hamming);
-
-    let mut rows = sqlx::query(&config.database_query).fetch(conn);
-
-    let mut count = 0;
-
-    let start = std::time::Instant::now();
-
-    while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
-        let node: Node = row.get::<i64, _>(0).into();
-
-        // Avoid checking if each value is unique if we were told that the
-        // database query only returns unique values.
-        let timer = TREE_ADD_DURATION.start_timer();
-        if config.database_is_unique || tree.find_exact(&node).is_none() {
-            tree.add(node);
-        }
-        timer.stop_and_record();
-
-        count += 1;
-        if count % 250_000 == 0 {
-            tracing::debug!(count, "loaded more rows");
-        }
-    }
-
-    let dur = std::time::Instant::now().duration_since(start);
-
-    tracing::info!(count, "completed loading rows in {:?}", dur);
-
-    Ok(tree)
-}
-
-#[derive(serde::Deserialize)]
-struct Payload {
-    hash: i64,
-}
-
-/// Listen for incoming payloads.
-///
-/// This will create a new tree to ensure all items are present. It will also
-/// automatically recreate trees as needed if the database connection is lost.
-async fn listen_for_payloads(
-    conn: Pool<Postgres>,
-    config: Config,
-    tree: Tree,
-    initial: futures::channel::oneshot::Sender<()>,
-) -> Result<(), Error> {
-    let mut listener = PgListener::connect_with(&conn)
-        .await
-        .map_err(Error::Listener)?;
-    listener
-        .listen(&config.database_subscribe)
-        .await
-        .map_err(Error::Listener)?;
-
-    let new_tree = create_tree(&conn, &config).await?;
-    {
-        let mut tree = tree.write().await;
-        *tree = new_tree;
-    }
-
-    initial
-        .send(())
-        .expect_or_log("nothing listening for initial data");
-
-    loop {
-        while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
-            let payload: Payload =
-                serde_json::from_str(notification.payload()).map_err(Error::Data)?;
-            tracing::debug!(hash = payload.hash, "evaluating new payload");
-
-            let node: Node = payload.hash.into();
-
-            let _timer = TREE_ADD_DURATION.start_timer();
-
-            let tree = tree.upgradable_read().await;
-            if tree.find_exact(&node).is_some() {
-                tracing::trace!("hash already existed in tree");
-                continue;
-            }
-
-            tracing::trace!("hash did not exist, adding to tree");
-            let mut tree = RwLockUpgradableReadGuard::upgrade(tree).await;
-            tree.add(node);
-        }
-
-        tracing::error!("disconnected from listener, recreating tree");
-        async_std::task::sleep(std::time::Duration::from_secs(10)).await;
-        let new_tree = create_tree(&conn, &config).await?;
-        {
-            let mut tree = tree.write().await;
-            *tree = new_tree;
-        }
-    }
-}
+#[get("/health")]
+async fn health() -> impl Responder {
+    "OK"
+}
+
+#[get("/metrics")]
+async fn metrics() -> Result<HttpResponse, std::convert::Infallible> {
+    let mut buffer = Vec::new();
+    let encoder = TextEncoder::new();
+
+    let metric_families = prometheus::gather();
+    encoder.encode(&metric_families, &mut buffer).unwrap();
+
+    Ok(HttpResponse::Ok().body(buffer))
 }
src/middlewares.rs · 139 changes (file deleted)

@@ -1,139 +0,0 @@
-use std::collections::HashMap;
-use std::convert::TryFrom;
-
-use opentelemetry::{
-    global::get_text_map_propagator,
-    trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer},
-    Context,
-};
-use opentelemetry_semantic_conventions::trace;
-use prometheus::{Encoder, TextEncoder};
-use tide::{
-    http::{
-        headers::{HeaderName, HeaderValue},
-        mime,
-    },
-    Middleware, Request, Response,
-};
-
-pub struct TidePrometheusMiddleware;
-
-impl TidePrometheusMiddleware {
-    const ROUTE: &'static str = "/metrics";
-}
-
-#[tide::utils::async_trait]
-impl<State: Clone + Send + Sync + 'static> Middleware<State> for TidePrometheusMiddleware {
-    async fn handle(&self, req: Request<State>, next: tide::Next<'_, State>) -> tide::Result {
-        let path = req.url().path().to_owned();
-
-        if path == Self::ROUTE {
-            let mut buffer = Vec::new();
-            let encoder = TextEncoder::new();
-
-            let metric_families = prometheus::gather();
-            encoder.encode(&metric_families, &mut buffer).unwrap();
-
-            return Ok(Response::builder(200)
-                .body(buffer)
-                .content_type(mime::PLAIN)
-                .build());
-        }
-
-        let method = req.method().to_string();
-
-        let start = std::time::Instant::now();
-        let res = next.run(req).await;
-        let end = std::time::Instant::now().duration_since(start);
-
-        let status_code = res.status().to_string();
-
-        let labels: Vec<&str> = vec![&path, &method, &status_code];
-
-        crate::HTTP_REQUEST_COUNT.with_label_values(&labels).inc();
-        crate::HTTP_REQUEST_DURATION
-            .with_label_values(&labels)
-            .observe(end.as_secs_f64());
-
-        Ok(res)
-    }
-}
-
-pub struct TideOpentelemMiddleware<T: Tracer> {
-    tracer: T,
-}
-
-impl<T: Tracer> TideOpentelemMiddleware<T> {
-    pub fn new(tracer: T) -> Self {
-        Self { tracer }
-    }
-}
-
-#[tide::utils::async_trait]
-impl<T: Tracer + Send + Sync, State: Clone + Send + Sync + 'static> Middleware<State>
-    for TideOpentelemMiddleware<T>
-{
-    async fn handle(&self, req: Request<State>, next: tide::Next<'_, State>) -> tide::Result {
-        let parent_cx = get_parent_cx(&req);
-
-        let method = req.method().to_string();
-        let url = req.url();
-
-        let attributes = vec![
-            trace::HTTP_METHOD.string(method.clone()),
-            trace::HTTP_SCHEME.string(url.scheme().to_string()),
-            trace::HTTP_URL.string(url.to_string()),
-        ];
-
-        let mut span_builder = self
-            .tracer
-            .span_builder(format!("{} {}", method, url.path()))
-            .with_kind(SpanKind::Server)
-            .with_attributes(attributes);
-
-        if parent_cx.span().span_context().is_remote() {
-            tracing::trace!("incoming request has remote span: {:?}", parent_cx);
-            span_builder = span_builder.with_parent_context(parent_cx);
-        }
-
-        let mut span = span_builder.start(&self.tracer);
-        span.add_event("request.started".to_owned(), vec![]);
-        let cx = &Context::current_with_span(span);
-
-        let mut res = next.run(req).with_context(cx.clone()).await;
-
-        let span = cx.span();
-        span.add_event("request.completed".to_owned(), vec![]);
-        span.set_attribute(trace::HTTP_STATUS_CODE.i64(u16::from(res.status()).into()));
-
-        if let Some(len) = res.len().and_then(|len| i64::try_from(len).ok()) {
-            span.set_attribute(trace::HTTP_RESPONSE_CONTENT_LENGTH.i64(len));
-        }
-
-        let mut injector = HashMap::new();
-        get_text_map_propagator(|propagator| propagator.inject_context(&cx, &mut injector));
-
-        for (key, value) in injector {
-            let header_name = HeaderName::from_bytes(key.into_bytes());
-            let header_value = HeaderValue::from_bytes(value.into_bytes());
-
-            if let (Ok(name), Ok(value)) = (header_name, header_value) {
-                res.insert_header(name, value);
-            } else {
-                tracing::error!("injected header data was invalid");
-            }
-        }
-
-        Ok(res)
-    }
-}
-
-fn get_parent_cx<State>(req: &Request<State>) -> Context {
-    let mut req_headers = HashMap::new();
-
-    for (key, value) in req.iter() {
-        req_headers.insert(key.to_string(), value.last().to_string());
-    }
-
-    get_text_map_propagator(|propagator| propagator.extract(&req_headers))
-}
src/tree.rs · 143 changes (new file)

@@ -0,0 +1,143 @@
+use std::sync::Arc;
+
+use sqlx::{postgres::PgListener, Pool, Postgres, Row};
+use tokio::sync::RwLock;
+use tracing_unwrap::ResultExt;
+
+use crate::{Config, Error};
+
+pub(crate) type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
+
+/// A hamming distance metric.
+pub(crate) struct Hamming;
+
+impl bk_tree::Metric<Node> for Hamming {
+    fn distance(&self, a: &Node, b: &Node) -> u32 {
+        hamming::distance_fast(&a.0, &b.0).expect_or_log("hashes did not have same byte alignment")
+            as u32
+    }
+
+    fn threshold_distance(&self, a: &Node, b: &Node, _threshold: u32) -> Option<u32> {
+        Some(self.distance(a, b))
+    }
+}
+
+/// A value of a node in the BK tree.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) struct Node([u8; 8]);
+
+impl From<i64> for Node {
+    fn from(num: i64) -> Self {
+        Self(num.to_be_bytes())
+    }
+}
+
+impl From<Node> for i64 {
+    fn from(node: Node) -> Self {
+        i64::from_be_bytes(node.0)
+    }
+}
+
+/// Create a new BK tree and pull in all hashes from provided query.
+///
+/// This must be called after you have started a listener, otherwise items may
+/// be lost.
+async fn create_tree(
+    conn: &Pool<Postgres>,
+    config: &Config,
+) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
+    use futures::TryStreamExt;
+
+    tracing::warn!("creating new tree");
+    let mut tree = bk_tree::BKTree::new(Hamming);
+    let mut rows = sqlx::query(&config.database_query).fetch(conn);
+
+    let mut count = 0;
+
+    let start = std::time::Instant::now();
+
+    while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
+        let node: Node = row.get::<i64, _>(0).into();
+
+        // Avoid checking if each value is unique if we were told that the
+        // database query only returns unique values.
+        let timer = crate::TREE_ADD_DURATION.start_timer();
+        if config.database_is_unique || tree.find_exact(&node).is_none() {
+            tree.add(node);
+        }
+        timer.stop_and_record();
+
+        count += 1;
+        if count % 250_000 == 0 {
+            tracing::debug!(count, "loaded more rows");
+        }
+    }
+
+    let dur = std::time::Instant::now().duration_since(start);
+
+    tracing::info!(count, "completed loading rows in {:?}", dur);
+    Ok(tree)
+}
+
+#[derive(serde::Deserialize)]
+struct Payload {
+    hash: i64,
+}
+
+/// Listen for incoming payloads.
+///
+/// This will create a new tree to ensure all items are present. It will also
+/// automatically recreate trees as needed if the database connection is lost.
+pub(crate) async fn listen_for_payloads(
+    conn: Pool<Postgres>,
+    config: Config,
+    tree: Tree,
+    initial: futures::channel::oneshot::Sender<()>,
+) -> Result<(), Error> {
+    let mut listener = PgListener::connect_with(&conn)
+        .await
+        .map_err(Error::Listener)?;
+    listener
+        .listen(&config.database_subscribe)
+        .await
+        .map_err(Error::Listener)?;
+
+    let new_tree = create_tree(&conn, &config).await?;
+    {
+        let mut tree = tree.write().await;
+        *tree = new_tree;
+    }
+
+    initial
+        .send(())
+        .expect_or_log("nothing listening for initial data");
+
+    loop {
+        while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
+            let payload: Payload =
+                serde_json::from_str(notification.payload()).map_err(Error::Data)?;
+            tracing::debug!(hash = payload.hash, "evaluating new payload");
+
+            let node: Node = payload.hash.into();
+
+            let _timer = crate::TREE_ADD_DURATION.start_timer();
+
+            let mut tree = tree.write().await;
+            if tree.find_exact(&node).is_some() {
+                tracing::trace!("hash already existed in tree");
+                continue;
+            }
+
+            tracing::trace!("hash did not exist, adding to tree");
+            tree.add(node);
+        }
+
+        tracing::error!("disconnected from listener, recreating tree");
+        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
+        let new_tree = create_tree(&conn, &config).await?;
+        {
+            let mut tree = tree.write().await;
+            *tree = new_tree;
+        }
+    }
+}