Skip to content

Commit

Permalink
Migrate to pyo3 v0.21 Bound API (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX committed Mar 10, 2024
1 parent 8365eab commit 8be60ee
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ name = "libipld"
crate-type = ["rlib", "cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["generate-import-lib", "anyhow"] }
pyo3 = { version = "0.21.0-beta.0", features = ["generate-import-lib", "anyhow"] }
python3-dll-a = "0.2.7"
anyhow = "1.0.75"
futures = "0.3"
Expand Down
2 changes: 1 addition & 1 deletion profiling/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ structopt = "0.3.26"
clap = "4.5.1"

[dependencies.pyo3]
version = "0.20"
version = "0.21.0-beta.0"
98 changes: 51 additions & 47 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,43 @@ use pyo3::{PyObject, Python};
use pyo3::conversion::ToPyObject;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::pybacked::PyBackedStr;

fn car_header_to_pydict<'py>(py: Python<'py>, header: &CarHeader) -> &'py PyDict {
let dict_obj = PyDict::new(py);
fn car_header_to_pydict<'py>(py: Python<'py>, header: &CarHeader) -> Bound<'py, PyDict> {
let dict_obj = PyDict::new_bound(py);

dict_obj.set_item("version", header.version()).unwrap();

let roots = PyList::empty(py);
let roots = PyList::empty_bound(py);
header.roots().iter().for_each(|cid| {
let cid_obj = cid.to_string().to_object(py);
roots.append(cid_obj).unwrap();
});

dict_obj.set_item("roots", roots).unwrap();

dict_obj.into()
dict_obj
}

fn cid_hash_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> &'py PyDict {
fn cid_hash_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> Bound<'py, PyDict> {
let hash = cid.hash();
let dict_obj = PyDict::new(py);
let dict_obj = PyDict::new_bound(py);

dict_obj.set_item("code", hash.code()).unwrap();
dict_obj.set_item("size", hash.size()).unwrap();
dict_obj.set_item("digest", PyBytes::new(py, &hash.digest())).unwrap();
dict_obj.set_item("digest", PyBytes::new_bound(py, &hash.digest())).unwrap();

dict_obj.into()
dict_obj
}

fn cid_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> &'py PyDict {
let dict_obj = PyDict::new(py);
fn cid_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> Bound<'py, PyDict> {
let dict_obj = PyDict::new_bound(py);

dict_obj.set_item("version", cid.version() as u64).unwrap();
dict_obj.set_item("codec", cid.codec()).unwrap();
dict_obj.set_item("hash", cid_hash_to_pydict(py, cid)).unwrap();

dict_obj.into()
dict_obj
}

fn decode_len(len: u64) -> Result<usize> {
Expand All @@ -65,15 +66,17 @@ fn map_key_cmp(a: &str, b: &str) -> std::cmp::Ordering {
}
}

