Commit 8c8521da authored by ehebrard's avatar ehebrard
Browse files

automatic binarization

parent 05284541
......@@ -85,21 +85,13 @@ int main(int argc, char *argv[]) {
WeightedDataset<int> input;
////// READING
if (opt.binarize) {
// try {
// read_binary(input, opt);
// } catch (const std::exception &e) {
if (opt.verbosity >= DTOptions::NORMAL)
cout << "c format not recognized or input non-binary, binarizing...\n";
read_non_binary(input, opt);
} else {
read_binary(input, opt);
}
// std::function<bool(const int f)> relevant = [](const int f) { return true; };
// input.printDatasetToFile(cout, string(" "), string(""), relevant,
// opt.outtarget != -1, false);
// cout << input << endl;
// }
if (opt.sample < 1)
input.sample(opt.sample);
......
......@@ -38,13 +38,13 @@ int run_algorithm(DTOptions &opt) {
WeightedDataset<E_t> input;
////// READING
if (opt.binarize) {
read_non_binary(input, opt);
} else {
try {
read_binary(input, opt);
} catch (const std::exception &e) {
if (opt.verbosity >= DTOptions::NORMAL)
cout << "c format not recognized or input non-binary, binarizing...\n";
read_non_binary(input, opt);
}
if (opt.verbosity >= DTOptions::NORMAL)
......@@ -66,9 +66,6 @@ int run_algorithm(DTOptions &opt) {
<< endl;
////// SOLVING
cout << "opt.mindepth: " << opt.mindepth << " opt.minsize " << opt.minsize << endl;
if (opt.mindepth) {
if (opt.minsize)
A.minimize_error_depth_size();
......
......@@ -39,10 +39,12 @@ int run_algorithm(DTOptions &opt) {
WeightedDataset<E_t> input;
////// READING
if (opt.binarize) {
read_non_binary(input, opt);
} else {
try {
read_binary(input, opt);
} catch (const std::exception &e) {
if (opt.verbosity >= DTOptions::NORMAL)
cout << "c format not recognized or input non-binary, binarizing...\n";
read_non_binary(input, opt);
}
// in compilation, noise and duplicates must be removed (we should probably
......
......@@ -41,7 +41,7 @@ void read(const std::string &fn, header_declaration notify_header,
} catch (std::exception &e) {
std::cout.flush();
cerr << "ERROR: " << e.what() << std::endl;
exit(1);
throw e;
}
}
......@@ -77,7 +77,7 @@ void read_binary(const std::string &fn, data_declaration notify_data,
} catch (std::exception &e) {
std::cout.flush();
cerr << "ERROR: " << e.what() << std::endl;
exit(1);
throw e;
}
}
......
......@@ -58,7 +58,6 @@ public:
enum feature_strategy { MINERROR = 0, ENTROPY = 1, GINI = 2, HYBRID = 3 };
int feature_strategy;
bool binarize;
double split;
int ada_it;
int ada_stop;
......@@ -92,13 +91,13 @@ public:
restart_base(opt.restart_base), restart_factor(opt.restart_factor),
time(opt.time), search(opt.search), bounding(opt.bounding),
node_strategy(opt.node_strategy),
feature_strategy(opt.feature_strategy), binarize(opt.binarize),
split(opt.split), ada_it(opt.ada_it), ada_stop(opt.ada_stop),
filter(opt.filter), reference_class(opt.reference_class),
mindepth(opt.mindepth), minsize(opt.minsize),
preprocessing(opt.preprocessing), progress(opt.progress),
delimiter(opt.delimiter), intarget(opt.intarget),
outtarget(opt.outtarget), pruning(opt.pruning) {}
feature_strategy(opt.feature_strategy), split(opt.split),
ada_it(opt.ada_it), ada_stop(opt.ada_stop), filter(opt.filter),
reference_class(opt.reference_class), mindepth(opt.mindepth),
minsize(opt.minsize), preprocessing(opt.preprocessing),
progress(opt.progress), delimiter(opt.delimiter),
intarget(opt.intarget), outtarget(opt.outtarget), pruning(opt.pruning) {
}
ostream &display(ostream &os);
};
......
......@@ -134,7 +134,6 @@ protected:
vector<T> value_set;
public:
// virtual ~ClassicEncoding() {}
// encode the values of the iterator
......@@ -190,7 +189,7 @@ public:
}
}
virtual const string getType() const {return "binary-direct";}
virtual const string getType() const { return "binary-direct"; }
// returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const {
......@@ -230,7 +229,7 @@ public:
}
}
virtual const string getType() const {return "binary-scaled";}
virtual const string getType() const { return "binary-scaled"; }
// returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const { return "&=?"; }
......@@ -240,9 +239,11 @@ template <typename T> class Order : public ClassicEncoding<T> {
private:
size_t num_examples;
const string &feature_name;
public:
Order(const size_t n) : ClassicEncoding<T>(), num_examples(n) {}
Order(const size_t n, const string &f)
: ClassicEncoding<T>(), num_examples(n), feature_name(f) {}
virtual size_t size() const {
return ClassicEncoding<T>::value_set.size() - 1;
......@@ -262,14 +263,23 @@ public:
// cout << ClassicEncoding<T>::value_set.size() << " " << (num_examples) ;
if(ClassicEncoding<T>::value_set.size() < sqrt(num_examples)) {
if (ClassicEncoding<T>::value_set.size() < sqrt(num_examples)) {
// cout << " full\n";
full_encoding();
} else {
// cout << " reduced\n";
reduced_encoding();
size_t num_intervals{static_cast<size_t>(
log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
if (num_intervals < 1)
num_intervals = 1;
cout << "c possible precision loss when binarizing feature "
<< feature_name << " (" << ClassicEncoding<T>::value_set.size()
<< " distinct values -> " << num_intervals << " intervals)"
<< "\n";
reduced_encoding(num_intervals);
}
}
......@@ -284,14 +294,9 @@ public:
e.resize(ve - vb, true);
ClassicEncoding<T>::encoding_map[*i] = e;
}
}
void reduced_encoding() {
size_t num_intervals{static_cast<size_t>(log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
if(num_intervals < 1)
num_intervals = 1;
void reduced_encoding(const size_t num_intervals) {
vector<size_t> boundary;
size_t i_size{ClassicEncoding<T>::value_set.size() / num_intervals};
......@@ -318,8 +323,7 @@ public:
}
}
virtual const string getType() const {return "order";}
virtual const string getType() const { return "order"; }
// returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const {
......@@ -353,9 +357,10 @@ public:
// assert(*(v - 1) < *v);
// }
size_t num_intervals{static_cast<size_t>(log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
size_t num_intervals{static_cast<size_t>(
log(static_cast<double>(ClassicEncoding<T>::value_set.size())))};
if(num_intervals < 1)
if (num_intervals < 1)
num_intervals = 1;
vector<size_t> boundary;
......@@ -436,7 +441,7 @@ public:
}
}
virtual const string getType() const {return "direct";}
virtual const string getType() const { return "direct"; }
// returns (in string format) the test x[i]
const string getLabel(const int i, const int v) const {
......@@ -493,15 +498,15 @@ public:
//@{
explicit TypedDataSet() {}
~TypedDataSet() {
while(not int_encoder.empty()) {
while (not int_encoder.empty()) {
delete int_encoder.back();
int_encoder.pop_back();
}
while(not float_encoder.empty()) {
while (not float_encoder.empty()) {
delete float_encoder.back();
float_encoder.pop_back();
}
while(not symb_encoder.empty()) {
while (not symb_encoder.empty()) {
delete symb_encoder.back();
symb_encoder.pop_back();
}
......@@ -622,12 +627,14 @@ public:
// int_encoder.push_back(enc);
// }
else {
enc = new Order<int>(int_value[feature_rank[f]].size());
enc = new Order<int>(int_value[feature_rank[f]].size(),
feature_label[f]);
enc->encode(int_buffer.begin(), int_buffer.end());
int_encoder.push_back(enc);
}
// cout << "int (" << feature_rank[f] << "/" << int_encoder.size() << "):";
// cout << "int (" << feature_rank[f] << "/" << int_encoder.size() <<
// "):";
// for (auto v : int_buffer)
// cout << "\n" << v << " -> " << enc->getEncoding(v);
// cout << endl;
......@@ -649,10 +656,8 @@ public:
// cout << "constructor\n";
Encoding<float> *enc = new Order<float>(float_value[feature_rank[f]].size());
Encoding<float> *enc = new Order<float>(
float_value[feature_rank[f]].size(), feature_label[f]);
// cout << "encode\n";
......@@ -738,7 +743,6 @@ public:
}
// cout << bin.example_count() << endl;
}
std::ostream &display(std::ostream &os) const {
......@@ -779,22 +783,22 @@ public:
int r{feature_rank[f]};
switch (t) {
case INTEGER:
os << " " << int_encoder[r]->getType() ;
if(int_encoder[r]->size() > 1)
os << " " << int_encoder[r]->getType();
if (int_encoder[r]->size() > 1)
os << " " << int_encoder[r]->size();
for (int j{0}; j < int_encoder[r]->size(); ++j)
os << " " << int_encoder[r]->getLabel(j);
break;
case FLOAT:
os << " " << float_encoder[r]->getType() ;
if(float_encoder[r]->size() > 1)
os << " " << float_encoder[r]->getType();
if (float_encoder[r]->size() > 1)
os << " " << float_encoder[r]->size();
for (int j{0}; j < float_encoder[r]->size(); ++j)
os << " " << float_encoder[r]->getLabel(j);
break;
case SYMBOL:
os << " " << symb_encoder[r]->getType() ;
if(symb_encoder[r]->size() > 1)
os << " " << symb_encoder[r]->getType();
if (symb_encoder[r]->size() > 1)
os << " " << symb_encoder[r]->size();
for (int j{0}; j < symb_encoder[r]->size(); ++j)
os << " " << symb_encoder[r]->getLabel(j);
......
......@@ -163,9 +163,6 @@ DTOptions blossom::parse_dt(int argc, char *argv[]) {
cmd.add<SwitchArg>(opt.bounding, "", "nolb", "switch bound reasoning off",
true);
cmd.add<SwitchArg>(opt.binarize, "", "binarize", "binarize the data set",
false);
cmd.add<ValueArg<double>>(opt.split, "", "split", "proportion of examples in the test set",
false, 0.0, "double");
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment