1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*!
Sparrow is an implementation of TMSN for boosting.

At a high level, Sparrow consists of three components:

1. scanner: it runs the boosting process, which scans the samples in memory, and
updates the current model by finding a new weak rule to be added to the score function;
2. sampler: it samples examples from disk and updates their scores according to the latest
score function;
3. model manager: it assigns tasks to the scanners, receives model updates from them, and
maintains the current score function.
*/
#[macro_use] extern crate crossbeam_channel;
#[macro_use] extern crate lazy_static;
#[macro_use] extern crate log;
#[macro_use] extern crate serde_derive;
extern crate bincode;
extern crate bufstream;
extern crate evmap;
extern crate ordered_float;
extern crate rand;
extern crate rayon;
extern crate s3;
extern crate serde_json;
extern crate serde_yaml;
extern crate threadpool;
extern crate time;
extern crate tmsn;
extern crate metricslib;

/// Common functions and classes.
mod commons;
mod config;
/// Validating models
mod testing;
/// Implementation of the components running on the head node, specifically the sampler
/// and the model manager
pub mod head;
/// Implementation of the scanner
/// ![](/images/scanner.png)
pub mod scanner;

use config::Config;
use config::SampleMode;
use commons::Model;
use commons::bins::Bins;
use commons::tree::ADTree as Tree;

use scanner::start as start_scanner;
use head::start_head;
use testing::validate;

use commons::bins::load_bins;
use commons::io::clear_s3_bucket;
use commons::persistent_io::read_model;

// Types
// Crate-wide aliases for the feature, label and example types.
// TODO: decide TFeature according to the bin size
use commons::labeled_data::LabeledData;
/// Feature type of `RawExample` (presumably the value before binning; confirm).
type RawTFeature = f32;
/// Feature type of `Example`; per the TODO above it should eventually follow the bin size.
type TFeature = u8;
/// Label type shared by `RawExample` and `Example`.
type TLabel = i8;
/// Labeled data whose features are `RawTFeature` values.
type RawExample = LabeledData<RawTFeature, TLabel>;
/// Labeled data whose features are `TFeature` values.
type Example = LabeledData<TFeature, TLabel>;

/// Region argument passed to `clear_s3_bucket` in `prep_training`.
const REGION:   &str = "us-east-1";
/// Bucket argument passed to `clear_s3_bucket` in `prep_training`.
const BUCKET:   &str = "tmsn-cache2";


/// Load everything `training` needs before it can start.
///
/// Reads the configuration at `config_filepath`, derives the sample mode from its
/// `sampling_mode` field and loads the bins. When the sample mode is `SampleMode::S3`,
/// `clear_s3_bucket` is called for this experiment's name before the bins are loaded.
///
/// Returns the configuration, the sample mode and the loaded bins, in that order.
fn prep_training(config_filepath: &String) -> (Config, SampleMode, Vec<Bins>) {
    let cfg = Config::new(config_filepath);
    let mode = SampleMode::new(&cfg.sampling_mode);

    // The S3 clean-up only applies to the S3 sample mode and must precede bin loading.
    let uses_s3 = mode == SampleMode::S3;
    if uses_s3 {
        clear_s3_bucket(REGION, BUCKET, cfg.exp_name.as_str());
    }

    debug!("Loading bins.");
    let loaded_bins = load_bins(cfg.sampler_scanner.as_str(), Some(&cfg));

    (cfg, mode, loaded_bins)
}

/// Run training as either a scanner or the head node.
///
/// The role is taken from `sampler_scanner` in the configuration: `"scanner"` runs
/// `scanner::start`, and every other value runs `head::start_head`.
///
/// The initial model is read back with `read_model` (with its `base_version` reset to 0)
/// when `resume_training` is set and the role is `"sampler"`; otherwise a new tree is
/// created.
///
/// Parameter:
///
/// * config_filepath: the filepath to the configuration file
pub fn training(config_filepath: &String) {
    let (config, sample_mode, bins) = prep_training(config_filepath);

    let resume_on_sampler = config.resume_training && config.sampler_scanner == "sampler";
    let init_tree: Model = if resume_on_sampler {
        // Resuming from an earlier training
        debug!("resume_training is enabled");
        // Only the third element of the tuple returned by `read_model` is used here.
        let mut restored = read_model().2;
        restored.base_version = 0;
        debug!("Loaded an existing tree");
        restored
    } else {
        debug!("Created a new tree");
        // TODO: extend for the cases that more than 4 nodes were used for creating grids
        let tree_size = config.num_trees * (4 + config.num_splits + 1) + 10;
        Tree::new(tree_size)
    };

    if config.sampler_scanner != "scanner" {
        // Reached for "sampler" and for any other value that is not "scanner".
        start_head(&config, &sample_mode, &bins, &init_tree);
    } else {
        start_scanner(&config, &sample_mode, &bins, &init_tree);
    }
}


/// Validate a model using the settings in a configuration file.
///
/// Loads the configuration and forwards its testing-related fields to `validate`.
///
/// Parameter:
///
/// * config_filepath: the filepath to the configuration file
pub fn testing(config_filepath: &String) {
    let cfg: Config = Config::new(config_filepath);

    // `validate` receives owned copies, so `cfg` itself stays intact.
    let models_table = cfg.models_table_filename.clone();
    let testing_file = cfg.testing_filename.clone();
    let positive = cfg.positive.clone();

    validate(
        models_table,
        testing_file,
        cfg.num_testing_examples,
        cfg.num_features,
        cfg.batch_size,
        positive,
        cfg.incremental_testing,
        cfg.testing_scores_only,
    );
}