use rayon::prelude::*;

use Example;
use TFeature;
use commons::ExampleInSampleSet;
use commons::Model;
use commons::bins::Bins;
use scanner::buffer_loader::BufferLoader;

use commons::get_bound;

use super::learner::NUM_PREDS;
use super::learner::PREDS;
use super::learner::RuleStats;
use super::learner::TreeNode;


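/// Precomputes, for every example in `data`, the per-rule statistics consumed by
/// `find_tree_node`: for each prediction pair in `PREDS`, slot 0 holds the weighted score
/// contribution and slot 1 the centered squared term used for the stopping-rule bound.
/// Returns, per example, the leaf index reported by `tree.get_leaf_index_prediction` for the
/// node being expanded, the example weight, and the `(example, stats)` pair.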
pub fn preprocess_data<'a>(
    data: &'a[ExampleInSampleSet], tree: &Model, expand_node: usize, rho_gamma: f32,
) -> Vec<(usize, f32, (&'a Example, RuleStats))> {
    data.par_iter().map(|(example, (weight, _, _, _))| {
        let labeled_weight = weight * (example.label as f32);
        let null_weight = 2.0 * rho_gamma * weight;
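        // For each prediction pair in PREDS: slot 0 accumulates the weighted score and
        // slot 1 the centered squared term; within each slot, the tuple holds the values
        // for the pair's two predictions (used for the left and right child, respectively,
        // in `find_tree_node`).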
        let mut vals: RuleStats = [[(0.0, 0.0); 2]; NUM_PREDS];
        PREDS.iter().enumerate().for_each(|(i, pred)| {
            let abs_val = (pred.0 * labeled_weight, pred.1 * labeled_weight);
            let ci      = (abs_val.0 - null_weight, abs_val.1 - null_weight);
            vals[i][0]  = abs_val;
            vals[i][1]  = (
                ci.0 * ci.0 - null_weight * null_weight,
                ci.1 * ci.1 - null_weight * null_weight
            );
        });
        (tree.get_leaf_index_prediction(expand_node, example), *weight, (example, vals))
    }).collect()
}

/// Scans one feature (`feature_index`) of the preprocessed examples in `data`: for every split
/// value in `bin` and every rule type in `PREDS`, it updates the running statistics in
/// `weak_rules_score` and `sum_c_squared` (plus the per-bin debug counters in `debug_info`),
/// and returns a candidate `TreeNode` whose stopping rule fired (the last one found), if any.
///
/// If `total_weight` takes into account the examples on which the expanded node abstained, the
/// comparison is then among all "specialists".
pub fn find_tree_node<'a>(
    data: &'a Vec<(f32, (&Example, RuleStats))>, feature_index: usize,
    rho_gamma: f32, count: usize, total_weight: f32, total_weight_sq: f32, expand_node: usize,
    bin: &'a Bins, weak_rules_score: &'a mut Vec<[f32; 2]>, sum_c_squared: &'a mut Vec<[f32; 2]>,
    debug_info: (((&'a mut Vec<f32>, &'a mut Vec<f32>), &'a mut Vec<f32>), &'a mut Vec<f32>),
) -> Option<TreeNode> {
    let (((num_positive, num_negative), weight_positive), weight_negative) = debug_info;

    // bin_accum_vals[split][rule][stat] accumulates, for each split value and rule type, a
    // (left-child, right-child) pair of statistics: stat 0 is the weighted score, stat 1 the
    // centered squared term. The last element is for the examples larger than all split values.
    let mut bin_accum_vals: Vec<RuleStats> =
        vec![[[(0.0, 0.0); 2]; NUM_PREDS]; bin.len() + 1];
    // Track the counts and total weights of the positive and negative examples
    let mut counts: [usize; 2] = [0, 0];
    let mut weights: [f32; 2]  = [0.0, 0.0];
    data.iter()
        .for_each(|(w, (example, vals))| {
            let flip_index = example.feature[feature_index] as usize;
            let t = &mut bin_accum_vals[flip_index];
            for j in 0..NUM_PREDS {
                for k in 0..t[j].len() {
                    t[j][k].0 += vals[j][k].0;
                    t[j][k].1 += vals[j][k].1;
                }
            }
            if example.label > 0 {
                counts[0]  += 1;
                weights[0] += w;
            } else {
                counts[1]  += 1;
                weights[1] += w;
            }
        });

    let mut accum_left  = [[0.0; 2]; NUM_PREDS];
    let mut accum_right = [[0.0; 2]; NUM_PREDS];
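    // For each rule type, accum_left/accum_right hold the running (score, squared-term) sums
    // over the examples currently assigned to the left and right child, respectively.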
    // Accumulate the stats of all examples that initially go to the right child
    for j in 0..bin.len() { // Split value
        for pred_idx in 0..NUM_PREDS { // Types of rule
            for it in 0..accum_right[pred_idx].len() {
                accum_right[pred_idx][it] +=
                    bin_accum_vals[j][pred_idx][it].1;
            }
        }
    }
    // Now sweep over the splitting values of the bin, moving examples from the right child to the left
    let mut valid_weak_rule = None;
    (0..bin.len()).for_each(|j| {
        for pred_idx in 0..NUM_PREDS { // Types of rule
            // Move examples from the right to the left child
            for it in 0..accum_left[pred_idx].len() {
                accum_left[pred_idx][it]  +=
                    bin_accum_vals[j][pred_idx][it].0;
                accum_right[pred_idx][it] -=
                    bin_accum_vals[j][pred_idx][it].1;
            }
            let accum: Vec<f32> = accum_left[pred_idx].iter()
                                                        .zip(accum_right[pred_idx].iter())
                                                        .map(|(a, b)| *a + *b)
                                                        .collect();
            {
                let weak_rules_score = &mut weak_rules_score[j][pred_idx];
                let sum_c_squared    = &mut sum_c_squared[j][pred_idx];
                let num_positive = &mut num_positive[j];
                let num_negative = &mut num_negative[j];
                let weight_positive = &mut weight_positive[j];
                let weight_negative = &mut weight_negative[j];

                *weak_rules_score   += accum[0];
                *sum_c_squared      += accum[1];
                *num_positive       += counts[0] as f32;
                *num_negative       += counts[1] as f32;
                *weight_positive    += weights[0];
                *weight_negative    += weights[1];
                // Check stopping rule
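                // sum_c is the accumulated score minus the target margin
                // 2 * rho_gamma * total_weight; the rule is accepted once it exceeds the
                // concentration bound computed by `get_bound`.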
                let sum_c = *weak_rules_score - 2.0 * rho_gamma * total_weight;
                let sum_c_squared = *sum_c_squared +
                    4.0 * rho_gamma * rho_gamma * total_weight_sq;
                let bound = get_bound(sum_c, sum_c_squared);
                if sum_c > bound {
                    let base_pred = 0.5 * (
                        (0.5 + rho_gamma) / (0.5 - rho_gamma)
                    ).ln();
                    let real_pred =
                        (base_pred * PREDS[pred_idx].0, base_pred * PREDS[pred_idx].1);
                    valid_weak_rule = Some(
                        TreeNode {
                            prt_index:      expand_node,
                            feature:        feature_index,
                            threshold:      j as TFeature,
                            predict:        real_pred,

                            gamma:          rho_gamma,
                            raw_martingale: *weak_rules_score,
                            sum_c:          sum_c,
                            sum_c_squared:  sum_c_squared,
                            bound:          bound,
                            num_scanned:    count,

                            positive:        *num_positive as usize,
                            negative:        *num_negative as usize,
                            positive_weight: *weight_positive,
                            negative_weight: *weight_negative,

                            fallback:       false,
                        }
                    );
                }
            }
        }
    });
    valid_weak_rule
}


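/// Constructs a fallback `TreeNode` for the given split (no stopping rule fired): gamma is set
/// to half of `ratio`, the prediction is the confidence-rated value
/// 0.5 * ln((0.5 + gamma) / (0.5 - gamma)) scaled by the chosen prediction pair, and the
/// bookkeeping fields are zeroed since they are only used for debugging.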
pub fn gen_tree_node(
    expand_node_index: usize, feature_index: usize, bin_index: usize, rule_index: usize, ratio: f32,
) -> TreeNode {
    let rho_gamma = ratio / 2.0;
    let base_pred = 0.5 * (
        (0.5 + rho_gamma) / (0.5 - rho_gamma)
    ).ln();
    let real_pred =
        (base_pred * PREDS[rule_index].0, base_pred * PREDS[rule_index].1);
    TreeNode {
        prt_index:      expand_node_index,
        feature:        feature_index,
        threshold:      bin_index as TFeature,
        predict:        real_pred,
        gamma:          rho_gamma,

        fallback:        true,

        // other attributes are for debugging purpose only
        raw_martingale: 0.0,
        sum_c:          0.0,
        sum_c_squared:  0.0,
        bound:          0.0,
        num_scanned:    0,

        positive:        0,
        negative:        0,
        positive_weight: 0.0,
        negative_weight: 0.0,
    }
}


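/// Scans roughly `max_sample_size` examples from `data_loader` to estimate the root node:
/// the empirical edge `gamma = |0.5 - n_pos / (n_pos + n_neg)|` and the prediction
/// `0.5 * ln(n_pos / n_neg)`. Returns `(gamma, prediction, gamma)`.
///
/// For example, with 75 positive and 25 negative examples, `gamma = 0.25` and the prediction
/// is `0.5 * ln(3) ≈ 0.549`.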
pub fn get_base_node(max_sample_size: usize, data_loader: &mut BufferLoader) -> (f32, f32, f32) {
    let mut sample_size = max_sample_size;
    let mut n_pos = 0;
    let mut n_neg = 0;
    while sample_size > 0 {
        let (data, _) = data_loader.get_next_batch(true);
        let (num_pos, num_neg) =
            data.par_iter().fold(
                || (0, 0),
                |(num_pos, num_neg), (example, _)| {
                    if example.label > 0 {
                        (num_pos + 1, num_neg)
                    } else {
                        (num_pos, num_neg + 1)
                    }
                }
            ).reduce(|| (0, 0), |(a1, a2), (b1, b2)| (a1 + b1, a2 + b2));
        n_pos += num_pos;
        n_neg += num_neg;
        // Use saturating_sub: the last batch may contain more examples than are still needed,
        // which would otherwise underflow the unsigned counter.
        sample_size = sample_size.saturating_sub(data.len());
    }

    let gamma = (0.5 - n_pos as f32 / (n_pos + n_neg) as f32).abs();
    let prediction = 0.5 * (n_pos as f32 / n_neg as f32).ln();
    info!("root-tree-info, {}, {}, {}, {}", 1, max_sample_size, gamma, gamma * gamma);
    (gamma, prediction, gamma)
}
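
// A minimal sanity check of `gen_tree_node`, added as a sketch: it only relies on the fields
// constructed above (`gamma`, `fallback`) and touches no I/O or loader state.
#[cfg(test)]
mod tests {
    use super::gen_tree_node;

    #[test]
    fn fallback_node_uses_half_ratio_as_gamma() {
        // `gen_tree_node` halves the given ratio to obtain rho_gamma and marks the node
        // as a fallback rule.
        let node = gen_tree_node(0, 3, 7, 0, 0.2);
        assert!((node.gamma - 0.1).abs() < 1e-6);
        assert!(node.fallback);
    }
}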