Commit 1e5a2800 authored by Louis Jean's avatar Louis Jean
Browse files

Add convenience methods to set example weights, change the wrapper to use the...

Add convenience methods to set example weights, change the wrapper to use the added methods, add an executable python script to run ./bud_first_search through python
parent f25f4d14
......@@ -35,9 +35,11 @@ opt = bfs.parse(bfs.to_str_vec(["--max_depth", "3"]))
opt.verbosity = bfs.Verbosity.YACKING
wood = bfs.Wood()
algo = bfs.BacktrackingAlgorithm(wood, opt)
# TODO change to algo.addExamples(test) ?
bfs.addExamples(algo, bfs.to_sample_vec(test))
algo = bfs.BacktrackingAlgo(wood, opt)
for sample in test:
algo.addExample(sample)
algo.minimize_error()
tree = algo.getSolution()
......
from .bud_first_search import *
from .wrapper import DTOptions, Wood, Tree, BacktrackingAlgorithm
from .wrapper import parse, addExamples
\ No newline at end of file
from .wrapper import DTOptions, Wood, Tree, BacktrackingAlgo
from .wrapper import parse, read_binary
\ No newline at end of file
import sys
import bud_first_search as bfs
test = bfs.TEST_SAMPLE
opt = bfs.parse(bfs.to_str_vec(sys.argv))
filename = str(opt.instance_file)
wood = bfs.Wood()
algo = bfs.BacktrackingAlgo(wood, opt)
if filename == "":
print("No input file!")
sys.exit(-1)
else:
bfs.read_binary(algo, opt)
algo.minimize_error()
tree = algo.getSolution()
nodes, edges = bfs.read_tree(tree)
print("nodes:", nodes,"\n\nedges: ", edges)
\ No newline at end of file
......@@ -32,16 +32,6 @@ def to_str_vec(str_list):
return vec
def to_sample_vec(samples):
samples_vec = wrapper.example_vec(len(samples))
for i in range(len(samples)):
sample = samples[i]
features_vec = wrapper.int_vec(sample[:-1])
samples_vec[i] = wrapper.Example(features_vec, sample[-1])
return samples_vec
def read_tree(tree):
nodes = []
edges = []
......@@ -83,8 +73,10 @@ class BudFirstSearch:
def fit(self, samples):
self.wood = wrapper.Wood()
self.algo = wrapper.BacktrackingAlgorithm(self.wood, self.opt)
wrapper.addExamples(self.algo, to_sample_vec(samples))
self.algo = wrapper.BacktrackingAlgo(self.wood, self.opt)
for sample in samples:
self.algo.addExample(sample)
self.algo.minimize_error()
self.tree = self.algo.getSolution()
......@@ -117,10 +109,3 @@ class BudFirstSearch:
return node["feat"]
TEST_SAMPLE = [[1, 0, 1], [1, 1, 0], [0, 1, 1], [0, 0, 0]]
if __name__ == "__main__":
b = BudFirstSearch()
b.opt.max_depth = 3
b.opt.verbosity = Verbosity.YACKING
b.opt.feature_strategy = FeatureStrategy.GINI
b.fit(TEST_SAMPLE)
......@@ -7,19 +7,5 @@
#define SWIG_FILE_WITH_INIT
// Dataset
struct Example {
std::vector<int> features;
int target;
Example() {}
Example(std::vector<int> features, int target)
: features(features), target(target) {}
};
extern void addExamples(primer::BacktrackingAlgorithm &algo, std::vector<Example> data);
extern DTOptions parse(std::vector<std::string> params);
extern void free(void *ptr);
extern void read_binary(primer::BacktrackingAlgorithm<> &A, DTOptions &opt);
......@@ -2,6 +2,9 @@
#include <iostream>
#include "CSVReader.hpp"
#include "TXTReader.hpp"
using namespace primer;
/*
......@@ -27,12 +30,6 @@ void addNode(Tree &tree, int node, Results &res) {
}
*/
void addExamples(primer::BacktrackingAlgorithm &algo, std::vector<Example> data) {
for (Example &example: data) {
algo.addExample(example.features.begin(), example.features.end(), example.target);
}
}
DTOptions parse(std::vector<std::string> params) {
std::vector<char*> cparams;
for (auto &param : params) {
......@@ -42,6 +39,23 @@ DTOptions parse(std::vector<std::string> params) {
return parse_dt(cparams.size(), &cparams[0]);
}
void free(void *ptr) {
delete ptr;
}
\ No newline at end of file
void read_binary(BacktrackingAlgorithm<> &A, DTOptions &opt) {
string ext{opt.instance_file.substr(opt.instance_file.find_last_of(".") + 1)};
if (opt.format == "csv" or (opt.format == "guess" and ext == "csv")) {
csv::read_binary(opt.instance_file, [&](vector<int> &data) {
A.addExample(data.begin(), data.end() - 1, data.back());
});
} else if (opt.format == "dl8" or (opt.format == "guess" and ext == "dl8")) {
txt::read_binary(opt.instance_file, [&](vector<int> &data) {
auto y = *data.begin();
A.addExample(data.begin() + 1, data.end(), y);
});
} else {
if (opt.format != "txt" and ext != "txt")
cout << "p Warning, unrecognized format, trying txt\n";
txt::read_binary(opt.instance_file, [&](vector<int> &data) {
A.addExample(data.begin(), data.end() - 1, data.back());
});
}
}
......@@ -6,22 +6,10 @@
#include "budFirstSearch.h"
%}
struct Example {
std::vector<int> features;
int target;
Example();
Example(std::vector<int> features, int target);
};
extern void addExamples(primer::BacktrackingAlgorithm &algo, std::vector<Example> data);
extern DTOptions parse(std::vector<std::string> params);
extern void free(void* ptr);
extern void read_binary(primer::BacktrackingAlgorithm<IntegerError<int>, int> &A, DTOptions &opt);
namespace std {
%template(example_vec) vector<Example>;
%template(int_vec) vector<int>;
%template(cstr_vec) vector<char*>;
%template(str_vec) vector<string>;
......@@ -31,6 +19,11 @@ namespace std {
class DTOptions {
public:
std::string instance_file;
std::string debug;
std::string output;
std::string format;
int verbosity;
int seed;
......@@ -94,6 +87,7 @@ namespace primer {
// BacktrackingAlgorithm
template <class Error, class ErrorType>
class BacktrackingAlgorithm {
public:
BacktrackingAlgorithm() = delete;
......@@ -102,5 +96,14 @@ namespace primer {
void minimize_error_depth();
void minimize_error_depth_size();
Tree getSolution();
void addExample(const std::vector<int> &example);
void addExample(const std::vector<int> &example, int weight);
};
template <class ErrorType> class IntegerError;
template <class ErrorType> class WeightedError;
%template(BacktrackingAlgo) BacktrackingAlgorithm<IntegerError<int>, int>;
%template(WeightedBacktrackingAlgo) BacktrackingAlgorithm<WeightedError<int>, int>;
%template(WeightedBacktrackingAlgod) BacktrackingAlgorithm<WeightedError<double>, double>;
}
......@@ -33,7 +33,7 @@ public:
/** This method is called everytime a new example is added to the dataset.
* \param i index of the added example */
void add_example(Algo &algo, const int y, const size_t i) {}
void add_example(Algo &algo, const int y, const size_t i, const E_t weight = 1) {}
E_t node_error(const Algo &algo, const int i) const;
......@@ -52,7 +52,7 @@ public:
/** This method is called everytime a new example is added to the dataset.
* \param i index of the added example */
void add_example(Algo &algo, const int y, const size_t i);
void add_example(Algo &algo, const int y, const size_t i, const E_t weight = 1);
void set_weight(const int y, const size_t i, const E_t weight);
......@@ -362,7 +362,9 @@ public:
E_t error() const;
template <class rIter>
void addExample(rIter beg_sample, rIter end_sample, const bool y);
void addExample(rIter beg_sample, rIter end_sample, const bool y, const E_t weight = 1);
void addExample(const std::vector<int> &example, const E_t weight = 1);
/*!@name Printing*/
//@{
......@@ -374,7 +376,7 @@ public:
template <class ErrorPolicy, typename E_t>
template <class rIter>
inline void BacktrackingAlgorithm<ErrorPolicy, E_t>::addExample(rIter beg_sample, rIter end_sample,
const bool y) {
const bool y, const E_t weight) {
int n{static_cast<int>(end_sample - beg_sample)};
if (n > num_feature) {
......@@ -407,7 +409,7 @@ inline void BacktrackingAlgorithm<ErrorPolicy, E_t>::addExample(rIter beg_sample
++k;
}
error_policy.add_example(*this, y, example[y].size() - 1);
error_policy.add_example(*this, y, example[y].size() - 1, weight);
// cout << endl;
}
......
......@@ -44,11 +44,11 @@ E_t IntegerError<E_t>::node_error(const IntegerError::Algo &algo, const int i) c
// ===== WeightedError
template <typename E_t>
void WeightedError<E_t>::add_example(WeightedError<E_t>::Algo &algo, const int y, const size_t i) {
void WeightedError<E_t>::add_example(WeightedError<E_t>::Algo &algo, const int y, const size_t i, const E_t weight) {
if (weights[y].size() <= i) {
weights[y].resize(i+1);
}
weights[y][i] = 1;
weights[y][i] = weight;
}
template <typename E_t>
......@@ -1326,6 +1326,11 @@ void BacktrackingAlgorithm<ErrorPolicy, E_t>::minimize_error_depth_size() {
print_new_best();
}
template <class ErrorPolicy, typename E_t>
void BacktrackingAlgorithm<ErrorPolicy, E_t>::addExample(const std::vector<int> &example, const E_t weight) {
addExample(example.begin(), example.end() - 1, example.back(), weight);
}
template <class ErrorPolicy, typename E_t>
bool BacktrackingAlgorithm<ErrorPolicy, E_t>::fail() {
for (auto b : blossom) {
......
......@@ -57,13 +57,13 @@ $(BIN)/%: $(MOD)/obj/%.o $(PLIBOBJ)
$(MOD)/obj/%.o: $(MOD)/src/%.cpp
@echo 'compile '$<
$(CCC) $(CFLAGS) -c $< -o $@
$(CCC) $(CFLAGS) -c $< -o $@
# Examples, one at a time
%: $(MOD)/obj/%.o $(PLIBOBJ)
@echo 'link '$<
@echo 'link '$<
$(CCC) $(CFLAGS) $(PLIBOBJ) $(LFLAGS) $< -lm -o $(BIN)/$@
wrapper: $(PLIBOBJ)
(cd bud_first_search && make)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment