mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-12-22 09:15:33 +00:00
Latest navi open source refresh
latest code change including the global thread pool Closes twitter/the-algorithm#452 Closes twitter/the-algorithm#505
This commit is contained in:
parent
6e5c875a69
commit
4df87a278e
|
@ -31,6 +31,11 @@ In navi/navi, you can run the following commands:
|
|||
- `scripts/run_onnx.sh` for [Onnx](https://onnx.ai/)
|
||||
|
||||
Do note that you need to create a models directory and create some versions, preferably using epoch time, e.g., `1679693908377`.
|
||||
so the models structure looks like:
|
||||
models/
|
||||
-web_click
|
||||
- 1809000
|
||||
- 1809010
|
||||
|
||||
## Build
|
||||
You can adapt the above scripts to build using Cargo.
|
||||
|
|
|
@ -3,7 +3,6 @@ name = "dr_transform"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
@ -12,7 +11,6 @@ bpr_thrift = { path = "../thrift_bpr_adapter/thrift/"}
|
|||
segdense = { path = "../segdense/"}
|
||||
thrift = "0.17.0"
|
||||
ndarray = "0.15"
|
||||
ort = {git ="https://github.com/pykeio/ort.git", tag="v1.14.2"}
|
||||
base64 = "0.20.0"
|
||||
npyz = "0.7.2"
|
||||
log = "0.4.17"
|
||||
|
@ -21,6 +19,11 @@ prometheus = "0.13.1"
|
|||
once_cell = "1.17.0"
|
||||
rand = "0.8.5"
|
||||
itertools = "0.10.5"
|
||||
anyhow = "1.0.70"
|
||||
[target.'cfg(not(target_os="linux"))'.dependencies]
|
||||
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling"], tag="v1.14.6"}
|
||||
[target.'cfg(target_os="linux")'.dependencies]
|
||||
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling", "tensorrt", "cuda", "copy-dylibs"], tag="v1.14.6"}
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.0"
|
||||
|
||||
|
|
|
@ -3,3 +3,4 @@ pub mod converter;
|
|||
#[cfg(test)]
|
||||
mod test;
|
||||
pub mod util;
|
||||
pub extern crate ort;
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
[package]
|
||||
name = "navi"
|
||||
version = "2.0.42"
|
||||
version = "2.0.45"
|
||||
edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[[bin]]
|
||||
name = "navi"
|
||||
|
@ -16,12 +15,19 @@ required-features=["torch"]
|
|||
name = "navi_onnx"
|
||||
path = "src/bin/navi_onnx.rs"
|
||||
required-features=["onnx"]
|
||||
[[bin]]
|
||||
name = "navi_onnx_test"
|
||||
path = "src/bin/bin_tests/navi_onnx_test.rs"
|
||||
[[bin]]
|
||||
name = "navi_torch_test"
|
||||
path = "src/bin/bin_tests/navi_torch_test.rs"
|
||||
required-features=["torch"]
|
||||
|
||||
[features]
|
||||
default=[]
|
||||
navi_console=[]
|
||||
torch=["tch"]
|
||||
onnx=["ort"]
|
||||
onnx=[]
|
||||
tf=["tensorflow"]
|
||||
[dependencies]
|
||||
itertools = "0.10.5"
|
||||
|
@ -47,6 +53,7 @@ parking_lot = "0.12.1"
|
|||
rand = "0.8.5"
|
||||
rand_pcg = "0.3.1"
|
||||
random = "0.12.2"
|
||||
x509-parser = "0.15.0"
|
||||
sha256 = "1.0.3"
|
||||
tonic = { version = "0.6.2", features=['compression', 'tls'] }
|
||||
tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread", "fs", "process"] }
|
||||
|
@ -55,16 +62,12 @@ npyz = "0.7.3"
|
|||
base64 = "0.21.0"
|
||||
histogram = "0.6.9"
|
||||
tch = {version = "0.10.3", optional = true}
|
||||
tensorflow = { version = "0.20.0", optional = true }
|
||||
tensorflow = { version = "0.18.0", optional = true }
|
||||
once_cell = {version = "1.17.1"}
|
||||
ndarray = "0.15"
|
||||
serde = "1.0.154"
|
||||
serde_json = "1.0.94"
|
||||
dr_transform = { path = "../dr_transform"}
|
||||
[target.'cfg(not(target_os="linux"))'.dependencies]
|
||||
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling"], optional = true, tag="v1.14.2"}
|
||||
[target.'cfg(target_os="linux")'.dependencies]
|
||||
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling", "tensorrt", "cuda", "copy-dylibs"], optional = true, tag="v1.14.2"}
|
||||
[build-dependencies]
|
||||
tonic-build = {version = "0.6.2", features=['prost', "compression"] }
|
||||
[profile.release]
|
||||
|
@ -74,3 +77,5 @@ ndarray-rand = "0.14.0"
|
|||
tokio-test = "*"
|
||||
assert_cmd = "2.0"
|
||||
criterion = "0.4.0"
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
#!/bin/sh
|
||||
#RUST_LOG=debug LD_LIBRARY_PATH=so/onnx/lib target/release/navi_onnx --port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \
|
||||
RUST_LOG=info LD_LIBRARY_PATH=so/onnx/lib cargo run --bin navi_onnx --features onnx -- \
|
||||
--port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \
|
||||
--port 8030 --num-worker-threads 8 \
|
||||
--model-check-interval-secs 30 \
|
||||
--model-dir models/int8 \
|
||||
--output caligrated_probabilities \
|
||||
--input "" \
|
||||
--modelsync-cli "echo" \
|
||||
--onnx-ep-options use_arena=true
|
||||
--onnx-ep-options use_arena=true \
|
||||
--model-dir models/prod_home --output caligrated_probabilities --input "" --intra-op-parallelism 8 --inter-op-parallelism 8 --max-batch-size 1 --batch-time-out-millis 1 \
|
||||
--model-dir models/prod_home1 --output caligrated_probabilities --input "" --intra-op-parallelism 8 --inter-op-parallelism 8 --max-batch-size 1 --batch-time-out-millis 1 \
|
||||
|
|
|
@ -1,11 +1,24 @@
|
|||
use anyhow::Result;
|
||||
use log::info;
|
||||
use navi::cli_args::{ARGS, MODEL_SPECS};
|
||||
use navi::onnx_model::onnx::OnnxModel;
|
||||
use navi::{bootstrap, metrics};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
env_logger::init();
|
||||
assert_eq!(MODEL_SPECS.len(), ARGS.inter_op_parallelism.len());
|
||||
info!("global: {:?}", ARGS.onnx_global_thread_pool_options);
|
||||
let assert_session_params = if ARGS.onnx_global_thread_pool_options.is_empty() {
|
||||
// std::env::set_var("OMP_NUM_THREADS", "1");
|
||||
info!("now we use per session thread pool");
|
||||
MODEL_SPECS.len()
|
||||
}
|
||||
else {
|
||||
info!("now we use global thread pool");
|
||||
0
|
||||
};
|
||||
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
|
||||
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
|
||||
|
||||
metrics::register_custom_metrics();
|
||||
bootstrap::bootstrap(OnnxModel::new)
|
||||
}
|
||||
|
|
|
@ -207,6 +207,9 @@ impl<T: Model> PredictionService for PredictService<T> {
|
|||
PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
|
||||
PredictResult::ModelNotFound(idx) => {
|
||||
Err(Status::not_found(format!("model index {}", idx)))
|
||||
},
|
||||
PredictResult::ModelNotReady(idx) => {
|
||||
Err(Status::unavailable(format!("model index {}", idx)))
|
||||
}
|
||||
PredictResult::ModelVersionNotFound(idx, version) => Err(
|
||||
Status::not_found(format!("model index:{}, version {}", idx, version)),
|
||||
|
|
|
@ -87,13 +87,11 @@ pub struct Args {
|
|||
pub intra_op_parallelism: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "14",
|
||||
help = "number of threads to parallelize computations of the graph"
|
||||
)]
|
||||
pub inter_op_parallelism: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "serving_default",
|
||||
help = "signature of a serving. only TF"
|
||||
)]
|
||||
pub serving_sig: Vec<String>,
|
||||
|
@ -107,10 +105,12 @@ pub struct Args {
|
|||
help = "max warmup records to use. warmup only implemented for TF"
|
||||
)]
|
||||
pub max_warmup_records: usize,
|
||||
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
|
||||
pub onnx_global_thread_pool_options: Vec<(String, String)>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "true",
|
||||
help = "when to use graph parallelization. only for ONNX"
|
||||
long,
|
||||
default_value = "true",
|
||||
help = "when to use graph parallelization. only for ONNX"
|
||||
)]
|
||||
pub onnx_use_parallel_mode: String,
|
||||
// #[clap(long, default_value = "false")]
|
||||
|
|
|
@ -146,6 +146,7 @@ pub enum PredictResult {
|
|||
Ok(Vec<TensorScores>, i64),
|
||||
DropDueToOverload,
|
||||
ModelNotFound(usize),
|
||||
ModelNotReady(usize),
|
||||
ModelVersionNotFound(usize, i64),
|
||||
}
|
||||
|
||||
|
|
|
@ -13,21 +13,22 @@ pub mod onnx {
|
|||
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
|
||||
use itertools::Itertools;
|
||||
use log::{debug, info};
|
||||
use ort::environment::Environment;
|
||||
use ort::session::Session;
|
||||
use ort::tensor::InputTensor;
|
||||
use ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
|
||||
use dr_transform::ort::environment::Environment;
|
||||
use dr_transform::ort::session::Session;
|
||||
use dr_transform::ort::tensor::InputTensor;
|
||||
use dr_transform::ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
|
||||
use dr_transform::ort::LoggingLevel;
|
||||
use serde_json::Value;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, fs};
|
||||
use tokio::time::Instant;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
|
||||
Environment::builder()
|
||||
.with_name("onnx home")
|
||||
.with_log_level(ort::LoggingLevel::Error)
|
||||
.with_log_level(LoggingLevel::Error)
|
||||
.with_global_thread_pool(ARGS.onnx_global_thread_pool_options.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
);
|
||||
|
@ -101,23 +102,30 @@ pub mod onnx {
|
|||
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
|
||||
let mut builder = SessionBuilder::new(&ENVIRONMENT)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?
|
||||
.with_inter_threads(
|
||||
utils::get_config_or(
|
||||
model_config,
|
||||
"inter_op_parallelism",
|
||||
&ARGS.inter_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
)?
|
||||
.with_intra_threads(
|
||||
utils::get_config_or(
|
||||
model_config,
|
||||
"intra_op_parallelism",
|
||||
&ARGS.intra_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
)?
|
||||
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?;
|
||||
if ARGS.onnx_global_thread_pool_options.is_empty() {
|
||||
builder = builder
|
||||
.with_inter_threads(
|
||||
utils::get_config_or(
|
||||
model_config,
|
||||
"inter_op_parallelism",
|
||||
&ARGS.inter_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
)?
|
||||
.with_intra_threads(
|
||||
utils::get_config_or(
|
||||
model_config,
|
||||
"intra_op_parallelism",
|
||||
&ARGS.intra_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
)?;
|
||||
}
|
||||
else {
|
||||
builder = builder.with_disable_per_session_threads()?;
|
||||
}
|
||||
builder = builder
|
||||
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
|
||||
.with_execution_providers(&OnnxModel::ep_choices())?;
|
||||
match &ARGS.profiling {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use arrayvec::ArrayVec;
|
||||
use itertools::Itertools;
|
||||
use log::{error, info, warn};
|
||||
use log::{error, info};
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::string::String;
|
||||
use std::sync::Arc;
|
||||
|
@ -179,17 +179,17 @@ impl<T: Model> PredictService<T> {
|
|||
//initialize the latest version array
|
||||
let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
|
||||
loop {
|
||||
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
|
||||
warn!("config file {} not found due to: {}", meta_file, e);
|
||||
Value::Null
|
||||
});
|
||||
info!("***polling for models***"); //nice deliminter
|
||||
info!("config:{}", config);
|
||||
if let Some(ref cli) = ARGS.modelsync_cli {
|
||||
if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
|
||||
error!("model sync cli running error:{}", e)
|
||||
}
|
||||
}
|
||||
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
|
||||
info!("config file {} not found due to: {}", meta_file, e);
|
||||
Value::Null
|
||||
});
|
||||
info!("config:{}", config);
|
||||
for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
|
||||
let model_dir = &ARGS.model_dir[idx];
|
||||
PredictService::scan_load_latest_model_from_model_dir(
|
||||
|
@ -229,26 +229,32 @@ impl<T: Model> PredictService<T> {
|
|||
let no_more_msg = match msg {
|
||||
Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
|
||||
if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
|
||||
match version {
|
||||
None => model_predictors[0].push(val, resp, ts),
|
||||
Some(the_version) => match model_predictors
|
||||
.iter_mut()
|
||||
.find(|x| x.model.version() == the_version)
|
||||
{
|
||||
None => resp
|
||||
.send(PredictResult::ModelVersionNotFound(
|
||||
model_spec_at,
|
||||
the_version,
|
||||
))
|
||||
.unwrap_or_else(|e| {
|
||||
error!("cannot send back version error: {:?}", e)
|
||||
}),
|
||||
Some(predictor) => predictor.push(val, resp, ts),
|
||||
},
|
||||
if model_predictors.is_empty() {
|
||||
resp.send(PredictResult::ModelNotReady(model_spec_at))
|
||||
.unwrap_or_else(|e| error!("cannot send back model not ready error: {:?}", e));
|
||||
}
|
||||
else {
|
||||
match version {
|
||||
None => model_predictors[0].push(val, resp, ts),
|
||||
Some(the_version) => match model_predictors
|
||||
.iter_mut()
|
||||
.find(|x| x.model.version() == the_version)
|
||||
{
|
||||
None => resp
|
||||
.send(PredictResult::ModelVersionNotFound(
|
||||
model_spec_at,
|
||||
the_version,
|
||||
))
|
||||
.unwrap_or_else(|e| {
|
||||
error!("cannot send back version error: {:?}", e)
|
||||
}),
|
||||
Some(predictor) => predictor.push(val, resp, ts),
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resp.send(PredictResult::ModelNotFound(model_spec_at))
|
||||
.unwrap_or_else(|e| error!("cannot send back model error: {:?}", e))
|
||||
.unwrap_or_else(|e| error!("cannot send back model not found error: {:?}", e))
|
||||
}
|
||||
MPSC_CHANNEL_SIZE.dec();
|
||||
false
|
||||
|
|
|
@ -3,9 +3,9 @@ name = "segdense"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
env_logger = "0.10.0"
|
||||
serde = { version = "1.0.104", features = ["derive"] }
|
||||
serde_json = "1.0.48"
|
||||
log = "0.4.17"
|
||||
|
|
Loading…
Reference in a new issue