1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#[macro_use] extern crate crossbeam_channel;
#[macro_use] extern crate lazy_static;
#[macro_use] extern crate log;
#[macro_use] extern crate serde_derive;
extern crate bincode;
extern crate bufstream;
extern crate evmap;
extern crate ordered_float;
extern crate rand;
extern crate rayon;
extern crate s3;
extern crate serde_json;
extern crate serde_yaml;
extern crate threadpool;
extern crate time;
extern crate tmsn;
extern crate metricslib;
mod commons;
mod config;
mod testing;
pub mod head;
pub mod scanner;
use config::Config;
use config::SampleMode;
use commons::Model;
use commons::bins::Bins;
use commons::tree::ADTree as Tree;
use scanner::start as start_scanner;
use head::start_head;
use testing::validate;
use commons::bins::load_bins;
use commons::io::clear_s3_bucket;
use commons::persistent_io::read_model;
use commons::labeled_data::LabeledData;
/// Feature value type used by `RawExample` (presumably feature values before
/// binning — TODO confirm).
type RawTFeature = f32;
/// Feature value type used by `Example`; presumably a bin index produced with
/// `Bins` — TODO confirm.
type TFeature = u8;
/// Label type shared by `RawExample` and `Example`.
type TLabel = i8;
/// Labeled example whose features are `RawTFeature` (`f32`) values.
type RawExample = LabeledData<RawTFeature, TLabel>;
/// Labeled example whose features are `TFeature` (`u8`) values.
type Example = LabeledData<TFeature, TLabel>;
/// Region string passed to `clear_s3_bucket` by `prep_training` when the
/// sampling mode is S3.
const REGION: &str = "us-east-1";
/// Bucket name passed to `clear_s3_bucket` by `prep_training` when the
/// sampling mode is S3.
/// NOTE(review): hard-coded; other modules may also rely on it — confirm before changing.
const BUCKET: &str = "tmsn-cache2";
/// Loads the configuration at `config_filepath` and prepares the shared training state.
///
/// Steps, in order:
/// 1. parse the `Config`;
/// 2. derive the `SampleMode` from `config.sampling_mode`;
/// 3. if that mode is `SampleMode::S3`, call `clear_s3_bucket` with `REGION`, `BUCKET`
///    and `config.exp_name`;
/// 4. load the bins for the role named in `config.sampler_scanner`.
///
/// Returns `(config, sample_mode, bins)`.
fn prep_training(config_filepath: &String) -> (Config, SampleMode, Vec<Bins>) {
    let cfg = Config::new(config_filepath);
    let mode = SampleMode::new(&cfg.sampling_mode);

    // Only the S3 sampling mode touches the shared bucket; this must happen
    // before the bins are loaded, matching the original ordering.
    let uses_s3 = mode == SampleMode::S3;
    if uses_s3 {
        clear_s3_bucket(REGION, BUCKET, cfg.exp_name.as_str());
    }

    debug!("Loading bins.");
    let loaded_bins = load_bins(cfg.sampler_scanner.as_str(), Some(&cfg));

    (cfg, mode, loaded_bins)
}
/// Entry point for a training node.
///
/// Loads the configuration via `prep_training`, builds the starting model, and then
/// dispatches on `config.sampler_scanner`:
///
/// * `"scanner"` runs `start_scanner`;
/// * any other value runs `start_head`.
///
/// When `config.resume_training` is set and the role is `"sampler"`, the starting model
/// is the third element returned by `read_model()`, with its `base_version` reset to 0.
/// Otherwise a fresh `Tree` is created.
pub fn training(config_filepath: &String) {
    let (config, sample_mode, bins) = prep_training(config_filepath);
    let role = config.sampler_scanner.as_str();

    let resume = config.resume_training && role == "sampler";
    let starting_model: Model = if resume {
        debug!("resume_training is enabled");
        let (_, _, mut loaded) = read_model();
        loaded.base_version = 0;
        debug!("Loaded an existing tree");
        loaded
    } else {
        debug!("Created a new tree");
        // The size argument scales with `num_trees` and `num_splits`; the meaning of the
        // constants 4, +1 and +10 is not documented here — TODO confirm.
        Tree::new(config.num_trees * (4 + config.num_splits + 1) + 10)
    };

    // Arms are wrapped in blocks so both evaluate to `()` regardless of what the
    // start functions return.
    match role {
        "scanner" => {
            start_scanner(&config, &sample_mode, &bins, &starting_model);
        }
        _ => {
            start_head(&config, &sample_mode, &bins, &starting_model);
        }
    }
}
pub fn testing(config_filepath: &String) {
let config: Config = Config::new(config_filepath);
validate(
config.models_table_filename.clone(),
config.testing_filename.clone(),
config.num_testing_examples,
config.num_features,
config.batch_size,
config.positive.clone(),
config.incremental_testing,
config.testing_scores_only,
);
}