10#include <pybind11/stl.h>
19 Matrix(py::ssize_t rows, py::ssize_t cols) : m_rows(rows), m_cols(cols) {
20 print_created(
this, std::to_string(m_rows) +
"x" + std::to_string(m_cols) +
" matrix");
22 m_data =
new float[(
size_t) (rows * cols)];
23 memset(m_data, 0,
sizeof(
float) * (
size_t) (rows * cols));
26 Matrix(
const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
28 std::to_string(m_rows) +
"x" + std::to_string(m_cols) +
" matrix");
30 m_data =
new float[(
size_t) (m_rows * m_cols)];
31 memcpy(m_data, s.m_data,
sizeof(
float) * (
size_t) (m_rows * m_cols));
34 Matrix(Matrix &&s) noexcept : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
43 std::to_string(m_rows) +
"x" + std::to_string(m_cols) +
" matrix");
47 Matrix &operator=(
const Matrix &s) {
52 std::to_string(m_rows) +
"x" + std::to_string(m_cols) +
" matrix");
56 m_data =
new float[(
size_t) (m_rows * m_cols)];
57 memcpy(m_data, s.m_data,
sizeof(
float) * (
size_t) (m_rows * m_cols));
61 Matrix &operator=(Matrix &&s)
noexcept {
63 std::to_string(m_rows) +
"x" + std::to_string(m_cols) +
" matrix");
76 float operator()(py::ssize_t i, py::ssize_t j)
const {
77 return m_data[(
size_t) (i * m_cols + j)];
80 float &operator()(py::ssize_t i, py::ssize_t j) {
81 return m_data[(
size_t) (i * m_cols + j)];
84 float *
data() {
return m_data; }
86 py::ssize_t rows()
const {
return m_rows; }
87 py::ssize_t cols()
const {
return m_cols; }
94 py::class_<Matrix>(m,
"Matrix", py::buffer_protocol())
95 .def(py::init<py::ssize_t, py::ssize_t>())
97 .def(py::init([](
const py::buffer &b) {
98 py::buffer_info info = b.request();
99 if (info.format != py::format_descriptor<float>::format() || info.ndim != 2) {
100 throw std::runtime_error(
"Incompatible buffer format!");
103 auto *v =
new Matrix(info.shape[0], info.shape[1]);
104 memcpy(v->data(), info.ptr,
sizeof(
float) * (
size_t) (v->rows() * v->cols()));
108 .def(
"rows", &Matrix::rows)
109 .def(
"cols", &Matrix::cols)
113 [](
const Matrix &m, std::pair<py::ssize_t, py::ssize_t> i) {
114 if (i.first >= m.rows() || i.second >= m.cols()) {
115 throw py::index_error();
117 return m(i.first, i.second);
120 [](Matrix &m, std::pair<py::ssize_t, py::ssize_t> i,
float v) {
121 if (i.first >= m.rows() || i.second >= m.cols()) {
122 throw py::index_error();
124 m(i.first, i.second) = v;
127 .def_buffer([](Matrix &m) -> py::buffer_info {
128 return py::buffer_info(
130 {m.rows(), m.cols()},
131 {sizeof(float) * size_t(m.cols()),
136 class SquareMatrix :
public Matrix {
138 explicit SquareMatrix(py::ssize_t n) : Matrix(n, n) {}
141 py::class_<SquareMatrix, Matrix>(m,
"SquareMatrix").def(py::init<py::ssize_t>());
149 py::buffer_info get_buffer_info() {
150 return py::buffer_info(
151 &value,
sizeof(value), py::format_descriptor<int32_t>::format(), 1);
154 py::class_<Buffer>(m,
"Buffer", py::buffer_protocol())
156 .def_readwrite(
"value", &Buffer::value)
157 .def_buffer(&Buffer::get_buffer_info);
160 std::unique_ptr<int32_t> value;
163 int32_t get_value()
const {
return *value; }
164 void set_value(int32_t v) { *value = v; }
166 py::buffer_info get_buffer_info()
const {
167 return py::buffer_info(
168 value.get(),
sizeof(*value), py::format_descriptor<int32_t>::format(), 1);
171 ConstBuffer() : value(
new int32_t{0}) {}
173 py::class_<ConstBuffer>(m,
"ConstBuffer", py::buffer_protocol())
175 .def_property(
"value", &ConstBuffer::get_value, &ConstBuffer::set_value)
176 .def_buffer(&ConstBuffer::get_buffer_info);
178 struct DerivedBuffer :
public Buffer {};
179 py::class_<DerivedBuffer>(m,
"DerivedBuffer", py::buffer_protocol())
181 .def_readwrite(
"value", (int32_t DerivedBuffer::*) &DerivedBuffer::value)
182 .def_buffer(&DerivedBuffer::get_buffer_info);
184 struct BufferReadOnly {
185 const uint8_t value = 0;
186 explicit BufferReadOnly(uint8_t value) : value(value) {}
188 py::buffer_info get_buffer_info() {
return py::buffer_info(&value, 1); }
190 py::class_<BufferReadOnly>(m,
"BufferReadOnly", py::buffer_protocol())
191 .def(py::init<uint8_t>())
192 .def_buffer(&BufferReadOnly::get_buffer_info);
194 struct BufferReadOnlySelect {
196 bool readonly =
false;
198 py::buffer_info get_buffer_info() {
return py::buffer_info(&value, 1, readonly); }
200 py::class_<BufferReadOnlySelect>(m,
"BufferReadOnlySelect", py::buffer_protocol())
202 .def_readwrite(
"value", &BufferReadOnlySelect::value)
203 .def_readwrite(
"readonly", &BufferReadOnlySelect::readonly)
204 .def_buffer(&BufferReadOnlySelect::get_buffer_info);
207 py::class_<py::buffer_info>(m,
"buffer_info")
209 .def_readonly(
"itemsize", &py::buffer_info::itemsize)
210 .def_readonly(
"size", &py::buffer_info::size)
211 .def_readonly(
"format", &py::buffer_info::format)
212 .def_readonly(
"ndim", &py::buffer_info::ndim)
213 .def_readonly(
"shape", &py::buffer_info::shape)
214 .def_readonly(
"strides", &py::buffer_info::strides)
215 .def_readonly(
"readonly", &py::buffer_info::readonly)
216 .def(
"__repr__", [](py::handle
self) {
217 return py::str(
"itemsize={0.itemsize!r}, size={0.size!r}, format={0.format!r}, "
218 "ndim={0.ndim!r}, shape={0.shape!r}, strides={0.strides!r}, "
219 "readonly={0.readonly!r}")
buffer_info request(bool writable=false) const
void print_copy_created(T *inst, Values &&...values)
void print_copy_assigned(T *inst, Values &&...values)
void print_created(T *inst, Values &&...values)
void print_destroyed(T *inst, Values &&...values)
void print_move_assigned(T *inst, Values &&...values)
void print_move_created(T *inst, Values &&...values)
#define TEST_SUBMODULE(name, variable)
arr data(const arr &a, Ix... index)