Commit 8c8521da authored by ehebrard's avatar ehebrard
Browse files

automatic binarization

parent 05284541
...@@ -85,21 +85,13 @@ int main(int argc, char *argv[]) { ...@@ -85,21 +85,13 @@ int main(int argc, char *argv[]) {
WeightedDataset<int> input; WeightedDataset<int> input;
////// READING ////// READING
if (opt.binarize) { // try {
// read_binary(input, opt);
read_non_binary(input, opt); // } catch (const std::exception &e) {
if (opt.verbosity >= DTOptions::NORMAL)
} else { cout << "c format not recognized or input non-binary, binarizing...\n";
read_non_binary(input, opt);
read_binary(input, opt); // }
}
// std::function<bool(const int f)> relevant = [](const int f) { return true; };
// input.printDatasetToFile(cout, string(" "), string(""), relevant,
// opt.outtarget != -1, false);
// cout << input << endl;
if (opt.sample < 1) if (opt.sample < 1)
input.sample(opt.sample); input.sample(opt.sample);
......
...@@ -38,13 +38,13 @@ int run_algorithm(DTOptions &opt) { ...@@ -38,13 +38,13 @@ int run_algorithm(DTOptions &opt) {
WeightedDataset<E_t> input; WeightedDataset<E_t> input;
////// READING ////// READING
if (opt.binarize) {
read_non_binary(input, opt);
} else {
try {
read_binary(input, opt); read_binary(input, opt);
} catch (const std::exception &e) {
if (opt.verbosity >= DTOptions::NORMAL)
cout << "c format not recognized or input non-binary, binarizing...\n";
read_non_binary(input, opt);
} }
if (opt.verbosity >= DTOptions::NORMAL) if (opt.verbosity >= DTOptions::NORMAL)
...@@ -66,9 +66,6 @@ int run_algorithm(DTOptions &opt) { ...@@ -66,9 +66,6 @@ int run_algorithm(DTOptions &opt) {
<< endl; << endl;
////// SOLVING ////// SOLVING
cout << "opt.mindepth: " << opt.mindepth << " opt.minsize " << opt.minsize << endl;
if (opt.mindepth) { if (opt.mindepth) {
if (opt.minsize) if (opt.minsize)
A.minimize_error_depth_size(); A.minimize_error_depth_size();
......
...@@ -39,10 +39,12 @@ int run_algorithm(DTOptions &opt) { ...@@ -39,10 +39,12 @@ int run_algorithm(DTOptions &opt) {
WeightedDataset<E_t> input; WeightedDataset<E_t> input;
////// READING ////// READING
if (opt.binarize) { try {
read_non_binary(input, opt);
} else {
read_binary(input, opt); read_binary(input, opt);
} catch (const std::exception &e) {
if (opt.verbosity >= DTOptions::NORMAL)
cout << "c format not recognized or input non-binary, binarizing...\n";
read_non_binary(input, opt);
} }
// in compilation, noise and duplicates must be removed (we should probably // in compilation, noise and duplicates must be removed (we should probably
......
...@@ -41,7 +41,7 @@ void read(const std::string &fn, header_declaration notify_header, ...@@ -41,7 +41,7 @@ void read(const std::string &fn, header_declaration notify_header,
} catch (std::exception &e) { } catch (std::exception &e) {
std::cout.flush(); std::cout.flush();
cerr << "ERROR: " << e.what() << std::endl; cerr << "ERROR: " << e.what() << std::endl;
exit(1); throw e;
} }
} }
...@@ -77,7 +77,7 @@ void read_binary(const std::string &fn, data_declaration notify_data, ...@@ -77,7 +77,7 @@ void read_binary(const std::string &fn, data_declaration notify_data,
} catch (std::exception &e) { } catch (std::exception &e) {
std::cout.flush(); std::cout.flush();
cerr << "ERROR: " << e.what() << std::endl; cerr << "ERROR: " << e.what() << std::endl;
exit(1); throw e;
} }
} }
......
...@@ -58,7 +58,6 @@ public: ...@@ -58,7 +58,6 @@ public:
enum feature_strategy { MINERROR = 0, ENTROPY = 1, GINI = 2, HYBRID = 3 }; enum feature_strategy { MINERROR = 0, ENTROPY = 1, GINI = 2, HYBRID = 3 };
int feature_strategy; int feature_strategy;
bool binarize;
double split; double split;
int ada_it; int ada_it;
int ada_stop; int ada_stop;
...@@ -92,13 +91,13 @@ public: ...@@ -92,13 +91,13 @@ public:
restart_base(opt.restart_base), restart_factor(opt.restart_factor), restart_base(opt.restart_base), restart_factor(opt.restart_factor),
time(opt.time), search(opt.search), bounding(opt.bounding), time(opt.time), search(opt.search), bounding(opt.bounding),
node_strategy(opt.node_strategy), node_strategy(opt.node_strategy),
feature_strategy(opt.feature_strategy), binarize(opt.binarize), feature_strategy(opt.feature_strategy), split(opt.split),
split(opt.split), ada_it(opt.ada_it), ada_stop(opt.ada_stop), ada_it(opt.ada_it), ada_stop(opt.ada_stop), filter(opt.filter),
filter(opt.filter), reference_class(opt.reference_class), reference_class(opt.reference_class), mindepth(opt.mindepth),
mindepth(opt.mindepth), minsize(opt.minsize), minsize(opt.minsize), preprocessing(opt.preprocessing),
preprocessing(opt.preprocessing), progress(opt.progress), progress(opt.progress), delimiter(opt.delimiter),
delimiter(opt.delimiter), intarget(opt.intarget), intarget(opt.intarget), outtarget(opt.outtarget), pruning(opt.pruning) {
outtarget(opt.outtarget), pruning(opt.pruning) {} }
ostream &display(ostream &os); ostream &display(ostream &os);
}; };
......
...@@ -56,8 +56,8 @@ instance concatenate(const instance &w1, const instance &w2) { ...@@ -56,8 +56,8 @@ instance concatenate(const instance &w1, const instance &w2) {
template <typename T> class Encoding { template <typename T> class Encoding {
public: public:
virtual ~Encoding() {} virtual ~Encoding() {}
// encode the values of the iterator // encode the values of the iterator
virtual void encode(typename std::vector<T>::iterator beg, virtual void encode(typename std::vector<T>::iterator beg,
typename std::vector<T>::iterator end) = 0; typename std::vector<T>::iterator end) = 0;
...@@ -66,8 +66,8 @@ public: ...@@ -66,8 +66,8 @@ public:
virtual const instance &getEncoding(T &x) const = 0; virtual const instance &getEncoding(T &x) const = 0;
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual const string getType() const = 0; virtual const string getType() const = 0;
// returns (in string format) the test corresponding to x[i]==v // returns (in string format) the test corresponding to x[i]==v
virtual const string getLabel(const int i, const int v) const = 0; virtual const string getLabel(const int i, const int v) const = 0;
...@@ -87,8 +87,8 @@ protected: ...@@ -87,8 +87,8 @@ protected:
vector<instance> lit; vector<instance> lit;
public: public:
// virtual ~TrivialEncoding() {} // virtual ~TrivialEncoding() {}
virtual size_t size() const { return 1; } virtual size_t size() const { return 1; }
// encode the values of the iterator // encode the values of the iterator
...@@ -105,12 +105,12 @@ public: ...@@ -105,12 +105,12 @@ public:
virtual const instance &getEncoding(T &x) const { virtual const instance &getEncoding(T &x) const {
return lit[(x == value_set[1])]; return lit[(x == value_set[1])];
} }
virtual const string getType() const { virtual const string getType() const {
std::stringstream ss; std::stringstream ss;
ss << "trivial " << value_set[0] << " " << value_set[1]; ss << "trivial " << value_set[0] << " " << value_set[1];
return ss.str(); return ss.str();
} }
// returns (in string format) the test x[i] // returns (in string format) the test x[i]
virtual const string getLabel(const int i, const int v) const { virtual const string getLabel(const int i, const int v) const {
...@@ -134,9 +134,8 @@ protected: ...@@ -134,9 +134,8 @@ protected:
vector<T> value_set; vector<T> value_set;
public: public:
// virtual ~ClassicEncoding() {}
// virtual ~ClassicEncoding() {}
// encode the values of the iterator // encode the values of the iterator
// template <typename RandomIt> // template <typename RandomIt>
virtual void encode(typename std::vector<T>::iterator beg, virtual void encode(typename std::vector<T>::iterator beg,
...@@ -190,7 +189,7 @@ public: ...@@ -190,7 +189,7 @@ public:
} }
} }
virtual const string getType() const {return "binary-direct";} virtual const string getType() const { return "binary-direct"; }
// returns (in string format) the test x[i] // returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const { const string getLabel(const int i, const int v) const {
...@@ -229,8 +228,8 @@ public: ...@@ -229,8 +228,8 @@ public:
ClassicEncoding<T>::encoding_map[*i] = e; ClassicEncoding<T>::encoding_map[*i] = e;
} }
} }
virtual const string getType() const {return "binary-scaled";} virtual const string getType() const { return "binary-scaled"; }
// returns (in string format) the test x[i] // returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const { return "&=?"; } const string getLabel(const int i, const int v) const { return "&=?"; }
...@@ -240,10 +239,12 @@ template <typename T> class Order : public ClassicEncoding<T> { ...@@ -240,10 +239,12 @@ template <typename T> class Order : public ClassicEncoding<T> {
private: private:
size_t num_examples; size_t num_examples;
const string &feature_name;
public: public:
Order(const size_t n) : ClassicEncoding<T>(), num_examples(n) {} Order(const size_t n, const string &f)
: ClassicEncoding<T>(), num_examples(n), feature_name(f) {}
virtual size_t size() const { virtual size_t size() const {
return ClassicEncoding<T>::value_set.size() - 1; return ClassicEncoding<T>::value_set.size() - 1;
} }
...@@ -259,22 +260,31 @@ public: ...@@ -259,22 +260,31 @@ public:
// for (auto v : ClassicEncoding<T>::value_set) // for (auto v : ClassicEncoding<T>::value_set)
// cout << " " << v; // cout << " " << v;
// cout << endl; // cout << endl;
// cout << ClassicEncoding<T>::value_set.size() << " " << (num_examples) ; // cout << ClassicEncoding<T>::value_set.size() << " " << (num_examples) ;
if(ClassicEncoding<T>::value_set.size() < sqrt(num_examples)) { if (ClassicEncoding<T>::value_set.size() < sqrt(num_examples)) {
// cout << " full\n"; // cout << " full\n";
full_encoding(); full_encoding();
} else { } else {
// cout << " reduced\n";
size_t num_intervals{static_cast<size_t>(
reduced_encoding(); log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
} if (num_intervals < 1)
num_intervals = 1;
cout << "c possible precision loss when binarizing feature "
<< feature_name << " (" << ClassicEncoding<T>::value_set.size()
<< " distinct values -> " << num_intervals << " intervals)"
<< "\n";
reduced_encoding(num_intervals);
}
} }
void full_encoding() { void full_encoding() {
auto vb{ClassicEncoding<T>::value_set.begin()}; auto vb{ClassicEncoding<T>::value_set.begin()};
auto ve{ClassicEncoding<T>::value_set.end() - 1}; auto ve{ClassicEncoding<T>::value_set.end() - 1};
...@@ -284,14 +294,9 @@ public: ...@@ -284,14 +294,9 @@ public:
e.resize(ve - vb, true); e.resize(ve - vb, true);
ClassicEncoding<T>::encoding_map[*i] = e; ClassicEncoding<T>::encoding_map[*i] = e;
} }
}
}
void reduced_encoding(const size_t num_intervals) {
void reduced_encoding() {
size_t num_intervals{static_cast<size_t>(log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
if(num_intervals < 1)
num_intervals = 1;
vector<size_t> boundary; vector<size_t> boundary;
size_t i_size{ClassicEncoding<T>::value_set.size() / num_intervals}; size_t i_size{ClassicEncoding<T>::value_set.size() / num_intervals};
...@@ -316,10 +321,9 @@ public: ...@@ -316,10 +321,9 @@ public:
e.resize(boundary.size(), true); e.resize(boundary.size(), true);
ClassicEncoding<T>::encoding_map[*i] = e; ClassicEncoding<T>::encoding_map[*i] = e;
} }
} }
virtual const string getType() const { return "order"; }
virtual const string getType() const {return "order";}
// returns (in string format) the test x[i] // returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const { const string getLabel(const int i, const int v) const {
...@@ -352,11 +356,12 @@ public: ...@@ -352,11 +356,12 @@ public:
// v < ClassicEncoding<T>::value_set.end(); ++v) { // v < ClassicEncoding<T>::value_set.end(); ++v) {
// assert(*(v - 1) < *v); // assert(*(v - 1) < *v);
// } // }
size_t num_intervals{static_cast<size_t>(log(static_cast<double>(ClassicEncoding<T>::value_set.size())))}; size_t num_intervals{static_cast<size_t>(
log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
if(num_intervals < 1)
num_intervals = 1; if (num_intervals < 1)
num_intervals = 1;
vector<size_t> boundary; vector<size_t> boundary;
size_t i_size{ClassicEncoding<T>::value_set.size() / num_intervals}; size_t i_size{ClassicEncoding<T>::value_set.size() / num_intervals};
...@@ -435,8 +440,8 @@ public: ...@@ -435,8 +440,8 @@ public:
ClassicEncoding<T>::encoding_map[*i] = e; ClassicEncoding<T>::encoding_map[*i] = e;
} }
} }
virtual const string getType() const {return "direct";} virtual const string getType() const { return "direct"; }
// returns (in string format) the test x[i] // returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const { const string getLabel(const int i, const int v) const {
...@@ -492,20 +497,20 @@ public: ...@@ -492,20 +497,20 @@ public:
/*!@name Constructors*/ /*!@name Constructors*/
//@{ //@{
explicit TypedDataSet() {} explicit TypedDataSet() {}
~TypedDataSet() { ~TypedDataSet() {
while(not int_encoder.empty()) { while (not int_encoder.empty()) {
delete int_encoder.back(); delete int_encoder.back();
int_encoder.pop_back(); int_encoder.pop_back();
} }
while(not float_encoder.empty()) { while (not float_encoder.empty()) {
delete float_encoder.back(); delete float_encoder.back();
float_encoder.pop_back(); float_encoder.pop_back();
} }
while(not symb_encoder.empty()) { while (not symb_encoder.empty()) {
delete symb_encoder.back(); delete symb_encoder.back();
symb_encoder.pop_back(); symb_encoder.pop_back();
} }
} }
//@} //@}
/*!@name Accessors*/ /*!@name Accessors*/
...@@ -622,12 +627,14 @@ public: ...@@ -622,12 +627,14 @@ public:
// int_encoder.push_back(enc); // int_encoder.push_back(enc);
// } // }
else { else {
enc = new Order<int>(int_value[feature_rank[f]].size()); enc = new Order<int>(int_value[feature_rank[f]].size(),
feature_label[f]);
enc->encode(int_buffer.begin(), int_buffer.end()); enc->encode(int_buffer.begin(), int_buffer.end());
int_encoder.push_back(enc); int_encoder.push_back(enc);
} }
// cout << "int (" << feature_rank[f] << "/" << int_encoder.size() << "):"; // cout << "int (" << feature_rank[f] << "/" << int_encoder.size() <<
// "):";
// for (auto v : int_buffer) // for (auto v : int_buffer)
// cout << "\n" << v << " -> " << enc->getEncoding(v); // cout << "\n" << v << " -> " << enc->getEncoding(v);
// cout << endl; // cout << endl;
...@@ -648,11 +655,9 @@ public: ...@@ -648,11 +655,9 @@ public:
// exit(1); // exit(1);
// cout << "constructor\n"; // cout << "constructor\n";
Encoding<float> *enc = new Order<float>(float_value[feature_rank[f]].size()); Encoding<float> *enc = new Order<float>(
float_value[feature_rank[f]].size(), feature_label[f]);
// cout << "encode\n"; // cout << "encode\n";
...@@ -732,13 +737,12 @@ public: ...@@ -732,13 +737,12 @@ public:
// instance db; // instance db;
// bin.duplicate_format(binex, db); // bin.duplicate_format(binex, db);
// bin.add(db, label[i] != min_label); // bin.add(db, label[i] != min_label);
bin.addBitsetExample(binex, label[i] != min_label); bin.addBitsetExample(binex, label[i] != min_label);
// cout << binex << endl; // cout << binex << endl;
} }
// cout << bin.example_count() << endl;
// cout << bin.example_count() << endl;
} }
std::ostream &display(std::ostream &os) const { std::ostream &display(std::ostream &os) const {
...@@ -779,23 +783,23 @@ public: ...@@ -779,23 +783,23 @@ public:
int r{feature_rank[f]}; int r{feature_rank[f]};
switch (t) { switch (t) {
case INTEGER: case INTEGER:
os << " " << int_encoder[r]->getType() ; os << " " << int_encoder[r]->getType();
if(int_encoder[r]->size() > 1) if (int_encoder[r]->size() > 1)
os << " " << int_encoder[r]->size(); os << " " << int_encoder[r]->size();
for (int j{0}; j < int_encoder[r]->size(); ++j) for (int j{0}; j < int_encoder[r]->size(); ++j)
os << " " << int_encoder[r]->getLabel(j); os << " " << int_encoder[r]->getLabel(j);
break; break;
case FLOAT: case FLOAT:
os << " " << float_encoder[r]->getType() ; os << " " << float_encoder[r]->getType();
if(float_encoder[r]->size() > 1) if (float_encoder[r]->size() > 1)
os << " " << float_encoder[r]->size(); os << " " << float_encoder[r]->size();
for (int j{0}; j < float_encoder[r]->size(); ++j) for (int j{0}; j < float_encoder[r]->size(); ++j)
os << " " << float_encoder[r]->getLabel(j); os << " " << float_encoder[r]->getLabel(j);
break; break;
case SYMBOL: case SYMBOL:
os << " " << symb_encoder[r]->getType() ; os << " " << symb_encoder[r]->getType();
if(symb_encoder[r]->size() > 1) if (symb_encoder[r]->size() > 1)
os << " " << symb_encoder[r]->size(); os << " " << symb_encoder[r]->size();
for (int j{0}; j < symb_encoder[r]->size(); ++j) for (int j{0}; j < symb_encoder[r]->size(); ++j)
os << " " << symb_encoder[r]->getLabel(j); os << " " << symb_encoder[r]->getLabel(j);
break; break;
......
...@@ -163,9 +163,6 @@ DTOptions blossom::parse_dt(int argc, char *argv[]) { ...@@ -163,9 +163,6 @@ DTOptions blossom::parse_dt(int argc, char *argv[]) {
cmd.add<SwitchArg>(opt.bounding, "", "nolb", "switch bound reasoning off", cmd.add<SwitchArg>(opt.bounding, "", "nolb", "switch bound reasoning off",
true); true);
cmd.add<SwitchArg>(opt.binarize, "", "binarize", "binarize the data set",
false);
cmd.add<ValueArg<double>>(opt.split, "", "split", "proportion of examples in the test set", cmd.add<ValueArg<double>>(opt.split, "", "split", "proportion of examples in the test set",
false, 0.0, "double"); false, 0.0, "double");
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment