Commit a5550681 authored by ehebrard's avatar ehebrard
Browse files

invariants

parent 6604c615
......@@ -12,11 +12,23 @@ from the `code` directory.
- Command line
to learn a decision tree classifier:
```
bin/blossom <datafile> [options]
```
to binarize a non-binary dataset:
```
bin/binarizer <datafile> --print_ins [--output filename]
```
to compile a table:
```
bin/compile <datafile> [options]
```
- High level API (Sklearn estimator)
......
......@@ -101,6 +101,9 @@ int main(int argc, char *argv[]) {
// opt.outtarget != -1, false);
// // cout << input << endl;
if (opt.sample < 1)
input.sample(opt.sample);
////// PREPROCESING
if (opt.preprocessing)
input.preprocess(opt.verbosity >= DTOptions::NORMAL);
......
......@@ -71,8 +71,11 @@ int run_algorithm(DTOptions &opt) {
A.minimize_error_depth_size();
else
A.minimize_error_depth();
} else
} else {
if (opt.minsize)
A.set_size_objective();
A.minimize_error();
}
Tree<E_t> sol = A.getSolution();
......
......@@ -115,6 +115,8 @@ public:
bool search();
void initialise_search();
void set_size_objective() { size_matters = true; }
void minimize_error();
void minimize_error_depth();
......
......@@ -3,6 +3,7 @@
#include <vector>
#include <algorithm>
#include <random>
#include "typedef.hpp"
#include "CmdLine.hpp"
......@@ -30,6 +31,9 @@ public:
// template <class Algo> void toInc(Algo &algo);
// template <class Algo> void setup(Algo &algo) const;
void preprocess(const bool verbose = false);
// randomly select ratio * count(c) examples from classes c in {0,1}
void sample(const double ratio, const long seed=12345);
size_t input_count(const bool c) const { return data[c].size(); }
size_t input_example_count() const { return input_count(0) + input_count(1); }
......@@ -166,7 +170,20 @@ inline void WeightedDataset<E_t>::addExample(rIter beg_row, rIter end_row,
// algo.setErrorOffset(suppression_count);
// }
template <typename E_t> inline void WeightedDataset<E_t>::preprocess(const bool verbose) {
template <typename E_t> void WeightedDataset<E_t>::sample(const double ratio, const long seed) {
mt19937 random_generator;
random_generator.seed(seed);
for(auto y{0}; y<2; ++y) {
size_t target{static_cast<size_t>(static_cast<double>(count(y)) * ratio)};
while(count(y) > target) {
auto i{random_generator() % count(y)};
examples[y].remove_back(examples[y][i]);
}
}
}
template <typename E_t> void WeightedDataset<E_t>::preprocess(const bool verbose) {
auto t{cpu_time()};
......
......@@ -1415,6 +1415,47 @@ void BacktrackingAlgorithm<ErrorPolicy, E_t>::minimize_error_depth() {
}
// template <template<typename> class ErrorPolicy, typename E_t>
// void BacktrackingAlgorithm<ErrorPolicy, E_t>::minimize_error_size() {
//
// initialise_search();
//
// if (options.verbosity > DTOptions::QUIET)
// separator("search");
//
// auto perfect{false};
// // auto saved_error{ub_error};
// while (search() and is_null<E_t>(ub_error)) {
// perfect = true;
// // saved_error = ub_error;
// ub_error = min_positive<E_t>();
// ub_depth = actual_depth - 1;
//
// restart(true);
//
// if (ub_depth == 1)
// singleDecision();
// }
//
// if (perfect) {
// ++ub_depth;
// ub_error = 0;
// }
// // else
// // cleaning();
//
// if (options.verbosity > DTOptions::QUIET) {
// if (interrupted)
// separator("interrupted");
// else
// separator("optimal");
// }
//
// if (options.verbosity > DTOptions::SILENT)
// print_new_best();
// }
template <template<typename> class ErrorPolicy, typename E_t>
void BacktrackingAlgorithm<ErrorPolicy, E_t>::minimize_error_depth_size() {
......
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment