Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src_cpp/include/cached_import/py_cached_modules.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ class DecimalCachedItem : public PythonCachedItem {
class ImportLibCachedItem : public PythonCachedItem {
class UtilCachedItem : public PythonCachedItem {
public:
explicit UtilCachedItem(PythonCachedItem* parent)
: PythonCachedItem{"util", parent}, find_spec{"find_spec", this} {}
UtilCachedItem() : PythonCachedItem{"importlib.util"}, find_spec{"find_spec", this} {}

PythonCachedItem find_spec;
};

public:
ImportLibCachedItem() : PythonCachedItem("importlib"), util(this) {}
ImportLibCachedItem() : PythonCachedItem("importlib"), util() {}

UtilCachedItem util;
};
Expand Down
146 changes: 146 additions & 0 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "include/py_udf.h"
#include "main/connection.h"
#include "main/query_result/materialized_query_result.h"
#include "numpy/numpy_type.h"
#include "pandas/pandas_scan.h"
#include "processor/result/factorized_table.h"
#include "pyarrow/pyarrow_scan.h"
Expand Down Expand Up @@ -388,6 +389,47 @@ bool integerFitsIn<uint8_t>(int64_t val) {
return val >= 0 && val <= UINT8_MAX;
}

static LogicalType pyHomogeneousListType(const py::list& lst) {
py::handle firstNonNull;
for (auto child : lst) {
if (!child.is_none()) {
firstNonNull = child;
break;
}
}
if (!firstNonNull) {
return LogicalType::LIST(LogicalType::ANY());
}
if (!py::isinstance<py::bool_>(firstNonNull) && !py::isinstance<py::int_>(firstNonNull) &&
!py::isinstance<py::float_>(firstNonNull)) {
return LogicalType::ANY();
}
for (auto child : lst) {
if (child.is_none()) {
continue;
}
if (child.get_type().ptr() != firstNonNull.get_type().ptr()) {
return LogicalType::ANY();
}
}
if (py::isinstance<py::bool_>(firstNonNull)) {
return LogicalType::LIST(LogicalType::BOOL());
}
if (py::isinstance<py::int_>(firstNonNull)) {
return LogicalType::LIST(LogicalType::INT64());
}
return LogicalType::LIST(LogicalType::DOUBLE());
}

static LogicalType pyNumpyArrayLogicalType(const py::array& arr) {
auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype"));
auto type = NumpyTypeUtils::numpyToLogicalType(npType);
for (auto i = 0; i < arr.ndim(); ++i) {
type = LogicalType::LIST(std::move(type));
}
return type;
}

static LogicalType pyLogicalType(const py::handle& val) {
auto datetime_datetime = importCache->datetime.datetime();
auto time_delta = importCache->datetime.timedelta();
Expand Down Expand Up @@ -468,8 +510,14 @@ static LogicalType pyLogicalType(const py::handle& val) {
childValueType = std::move(resultValue);
}
return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
} else if (py::isinstance<py::array>(val)) {
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
} else if (py::isinstance<py::list>(val)) {
py::list lst = py::reinterpret_borrow<py::list>(val);
auto homogeneousType = pyHomogeneousListType(lst);
if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
return homogeneousType;
}
auto childType = LogicalType::ANY();
for (auto child : lst) {
auto curChildType = pyLogicalType(child);
Expand Down Expand Up @@ -568,8 +616,14 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
structFields.emplace_back(std::move(keyName), std::move(keyType));
}
return LogicalType::STRUCT(std::move(structFields));
} else if (py::isinstance<py::array>(val)) {
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
} else if (py::isinstance<py::list>(val)) {
py::list lst = py::reinterpret_borrow<py::list>(val);
auto homogeneousType = pyHomogeneousListType(lst);
if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
return homogeneousType;
}
auto childType = LogicalType::ANY();
for (auto child : lst) {
auto curChildType = pyLogicalTypeFromParameter(child);
Expand Down Expand Up @@ -603,6 +657,90 @@ static std::string pythonObjectToJsonString(const py::handle& val) {
return py::cast<std::string>(jsonStr);
}

template<typename T>
static Value transformNumpyScalarAs(const void* ptr, const LogicalType& type) {
auto value = *reinterpret_cast<const T*>(ptr);
switch (type.getLogicalTypeID()) {
case LogicalTypeID::BOOL:
return Value::createValue<bool>(static_cast<bool>(value));
case LogicalTypeID::INT64:
return Value::createValue<int64_t>(static_cast<int64_t>(value));
case LogicalTypeID::UINT32:
return Value::createValue<uint32_t>(static_cast<uint32_t>(value));
case LogicalTypeID::INT32:
return Value::createValue<int32_t>(static_cast<int32_t>(value));
case LogicalTypeID::UINT16:
return Value::createValue<uint16_t>(static_cast<uint16_t>(value));
case LogicalTypeID::INT16:
return Value::createValue<int16_t>(static_cast<int16_t>(value));
case LogicalTypeID::UINT8:
return Value::createValue<uint8_t>(static_cast<uint8_t>(value));
case LogicalTypeID::INT8:
return Value::createValue<int8_t>(static_cast<int8_t>(value));
case LogicalTypeID::FLOAT:
return Value(static_cast<float>(value));
case LogicalTypeID::DOUBLE:
return Value::createValue<double>(static_cast<double>(value));
default:
throw RuntimeException("Unsupported numpy ndarray parameter child type " + type.toString());
}
}

static Value transformNumpyScalarAs(const void* ptr, NumpyNullableType npType,
const LogicalType& type) {
switch (npType) {
case NumpyNullableType::BOOL:
return transformNumpyScalarAs<bool>(ptr, type);
case NumpyNullableType::INT_8:
return transformNumpyScalarAs<int8_t>(ptr, type);
case NumpyNullableType::UINT_8:
return transformNumpyScalarAs<uint8_t>(ptr, type);
case NumpyNullableType::INT_16:
return transformNumpyScalarAs<int16_t>(ptr, type);
case NumpyNullableType::UINT_16:
return transformNumpyScalarAs<uint16_t>(ptr, type);
case NumpyNullableType::INT_32:
return transformNumpyScalarAs<int32_t>(ptr, type);
case NumpyNullableType::UINT_32:
return transformNumpyScalarAs<uint32_t>(ptr, type);
case NumpyNullableType::INT_64:
return transformNumpyScalarAs<int64_t>(ptr, type);
case NumpyNullableType::UINT_64:
return transformNumpyScalarAs<uint64_t>(ptr, type);
case NumpyNullableType::FLOAT_32:
return transformNumpyScalarAs<float>(ptr, type);
case NumpyNullableType::FLOAT_64:
return transformNumpyScalarAs<double>(ptr, type);
default:
throw RuntimeException("Unsupported numpy ndarray parameter dtype");
}
}

static Value transformNumpyArrayAs(const LogicalType& type, uint64_t dimension, const uint8_t* ptr,
const py::buffer_info& info, NumpyNullableType npType) {
if (dimension == static_cast<uint64_t>(info.ndim)) {
return transformNumpyScalarAs(ptr, npType, type);
}
if (type.getLogicalTypeID() != LogicalTypeID::LIST) {
throw RuntimeException("Cannot convert numpy ndarray parameter to " + type.toString());
}
std::vector<std::unique_ptr<Value>> children;
children.reserve(info.shape[dimension]);
const auto& childType = ListType::getChildType(type);
for (auto i = 0; i < info.shape[dimension]; ++i) {
auto childPtr = ptr + i * info.strides[dimension];
children.push_back(std::make_unique<Value>(
transformNumpyArrayAs(childType, dimension + 1, childPtr, info, npType)));
}
return Value(type.copy(), std::move(children));
}

static Value transformNumpyArrayAs(const py::array& arr, const LogicalType& type) {
auto info = arr.request();
auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")).type;
return transformNumpyArrayAs(type, 0, reinterpret_cast<const uint8_t*>(info.ptr), info, npType);
}

Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) {
// ignore the type of the actual python object, just directly cast
auto datetime_datetime = importCache->datetime.datetime();
Expand Down Expand Up @@ -632,6 +770,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
return Value::createValue<int8_t>(py::cast<py::int_>(val).cast<int8_t>());
case LogicalTypeID::DOUBLE:
return Value::createValue<double>(py::cast<py::float_>(val).cast<double>());
case LogicalTypeID::FLOAT:
return Value(py::cast<py::float_>(val).cast<float>());
case LogicalTypeID::DECIMAL: {
auto str = py::cast<std::string>(py::str(val));
int128_t result = 0;
Expand Down Expand Up @@ -708,6 +848,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
return Value{uuidToAppend};
}
case LogicalTypeID::LIST: {
if (py::isinstance<py::array>(val)) {
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
}
py::list lst = py::reinterpret_borrow<py::list>(val);
std::vector<std::unique_ptr<Value>> children;
for (auto child : lst) {
Expand Down Expand Up @@ -763,6 +906,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
auto jsonStr = pythonObjectToJsonString(val);
return Value::createValue<std::string>(jsonStr);
}
if (py::isinstance<py::array>(val)) {
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
}
py::list lst = py::reinterpret_borrow<py::list>(val);
std::vector<std::unique_ptr<Value>> children;
for (auto child : lst) {
Expand Down
90 changes: 90 additions & 0 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def _ensure_arrow_atexit_cleanup() -> None:
_LBUG_MAP = 55
_LBUG_UNION = 56
_LBUG_UUID = 59
_NUMPY_MODULE: Any | None = None
_NUMPY_IMPORT_ATTEMPTED = False


def _setup_signatures() -> None:
Expand Down Expand Up @@ -392,6 +394,16 @@ def _setup_signatures() -> None:
_LIB.lbug_value_create_int32.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_int64.argtypes = [ctypes.c_int64]
_LIB.lbug_value_create_int64.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_uint8.argtypes = [ctypes.c_uint8]
_LIB.lbug_value_create_uint8.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_uint16.argtypes = [ctypes.c_uint16]
_LIB.lbug_value_create_uint16.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_uint32.argtypes = [ctypes.c_uint32]
_LIB.lbug_value_create_uint32.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_uint64.argtypes = [ctypes.c_uint64]
_LIB.lbug_value_create_uint64.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_float.argtypes = [ctypes.c_float]
_LIB.lbug_value_create_float.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_double.argtypes = [ctypes.c_double]
_LIB.lbug_value_create_double.restype = ctypes.POINTER(_LbugValue)
_LIB.lbug_value_create_string.argtypes = [ctypes.c_char_p]
Expand Down Expand Up @@ -930,11 +942,89 @@ def _parse_rendered_value(value: str) -> Any:
return value


def _numpy_module() -> Any | None:
global _NUMPY_IMPORT_ATTEMPTED, _NUMPY_MODULE
if _NUMPY_IMPORT_ATTEMPTED:
return _NUMPY_MODULE
_NUMPY_IMPORT_ATTEMPTED = True
try:
import numpy as np
except ModuleNotFoundError:
return None
_NUMPY_MODULE = np
return np


def _is_numpy_scalar(value: Any) -> bool:
np = _numpy_module()
return bool(np is not None and isinstance(value, np.generic))


def _is_numpy_array(value: Any) -> bool:
np = _numpy_module()
return bool(np is not None and isinstance(value, np.ndarray))


def _numpy_scalar_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
dtype = value.dtype
kind = dtype.kind
item = value.item()
if kind == "b":
return _LIB.lbug_value_create_bool(bool(item))
if kind == "i":
if dtype.itemsize == 1:
return _LIB.lbug_value_create_int8(item)
if dtype.itemsize == 2:
return _LIB.lbug_value_create_int16(item)
if dtype.itemsize == 4:
return _LIB.lbug_value_create_int32(item)
return _LIB.lbug_value_create_int64(item)
if kind == "u":
if dtype.itemsize == 1:
return _LIB.lbug_value_create_uint8(item)
if dtype.itemsize == 2:
return _LIB.lbug_value_create_uint16(item)
if dtype.itemsize == 4:
return _LIB.lbug_value_create_uint32(item)
return _LIB.lbug_value_create_uint64(item)
if kind == "f":
if dtype.itemsize == 4:
return _LIB.lbug_value_create_float(item)
return _LIB.lbug_value_create_double(item)

return _value_from_python(item)


def _numpy_array_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
if value.ndim == 0:
return _numpy_scalar_value_from_python(value[()])

child_ptrs: list[ctypes.POINTER(_LbugValue)] = []
try:
for item in value:
child_ptrs.append(_value_from_python(item))
out = ctypes.POINTER(_LbugValue)()
arr_type = ctypes.POINTER(_LbugValue) * len(child_ptrs)
arr = arr_type(*child_ptrs) if child_ptrs else arr_type()
_check_state(
_LIB.lbug_value_create_list(len(child_ptrs), arr, ctypes.byref(out)),
"Failed to create numpy ndarray list value",
)
return out
finally:
for ptr in child_ptrs:
_LIB.lbug_value_destroy(ptr)


def _value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
if value is None:
return _LIB.lbug_value_create_null()
if isinstance(value, CAPIJsonParameter):
return _LIB.lbug_value_create_json(value.value.encode())
if _is_numpy_array(value):
return _numpy_array_value_from_python(value)
if _is_numpy_scalar(value):
return _numpy_scalar_value_from_python(value)
if isinstance(value, bool):
return _LIB.lbug_value_create_bool(value)
if isinstance(value, int) and not isinstance(value, bool):
Expand Down
23 changes: 23 additions & 0 deletions test/test_issue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pytest

try:
from tools.python_api.test.type_aliases import ConnDB
except ImportError:
Expand Down Expand Up @@ -129,6 +131,27 @@ def test_int8_type_sniffing(conn_db_readwrite: ConnDB) -> None:
result.close()


def test_issue_483_numpy_ndarray_parameter(conn_db_readwrite: ConnDB) -> None:
np = pytest.importorskip("numpy")

conn, _ = conn_db_readwrite
conn.execute("CREATE NODE TABLE T(id INT64, v FLOAT[3], PRIMARY KEY(id))")
conn.execute("CREATE (:T {id: 1})")

arr = np.array([0.1, 0.2, 0.3], dtype=np.float32)
result = conn.execute(
"MATCH (n:T {id: 1}) SET n.v = $v RETURN n.v",
{"v": arr},
)

assert result.has_next()
assert result.get_next() == [
[pytest.approx(0.1), pytest.approx(0.2), pytest.approx(0.3)]
]
assert not result.has_next()
result.close()


# TODO(Maxwell): check if we should change getCastCost() for the following test
# def test_issue_3248(conn_db_readwrite: ConnDB) -> None:
# conn, _ = conn_db_readwrite
Expand Down
Loading
Loading