fn sort_map_keys(keys: &PySequence, len: usize) -> Vec<(&str, usize)> {
fn sort_map_keys(keys: Bound<PySequence>, len: usize) -> Vec<(PyBackedStr, usize)> {
// Returns key and index.
let mut keys_str = Vec::with_capacity(len);
for i in 0..len {
let key: &PyString = keys.get_item(i).unwrap().downcast().unwrap();
keys_str.push((key.to_str().unwrap(), i));
let item = keys.get_item(i).unwrap();
let key = item.downcast::<PyString>().unwrap().to_owned();
let backed_str = PyBackedStr::try_from(key).unwrap();
keys_str.push((backed_str, i));
}

keys_str.sort_by(|a, b| {
keys_str.sort_by(|a, b| { // sort_unstable_by performs bad
let (s1, _) = a;
let (s2, _) = b;

Expand All @@ -90,24 +93,25 @@ fn decode_dag_cbor_to_pyobject<R: Read + Seek>(py: Python, r: &mut R, deep: usiz
MajorKind::NegativeInt => (-1 - decode::read_uint(r, major)? as i64).to_object(py),
MajorKind::ByteString => {
let len = decode::read_uint(r, major)?;
PyBytes::new(py, &decode::read_bytes(r, len)?).into()
PyBytes::new_bound(py, &decode::read_bytes(r, len)?).into()
}
MajorKind::TextString => {
let len = decode::read_uint(r, major)?;
decode::read_str(r, len)?.to_object(py)
}
MajorKind::Array => {
let len = decode_len(decode::read_uint(r, major)?)?;
// TODO (MarshalX): how to init list with capacity?
let list = PyList::empty(py);
let list = PyList::empty_bound(py);

for _ in 0..len {
list.append(decode_dag_cbor_to_pyobject(py, r, deep + 1)?).unwrap();
}

list.into()
}
MajorKind::Map => {
let len = decode_len(decode::read_uint(r, major)?)?;
let dict = PyDict::new(py);
let dict = PyDict::new_bound(py);

let mut prev_key: Option<String> = None;
for _ in 0..len {
Expand Down Expand Up @@ -135,6 +139,7 @@ fn decode_dag_cbor_to_pyobject<R: Read + Seek>(py: Python, r: &mut R, deep: usiz
let value = decode_dag_cbor_to_pyobject(py, r, deep + 1)?;
dict.set_item(key_py, value).unwrap();
}

dict.into()
}
MajorKind::Tag => {
Expand All @@ -157,7 +162,7 @@ fn decode_dag_cbor_to_pyobject<R: Read + Seek>(py: Python, r: &mut R, deep: usiz
Ok(py_object)
}

fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny, w: &mut W) -> Result<()> {
fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: Bound<'py, PyAny>, w: &mut W) -> Result<()> {
/* Order is important for performance!
Fast checks go first:
Expand All @@ -177,7 +182,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyBool>() {
let buf = if obj.is_true()? { [cbor::TRUE.into()] } else { [cbor::FALSE.into()] };
let buf = if obj.is_truthy()? { [cbor::TRUE.into()] } else { [cbor::FALSE.into()] };
w.write_all(&buf)?;

Ok(())
Expand All @@ -192,7 +197,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyList>() {
let seq: &PySequence = obj.downcast().unwrap();
let seq = obj.downcast::<PySequence>().unwrap();
let len = obj.len()?;

encode::write_u64(w, MajorKind::Array, len as u64)?;
Expand All @@ -203,15 +208,14 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyDict>() {
let map: &PyMapping = obj.downcast().unwrap();
let keys = map.keys()?;
let values = map.values()?;
let map = obj.downcast::<PyMapping>().unwrap();
let len = map.len()?;
let keys = sort_map_keys(map.keys()?, len);
let values = map.values()?;

encode::write_u64(w, MajorKind::Map, len as u64)?;

let sorted_keys = sort_map_keys(&keys, len);
for (key, i) in sorted_keys {
for (key, i) in keys {
let key_buf = key.as_bytes();
encode::write_u64(w, MajorKind::TextString, key_buf.len() as u64)?;
w.write_all(key_buf)?;
Expand All @@ -221,7 +225,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyFloat>() {
let f: &PyFloat = obj.downcast().unwrap();
let f = obj.downcast::<PyFloat>().unwrap();
let v = f.value();

if !v.is_finite() {
Expand All @@ -234,15 +238,15 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyBytes>() {
let b: &PyBytes = obj.downcast().unwrap();
let b = obj.downcast::<PyBytes>().unwrap();
let l: u64 = b.len()? as u64;

encode::write_u64(w, MajorKind::ByteString, l)?;
w.write_all(b.as_bytes())?;

Ok(())
} else if obj.is_instance_of::<PyString>() {
let s: &PyString = obj.downcast().unwrap();
let s = obj.downcast::<PyString>().unwrap();

// FIXME (MarshalX): it's not efficient to try to parse it as CID
let cid = Cid::try_from(s.to_str()?);
Expand Down Expand Up @@ -271,9 +275,9 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny
}

#[pyfunction]
fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult<&'py PyList> {
fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult<PyObject> {
let mut reader = BufReader::new(Cursor::new(data));
let decoded_parts = PyList::empty(py);
let decoded_parts = PyList::empty_bound(py);

loop {
let py_object = decode_dag_cbor_to_pyobject(py, &mut reader, 0);
Expand All @@ -284,11 +288,11 @@ fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult<&'py PyL
}
}

Ok(decoded_parts)
Ok(decoded_parts.into())
}

#[pyfunction]
pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(&'py PyDict, &'py PyDict)> {
pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(PyObject, PyObject)> {
let car_response = executor::block_on(CarReader::new(data));
if let Err(e) = car_response {
return Err(get_err("Failed to decode CAR", e.to_string()));
Expand All @@ -297,7 +301,7 @@ pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(&'py PyDict, &
let car = car_response.unwrap();

let header = car_header_to_pydict(py, car.header());
let parsed_blocks = PyDict::new(py);
let parsed_blocks = PyDict::new_bound(py);

let blocks: Vec<Result<(Cid, Vec<u8>), CarError>> = executor::block_on(car.stream().collect());
blocks.into_iter().for_each(|block| {
Expand All @@ -310,7 +314,7 @@ pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(&'py PyDict, &
}
});

Ok((header, parsed_blocks))
Ok((header.into(), parsed_blocks.into()))
}

#[pyfunction]
Expand All @@ -324,22 +328,22 @@ fn decode_dag_cbor(py: Python, data: &[u8]) -> PyResult<PyObject> {
}

#[pyfunction]
fn encode_dag_cbor<'py>(py: Python<'py>, data: &PyAny) -> PyResult<&'py PyBytes> {
fn encode_dag_cbor<'py>(py: Python<'py>, data: Bound<'py, PyAny>) -> PyResult<PyObject> {
let mut buf = &mut BufWriter::new(Vec::new());
if let Err(e) = encode_dag_cbor_from_pyobject(py, data, &mut buf) {
return Err(get_err("Failed to encode DAG-CBOR", e.to_string()));
}
if let Err(e) = buf.flush() {
return Err(get_err("Failed to flush buffer", e.to_string()));
}
Ok(PyBytes::new(py, &buf.get_ref()))
Ok(PyBytes::new_bound(py, &buf.get_ref()).into())
}

#[pyfunction]
fn decode_cid(py: Python, data: String) -> PyResult<&PyDict> {
fn decode_cid(py: Python, data: String) -> PyResult<PyObject> {
let cid = Cid::try_from(data.as_str());
if let Ok(cid) = cid {
Ok(cid_to_pydict(py, &cid))
Ok(cid_to_pydict(py, &cid).into())
} else {
Err(get_err("Failed to decode CID", cid.unwrap_err().to_string()))
}
Expand All @@ -349,23 +353,23 @@ fn decode_cid(py: Python, data: String) -> PyResult<&PyDict> {
fn decode_multibase(py: Python, data: String) -> PyResult<(char, PyObject)> {
let base = multibase::decode(data);
if let Ok((base, data)) = base {
Ok((base.code(), PyBytes::new(py, &data).into()))
Ok((base.code(), PyBytes::new_bound(py, &data).into()))
} else {
Err(get_err("Failed to decode multibase", base.unwrap_err().to_string()))
}
}

#[pyfunction]
fn encode_multibase(code: char, data: &PyAny) -> PyResult<String> {
fn encode_multibase(code: char, data: Bound<PyAny>) -> PyResult<String> {
let data_bytes: &[u8];
if data.is_instance_of::<PyBytes>() {
let b: &PyBytes = data.downcast().unwrap();
let b = data.downcast::<PyBytes>().unwrap();
data_bytes = b.as_bytes();
} else if data.is_instance_of::<PyByteArray>() {
let b: &PyByteArray = data.downcast().unwrap();
data_bytes = unsafe { b.as_bytes() };
let ba = data.downcast::<PyByteArray>().unwrap();
data_bytes = unsafe { ba.as_bytes() };
} else if data.is_instance_of::<PyString>() {
let s: &PyString = data.downcast().unwrap();
let s = data.downcast::<PyString>().unwrap();
data_bytes = s.to_str()?.as_bytes();
} else {
return Err(get_err("Failed to encode multibase", "Unsupported data type".to_string()));
Expand All @@ -384,7 +388,7 @@ fn get_err(msg: &str, err: String) -> PyErr {
}

#[pymodule]
fn libipld(_py: Python, m: &PyModule) -> PyResult<()> {
fn libipld(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(decode_cid, m)?)?;
m.add_function(wrap_pyfunction!(decode_car, m)?)?;
m.add_function(wrap_pyfunction!(decode_dag_cbor, m)?)?;
Expand Down

0 comments on commit 8be60ee

Please sign in to comment.