Commit 82adf61b authored by Pierre Fernbach's avatar Pierre Fernbach
Browse files

Merge branch 'topic/pickle' into 'devel'

[Python] add pickle support

See merge request loco-3d/multicontact-api!13
parents eeab657f 7b2c17e4
Pipeline #9453 passed with stage
in 8 minutes and 12 seconds
......@@ -260,6 +260,7 @@ struct ContactSequencePythonVisitor : public bp::def_visitor<ContactSequencePyth
static bp::list getAllEffectorsInContactAsList(CS& self) {
return toPythonList<std::string>(self.getAllEffectorsInContact());
}
};
} // namespace python
} // namespace multicontact_api
......
......@@ -11,6 +11,26 @@ namespace python {
namespace bp = boost::python;
template <typename Derived>
struct cs_pickle_suite : bp::pickle_suite {
static bp::object getstate (const Derived& cs) {
std::ostringstream os;
boost::archive::text_oarchive oa(os);
oa << cs;
return bp::str(os.str());
}
static void
setstate(Derived& cs, bp::object entries) {
bp::str s = bp::extract<bp::str> (entries)();
std::string st = bp::extract<std::string> (s)();
std::istringstream is (st);
boost::archive::text_iarchive ia (is);
ia >> cs;
}
};
template <typename Derived>
struct SerializableVisitor : public boost::python::def_visitor<SerializableVisitor<Derived> > {
template <class PyClass>
......@@ -20,7 +40,8 @@ struct SerializableVisitor : public boost::python::def_visitor<SerializableVisit
.def("saveAsXML", &Derived::saveAsXML, bp::args("filename", "tag_name"), "Saves *this inside a XML file.")
.def("loadFromXML", &Derived::loadFromXML, bp::args("filename", "tag_name"), "Loads *this from a XML file.")
.def("saveAsBinary", &Derived::saveAsBinary, bp::args("filename"), "Saves *this inside a binary file.")
.def("loadFromBinary", &Derived::loadFromBinary, bp::args("filename"), "Loads *this from a binary file.");
.def("loadFromBinary", &Derived::loadFromBinary, bp::args("filename"), "Loads *this from a binary file.")
.def_pickle(cs_pickle_suite<Derived>());
}
};
} // namespace python
......
......@@ -11,7 +11,7 @@ from numpy import array, array_equal, isclose, random
import pinocchio as pin
from multicontact_api import ContactModelPlanar, ContactPatch, ContactPhase, ContactSequence
from pinocchio import SE3, Quaternion
import pickle
pin.switchToNumpyArray()
......@@ -182,6 +182,9 @@ class ContactModelTest(unittest.TestCase):
mp_xml = ContactModelPlanar()
mp_xml.loadFromXML("mp_test.xml", 'ContactPatch')
self.assertEqual(mp1, mp_xml)
mp_pickled = pickle.dumps(mp1)
mp_from_pickle = pickle.loads(mp_pickled)
self.assertEqual(mp1, mp_from_pickle)
def test_contact_model_serialization_full(self):
mu = 0.3
......@@ -200,6 +203,9 @@ class ContactModelTest(unittest.TestCase):
mp_xml = ContactModelPlanar()
mp_xml.loadFromXML("mp_test.xml", 'ContactPatch')
self.assertEqual(mp1, mp_xml)
mp_pickled = pickle.dumps(mp1)
mp_from_pickle = pickle.loads(mp_pickled)
self.assertEqual(mp1, mp_from_pickle)
class ContactPatchTest(unittest.TestCase):
......@@ -278,6 +284,9 @@ class ContactPatchTest(unittest.TestCase):
cp_xml = ContactPatch()
cp_xml.loadFromXML("cp_test.xml", 'ContactPatch')
self.assertEqual(cp1, cp_xml)
cp_pickled = pickle.dumps(cp1)
cp_from_pickle = pickle.loads(cp_pickled)
self.assertEqual(cp1, cp_from_pickle)
def test_serialization_full(self):
p = SE3()
......@@ -295,6 +304,9 @@ class ContactPatchTest(unittest.TestCase):
cp_xml = ContactPatch()
cp_xml.loadFromXML("cp_test.xml", 'ContactPatch')
self.assertEqual(cp1, cp_xml)
cp_pickled = pickle.dumps(cp1)
cp_from_pickle = pickle.loads(cp_pickled)
self.assertEqual(cp1, cp_from_pickle)
class ContactPhaseTest(unittest.TestCase):
......@@ -1144,6 +1156,9 @@ class ContactPhaseTest(unittest.TestCase):
cp_xml = ContactPhase()
cp_xml.loadFromXML("cp_test.xml", 'ContactPhase')
self.assertEqual(cp1, cp_xml)
cp_pickled = pickle.dumps(cp1)
cp_from_pickle = pickle.loads(cp_pickled)
self.assertEqual(cp1, cp_from_pickle)
def test_contact_phase_serialization_full(self):
cp1 = buildRandomContactPhase(0., 2.)
......@@ -1160,6 +1175,9 @@ class ContactPhaseTest(unittest.TestCase):
cp_xml.loadFromXML("cp_test_full.xml", 'ContactPhase')
self.assertEqual(cp1, cp_xml)
# TODO : check serialization from another file
cp_pickled = pickle.dumps(cp1)
cp_from_pickle = pickle.loads(cp_pickled)
self.assertEqual(cp1, cp_from_pickle)
def test_contact_phase_contacts_variation(self):
# # contacts repositioned :
......@@ -1741,6 +1759,15 @@ class ContactSequenceTest(unittest.TestCase):
with self.assertRaises(ValueError):
cs1.phaseAtTime(10.)
def test_pickle_contact_sequence(self):
cs = ContactSequence()
for i in range(10):
cp = buildRandomContactPhase(0., 2.)
cs.append(cp)
cs_pickled = pickle.dumps(cs)
cs_from_pickle = pickle.loads(cs_pickled)
self.assertEqual(cs_from_pickle, cs)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment