From 572f7d65af1b2bcc045af08f42817258c9cc9801 Mon Sep 17 00:00:00 2001 From: polka777 Date: Wed, 22 Apr 2026 15:11:20 +0300 Subject: [PATCH 01/14] feat(type): register Pair type in type system --- include/spla.hpp | 1 + include/spla/op.hpp | 13 ++++++++++++ include/spla/pair.hpp | 28 ++++++++++++++++++++++++++ src/binding/c_op.cpp | 4 ++++ src/binding/c_type.cpp | 3 +++ src/core/ttype.hpp | 5 +++++ src/op.cpp | 45 ++++++++++++++++++++++++++++++++++++++++++ tests/test_op.cpp | 1 + 8 files changed, 100 insertions(+) create mode 100644 include/spla/pair.hpp diff --git a/include/spla.hpp b/include/spla.hpp index dc82624a6..6c19daa2c 100644 --- a/include/spla.hpp +++ b/include/spla.hpp @@ -45,5 +45,6 @@ #include "spla/timer.hpp" #include "spla/type.hpp" #include "spla/vector.hpp" +#include "spla/pair.hpp" #endif//SPLA_SPLA_HPP diff --git a/include/spla/op.hpp b/include/spla/op.hpp index 96dd16848..738b130ad 100644 --- a/include/spla/op.hpp +++ b/include/spla/op.hpp @@ -30,6 +30,8 @@ #include "object.hpp" #include "type.hpp" +#include "spla/pair.hpp" + #include @@ -64,6 +66,7 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr make_pair(std::string name, std::string code, std::function function); }; /** @@ -78,6 +81,8 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr make_pair(std::string name, std::string code, std::function function); + }; /** @@ -91,6 +96,8 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr make_pair(std::string name, std::string code, std::function function); + }; //////////////////////////////// Unary //////////////////////////////// @@ -130,6 +137,8 @@ namespace spla { SPLA_API extern ref_ptr FLOOR_FLOAT; SPLA_API extern ref_ptr ROUND_FLOAT; SPLA_API extern ref_ptr TRUNC_FLOAT; + SPLA_API extern ref_ptr IDENTITY_PAIR; + //////////////////////////////// Binary //////////////////////////////// @@ -182,6 +191,9 @@ namespace spla { SPLA_API extern ref_ptr BXOR_INT; SPLA_API extern ref_ptr BXOR_UINT; + SPLA_API extern ref_ptr MIN_PAIR; + SPLA_API extern ref_ptr MUL_PAIR; + //////////////////////////////// Select //////////////////////////////// SPLA_API extern ref_ptr EQZERO_INT; @@ -205,6 +217,7 @@ namespace spla { SPLA_API extern ref_ptr ALWAYS_INT; SPLA_API extern ref_ptr ALWAYS_UINT; SPLA_API extern ref_ptr ALWAYS_FLOAT; + SPLA_API extern ref_ptr ALWAYS_PAIR; SPLA_API extern ref_ptr NEVER_INT; SPLA_API extern ref_ptr NEVER_UINT; SPLA_API extern ref_ptr NEVER_FLOAT; diff --git a/include/spla/pair.hpp b/include/spla/pair.hpp new file mode 100644 index 000000000..2b40b468c --- /dev/null +++ b/include/spla/pair.hpp @@ -0,0 +1,28 @@ +#ifndef SPLA_PAIR_HPP +#define SPLA_PAIR_HPP +#include + + +namespace spla { + struct Pair { + float weight; + int vertex; + + Pair(): weight(std::numeric_limits::infinity()), vertex(-1){} + Pair(float w, int v): weight(w), vertex(v){} + + bool operator<(const Pair& other) const { + return weight < other.weight; + } + bool operator==(const Pair& other) const { + return weight == other.weight && vertex == other.vertex; + } + + bool operator!=(const Pair& other) const { + return !(*this == other); + } + + Pair& operator=(const Pair& other) = default; + }; +} +#endif \ No newline at end of file diff --git a/src/binding/c_op.cpp b/src/binding/c_op.cpp index e3202e4f0..a6f71cc92 100644 --- a/src/binding/c_op.cpp +++ b/src/binding/c_op.cpp @@ -30,6 +30,7 @@ spla_OpUnary spla_OpUnary_IDENTITY_INT() { return as_ptr(spla::IDENTITY_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_IDENTITY_UINT() { return as_ptr(spla::IDENTITY_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_IDENTITY_FLOAT() { return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); } +spla_OpUnary spla_OpUnary_IDENTITY_PAIR() { return as_ptr(spla::IDENTITY_PAIR.ref_and_get()); } spla_OpUnary spla_OpUnary_AINV_INT() { return as_ptr(spla::AINV_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_AINV_UINT() { return as_ptr(spla::AINV_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_AINV_FLOAT() { return as_ptr(spla::AINV_FLOAT.ref_and_get()); } @@ -103,6 +104,8 @@ spla_OpBinary spla_OpBinary_BAND_INT() { return as_ptr(spla::BA spla_OpBinary spla_OpBinary_BAND_UINT() { return as_ptr(spla::BAND_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_BXOR_INT() { return as_ptr(spla::BXOR_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_BXOR_UINT() { return as_ptr(spla::BXOR_UINT.ref_and_get()); } +spla_OpBinary spla_OpBinary_MIN_PAIR() { return as_ptr(spla::MIN_PAIR.ref_and_get()); } +spla_OpBinary spla_OpBinary_MUL_PAIR() { return as_ptr(spla::MUL_PAIR.ref_and_get()); } spla_OpSelect spla_OpSelect_EQZERO_INT() { return as_ptr(spla::EQZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_EQZERO_UINT() { return as_ptr(spla::EQZERO_UINT.ref_and_get()); } @@ -125,6 +128,7 @@ spla_OpSelect spla_OpSelect_LEZERO_FLOAT() { return as_ptr(spla spla_OpSelect spla_OpSelect_ALWAYS_INT() { return as_ptr(spla::ALWAYS_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_ALWAYS_UINT() { return as_ptr(spla::ALWAYS_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_ALWAYS_FLOAT() { return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); } +spla_OpSelect spla_OpSelect_ALWAYS_PAIR() { return as_ptr(spla::ALWAYS_PAIR.ref_and_get()); } spla_OpSelect spla_OpSelect_NEVER_INT() { return as_ptr(spla::NEVER_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_NEVER_UINT() { return as_ptr(spla::NEVER_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_NEVER_FLOAT() { return as_ptr(spla::NEVER_FLOAT.ref_and_get()); } \ No newline at end of file diff --git a/src/binding/c_type.cpp b/src/binding/c_type.cpp index 2ca432791..09fbaddce 100644 --- a/src/binding/c_type.cpp +++ b/src/binding/c_type.cpp @@ -38,4 +38,7 @@ spla_Type spla_Type_UINT() { } spla_Type spla_Type_FLOAT() { return as_ptr(spla::FLOAT.get()); +} +spla_Type spla_Type_PAIR() { + return as_ptr(spla::PAIR.get()); } \ No newline at end of file diff --git a/src/core/ttype.hpp b/src/core/ttype.hpp index 344de29ac..6bb273548 100644 --- a/src/core/ttype.hpp +++ b/src/core/ttype.hpp @@ -129,6 +129,11 @@ namespace spla { ref_ptr> get_ttype() { return FLOAT.cast_safe>(); } + template<> + ref_ptr> get_ttype() { + return PAIR.cast_safe>(); + } + /** * @} diff --git a/src/op.cpp b/src/op.cpp index a1bca7514..de85937e8 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -28,9 +28,11 @@ #include #include "spla/op.hpp" +#include "spla/pair.hpp" #include #include +#include namespace spla { @@ -69,6 +71,8 @@ namespace spla { ref_ptr FLOOR_FLOAT; ref_ptr ROUND_FLOAT; ref_ptr TRUNC_FLOAT; + ref_ptr IDENTITY_PAIR; + ////////////////////////////////////////////////////////////////////////////// @@ -121,6 +125,9 @@ namespace spla { ref_ptr BXOR_INT; ref_ptr BXOR_UINT; + ref_ptr MIN_PAIR; + ref_ptr MUL_PAIR; + ////////////////////////////////////////////////////////////////////////////// ref_ptr EQZERO_INT; @@ -144,10 +151,12 @@ namespace spla { ref_ptr ALWAYS_INT; ref_ptr ALWAYS_UINT; ref_ptr ALWAYS_FLOAT; + ref_ptr ALWAYS_PAIR; ref_ptr NEVER_INT; ref_ptr NEVER_UINT; ref_ptr NEVER_FLOAT; + template inline T min(T a, T b) { return std::min(a, b); } @@ -190,6 +199,7 @@ namespace spla { DECL_OP_UNA_S(FLOOR_FLOAT, FLOOR, T_FLOAT, { return floor(a); }); DECL_OP_UNA_S(ROUND_FLOAT, ROUND, T_FLOAT, { return round(a); }); DECL_OP_UNA_S(TRUNC_FLOAT, TRUNC, T_FLOAT, { return trunc(a); }); + IDENTITY_PAIR = spla::OpUnary::make_pair("IDENTITY_PAIR", "(a) identity_pair(a)", [](Pair a) { return a; }); DECL_OP_BIN_S(PLUS_INT, PLUS, T_INT, { return a + b; }); DECL_OP_BIN_S(PLUS_UINT, PLUS, T_UINT, { return a + b; }); @@ -240,6 +250,15 @@ namespace spla { DECL_OP_BIN_S(BXOR_INT, BXOR, T_INT, { return a ^ b; }); DECL_OP_BIN_S(BXOR_UINT, BXOR, T_UINT, { return a ^ b; }); + MUL_PAIR = OpBinary::make_pair("MUL_PAIR", + "(a, b) make_pair(a.weight, b.vertex)", + [](Pair a, Pair b) { return Pair(a.weight, b.vertex); }); + MIN_PAIR = OpBinary::make_pair("MIN_PAIR", + "(a, b) min_pair(a, b)", + [](Pair a, Pair b) { + if (a.weight == b.weight) return a.vertex < b.vertex? a : b; + return a.weight < b.weight? a : b; }); + DECL_OP_SELECT(EQZERO_INT, EQZERO, T_INT, { return a == 0; }); DECL_OP_SELECT(EQZERO_UINT, EQZERO, T_UINT, { return a == 0; }); DECL_OP_SELECT(EQZERO_FLOAT, EQZERO, T_FLOAT, { return a == 0; }); @@ -261,9 +280,11 @@ namespace spla { DECL_OP_SELECT(ALWAYS_INT, ALWAYS, T_INT, { return 1; }); DECL_OP_SELECT(ALWAYS_UINT, ALWAYS, T_UINT, { return 1; }); DECL_OP_SELECT(ALWAYS_FLOAT, ALWAYS, T_FLOAT, { return 1; }); + ALWAYS_PAIR = OpSelect::make_pair("ALWAYS_PAIR", "(a) pair_always(a)", [](Pair a) { return 1; }); DECL_OP_SELECT(NEVER_INT, NEVER, T_INT, { return 0; }); DECL_OP_SELECT(NEVER_UINT, NEVER, T_UINT, { return 0; }); DECL_OP_SELECT(NEVER_FLOAT, NEVER, T_FLOAT, { return 0; }); + } ref_ptr OpUnary::make_int(std::string name, std::string code, std::function function) { @@ -290,6 +311,14 @@ namespace spla { op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); return op.as(); } + ref_ptr OpUnary::make_pair(std::string name, std::string code, std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } ref_ptr OpBinary::make_int(std::string name, std::string code, std::function function) { auto op = make_ref>(); @@ -315,6 +344,14 @@ namespace spla { op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); return op.as(); } + ref_ptr OpBinary::make_pair(std::string name, std::string code, std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } ref_ptr OpSelect::make_int(std::string name, std::string code, std::function function) { auto op = make_ref>(); @@ -340,5 +377,13 @@ namespace spla { op->key = op->name + "_" + op->get_type_arg_0()->get_code(); return op.as(); } + ref_ptr OpSelect::make_pair(std::string name, std::string code, std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); + } }// namespace spla \ No newline at end of file diff --git a/tests/test_op.cpp b/tests/test_op.cpp index a1d915891..db74ae415 100644 --- a/tests/test_op.cpp +++ b/tests/test_op.cpp @@ -48,6 +48,7 @@ TEST(op_binary, info_built_in) { display_op_info(spla::MIN_FLOAT); display_op_info(spla::BONE_FLOAT); + display_op_info(spla::MIN_PAIR); } TEST(op_binary, custom) { From 9cafb1062dd8f8886ff02650cd8af3795fba8aec Mon Sep 17 00:00:00 2001 From: polka777 Date: Wed, 22 Apr 2026 15:34:13 +0300 Subject: [PATCH 02/14] feat(matrix,vector,scalar): add Pair type support --- include/spla/matrix.hpp | 2 + include/spla/scalar.hpp | 3 ++ include/spla/type.hpp | 3 ++ include/spla/vector.hpp | 2 + src/core/tmatrix.hpp | 56 ++++++++++++++++++++++++ src/core/tscalar.hpp | 96 +++++++++++++++++++++++++++++++++++++++-- src/core/tvector.hpp | 64 +++++++++++++++++++++++++++ src/matrix.cpp | 3 ++ src/scalar.cpp | 4 ++ src/type.cpp | 1 + src/vector.cpp | 3 ++ 11 files changed, 234 insertions(+), 3 deletions(-) diff --git a/include/spla/matrix.hpp b/include/spla/matrix.hpp index 706a5dc86..7038a68ad 100644 --- a/include/spla/matrix.hpp +++ b/include/spla/matrix.hpp @@ -57,9 +57,11 @@ namespace spla { SPLA_API virtual Status set_int(uint row_id, uint col_id, std::int32_t value) = 0; SPLA_API virtual Status set_uint(uint row_id, uint col_id, std::uint32_t value) = 0; SPLA_API virtual Status set_float(uint row_id, uint col_id, float value) = 0; + SPLA_API virtual Status set_pair(uint row_id, uint col_id, Pair value) = 0; SPLA_API virtual Status get_int(uint row_id, uint col_id, std::int32_t& value) = 0; SPLA_API virtual Status get_uint(uint row_id, uint col_id, std::uint32_t& value) = 0; SPLA_API virtual Status get_float(uint row_id, uint col_id, float& value) = 0; + SPLA_API virtual Status get_pair(uint row_id, uint col_id, Pair& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) = 0; SPLA_API virtual Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) = 0; SPLA_API virtual Status clear() = 0; diff --git a/include/spla/scalar.hpp b/include/spla/scalar.hpp index 0a0588914..a4fa14a29 100644 --- a/include/spla/scalar.hpp +++ b/include/spla/scalar.hpp @@ -49,12 +49,15 @@ namespace spla { SPLA_API virtual Status set_int(std::int32_t value) = 0; SPLA_API virtual Status set_uint(std::uint32_t value) = 0; SPLA_API virtual Status set_float(float value) = 0; + SPLA_API virtual Status set_pair(Pair value) {return Status::InvalidArgument;} SPLA_API virtual Status get_int(std::int32_t& value) = 0; SPLA_API virtual Status get_uint(std::uint32_t& value) = 0; SPLA_API virtual Status get_float(float& value) = 0; + SPLA_API virtual Status get_pair(Pair& value) {return Status::InvalidArgument;} SPLA_API virtual T_INT as_int() = 0; SPLA_API virtual T_UINT as_uint() = 0; SPLA_API virtual T_FLOAT as_float() = 0; + SPLA_API virtual T_PAIR as_pair() = 0; SPLA_API static ref_ptr make(const ref_ptr& type); SPLA_API static ref_ptr make_int(std::int32_t value); diff --git a/include/spla/type.hpp b/include/spla/type.hpp index f61adf867..7ed7919c0 100644 --- a/include/spla/type.hpp +++ b/include/spla/type.hpp @@ -29,6 +29,7 @@ #define SPLA_TYPE_HPP #include "object.hpp" +#include "pair.hpp" #include @@ -58,11 +59,13 @@ namespace spla { using T_INT = std::int32_t; using T_UINT = std::uint32_t; using T_FLOAT = float; + using T_PAIR = Pair; SPLA_API extern ref_ptr BOOL; SPLA_API extern ref_ptr INT; SPLA_API extern ref_ptr UINT; SPLA_API extern ref_ptr FLOAT; + SPLA_API extern ref_ptr PAIR; /** * @} diff --git a/include/spla/vector.hpp b/include/spla/vector.hpp index 0ba6d512d..dedafb901 100644 --- a/include/spla/vector.hpp +++ b/include/spla/vector.hpp @@ -59,6 +59,8 @@ namespace spla { SPLA_API virtual Status get_int(uint row_id, T_INT& value) = 0; SPLA_API virtual Status get_uint(uint row_id, T_UINT& value) = 0; SPLA_API virtual Status get_float(uint row_id, float& value) = 0; + SPLA_API virtual Status get_pair(uint row_id, Pair& value) = 0; + SPLA_API virtual Status set_pair(uint row_id, Pair value) = 0; SPLA_API virtual Status fill_noize(uint seed) = 0; SPLA_API virtual Status fill_with(const ref_ptr& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys, const ref_ptr& values) = 0; diff --git a/src/core/tmatrix.hpp b/src/core/tmatrix.hpp index cd36e72b5..552035df3 100644 --- a/src/core/tmatrix.hpp +++ b/src/core/tmatrix.hpp @@ -70,9 +70,11 @@ namespace spla { Status set_int(uint row_id, uint col_id, std::int32_t value) override; Status set_uint(uint row_id, uint col_id, std::uint32_t value) override; Status set_float(uint row_id, uint col_id, float value) override; + Status set_pair(uint row_id, uint col_id, Pair value) override { return Status::InvalidArgument;} Status get_int(uint row_id, uint col_id, int32_t& value) override; Status get_uint(uint row_id, uint col_id, uint32_t& value) override; Status get_float(uint row_id, uint col_id, float& value) override; + Status get_pair(uint row_id, uint col_id, Pair& value) override { return Status::InvalidArgument;} Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) override; Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) override; Status clear() override; @@ -135,6 +137,8 @@ namespace spla { if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_int()); if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_uint()); if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_float()); + if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_pair()); + return Status::Ok; } @@ -161,18 +165,31 @@ namespace spla { cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TMatrix::set_int(uint row_id, uint col_id, std::int32_t value) { + return Status::InvalidArgument; + } template Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { validate_rwd(FormatMatrix::CpuLil); cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { + return Status::InvalidArgument; + } template Status TMatrix::set_float(uint row_id, uint col_id, float value) { validate_rwd(FormatMatrix::CpuLil); cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TMatrix::set_float(uint row_id, uint col_id, float value) { + return Status::InvalidArgument; + } + template Status TMatrix::get_int(uint row_id, uint col_id, int32_t& value) { @@ -188,6 +205,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TMatrix::get_int(uint row_id, uint col_id, std::int32_t& value) { + return Status::InvalidArgument; + } template Status TMatrix::get_uint(uint row_id, uint col_id, uint32_t& value) { validate_rw(FormatMatrix::CpuDok); @@ -202,6 +223,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TMatrix::get_uint(uint row_id, uint col_id, std::uint32_t& value) { + return Status::InvalidArgument; + } template Status TMatrix::get_float(uint row_id, uint col_id, float& value) { validate_rw(FormatMatrix::CpuDok); @@ -216,6 +241,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TMatrix::get_float(uint row_id, uint col_id, float& value) { + return Status::InvalidArgument; + } template Status TMatrix::build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) { @@ -314,6 +343,33 @@ namespace spla { return storage_manager.get(); } + template<> + inline Status TMatrix::set_pair(uint row_id, uint col_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, value, *get>()); + return Status::Ok; + } + template<> + inline Status TMatrix::get_pair(uint row_id, uint col_id, Pair& value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + validate_rw(FormatMatrix::CpuDok); + + auto& Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } /** * @} diff --git a/src/core/tscalar.hpp b/src/core/tscalar.hpp index 091a469fa..8268ceff3 100644 --- a/src/core/tscalar.hpp +++ b/src/core/tscalar.hpp @@ -60,6 +60,8 @@ namespace spla { T_INT as_int() override { return static_cast(m_value); } T_UINT as_uint() override { return static_cast(m_value); } T_FLOAT as_float() override { return static_cast(m_value); } + T_PAIR as_pair() override { return static_cast(m_value); } + void set_label(std::string label) override; const std::string& get_label() const override; @@ -131,10 +133,98 @@ namespace spla { T TScalar::get_value() const { return m_value; } + template<> + inline T_PAIR TScalar::as_pair() { + return Pair(); + } + template<> + inline T_PAIR TScalar::as_pair() { + return Pair(); + } + template<> + inline T_PAIR TScalar::as_pair() { + return Pair(); + } + + template<> + class TScalar final : public Scalar { + public: + TScalar() = default; + explicit TScalar(Pair value) : m_value(value) {} + ~TScalar() override = default; - /** - * @} - */ + Status set_pair(Pair value) { + m_value = value; + return Status::Ok; + } + + Status get_pair(Pair& value) const { + value = m_value; + return Status::Ok; + } + + ref_ptr get_type() override { + return PAIR; + } + + Status set_int(std::int32_t) override { + return Status::InvalidArgument; + } + + Status set_uint(std::uint32_t) override { + return Status::InvalidArgument; + } + + Status set_float(float) override { + return Status::InvalidArgument; + } + + Status get_int(std::int32_t& ) override { + return Status::InvalidArgument; + } + + Status get_uint(std::uint32_t& ) override { + return Status::InvalidArgument; + } + + Status get_float(float& ) override { + return Status::InvalidArgument; + } + + T_INT as_int() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to int"); + return 0; + } + + T_UINT as_uint() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to uint"); + return 0; + } + + T_FLOAT as_float() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to float"); + return 0.0f; + } + T_PAIR as_pair() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to pair"); + return Pair(); + } + + void set_label(std::string label) override { + m_label = std::move(label); + } + + const std::string& get_label() const override { + return m_label; + } + + Pair& get_value() { return m_value; } + Pair get_value() const { return m_value; } + + private: + std::string m_label; + Pair m_value = Pair(); + }; }// namespace spla diff --git a/src/core/tvector.hpp b/src/core/tvector.hpp index af3e5c71f..434f3e9bd 100644 --- a/src/core/tvector.hpp +++ b/src/core/tvector.hpp @@ -42,6 +42,7 @@ #include #include +#include "spla/pair.hpp" namespace spla { @@ -72,9 +73,11 @@ namespace spla { Status set_int(uint row_id, std::int32_t value) override; Status set_uint(uint row_id, std::uint32_t value) override; Status set_float(uint row_id, float value) override; + Status set_pair(uint row_id, Pair value) override { return Status::InvalidArgument;} Status get_int(uint row_id, int32_t& value) override; Status get_uint(uint row_id, uint32_t& value) override; Status get_float(uint row_id, float& value) override; + Status get_pair(uint row_id, Pair& value) override { return Status::InvalidArgument;} Status fill_noize(uint seed) override; Status fill_with(const ref_ptr& value) override; Status build(const ref_ptr& keys, const ref_ptr& values) override; @@ -167,6 +170,11 @@ namespace spla { cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TVector::set_int(uint row_id, std::int32_t value) { + return Status::InvalidArgument; + } + template Status TVector::set_uint(uint row_id, std::uint32_t value) { if (is_valid(FormatVector::CpuDense)) { @@ -179,6 +187,10 @@ namespace spla { cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TVector::set_uint(uint row_id, std::uint32_t value) { + return Status::InvalidArgument; + } template Status TVector::set_float(uint row_id, float value) { if (is_valid(FormatVector::CpuDense)) { @@ -191,6 +203,10 @@ namespace spla { cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TVector::set_float(uint row_id, float value) { + return Status::InvalidArgument; + } template Status TVector::get_int(uint row_id, int32_t& value) { @@ -206,6 +222,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TVector::get_int(uint row_id, int32_t& value) { + return Status::InvalidArgument; + } template Status TVector::get_uint(uint row_id, uint32_t& value) { validate_rw(FormatVector::CpuDok); @@ -220,6 +240,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TVector::get_uint(uint row_id, uint32_t& value) { + return Status::InvalidArgument; + } template Status TVector::get_float(uint row_id, float& value) { validate_rw(FormatVector::CpuDok); @@ -234,6 +258,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TVector::get_float(uint row_id, float& value) { + return Status::InvalidArgument; + } template Status TVector::fill_noize(uint seed) { @@ -359,6 +387,42 @@ namespace spla { return storage_manager.get(); } + template<> + inline Status TVector::set_pair(uint row_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = value; + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, value, *get>()); + return Status::Ok; + } + template<> + inline Status TVector::get_pair(uint row_id, Pair& value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rw(FormatVector::CpuDok); + + const auto& Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + + if (entry != Ax.end()) { + value = entry->second; + } else { + value = m_storage.get_fill_value(); + } + + return Status::Ok; + } + /** * @} diff --git a/src/matrix.cpp b/src/matrix.cpp index 4ecab16aa..aa319ca3b 100644 --- a/src/matrix.cpp +++ b/src/matrix.cpp @@ -51,6 +51,9 @@ namespace spla { if (type == FLOAT) { return ref_ptr(new TMatrix(n_rows, n_cols)); } + if (type == PAIR) { + return ref_ptr(new TMatrix(n_rows, n_cols)); + } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); return ref_ptr(); diff --git a/src/scalar.cpp b/src/scalar.cpp index e48ce958e..b090a2757 100644 --- a/src/scalar.cpp +++ b/src/scalar.cpp @@ -47,6 +47,9 @@ namespace spla { if (type == FLOAT) { return ref_ptr(new TScalar()); } + if (type == PAIR) { + return ref_ptr(new TScalar()); + } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); return ref_ptr(); @@ -61,5 +64,6 @@ namespace spla { ref_ptr Scalar::Scalar::make_float(float value) { return ref_ptr(new TScalar(value)); } + }// namespace spla \ No newline at end of file diff --git a/src/type.cpp b/src/type.cpp index 7b75a7154..28a97edc3 100644 --- a/src/type.cpp +++ b/src/type.cpp @@ -33,5 +33,6 @@ namespace spla { ref_ptr INT = TType::make_type("INT", "I", "int", "signed 4 byte integral type", 2); ref_ptr UINT = TType::make_type("UINT", "U", "uint", "unsigned 4 byte integral type", 3); ref_ptr FLOAT = TType::make_type("FLOAT", "F", "float", "4 byte floating point type", 4); + ref_ptr PAIR = TType::make_type("PAIR", "P", "struct Pair", "weight-vertex pair float-int", 5); }// namespace spla \ No newline at end of file diff --git a/src/vector.cpp b/src/vector.cpp index e6427d90f..0e3a47863 100644 --- a/src/vector.cpp +++ b/src/vector.cpp @@ -51,6 +51,9 @@ namespace spla { if (type == FLOAT) { return ref_ptr(new TVector(n_rows)); } + if (type == spla::PAIR) { + return ref_ptr(new TVector(n_rows)); + } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); return ref_ptr{}; From 204a7bc46707c3ddb2e7a46f8b82b83570dc6178 Mon Sep 17 00:00:00 2001 From: polka777 Date: Wed, 22 Apr 2026 15:44:38 +0300 Subject: [PATCH 03/14] feat(opencl pair) support pair in opencl, operations mxv, extract row for cpu --- src/cpu/cpu_algo_registry.cpp | 2 ++ src/opencl/cl_algo_registry.cpp | 3 +++ src/opencl/cl_mxv.hpp | 8 +++++--- src/opencl/cl_program_builder.cpp | 7 +++++-- src/opencl/generated/auto_common_api.hpp | 26 +++++++++++++++++++++++- src/opencl/kernels/common_api.cl | 24 ++++++++++++++++++++++ src/opencl/kernels/mxv.cl | 16 +++++++++++++-- 7 files changed, 78 insertions(+), 8 deletions(-) diff --git a/src/cpu/cpu_algo_registry.cpp b/src/cpu/cpu_algo_registry.cpp index 4944eb756..379e602c7 100644 --- a/src/cpu/cpu_algo_registry.cpp +++ b/src/cpu/cpu_algo_registry.cpp @@ -127,6 +127,8 @@ namespace spla { g_registry->add(MAKE_KEY_CPU_0("m_extract_row", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_extract_row", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_extract_row", FLOAT), std::make_shared>()); + g_registry->add(MAKE_KEY_CPU_0("m_extract_row", PAIR), std::make_shared>()); + // algorthm m_extract_column g_registry->add(MAKE_KEY_CPU_0("m_extract_column", INT), std::make_shared>()); diff --git a/src/opencl/cl_algo_registry.cpp b/src/opencl/cl_algo_registry.cpp index 65726579b..803316bb5 100644 --- a/src/opencl/cl_algo_registry.cpp +++ b/src/opencl/cl_algo_registry.cpp @@ -41,6 +41,7 @@ #include #include + namespace spla { void register_algo_cl(class Registry* g_registry) { @@ -83,6 +84,7 @@ namespace spla { g_registry->add(MAKE_KEY_CL_0("mxv_masked", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxv_masked", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxv_masked", FLOAT), std::make_shared>()); + g_registry->add(MAKE_KEY_CL_0("mxv_masked", PAIR), std::make_shared>()); // algorthm vxm_masked g_registry->add(MAKE_KEY_CL_0("vxm_masked", INT), std::make_shared>()); @@ -93,6 +95,7 @@ namespace spla { g_registry->add(MAKE_KEY_CL_0("mxmT_masked", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxmT_masked", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxmT_masked", FLOAT), std::make_shared>()); + } }// namespace spla diff --git a/src/opencl/cl_mxv.hpp b/src/opencl/cl_mxv.hpp index 9a2e437bd..6f1602ec1 100644 --- a/src/opencl/cl_mxv.hpp +++ b/src/opencl/cl_mxv.hpp @@ -263,10 +263,12 @@ namespace spla { .add_type("TYPE", get_ttype().template as()) .add_op("OP_BINARY1", op_multiply.template as()) .add_op("OP_BINARY2", op_add.template as()) - .add_op("OP_SELECT", op_select.template as()) - .set_source(source_mxv) - .acquire(); + .add_op("OP_SELECT", op_select.template as()); + if constexpr (std::is_same_v) { + program_builder.add_define("USE_PAIR_COMPARISON", 1); + } + program_builder.set_source(source_mxv).acquire(); program = program_builder.get_program(); return true; diff --git a/src/opencl/cl_program_builder.cpp b/src/opencl/cl_program_builder.cpp index 0c17e4e13..5599744bf 100644 --- a/src/opencl/cl_program_builder.cpp +++ b/src/opencl/cl_program_builder.cpp @@ -31,6 +31,7 @@ #include #include +#include namespace spla { @@ -63,6 +64,7 @@ namespace spla { return *this; } void CLProgramBuilder::acquire() { + CLAccelerator* acc = get_acc_cl(); CLProgramCache* cache = acc->get_cache(); @@ -85,13 +87,14 @@ namespace spla { for (const auto& define : m_defines) { builder << "#define " << define.first << " " << define.second << "\n"; } + builder << source_common_api; for (const auto& function : m_functions) { - builder << function.second->get_type_res()->get_cpp() << " " - << function.first << function.second->get_source_cl() << "\n"; + builder << "#define " << function.first << function.second->get_source_cl() << "\n"; } builder << m_source; + m_program_code = builder.str(); m_program = std::make_shared(); diff --git a/src/opencl/generated/auto_common_api.hpp b/src/opencl/generated/auto_common_api.hpp index 0e1e2f7a6..30b06fe21 100644 --- a/src/opencl/generated/auto_common_api.hpp +++ b/src/opencl/generated/auto_common_api.hpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////// -// Copyright (c) 2021 - 2023 SparseLinearAlgebra +// Copyright (c) 2021 - 2026 SparseLinearAlgebra // Autogenerated file, do not modify //////////////////////////////////////////////////////////////////// @@ -7,6 +7,30 @@ static const char source_common_api[] = R"( +struct Pair{ + float weight; + int vertex; +}; + +struct Pair make_pair(float w, int v) { + struct Pair p; + p.weight = w; + p.vertex = v; + return p; + +} + +struct Pair min_pair(struct Pair a, struct Pair b) { + if (a.weight == b.weight) return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? a : b; +} + +int pair_always(struct Pair a) { + return 1; +} +struct Pair identity_pair(struct Pair a) { + return a; +} uint random_gen_java(ulong seed) { seed = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48L) - 1); diff --git a/src/opencl/kernels/common_api.cl b/src/opencl/kernels/common_api.cl index 73e3af610..aaaa91765 100644 --- a/src/opencl/kernels/common_api.cl +++ b/src/opencl/kernels/common_api.cl @@ -26,6 +26,30 @@ /**********************************************************************************/ #include "common_def.cl" +struct Pair{ + float weight; + int vertex; +}; + +struct Pair make_pair(float w, int v) { + struct Pair p; + p.weight = w; + p.vertex = v; + return p; + +} + +struct Pair min_pair(struct Pair a, struct Pair b) { + if (a.weight == b.weight) return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? a : b; +} + +int pair_always(struct Pair a) { + return 1; +} +struct Pair identity_pair(struct Pair a) { + return a; +} uint random_gen_java(ulong seed) { seed = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48L) - 1); diff --git a/src/opencl/kernels/mxv.cl b/src/opencl/kernels/mxv.cl index e7bc77337..c61aa8478 100644 --- a/src/opencl/kernels/mxv.cl +++ b/src/opencl/kernels/mxv.cl @@ -112,7 +112,13 @@ __kernel void mxv_scalar(__global const uint* g_Ap, const uint col_id = g_Aj[i]; sum = OP_BINARY2(sum, OP_BINARY1(g_Ax[i], g_vx[col_id])); - if (early_exit && (sum != init)) break; + if (early_exit) { + #ifdef USE_PAIR_COMPARISON + if (sum.weight != init.weight || sum.vertex != init.vertex) break; + #else + if (sum != init) break; + #endif + } } } @@ -162,7 +168,13 @@ __kernel void mxv_config_scalar(__global const uint* g_Ap, const uint col_id = g_Aj[i]; sum = OP_BINARY2(sum, OP_BINARY1(g_Ax[i], g_vx[col_id])); - if (early_exit && (sum != init)) break; + if (early_exit) { + #ifdef USE_PAIR_COMPARISON + if (sum.weight != init.weight || sum.vertex != init.vertex) break; + #else + if (sum != init) break; + #endif + } } g_rx[row_id] = sum; From 5b54ef77a28c2a7f2ba725631b11bc17b21d585d Mon Sep 17 00:00:00 2001 From: polka777 Date: Wed, 22 Apr 2026 15:46:53 +0300 Subject: [PATCH 04/14] feat(io) support read matrix with weight --- include/spla/io.hpp | 2 + src/io.cpp | 128 ++++++++++++++++++++++++++++++++------------ 2 files changed, 96 insertions(+), 34 deletions(-) diff --git a/include/spla/io.hpp b/include/spla/io.hpp index 90b0d5c54..9bfceda62 100644 --- a/include/spla/io.hpp +++ b/include/spla/io.hpp @@ -80,6 +80,7 @@ namespace spla { [[nodiscard]] SPLA_API const std::vector& get_Ai() const; [[nodiscard]] SPLA_API const std::vector& get_Aj() const; + [[nodiscard]] SPLA_API const std::vector& get_Aw() const; [[nodiscard]] SPLA_API uint get_n_rows() const; [[nodiscard]] SPLA_API uint get_n_cols() const; [[nodiscard]] SPLA_API std::size_t get_n_values() const; @@ -89,6 +90,7 @@ namespace spla { std::filesystem::path m_file_path; std::vector m_Ai; std::vector m_Aj; + std::vector m_Aw; bool m_base_is_zero = false; uint m_n_rows = 0; uint m_n_cols = 0; diff --git a/src/io.cpp b/src/io.cpp index 9a962087d..d21a79fd6 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -73,6 +73,11 @@ namespace spla { std::stringstream header(line); header >> m_n_rows >> m_n_cols >> nnz; + bool file_has_values = false; + if (line.find("pattern") == std::string::npos) { //есть подстрока pattern => граф невзвешенный + file_has_values = true; + } + std::cout << "Loading matrix-market coordinate format data... " << std::endl; std::cout << " Reading from " << m_file_path << std::endl; std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" << std::endl; @@ -94,10 +99,12 @@ namespace spla { std::size_t to_preallocate = to_read * (make_undirected ? 2 : 1); std::vector Ai; std::vector Aj; + std::vector Av; // preallocate to avoid copy Ai.reserve(to_preallocate); Aj.reserve(to_preallocate); + if (file_has_values) Av.reserve(to_preallocate); float job_done = 0.0f; float job_total = 35.0f; @@ -152,6 +159,15 @@ namespace spla { char* end = nullptr; auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); auto j = uint(std::strtoll(end, &end, 10)); + float val = 1.0f; //default value + + if (file_has_values) { + char* next = end; + while (*next == ' ' || *next == '\t') next++; + if (*next != '\n' && *next != '\0') { + val = static_cast(std::strtod(next, &end)); + } + } buffer_offset = line_end + 1; assert(i > 0 && j > 0); @@ -166,53 +182,94 @@ namespace spla { if (make_undirected) { Ai.push_back(j); Aj.push_back(i); + if (file_has_values) Av.push_back(val); } Ai.push_back(i); Aj.push_back(j); + if (file_has_values) Av.push_back(val); } t.lap_end();// parsing - std::vector sorted; - { - sorted.reserve(Ai.size()); - n_sort = Ai.size(); - + if (file_has_values) { + struct Edge { + uint i, j; + float w; + bool operator<(const Edge& other) const { + if (i != other.i) return i < other.i; + return j < other.j; + } + }; + + std::vector edges; + edges.reserve(Ai.size()); for (std::size_t k = 0; k < Ai.size(); k++) { - std::uint64_t entry = 0; - entry |= std::uint64_t(Ai[k]) << 32u; - entry |= std::uint64_t(Aj[k]) << 0u; - sorted.push_back(entry); + edges.push_back({Ai[k], Aj[k], Av[k]}); } - Ai.clear(); - Aj.clear(); - - std::sort(sorted.begin(), sorted.end()); - } - t.lap_end();// sorting - - std::vector reduced_Ai; - std::vector reduced_Aj; - { - reduced_Ai.reserve(sorted.size()); - reduced_Aj.reserve(sorted.size()); - - std::uint64_t entry_prev = 0xffffffffffffffff; - for (std::uint64_t entry : sorted) { - if (entry_prev != entry) { - uint i = uint((entry >> 32u) & 0xffffffff); - uint j = uint((entry >> 0u) & 0xffffffff); - reduced_Ai.push_back(i); - reduced_Aj.push_back(j); + + std::sort(edges.begin(), edges.end()); + + std::vector reduced_Ai; + std::vector reduced_Aj; + std::vector reduced_Av; + reduced_Ai.reserve(edges.size()); + reduced_Aj.reserve(edges.size()); + reduced_Av.reserve(edges.size()); + + for (std::size_t k = 0; k < edges.size(); k++) { + if (k == 0 || edges[k].i != edges[k-1].i || edges[k].j != edges[k-1].j) { + reduced_Ai.push_back(edges[k].i); + reduced_Aj.push_back(edges[k].j); + reduced_Av.push_back(edges[k].w); } - entry_prev = entry; } - + m_n_values = reduced_Ai.size(); - m_Ai = std::move(reduced_Ai); - m_Aj = std::move(reduced_Aj); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); + m_Aw = std::move(reduced_Av); + + } else { + std::vector sorted; + { + sorted.reserve(Ai.size()); + n_sort = Ai.size(); + + for (std::size_t k = 0; k < Ai.size(); k++) { + std::uint64_t entry = 0; + entry |= std::uint64_t(Ai[k]) << 32u; + entry |= std::uint64_t(Aj[k]) << 0u; + sorted.push_back(entry); + } + Ai.clear(); + Aj.clear(); + + std::sort(sorted.begin(), sorted.end()); + } + t.lap_end();// sorting + + std::vector reduced_Ai; + std::vector reduced_Aj; + { + reduced_Ai.reserve(sorted.size()); + reduced_Aj.reserve(sorted.size()); + + std::uint64_t entry_prev = 0xffffffffffffffff; + for (std::uint64_t entry : sorted) { + if (entry_prev != entry) { + uint i = uint((entry >> 32u) & 0xffffffff); + uint j = uint((entry >> 0u) & 0xffffffff); + reduced_Ai.push_back(i); + reduced_Aj.push_back(j); + } + entry_prev = entry; + } + + m_n_values = reduced_Ai.size(); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); + } } - t.lap_end();// reducing calc_stats(); t.lap_end();// stats @@ -367,6 +424,9 @@ namespace spla { const std::vector& MtxLoader::get_Aj() const { return m_Aj; } + const std::vector& MtxLoader::get_Aw() const { + return m_Aw; + } uint MtxLoader::get_n_rows() const { return m_n_rows; From c5452bcdec03e20ea0ece82b654a701736604b19 Mon Sep 17 00:00:00 2001 From: polka777 Date: Wed, 22 Apr 2026 15:48:44 +0300 Subject: [PATCH 05/14] feat(mst) add algorithm mst, executable to CMakeLists --- CMakeLists.txt | 1 + examples/mst.cpp | 124 +++++++++++++++ include/spla/algorithm.hpp | 20 +++ src/algorithm.cpp | 306 +++++++++++++++++++++++++++++++++++++ 4 files changed, 451 insertions(+) create mode 100644 examples/mst.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c8cfa4f0..54f890fb5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -402,6 +402,7 @@ if (SPLA_BUILD_EXAMPLES) spla_example_application(tc) spla_example_application(pi) spla_example_application(convert) + spla_example_application(mst) endif () ###################################################################### diff --git a/examples/mst.cpp b/examples/mst.cpp new file mode 100644 index 000000000..c14ceee5a --- /dev/null +++ b/examples/mst.cpp @@ -0,0 +1,124 @@ +/**********************************************************************************/ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ +/**********************************************************************************/ +/* MIT License */ +/* */ +/* Copyright (c) 2023 SparseLinearAlgebra */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy */ +/* of this software and associated documentation files (the "Software"), to deal */ +/* in the Software without restriction, including without limitation the rights */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be included in all */ +/* copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ +/* SOFTWARE. */ +/**********************************************************************************/ + +#include "common.hpp" +#include "options.hpp" + +#include +#include + +int main(int argc, const char* const* argv) { + auto options = make_options("mst", "Boruvka's Minimum Spanning Tree algorithm with spla library"); + + cxxopts::ParseResult args; + int ret; + + if (parse_options(argc, argv, options, args, ret)) { + std::cerr << "failed to parse options" << std::endl; + return ret; + } + + spla::Timer timer_total; + spla::Timer timer_gpu; + spla::Timer timer_ref; + spla::MtxLoader loader; + + timer_total.start(); + + if (!loader.load(args["mtxpath"].as())) { + std::cerr << "failed to load graph"; + return 1; + } + + std::string acc_info; + spla::Library* library = spla::Library::get(); + + library->set_platform(args["platform"].as()); + library->set_device(args["device"].as()); + library->set_queues_count(1); + library->get_accelerator_info(acc_info); + std::cout << "env: " << acc_info << std::endl; + + const spla::uint N = loader.get_n_rows(); + auto S = spla::Matrix::make(N, N, spla::PAIR); + + const auto& Ai = loader.get_Ai(); + const auto& Aj = loader.get_Aj(); + const auto& Aw = loader.get_Aw(); + + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + + auto T_gpu = spla::Matrix::make(N, N, spla::FLOAT); + + auto desc = spla::Descriptor::make(); + + const int n_iters = args["niters"].as(); + + double total_weight_gpu = 0.0; + + if (args["run-gpu"].as()) { + library->set_force_no_acceleration(false); + + for (int i = 0; i < n_iters; ++i) { + T_gpu->clear(); + S = spla::Matrix::make(N, N, spla::PAIR); + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + timer_gpu.lap_begin(); + spla::mst(T_gpu, S, desc, nullptr); + timer_gpu.lap_end(); + } + + total_weight_gpu = 0; + for (spla::uint i = 0; i < N; ++i) { + for (spla::uint j = i + 1; j < N; ++j) { + float w; + T_gpu->get_float(i, j, w); + if (w != 0.0) { + total_weight_gpu += w; + } + } + } + + std::cout << "GPU MST total weight: " << total_weight_gpu << std::endl; + } + + spla::Library::get()->finalize(); + + timer_total.stop(); + + std::cout << "\n=== Timing Results ===" << std::endl; + std::cout << "total(ms):" << timer_total.get_elapsed_ms() << std::endl; + std::cout << "gpu(ms): "; + timer_gpu.print(); + std::cout << std::endl; + + return 0; +} \ No newline at end of file diff --git a/include/spla/algorithm.hpp b/include/spla/algorithm.hpp index a26a2dbba..321771c7a 100644 --- a/include/spla/algorithm.hpp +++ b/include/spla/algorithm.hpp @@ -33,6 +33,7 @@ #include "matrix.hpp" #include "scalar.hpp" #include "vector.hpp" +#include "schedule.hpp" namespace spla { @@ -173,6 +174,25 @@ namespace spla { int& ntrins, std::vector>& Ai, const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief Boruvka's Minimum Spanning Tree algorithm + * + * Finds the Minimum Spanning Tree of a weighted undirected graph using + * Boruvka's algorithm with algebraic operations. + * + * @param T float matrix to store MST edges (result). Only upper triangle is used. + * @param S PAIR matrix adjacency matrix with edges (weight, vertex). + * The vertex field stores the target vertex of the edge. + * @param descriptor optional descriptor for algorithm configuration + * @param task_hnd optional pointer to store task handle for async execution + * + * @return ok on success + */ + SPLA_API Status mst( + const ref_ptr& T, + ref_ptr& S, + const ref_ptr& descriptor = spla::Descriptor::make(), + ref_ptr* task_hnd = nullptr); /** * @} diff --git a/src/algorithm.cpp b/src/algorithm.cpp index 189ea4b45..26326e7a7 100644 --- a/src/algorithm.cpp +++ b/src/algorithm.cpp @@ -37,6 +37,10 @@ #include #include #include +#include +#include + +#define INF std::numeric_limits::infinity() namespace spla { @@ -448,5 +452,307 @@ namespace spla { } #pragma endregion Pr +#pragma region Mst +Status mst( + const ref_ptr& T, + ref_ptr& S, + const ref_ptr& descriptor, + ref_ptr* task_hnd) { + + assert(S); + assert(T); + + struct timespec step_start, step_end; + double step_time; + + const auto n = S->get_n_rows(); + int comp = n; + + auto parent = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + parent->set_pair(i, T_PAIR(0.0f, i)); + } + auto edge = Vector::make(n, PAIR); + auto cedge = Vector::make(n, PAIR); + auto t_vec = Vector::make(n, PAIR); + auto mask = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + mask->set_pair(i, T_PAIR(1.0f, 0)); + } + auto init_inf = Scalar::make(PAIR); + T_PAIR init_val; + init_inf->set_pair(init_val); + int iteration = 0; + auto new_S = S; +#ifdef SPLA_RELEASE + std::cout << "start Boruvka MST, vertices = " << n << "\n"; + Timer tight; +#endif + + while (comp > 1) { +#ifdef SPLA_RELEASE + tight.start(); +#endif + iteration++; + int edges_added_this_iteration = 0; + // step 1, min edges for each vertices + clock_gettime(CLOCK_MONOTONIC, &step_start); + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, spla::ALWAYS_PAIR, init_inf); + clock_gettime(CLOCK_MONOTONIC, &step_end); + step_time = (step_end.tv_sec - step_start.tv_sec) + + (step_end.tv_nsec - step_start.tv_nsec) / 1e9; + std::cout << "--- Step 1 (min edges for each vertex, gpu): " << step_time * 1000 << " ms" << std::endl; + #ifdef SPLA_DEBUG + + std::cout << "edge = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; + } + std::cout << "]\n"; +#endif + // step 2, min edges for each component + clock_gettime(CLOCK_MONOTONIC, &step_start); + + for (int32_t i = 0; i < n; i++) { + cedge->set_pair(i, init_val); + } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + spla::T_PAIR p1; + spla::T_PAIR p2; + parent->get_pair(i, p); + auto p_i = p.vertex; // p_i = parent[i] + cedge->get_pair(p_i, p1); // p1 = cedge[parent[i]] + edge->get_pair(i, p2); // p2 = edge[i] + auto min_for_comp = p1.weight <= p2.weight? p1 : p2; // min(cedge[parent[i]], edge[i]) + cedge->set_pair(p_i, min_for_comp); + } + clock_gettime(CLOCK_MONOTONIC, &step_end); + step_time = (step_end.tv_sec - step_start.tv_sec) + + (step_end.tv_nsec - step_start.tv_nsec) / 1e9; + std::cout << "--- Step 2 (min edges for each component): " << step_time * 1000 << " ms" << std::endl; + +#ifdef SPLA_DEBUG + std::cout << "cedge = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + cedge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; + } + std::cout << "]\n"; +#endif + // step 3, когда нашли лучшее ребро компоненты распространяем его на все вершины + clock_gettime(CLOCK_MONOTONIC, &step_start); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + spla::T_PAIR cedge_v; + parent->get_pair(i, parent_v); + cedge->get_pair(parent_v.vertex, cedge_v); + t_vec->set_pair(i, cedge_v); //t[i] = cedge[parent[i]] + } + clock_gettime(CLOCK_MONOTONIC, &step_end); + step_time = (step_end.tv_sec - step_start.tv_sec) + + (step_end.tv_nsec - step_start.tv_nsec) / 1e9; + std::cout << "--- Step 3: " << step_time * 1000 << " ms" << std::endl; + + //step 4 выбор представителя для каждой компоненты(когда лучшее ребро в edges совпадает с лучшим ребром компоненты t) + clock_gettime(CLOCK_MONOTONIC, &step_start); + auto index = spla::Vector::make(n, spla::INT); + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR edge_v, t_v; + edge->get_pair(i, edge_v); + t_vec->get_pair(i, t_v); + if (edge_v == t_v) index->set_int(i, i); + else index->set_int(i, n); + } + auto temp = spla::Vector::make(n, spla::INT); + for (int32_t i = 0; i < n; i++) temp->set_int(i, n); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v, ind_v; + temp->get_int(p_i, temp_v); + index->get_int(i, ind_v); + spla::T_INT min_v = temp_v < ind_v? temp_v : ind_v; + temp->set_int(p_i, min_v); //temp[parent[i]] = min(temp[parent[i]], index[i]) + } +#ifdef SPLA_DEBUG + std::cout << "t = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + temp->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; +#endif + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v; + temp->get_int(p_i, temp_v); + index->set_int(i, temp_v); + } + clock_gettime(CLOCK_MONOTONIC, &step_end); + step_time = (step_end.tv_sec - step_start.tv_sec) + + (step_end.tv_nsec - step_start.tv_nsec) / 1e9; + std::cout << "--- Step 4(выбор представителя для каждой компоненты): " << step_time * 1000 << " ms" << std::endl; + +#ifdef SPLA_DEBUG + std::cout << "index = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + index->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; +#endif + //step 5 добавляем найденные ребра в MST + clock_gettime(CLOCK_MONOTONIC, &step_start); + auto new_parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + new_parent->set_pair(i, p); + } + for (int32_t i = 0; i < n; i++) { + spla::T_INT ind_v; + index->get_int(i, ind_v); + if (i == ind_v) { + auto row = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row, S, i, spla::IDENTITY_PAIR); + int min_vertex = -1; + float min_weight = INF; + + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR pair_row; + row->get_pair(j, pair_row); + auto pair_row_weight = pair_row.weight; + auto pair_row_vertex = pair_row.vertex; + if (pair_row_weight < INF) { + spla::T_PAIR p1, p2; + parent->get_pair(i, p1); + parent->get_pair(pair_row_vertex, p2); + if (p1.vertex != p2.vertex) { //разные компоненты + if (pair_row_weight < min_weight) { + min_weight = pair_row_weight; + min_vertex = j; + } + + } + } + + + } + if (min_vertex == -1) continue; + T->set_float(i, min_vertex, min_weight); + T->set_float(min_vertex, i, min_weight); + edges_added_this_iteration++; + if (i < min_vertex) { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(i, p); + new_parent->get_pair(min_vertex, old_p); + new_parent->set_pair(min_vertex, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } + } + else { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(min_vertex, p); + new_parent->get_pair(i, old_p); + new_parent->set_pair(i, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } + } + + } + } + parent = new_parent; + std::vector seen(n, false); + for (uint i = 0; i < n; i++) { + T_PAIR p; + parent->get_pair(i, p); + seen[p.vertex] = true; + } + + comp = 0; + for (uint i = 0; i < n; i++) { + if (seen[i]) comp++; + } + clock_gettime(CLOCK_MONOTONIC, &step_end); + step_time = (step_end.tv_sec - step_start.tv_sec) + + (step_end.tv_nsec - step_start.tv_nsec) / 1e9; + std::cout << "--- Step 5(добавляем найденные ребра в MST): " << step_time * 1000 << " ms" << std::endl; + + +#ifdef SPLA_DEBUG + std::cout << "parent = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + std::cout << p.vertex << ", "; + } + std::cout << "]\n"; +#endif +#ifdef SPLA_RELEASE + tight.stop(); + std::cout << " - iteration " << iteration + << " components " << comp + << " " << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); +#endif + if (comp == 1) { + std::cout << "MST complete after " << iteration << " iterations" << std::endl; + return Status::Ok; + } + //обновляем матрицу смежности + clock_gettime(CLOCK_MONOTONIC, &step_start); + auto filtered_S = spla::Matrix::make(n, n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR val; + S->get_pair(i, j, val); + if (val.weight != std::numeric_limits::infinity()) { + spla::T_PAIR parent_i, parent_j; + parent->get_pair(i, parent_i); + parent->get_pair(j, parent_j); + if ((parent_i.vertex != parent_j.vertex) && (val.weight != std::numeric_limits::infinity())) { + filtered_S->set_pair(i, j, val); + } + + } + } + } + S = filtered_S; + if (edges_added_this_iteration == 0) { + return Status::Ok; + } + clock_gettime(CLOCK_MONOTONIC, &step_end); + step_time = (step_end.tv_sec - step_start.tv_sec) + + (step_end.tv_nsec - step_start.tv_nsec) / 1e9; + std::cout << "--- Step 6(добавляем S): " << step_time * 1000 << " ms" << std::endl; + + + + + } + return Status::Ok; + } + +#pragma endregion Mst + }// namespace spla From f9c95a4c6a324fdc25e1df4a635124892d477c15 Mon Sep 17 00:00:00 2001 From: polka777 Date: Wed, 22 Apr 2026 15:50:40 +0300 Subject: [PATCH 06/14] feat(tests) test pair for operations, type registration and connection with spla primitives --- tests/CMakeLists.txt | 1 + tests/test_pair.cpp | 148 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 tests/test_pair.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e20657a89..244ea481f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -15,6 +15,7 @@ spla_test_target(test_mxmT) spla_test_target(test_mxv) spla_test_target(test_vxm) spla_test_target(test_op) +spla_test_target(test_pair) if (SPLA_BUILD_OPENCL) spla_test_target(test_opencl) spla_test_target(test_opencl_merge) diff --git a/tests/test_pair.cpp b/tests/test_pair.cpp new file mode 100644 index 000000000..9b91954b9 --- /dev/null +++ b/tests/test_pair.cpp @@ -0,0 +1,148 @@ + #include "test_common.hpp" +#include + +#include "spla.hpp" + +TEST(pair, struct_creation) { + spla::T_PAIR p(2.5f, 2); + EXPECT_EQ(p.weight, 2.5f); + EXPECT_EQ(p.vertex, 2); +} +TEST(pair, basic_operations) { + spla::T_PAIR p1(2.5f, 2); + spla::T_PAIR p2(1.5f, 5); + + EXPECT_EQ(p1.weight, 2.5f); + EXPECT_EQ(p1.vertex, 2); + EXPECT_TRUE(p2.weight < p1.weight); +} + +TEST(pair, type_registration) { + auto type = spla::PAIR; + ASSERT_TRUE(type); + EXPECT_EQ(type->get_name(), "PAIR"); + EXPECT_EQ(type->get_code(), "P"); + EXPECT_EQ(type->get_cpp(), "struct Pair"); + EXPECT_EQ(type->get_description(), "weight-vertex pair float-int"); + EXPECT_EQ(type->get_size(), sizeof(spla::Pair)); + EXPECT_EQ(type->get_id(), 5); +} + +TEST(pair, op_registration) { + spla::Library::get(); + EXPECT_EQ(spla::MIN_PAIR->get_name(), "MIN_PAIR"); +} +TEST(pair, set_get_pair_matrix) { + auto S = spla::Matrix::make(2, 2, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + spla::T_PAIR pair; + S->get_pair(0, 1, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 7.0f); + +} +TEST(pair, set_get_pair_vector) { + auto V = spla::Vector::make(2, spla::PAIR); + V->set_pair(0, spla::T_PAIR(1.0f, 1)); + V->set_pair(1, spla::T_PAIR(2.0f, 2)); + spla::T_PAIR pair; + V->get_pair(0, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); + V->get_pair(1, pair); + EXPECT_EQ(pair.vertex, 2); + EXPECT_EQ(pair.weight, 2.0f); + +} +TEST(pair, set_get_pair_scalar) { + auto V = spla::Scalar::make(spla::PAIR); + spla::T_PAIR pair = spla::T_PAIR(1.0f, 1); + V->set_pair(pair); + V->get_pair(pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); +} + +TEST(pair, mxv_pair) { + int32_t n = 7; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(0, 4, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 0, spla::T_PAIR(7.0f, 0)); + S->set_pair(1, 2, spla::T_PAIR(11.0f, 2)); + S->set_pair(1, 3, spla::T_PAIR(10.0f, 3)); + S->set_pair(1, 4, spla::T_PAIR(9.0f, 4)); + S->set_pair(2, 1, spla::T_PAIR(11.0f, 1)); + S->set_pair(2, 3, spla::T_PAIR(5.0f, 3)); + S->set_pair(3, 1, spla::T_PAIR(10.0f, 1)); + S->set_pair(3, 2, spla::T_PAIR(5.0f, 2)); + S->set_pair(3, 4, spla::T_PAIR(15.0f, 4)); + S->set_pair(3, 5, spla::T_PAIR(12.0f, 5)); + S->set_pair(3, 6, spla::T_PAIR(8.0f, 6)); + S->set_pair(4, 0, spla::T_PAIR(4.0f, 0)); + S->set_pair(4, 1, spla::T_PAIR(9.0f, 1)); + S->set_pair(4, 3, spla::T_PAIR(15.0f, 3)); + S->set_pair(4, 5, spla::T_PAIR(6.0f, 5)); + S->set_pair(5, 3, spla::T_PAIR(12.0f, 3)); + S->set_pair(5, 4, spla::T_PAIR(6.0f, 4)); + S->set_pair(5, 6, spla::T_PAIR(13.0f, 6)); + S->set_pair(6, 3, spla::T_PAIR(8.0f, 3)); + S->set_pair(6, 5, spla::T_PAIR(13.0f, 5)); + + auto parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + parent->set_pair(i, spla::T_PAIR(0.0f, i)); + } + auto edge = spla::Vector::make(n, spla::PAIR); + + + auto mask = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + mask->set_pair(i, spla::T_PAIR(0.0f, 0)); + } + auto init_inf = spla::Scalar::make(spla::PAIR); + spla::T_PAIR init_val(1e9f, -1); + init_inf->set_pair(init_val); + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, spla::ALWAYS_PAIR, init_inf); + + spla::T_PAIR expected[] = { + spla::T_PAIR(4.0f, 4), + spla::T_PAIR(7.0f, 0), + spla::T_PAIR(5.0f, 3), + spla::T_PAIR(5.0f, 2), + spla::T_PAIR(4.0f, 0), + spla::T_PAIR(6.0f, 4), + spla::T_PAIR(8.0f, 3) + }; + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } +} +TEST(pair, extract_row) { + int32_t n = 3; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(1, 0, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 2, spla::T_PAIR(7.0f, 0)); + + auto row1 = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row1, S, 1, spla::IDENTITY_PAIR); + spla::T_PAIR expected[] = { + spla::T_PAIR(4.0f, 4), + spla::T_PAIR(), + spla::T_PAIR(7.0f, 0) + }; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + row1->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } + +} + +SPLA_GTEST_MAIN From 239aea56fcf6baabe74cb0a04f7e05a8d479d84d Mon Sep 17 00:00:00 2001 From: polka777 Date: Thu, 23 Apr 2026 18:08:03 +0300 Subject: [PATCH 07/14] fix: tests --- src/opencl/cl_mxv.hpp | 11 +++++++---- src/opencl/cl_program_builder.cpp | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/opencl/cl_mxv.hpp b/src/opencl/cl_mxv.hpp index 6f1602ec1..862ae35b2 100644 --- a/src/opencl/cl_mxv.hpp +++ b/src/opencl/cl_mxv.hpp @@ -260,13 +260,16 @@ namespace spla { .add_define("WARP_SIZE", get_acc_cl()->get_wave_size()) .add_define("BLOCK_SIZE", m_block_size) .add_define("BLOCK_COUNT", m_block_count) - .add_type("TYPE", get_ttype().template as()) + .add_type("TYPE", get_ttype().template as()); + + if constexpr (std::is_same_v) { + program_builder.add_define("USE_PAIR_SEMANTICS", 1); + program_builder.add_define("USE_PAIR_COMPARISON", 1); + } else { + program_builder .add_op("OP_BINARY1", op_multiply.template as()) .add_op("OP_BINARY2", op_add.template as()) .add_op("OP_SELECT", op_select.template as()); - - if constexpr (std::is_same_v) { - program_builder.add_define("USE_PAIR_COMPARISON", 1); } program_builder.set_source(source_mxv).acquire(); program = program_builder.get_program(); diff --git a/src/opencl/cl_program_builder.cpp b/src/opencl/cl_program_builder.cpp index 5599744bf..e40b80a94 100644 --- a/src/opencl/cl_program_builder.cpp +++ b/src/opencl/cl_program_builder.cpp @@ -83,23 +83,30 @@ namespace spla { } std::stringstream builder; - + bool needs_pair_override = false; for (const auto& define : m_defines) { builder << "#define " << define.first << " " << define.second << "\n"; + if (define.first == "TYPE" && define.second.find("Pair") != std::string::npos) { + needs_pair_override = true; + } } builder << source_common_api; + if (needs_pair_override) { + builder << "#define OP_BINARY1(a, b) make_pair((a).weight, (b).vertex)\n\n"; + builder << "#define OP_BINARY2(a, b) min_pair(a, b)\n\n"; + builder << "#define OP_SELECT(a) pair_always(a)\n\n"; + } + for (const auto& function : m_functions) { - builder << "#define " << function.first << function.second->get_source_cl() << "\n"; + builder << function.second->get_type_res()->get_cpp() << " " + << function.first << function.second->get_source_cl() << "\n"; } builder << m_source; m_program_code = builder.str(); - m_program = std::make_shared(); - m_program->m_program = cl::Program(acc->get_context(), m_program_code); - Timer t; t.start(); auto status = m_program->m_program.build("-cl-std=CL1.2"); From 7752ae894cadea8f12c92fea8642ca743a6d5e22 Mon Sep 17 00:00:00 2001 From: polka777 Date: Fri, 24 Apr 2026 20:16:41 +0300 Subject: [PATCH 08/14] fix clang tidy code style fix clang tidy code style2 fix algorithm.cpp cland tidy 3 --- examples/mst.cpp | 220 +++--- include/spla.hpp | 2 +- include/spla/algorithm.hpp | 339 +++++---- include/spla/io.hpp | 4 +- include/spla/matrix.hpp | 6 +- include/spla/op.hpp | 18 +- include/spla/pair.hpp | 35 +- include/spla/scalar.hpp | 10 +- include/spla/vector.hpp | 4 +- src/algorithm.cpp | 1125 ++++++++++++++--------------- src/binding/c_op.cpp | 457 ++++++++---- src/binding/c_type.cpp | 8 +- src/core/tmatrix.hpp | 720 +++++++++--------- src/core/tscalar.hpp | 388 +++++----- src/core/ttype.hpp | 6 +- src/core/tvector.hpp | 809 +++++++++++---------- src/cpu/cpu_algo_registry.cpp | 4 +- src/io.cpp | 818 +++++++++++---------- src/matrix.cpp | 2 +- src/op.cpp | 772 ++++++++++---------- src/opencl/cl_algo_registry.cpp | 5 +- src/opencl/cl_mxv.hpp | 537 +++++++------- src/opencl/cl_program_builder.cpp | 21 +- src/opencl/kernels/common_def.cl | 95 +-- src/scalar.cpp | 108 +-- src/type.cpp | 62 +- src/vector.cpp | 97 +-- tests/test_pair.cpp | 223 +++--- 28 files changed, 3572 insertions(+), 3323 deletions(-) diff --git a/examples/mst.cpp b/examples/mst.cpp index c14ceee5a..23b31d051 100644 --- a/examples/mst.cpp +++ b/examples/mst.cpp @@ -1,124 +1,132 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include "common.hpp" #include "options.hpp" -#include #include +#include -int main(int argc, const char* const* argv) { - auto options = make_options("mst", "Boruvka's Minimum Spanning Tree algorithm with spla library"); - - cxxopts::ParseResult args; - int ret; - - if (parse_options(argc, argv, options, args, ret)) { - std::cerr << "failed to parse options" << std::endl; - return ret; - } - - spla::Timer timer_total; - spla::Timer timer_gpu; - spla::Timer timer_ref; - spla::MtxLoader loader; - - timer_total.start(); - - if (!loader.load(args["mtxpath"].as())) { - std::cerr << "failed to load graph"; - return 1; - } - - std::string acc_info; - spla::Library* library = spla::Library::get(); - - library->set_platform(args["platform"].as()); - library->set_device(args["device"].as()); - library->set_queues_count(1); - library->get_accelerator_info(acc_info); - std::cout << "env: " << acc_info << std::endl; - - const spla::uint N = loader.get_n_rows(); - auto S = spla::Matrix::make(N, N, spla::PAIR); - - const auto& Ai = loader.get_Ai(); - const auto& Aj = loader.get_Aj(); - const auto& Aw = loader.get_Aw(); - - for (std::size_t k = 0; k < loader.get_n_values(); ++k) { +int main(int argc, const char *const *argv) { + auto options = make_options( + "mst", "Boruvka's Minimum Spanning Tree algorithm with spla library"); + + cxxopts::ParseResult args; + int ret; + + if (parse_options(argc, argv, options, args, ret)) { + std::cerr << "failed to parse options" << std::endl; + return ret; + } + + spla::Timer timer_total; + spla::Timer timer_gpu; + spla::Timer timer_ref; + spla::MtxLoader loader; + + timer_total.start(); + + if (!loader.load(args["mtxpath"].as())) { + std::cerr << "failed to load graph"; + return 1; + } + + std::string acc_info; + spla::Library *library = spla::Library::get(); + + library->set_platform(args["platform"].as()); + library->set_device(args["device"].as()); + library->set_queues_count(1); + library->get_accelerator_info(acc_info); + std::cout << "env: " << acc_info << std::endl; + + const spla::uint N = loader.get_n_rows(); + auto S = spla::Matrix::make(N, N, spla::PAIR); + + const auto &Ai = loader.get_Ai(); + const auto &Aj = loader.get_Aj(); + const auto &Aw = loader.get_Aw(); + + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + + auto T_gpu = spla::Matrix::make(N, N, spla::FLOAT); + + auto desc = spla::Descriptor::make(); + + const int n_iters = args["niters"].as(); + + double total_weight_gpu = 0.0; + + if (args["run-gpu"].as()) { + library->set_force_no_acceleration(false); + + for (int i = 0; i < n_iters; ++i) { + T_gpu->clear(); + S = spla::Matrix::make(N, N, spla::PAIR); + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + timer_gpu.lap_begin(); + spla::mst(T_gpu, S, desc, nullptr); + timer_gpu.lap_end(); } - - auto T_gpu = spla::Matrix::make(N, N, spla::FLOAT); - - auto desc = spla::Descriptor::make(); - - const int n_iters = args["niters"].as(); - - double total_weight_gpu = 0.0; - - if (args["run-gpu"].as()) { - library->set_force_no_acceleration(false); - - for (int i = 0; i < n_iters; ++i) { - T_gpu->clear(); - S = spla::Matrix::make(N, N, spla::PAIR); - for (std::size_t k = 0; k < loader.get_n_values(); ++k) { - S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); - } - timer_gpu.lap_begin(); - spla::mst(T_gpu, S, desc, nullptr); - timer_gpu.lap_end(); - } - - total_weight_gpu = 0; - for (spla::uint i = 0; i < N; ++i) { - for (spla::uint j = i + 1; j < N; ++j) { - float w; - T_gpu->get_float(i, j, w); - if (w != 0.0) { - total_weight_gpu += w; - } - } + + total_weight_gpu = 0; + for (spla::uint i = 0; i < N; ++i) { + for (spla::uint j = i + 1; j < N; ++j) { + float w; + T_gpu->get_float(i, j, w); + if (w != 0.0) { + total_weight_gpu += w; } - - std::cout << "GPU MST total weight: " << total_weight_gpu << std::endl; + } } - - spla::Library::get()->finalize(); - - timer_total.stop(); - - std::cout << "\n=== Timing Results ===" << std::endl; - std::cout << "total(ms):" << timer_total.get_elapsed_ms() << std::endl; - std::cout << "gpu(ms): "; - timer_gpu.print(); - std::cout << std::endl; - - return 0; + + std::cout << "GPU MST total weight: " << total_weight_gpu << std::endl; + } + + spla::Library::get()->finalize(); + + timer_total.stop(); + + std::cout << "\n=== Timing Results ===" << std::endl; + std::cout << "total(ms):" << timer_total.get_elapsed_ms() << std::endl; + std::cout << "gpu(ms): "; + timer_gpu.print(); + std::cout << std::endl; + + return 0; } \ No newline at end of file diff --git a/include/spla.hpp b/include/spla.hpp index 6c19daa2c..8a2a4cc29 100644 --- a/include/spla.hpp +++ b/include/spla.hpp @@ -39,12 +39,12 @@ #include "spla/memview.hpp" #include "spla/object.hpp" #include "spla/op.hpp" +#include "spla/pair.hpp" #include "spla/ref.hpp" #include "spla/scalar.hpp" #include "spla/schedule.hpp" #include "spla/timer.hpp" #include "spla/type.hpp" #include "spla/vector.hpp" -#include "spla/pair.hpp" #endif//SPLA_SPLA_HPP diff --git a/include/spla/algorithm.hpp b/include/spla/algorithm.hpp index 321771c7a..b373265c8 100644 --- a/include/spla/algorithm.hpp +++ b/include/spla/algorithm.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_ALGORITHM_HPP @@ -32,172 +39,160 @@ #include "descriptor.hpp" #include "matrix.hpp" #include "scalar.hpp" -#include "vector.hpp" #include "schedule.hpp" +#include "vector.hpp" namespace spla { - /** - * @addtogroup spla - * @{ - */ +/** + * @addtogroup spla + * @{ + */ - /** - * @brief Breadth-first search algorithm - * - * @param v int vector to store reached distances - * @param A int matrix filled with 1 where exist edge from i to j - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status bfs( - const ref_ptr& v, - const ref_ptr& A, - uint s, - const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief Breadth-first search algorithm + * + * @param v int vector to store reached distances + * @param A int matrix filled with 1 where exist edge from i to j + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +bfs(const ref_ptr &v, const ref_ptr &A, uint s, + const ref_ptr &descriptor = spla::Descriptor::make()); - /** - * @brief Naive breadth-first search algorithm (reference cpu implementation) - * - * @param v int vector to store reached distances - * @param A int graph adjacency lists filled with 1 where exist edge from i to j - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status bfs_naive( - std::vector& v, - std::vector>& A, - uint s, - const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief Naive breadth-first search algorithm (reference cpu implementation) + * + * @param v int vector to store reached distances + * @param A int graph adjacency lists filled with 1 where exist edge from i to j + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +bfs_naive(std::vector &v, std::vector> &A, uint s, + const ref_ptr &descriptor = spla::Descriptor::make()); - /** - * @brief Single-source shortest path algorithm - * - * @param v float vector to store reached distances - * @param A float matrix filled with >0.0f distances where exist edge from i to j otherwise 0.0f - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status sssp( - const ref_ptr& v, - const ref_ptr& A, - uint s, - const ref_ptr& descriptor = ref_ptr()); +/** + * @brief Single-source shortest path algorithm + * + * @param v float vector to store reached distances + * @param A float matrix filled with >0.0f distances where exist edge from i to + * j otherwise 0.0f + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +sssp(const ref_ptr &v, const ref_ptr &A, uint s, + const ref_ptr &descriptor = ref_ptr()); - /** - * @brief Naive single-source shortest path algorithm (reference cpu implementation) - * - * @param v float vector to store reached distances - * @param Ai uint matrix column indices - * @param Ax float matrix values with >0.0f distances where exist edge from i to j - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status sssp_naive(std::vector& v, - std::vector>& Ai, - std::vector>& Ax, - uint s, - const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief Naive single-source shortest path algorithm (reference cpu + * implementation) + * + * @param v float vector to store reached distances + * @param Ai uint matrix column indices + * @param Ax float matrix values with >0.0f distances where exist edge from i to + * j + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +sssp_naive(std::vector &v, std::vector> &Ai, + std::vector> &Ax, uint s, + const ref_ptr &descriptor = spla::Descriptor::make()); - /** - * @brief PageRank algorithm - * - * @param p float vector to store result vertices weights - * @param A float graph matrix with weights A[i][j] = alpha / outdegree(i) - * @param alpha float alpha to control PageRank (default is 0.85) - * @param eps float tolerance to control precision of PageRank (default is 1e-6) - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status pr( - ref_ptr& p, - const ref_ptr& A, - float alpha = 0.85, - float eps = 1e-6, - const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief PageRank algorithm + * + * @param p float vector to store result vertices weights + * @param A float graph matrix with weights A[i][j] = alpha / outdegree(i) + * @param alpha float alpha to control PageRank (default is 0.85) + * @param eps float tolerance to control precision of PageRank (default is 1e-6) + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +pr(ref_ptr &p, const ref_ptr &A, float alpha = 0.85, + float eps = 1e-6, + const ref_ptr &descriptor = spla::Descriptor::make()); - /** - * @brief Naive PageRank algorithm (reference cpu implementation) - * - * @param p float vector to store result vertices weights - * @param Ai float graph matrix column indices - * @param Ax float graph matrix weights A[i][j] = alpha / outdegree(i) - * @param alpha float alpha to control PageRank (default is 0.85) - * @param eps float tolerance to control precision of PageRank (default is 1e-6) - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status pr_naive( - std::vector& p, - std::vector>& Ai, - std::vector>& Ax, - float alpha = 0.85, - float eps = 1e-6, - const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief Naive PageRank algorithm (reference cpu implementation) + * + * @param p float vector to store result vertices weights + * @param Ai float graph matrix column indices + * @param Ax float graph matrix weights A[i][j] = alpha / outdegree(i) + * @param alpha float alpha to control PageRank (default is 0.85) + * @param eps float tolerance to control precision of PageRank (default is 1e-6) + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status pr_naive( + std::vector &p, std::vector> &Ai, + std::vector> &Ax, float alpha = 0.85, float eps = 1e-6, + const ref_ptr &descriptor = spla::Descriptor::make()); - /** - * @brief Triangles counting algorithm - * - * @param ntrins Number of triangles counted - * @param A Lower trilingual int matrix with 1 where has edge in a graph - * @param B Buffer int matrix to store result - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status tc( - int& ntrins, - const ref_ptr& A, - const ref_ptr& B, - const ref_ptr& descriptor = spla::Descriptor::make()); +/** + * @brief Triangles counting algorithm + * + * @param ntrins Number of triangles counted + * @param A Lower trilingual int matrix with 1 where has edge in a graph + * @param B Buffer int matrix to store result + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +tc(int &ntrins, const ref_ptr &A, const ref_ptr &B, + const ref_ptr &descriptor = spla::Descriptor::make()); - /** - * @brief Naive triangles counting algorithm (reference cpu implementation) - * - * @param ntrins Number of triangles counted - * @param A Lower trilingual int matrix structure - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status tc_naive( - int& ntrins, - std::vector>& Ai, - const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief Boruvka's Minimum Spanning Tree algorithm - * - * Finds the Minimum Spanning Tree of a weighted undirected graph using - * Boruvka's algorithm with algebraic operations. - * - * @param T float matrix to store MST edges (result). Only upper triangle is used. - * @param S PAIR matrix adjacency matrix with edges (weight, vertex). - * The vertex field stores the target vertex of the edge. - * @param descriptor optional descriptor for algorithm configuration - * @param task_hnd optional pointer to store task handle for async execution - * - * @return ok on success - */ - SPLA_API Status mst( - const ref_ptr& T, - ref_ptr& S, - const ref_ptr& descriptor = spla::Descriptor::make(), - ref_ptr* task_hnd = nullptr); + * @brief Naive triangles counting algorithm (reference cpu implementation) + * + * @param ntrins Number of triangles counted + * @param A Lower trilingual int matrix structure + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ +SPLA_API Status +tc_naive(int &ntrins, std::vector> &Ai, + const ref_ptr &descriptor = spla::Descriptor::make()); +/** + * @brief Boruvka's Minimum Spanning Tree algorithm + * + * Finds the Minimum Spanning Tree of a weighted undirected graph using + * Boruvka's algorithm with algebraic operations. + * + * @param T float matrix to store MST edges (result). Only upper triangle is + * used. + * @param S PAIR matrix adjacency matrix with edges (weight, vertex). + * The vertex field stores the target vertex of the edge. + * @param descriptor optional descriptor for algorithm configuration + * @param task_hnd optional pointer to store task handle for async execution + * + * @return ok on success + */ +SPLA_API Status +mst(const ref_ptr &T, ref_ptr &S, + const ref_ptr &descriptor = spla::Descriptor::make(), + ref_ptr *task_hnd = nullptr); - /** - * @} - */ +/** + * @} + */ -}// namespace spla +} // namespace spla -#endif//SPLA_ALGORITHM_HPP +#endif // SPLA_ALGORITHM_HPP diff --git a/include/spla/io.hpp b/include/spla/io.hpp index 9bfceda62..542baf3d2 100644 --- a/include/spla/io.hpp +++ b/include/spla/io.hpp @@ -80,7 +80,7 @@ namespace spla { [[nodiscard]] SPLA_API const std::vector& get_Ai() const; [[nodiscard]] SPLA_API const std::vector& get_Aj() const; - [[nodiscard]] SPLA_API const std::vector& get_Aw() const; + [[nodiscard]] SPLA_API const std::vector &get_Aw() const; [[nodiscard]] SPLA_API uint get_n_rows() const; [[nodiscard]] SPLA_API uint get_n_cols() const; [[nodiscard]] SPLA_API std::size_t get_n_values() const; @@ -90,7 +90,7 @@ namespace spla { std::filesystem::path m_file_path; std::vector m_Ai; std::vector m_Aj; - std::vector m_Aw; + std::vector m_Aw; bool m_base_is_zero = false; uint m_n_rows = 0; uint m_n_cols = 0; diff --git a/include/spla/matrix.hpp b/include/spla/matrix.hpp index 7038a68ad..c9c32a6e8 100644 --- a/include/spla/matrix.hpp +++ b/include/spla/matrix.hpp @@ -57,11 +57,13 @@ namespace spla { SPLA_API virtual Status set_int(uint row_id, uint col_id, std::int32_t value) = 0; SPLA_API virtual Status set_uint(uint row_id, uint col_id, std::uint32_t value) = 0; SPLA_API virtual Status set_float(uint row_id, uint col_id, float value) = 0; - SPLA_API virtual Status set_pair(uint row_id, uint col_id, Pair value) = 0; + SPLA_API virtual Status set_pair(uint row_id, uint col_id, + Pair value) = 0; SPLA_API virtual Status get_int(uint row_id, uint col_id, std::int32_t& value) = 0; SPLA_API virtual Status get_uint(uint row_id, uint col_id, std::uint32_t& value) = 0; SPLA_API virtual Status get_float(uint row_id, uint col_id, float& value) = 0; - SPLA_API virtual Status get_pair(uint row_id, uint col_id, Pair& value) = 0; + SPLA_API virtual Status get_pair(uint row_id, uint col_id, + Pair &value) = 0; SPLA_API virtual Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) = 0; SPLA_API virtual Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) = 0; SPLA_API virtual Status clear() = 0; diff --git a/include/spla/op.hpp b/include/spla/op.hpp index 738b130ad..727bdb120 100644 --- a/include/spla/op.hpp +++ b/include/spla/op.hpp @@ -29,9 +29,8 @@ #define SPLA_OP_HPP #include "object.hpp" -#include "type.hpp" #include "spla/pair.hpp" - +#include "type.hpp" #include @@ -66,7 +65,9 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); - SPLA_API static ref_ptr make_pair(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr + make_pair(std::string name, std::string code, + std::function function); }; /** @@ -81,8 +82,9 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); - SPLA_API static ref_ptr make_pair(std::string name, std::string code, std::function function); - + SPLA_API static ref_ptr + make_pair(std::string name, std::string code, + std::function function); }; /** @@ -96,8 +98,9 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); - SPLA_API static ref_ptr make_pair(std::string name, std::string code, std::function function); - + SPLA_API static ref_ptr + make_pair(std::string name, std::string code, + std::function function); }; //////////////////////////////// Unary //////////////////////////////// @@ -139,7 +142,6 @@ namespace spla { SPLA_API extern ref_ptr TRUNC_FLOAT; SPLA_API extern ref_ptr IDENTITY_PAIR; - //////////////////////////////// Binary //////////////////////////////// SPLA_API extern ref_ptr PLUS_INT; diff --git a/include/spla/pair.hpp b/include/spla/pair.hpp index 2b40b468c..8799be29f 100644 --- a/include/spla/pair.hpp +++ b/include/spla/pair.hpp @@ -2,27 +2,22 @@ #define SPLA_PAIR_HPP #include - namespace spla { - struct Pair { - float weight; - int vertex; +struct Pair { + float weight; + int vertex; + + Pair() : weight(std::numeric_limits::infinity()), vertex(-1) {} + Pair(float w, int v) : weight(w), vertex(v) {} + + bool operator<(const Pair &other) const { return weight < other.weight; } + bool operator==(const Pair &other) const { + return weight == other.weight && vertex == other.vertex; + } - Pair(): weight(std::numeric_limits::infinity()), vertex(-1){} - Pair(float w, int v): weight(w), vertex(v){} + bool operator!=(const Pair &other) const { return !(*this == other); } - bool operator<(const Pair& other) const { - return weight < other.weight; - } - bool operator==(const Pair& other) const { - return weight == other.weight && vertex == other.vertex; - } - - bool operator!=(const Pair& other) const { - return !(*this == other); - } - - Pair& operator=(const Pair& other) = default; - }; -} + Pair &operator=(const Pair &other) = default; +}; +} // namespace spla #endif \ No newline at end of file diff --git a/include/spla/scalar.hpp b/include/spla/scalar.hpp index a4fa14a29..71533ac17 100644 --- a/include/spla/scalar.hpp +++ b/include/spla/scalar.hpp @@ -49,15 +49,19 @@ namespace spla { SPLA_API virtual Status set_int(std::int32_t value) = 0; SPLA_API virtual Status set_uint(std::uint32_t value) = 0; SPLA_API virtual Status set_float(float value) = 0; - SPLA_API virtual Status set_pair(Pair value) {return Status::InvalidArgument;} + SPLA_API virtual Status set_pair(Pair value) { + return Status::InvalidArgument; + } SPLA_API virtual Status get_int(std::int32_t& value) = 0; SPLA_API virtual Status get_uint(std::uint32_t& value) = 0; SPLA_API virtual Status get_float(float& value) = 0; - SPLA_API virtual Status get_pair(Pair& value) {return Status::InvalidArgument;} + SPLA_API virtual Status get_pair(Pair &value) { + return Status::InvalidArgument; + } SPLA_API virtual T_INT as_int() = 0; SPLA_API virtual T_UINT as_uint() = 0; SPLA_API virtual T_FLOAT as_float() = 0; - SPLA_API virtual T_PAIR as_pair() = 0; + SPLA_API virtual T_PAIR as_pair() = 0; SPLA_API static ref_ptr make(const ref_ptr& type); SPLA_API static ref_ptr make_int(std::int32_t value); diff --git a/include/spla/vector.hpp b/include/spla/vector.hpp index dedafb901..dd54d55df 100644 --- a/include/spla/vector.hpp +++ b/include/spla/vector.hpp @@ -59,8 +59,8 @@ namespace spla { SPLA_API virtual Status get_int(uint row_id, T_INT& value) = 0; SPLA_API virtual Status get_uint(uint row_id, T_UINT& value) = 0; SPLA_API virtual Status get_float(uint row_id, float& value) = 0; - SPLA_API virtual Status get_pair(uint row_id, Pair& value) = 0; - SPLA_API virtual Status set_pair(uint row_id, Pair value) = 0; + SPLA_API virtual Status get_pair(uint row_id, Pair &value) = 0; + SPLA_API virtual Status set_pair(uint row_id, Pair value) = 0; SPLA_API virtual Status fill_noize(uint seed) = 0; SPLA_API virtual Status fill_with(const ref_ptr& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys, const ref_ptr& values) = 0; diff --git a/src/algorithm.cpp b/src/algorithm.cpp index 26326e7a7..dc3cb493d 100644 --- a/src/algorithm.cpp +++ b/src/algorithm.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -37,8 +44,8 @@ #include #include #include -#include #include +#include #define INF std::numeric_limits::infinity() @@ -46,713 +53,663 @@ namespace spla { #pragma region Bfs - Status bfs(const ref_ptr& v, - const ref_ptr& A, - uint s, - const ref_ptr& descriptor) { - assert(v); - assert(A); +Status bfs(const ref_ptr &v, const ref_ptr &A, uint s, + const ref_ptr &descriptor) { + assert(v); + assert(A); - const auto N = v->get_n_rows(); + const auto N = v->get_n_rows(); - ref_ptr frontier_prev = Vector::make(N, INT); - ref_ptr frontier_new = Vector::make(N, INT); - ref_ptr frontier_size = Scalar::make_int(1); - ref_ptr depth = Scalar::make_int(1); - ref_ptr zero = Scalar::make_int(0); - int current_level = 1; - int discovered = 1; - bool frontier_empty = false; + ref_ptr frontier_prev = Vector::make(N, INT); + ref_ptr frontier_new = Vector::make(N, INT); + ref_ptr frontier_size = Scalar::make_int(1); + ref_ptr depth = Scalar::make_int(1); + ref_ptr zero = Scalar::make_int(0); + int current_level = 1; + int discovered = 1; + bool frontier_empty = false; - ref_ptr desc = Descriptor::make(); - desc->set_early_exit(true); - desc->set_struct_only(true); + ref_ptr desc = Descriptor::make(); + desc->set_early_exit(true); + desc->set_struct_only(true); - frontier_prev->set_int(s, 1); + frontier_prev->set_int(s, 1); - bool push = descriptor->get_push_only(); - bool pull = descriptor->get_pull_only(); - bool push_pull = descriptor->get_push_pull(); - float front_factor = descriptor->get_front_factor(); + bool push = descriptor->get_push_only(); + bool pull = descriptor->get_pull_only(); + bool push_pull = descriptor->get_push_pull(); + float front_factor = descriptor->get_front_factor(); - if (!(push || pull || push_pull)) push = true; + if (!(push || pull || push_pull)) + push = true; #ifndef SPLA_RELEASE - std::string mode; - if (push_pull) mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; - if (pull) mode = "(pull)"; - if (push) mode = "(push)"; + std::string mode; + if (push_pull) + mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; + if (pull) + mode = "(pull)"; + if (push) + mode = "(push)"; - std::cout << "start bfs from " << s << " " << mode << std::endl; + std::cout << "start bfs from " << s << " " << mode << std::endl; - Timer tight; + Timer tight; #endif - while (!frontier_empty) { + while (!frontier_empty) { #ifndef SPLA_RELEASE - tight.start(); + tight.start(); #endif - depth->set_int(current_level); - exec_v_assign_masked(v, frontier_prev, depth, SECOND_INT, NQZERO_INT); - - float front_density = float(frontier_size->as_int()) / float(N); - bool is_push_better = (front_density <= front_factor); - - if (push || (push_pull && is_push_better)) { - exec_vxm_masked(frontier_new, v, frontier_prev, A, BAND_INT, BOR_INT, EQZERO_INT, zero, desc); - } else { - exec_mxv_masked(frontier_new, v, A, frontier_prev, BAND_INT, BOR_INT, EQZERO_INT, zero, desc); - } + depth->set_int(current_level); + exec_v_assign_masked(v, frontier_prev, depth, SECOND_INT, NQZERO_INT); + + float front_density = float(frontier_size->as_int()) / float(N); + bool is_push_better = (front_density <= front_factor); + + if (push || (push_pull && is_push_better)) { + exec_vxm_masked(frontier_new, v, frontier_prev, A, BAND_INT, BOR_INT, + EQZERO_INT, zero, desc); + } else { + exec_mxv_masked(frontier_new, v, A, frontier_prev, BAND_INT, BOR_INT, + EQZERO_INT, zero, desc); + } - exec_v_count_mf(frontier_size, frontier_new); + exec_v_count_mf(frontier_size, frontier_new); #ifndef SPLA_RELEASE - tight.stop(); - std::cout << " - iter " << current_level - << " front " << frontier_size->as_int() << " discovered " << discovered << " " - << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iter " << current_level << " front " + << frontier_size->as_int() << " discovered " << discovered << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - frontier_empty = frontier_size->as_int() == 0; - discovered += frontier_size->as_int(); - current_level += 1; - - std::swap(frontier_prev, frontier_new); - } + frontier_empty = frontier_size->as_int() == 0; + discovered += frontier_size->as_int(); + current_level += 1; - return Status::Ok; - } + std::swap(frontier_prev, frontier_new); + } - Status bfs_naive(std::vector& v, - std::vector>& A, - uint s, - const ref_ptr& descriptor) { + return Status::Ok; +} - const auto N = v.size(); +Status bfs_naive(std::vector &v, std::vector> &A, + uint s, const ref_ptr &descriptor) { - std::queue front; - std::vector visited(N, false); + const auto N = v.size(); - std::fill(v.begin(), v.end(), 0); + std::queue front; + std::vector visited(N, false); - front.push(s); - visited[s] = true; - v[s] = 1; + std::fill(v.begin(), v.end(), 0); - while (!front.empty()) { - auto i = front.front(); - front.pop(); + front.push(s); + visited[s] = true; + v[s] = 1; - for (auto j : A[i]) { - if (!visited[j]) { - visited[j] = true; - v[j] = v[i] + 1; - front.push(j); - } - } - } + while (!front.empty()) { + auto i = front.front(); + front.pop(); - return Status::Ok; + for (auto j : A[i]) { + if (!visited[j]) { + visited[j] = true; + v[j] = v[i] + 1; + front.push(j); + } } + } + + return Status::Ok; +} #pragma endregion Bfs #pragma region Sssp - Status sssp(const ref_ptr& v, - const ref_ptr& A, - uint s, - const ref_ptr& descriptor) { - assert(v); - assert(A); +Status sssp(const ref_ptr &v, const ref_ptr &A, uint s, + const ref_ptr &descriptor) { + assert(v); + assert(A); - const auto N = v->get_n_rows(); - const auto inf = std::numeric_limits::max(); + const auto N = v->get_n_rows(); + const auto inf = std::numeric_limits::max(); - ref_ptr dummy_mask = Vector::make(N, FLOAT); - ref_ptr frontier = Vector::make(N, FLOAT); - ref_ptr feedback = Vector::make(N, FLOAT); - ref_ptr feedback_size = Scalar::make_int(0); - ref_ptr inf_init = Scalar::make_float(inf); - int current_level = 1; - bool feedback_empty = false; + ref_ptr dummy_mask = Vector::make(N, FLOAT); + ref_ptr frontier = Vector::make(N, FLOAT); + ref_ptr feedback = Vector::make(N, FLOAT); + ref_ptr feedback_size = Scalar::make_int(0); + ref_ptr inf_init = Scalar::make_float(inf); + int current_level = 1; + bool feedback_empty = false; - v->set_fill_value(inf_init); - feedback->set_fill_value(inf_init); - frontier->set_fill_value(inf_init); + v->set_fill_value(inf_init); + feedback->set_fill_value(inf_init); + frontier->set_fill_value(inf_init); - v->set_float(s, 0.0f); - feedback->set_float(s, 0.0f); + v->set_float(s, 0.0f); + feedback->set_float(s, 0.0f); - bool push = descriptor->get_push_only(); - bool pull = descriptor->get_pull_only(); - bool push_pull = descriptor->get_push_pull(); - float front_factor = descriptor->get_front_factor(); + bool push = descriptor->get_push_only(); + bool pull = descriptor->get_pull_only(); + bool push_pull = descriptor->get_push_pull(); + float front_factor = descriptor->get_front_factor(); - if (!(push || pull || push_pull)) push = true; + if (!(push || pull || push_pull)) + push = true; #ifndef SPLA_RELEASE - std::string mode; - if (push_pull) mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; - if (pull) mode = "(pull)"; - if (push) mode = "(push)"; + std::string mode; + if (push_pull) + mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; + if (pull) + mode = "(pull)"; + if (push) + mode = "(push)"; - std::cout << "start sssp from " << s << " " << mode << std::endl; + std::cout << "start sssp from " << s << " " << mode << std::endl; - Timer tight; + Timer tight; #endif - while (!feedback_empty) { + while (!feedback_empty) { #ifndef SPLA_RELEASE - tight.start(); + tight.start(); #endif - float front_density = float(feedback_size->as_int()) / float(N); - bool is_push_better = (front_density <= front_factor); - - if (push || (push_pull && is_push_better)) { - exec_vxm_masked(frontier, dummy_mask, feedback, A, PLUS_FLOAT, MIN_FLOAT, ALWAYS_FLOAT, inf_init); - } else { - exec_mxv_masked(frontier, dummy_mask, A, feedback, PLUS_FLOAT, MIN_FLOAT, ALWAYS_FLOAT, inf_init); - } + float front_density = float(feedback_size->as_int()) / float(N); + bool is_push_better = (front_density <= front_factor); + + if (push || (push_pull && is_push_better)) { + exec_vxm_masked(frontier, dummy_mask, feedback, A, PLUS_FLOAT, MIN_FLOAT, + ALWAYS_FLOAT, inf_init); + } else { + exec_mxv_masked(frontier, dummy_mask, A, feedback, PLUS_FLOAT, MIN_FLOAT, + ALWAYS_FLOAT, inf_init); + } - exec_v_eadd_fdb(v, frontier, feedback, MIN_FLOAT); - exec_v_count_mf(feedback_size, feedback); + exec_v_eadd_fdb(v, frontier, feedback, MIN_FLOAT); + exec_v_count_mf(feedback_size, feedback); #ifndef SPLA_RELEASE - tight.stop(); - std::cout << " - iter " << current_level - << " feed " << feedback_size->as_int() - << " " << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iter " << current_level << " feed " + << feedback_size->as_int() << " " << tight.get_elapsed_ms() + << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - feedback_empty = feedback_size->as_int() == 0; - current_level += 1; + feedback_empty = feedback_size->as_int() == 0; + current_level += 1; + } + + return Status::Ok; +} + +Status sssp_naive(std::vector &v, std::vector> &Ai, + std::vector> &Ax, uint s, + const ref_ptr &descriptor) { + + const auto N = v.size(); + const auto inf = std::numeric_limits::max(); + + std::queue front; + std::vector in_queue(N, false); + std::fill(v.begin(), v.end(), inf); + + front.push(s); + in_queue[s] = true; + v[s] = 0.0f; + + while (!front.empty()) { + auto i = front.front(); + front.pop(); + in_queue[i] = false; + + const auto &col_ids = Ai[i]; + const auto &col_vals = Ax[i]; + const auto n_vals = col_ids.size(); + + for (std::size_t k = 0; k < n_vals; k += 1) { + const uint j = col_ids[k]; + const float w = col_vals[k]; + + if (v[j] == inf || v[i] + w < v[j]) { + v[j] = v[i] + w; + if (!in_queue[j]) { + in_queue[j] = true; + front.push(j); } - - return Status::Ok; + } } + } - Status sssp_naive(std::vector& v, - std::vector>& Ai, - std::vector>& Ax, - uint s, - const ref_ptr& descriptor) { - - const auto N = v.size(); - const auto inf = std::numeric_limits::max(); - - std::queue front; - std::vector in_queue(N, false); - std::fill(v.begin(), v.end(), inf); - - front.push(s); - in_queue[s] = true; - v[s] = 0.0f; - - while (!front.empty()) { - auto i = front.front(); - front.pop(); - in_queue[i] = false; - - const auto& col_ids = Ai[i]; - const auto& col_vals = Ax[i]; - const auto n_vals = col_ids.size(); - - for (std::size_t k = 0; k < n_vals; k += 1) { - const uint j = col_ids[k]; - const float w = col_vals[k]; - - if (v[j] == inf || v[i] + w < v[j]) { - v[j] = v[i] + w; - if (!in_queue[j]) { - in_queue[j] = true; - front.push(j); - } - } - } - } - - return Status::Ok; - } + return Status::Ok; +} #pragma endregion Sssp #pragma region Pr - Status pr(ref_ptr& p, - const ref_ptr& A, - float alpha, - float eps, - const ref_ptr& descriptor) { - assert(p); - assert(A); +Status pr(ref_ptr &p, const ref_ptr &A, float alpha, float eps, + const ref_ptr &descriptor) { + assert(p); + assert(A); - const auto N = p->get_n_rows(); + const auto N = p->get_n_rows(); - ref_ptr dummy_mask = Vector::make(N, FLOAT); - ref_ptr p_prev = Vector::make(N, FLOAT); - ref_ptr p_tmp = Vector::make(N, FLOAT); - ref_ptr addition = Vector::make(N, FLOAT); - ref_ptr errors = Vector::make(N, FLOAT); - ref_ptr error2 = Scalar::make(FLOAT); - ref_ptr zero = Scalar::make_float(0.0f); + ref_ptr dummy_mask = Vector::make(N, FLOAT); + ref_ptr p_prev = Vector::make(N, FLOAT); + ref_ptr p_tmp = Vector::make(N, FLOAT); + ref_ptr addition = Vector::make(N, FLOAT); + ref_ptr errors = Vector::make(N, FLOAT); + ref_ptr error2 = Scalar::make(FLOAT); + ref_ptr zero = Scalar::make_float(0.0f); - addition->fill_with(Scalar::make_float((1.0f - alpha) / float(N))); - p_prev->fill_with(Scalar::make_float(1.0f / float(N))); + addition->fill_with(Scalar::make_float((1.0f - alpha) / float(N))); + p_prev->fill_with(Scalar::make_float(1.0f / float(N))); - float error = eps + 0.1f; + float error = eps + 0.1f; #ifndef SPLA_RELEASE - int iter = 0; + int iter = 0; - std::cout << "start pr alpha=" << alpha << " eps " << eps << std::endl; + std::cout << "start pr alpha=" << alpha << " eps " << eps << std::endl; - Timer tight; + Timer tight; #endif - while (error > eps) { + while (error > eps) { #ifndef SPLA_RELEASE - tight.start(); + tight.start(); #endif - // p = A*p + (1-alpha)/N - exec_mxv_masked(p_tmp, dummy_mask, A, p_prev, MULT_FLOAT, PLUS_FLOAT, ALWAYS_FLOAT, zero); - exec_v_eadd(p, p_tmp, addition, PLUS_FLOAT); + // p = A*p + (1-alpha)/N + exec_mxv_masked(p_tmp, dummy_mask, A, p_prev, MULT_FLOAT, PLUS_FLOAT, + ALWAYS_FLOAT, zero); + exec_v_eadd(p, p_tmp, addition, PLUS_FLOAT); - // error = sqrt((p[01]-prev[0])^2 + ... + p[N-1]-prev[N-1])^2) - exec_v_eadd(errors, p, p_prev, MINUS_POW2_FLOAT); - exec_v_reduce(error2, zero, errors, PLUS_FLOAT); + // error = sqrt((p[01]-prev[0])^2 + ... + p[N-1]-prev[N-1])^2) + exec_v_eadd(errors, p, p_prev, MINUS_POW2_FLOAT); + exec_v_reduce(error2, zero, errors, PLUS_FLOAT); - error = std::sqrt(error2->as_float()); + error = std::sqrt(error2->as_float()); - std::swap(p, p_prev); + std::swap(p, p_prev); #ifndef SPLA_RELEASE - tight.stop(); - std::cout << " - iter " << iter++ - << " error " << error - << " " << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iter " << iter++ << " error " << error << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - } + } - std::swap(p, p_prev); - return Status::Ok; - } + std::swap(p, p_prev); + return Status::Ok; +} - Status pr_naive(std::vector& p, - std::vector>& Ai, - std::vector>& Ax, - float alpha, - float eps, - const ref_ptr& descriptor) { +Status pr_naive(std::vector &p, std::vector> &Ai, + std::vector> &Ax, float alpha, float eps, + const ref_ptr &descriptor) { - const auto N = p.size(); + const auto N = p.size(); - std::vector p_prev(N, 1.0f / float(N)); + std::vector p_prev(N, 1.0f / float(N)); - float error = eps + 0.1f; + float error = eps + 0.1f; - while (error > eps) { - for (std::size_t i = 0; i < N; i++) { - p[i] = 0; + while (error > eps) { + for (std::size_t i = 0; i < N; i++) { + p[i] = 0; - for (std::size_t k = 0; k < Ai[i].size(); k++) { - p[i] += Ax[i][k] * p_prev[Ai[i][k]]; - } + for (std::size_t k = 0; k < Ai[i].size(); k++) { + p[i] += Ax[i][k] * p_prev[Ai[i][k]]; + } - p[i] += (1.0f - alpha) / float(N); - } + p[i] += (1.0f - alpha) / float(N); + } - error = 0.0f; + error = 0.0f; - for (std::size_t i = 0; i < N; i++) { - error += (p[i] - p_prev[i]) * (p[i] - p_prev[i]); - } + for (std::size_t i = 0; i < N; i++) { + error += (p[i] - p_prev[i]) * (p[i] - p_prev[i]); + } - error = std::sqrt(error); + error = std::sqrt(error); - std::swap(p, p_prev); - } + std::swap(p, p_prev); + } - std::swap(p, p_prev); - return Status::Ok; - } + std::swap(p, p_prev); + return Status::Ok; +} #pragma endregion Pr #pragma region Tc - Status tc( - int& ntrins, - const ref_ptr& A, - const ref_ptr& B, - const ref_ptr& descriptor) { - assert(A); - assert(B); +Status tc(int &ntrins, const ref_ptr &A, const ref_ptr &B, + const ref_ptr &descriptor) { + assert(A); + assert(B); - ref_ptr zero = Scalar::make_int(0); - ref_ptr result = Scalar::make(INT); + ref_ptr zero = Scalar::make_int(0); + ref_ptr result = Scalar::make(INT); #ifndef SPLA_RELEASE - std::cout << "start tc" << std::endl; + std::cout << "start tc" << std::endl; - Timer tight; - tight.start(); + Timer tight; + tight.start(); #endif - spla::exec_mxmT_masked(B, A, A, A, MULT_INT, PLUS_INT, GTZERO_INT, zero); - spla::exec_m_reduce(result, zero, B, PLUS_INT); + spla::exec_mxmT_masked(B, A, A, A, MULT_INT, PLUS_INT, GTZERO_INT, zero); + spla::exec_m_reduce(result, zero, B, PLUS_INT); - ntrins = result->as_int(); + ntrins = result->as_int(); #ifndef SPLA_RELEASE - tight.stop(); + tight.stop(); - std::cout << " - ntrins " << ntrins - << " " << tight.get_elapsed_ms() << " ms" << std::endl; + std::cout << " - ntrins " << ntrins << " " << tight.get_elapsed_ms() << " ms" + << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - return Status::Ok; - } + return Status::Ok; +} - Status tc_naive( - int& ntrins, - std::vector>& Ai, - const ref_ptr& descriptor) { - - ntrins = 0; - - for (const auto& row_Ai : Ai) { - for (const auto neighbor : row_Ai) { - const auto& row_neighbor = Ai[neighbor]; - - auto it1 = row_Ai.begin(); - auto it2 = row_neighbor.begin(); - - auto end1 = row_Ai.end(); - auto end2 = row_neighbor.end(); - - while (it1 != end1 && it2 != end2) { - if (*it1 == *it2) { - ++ntrins; - ++it1; - ++it2; - } else if (*it1 < *it2) { - ++it1; - } else { - ++it2; - } - } - } - } +Status tc_naive(int &ntrins, std::vector> &Ai, + const ref_ptr &descriptor) { + + ntrins = 0; + + for (const auto &row_Ai : Ai) { + for (const auto neighbor : row_Ai) { + const auto &row_neighbor = Ai[neighbor]; - return Status::Ok; + auto it1 = row_Ai.begin(); + auto it2 = row_neighbor.begin(); + + auto end1 = row_Ai.end(); + auto end2 = row_neighbor.end(); + + while (it1 != end1 && it2 != end2) { + if (*it1 == *it2) { + ++ntrins; + ++it1; + ++it2; + } else if (*it1 < *it2) { + ++it1; + } else { + ++it2; + } + } } + } + + return Status::Ok; +} #pragma endregion Pr #pragma region Mst -Status mst( - const ref_ptr& T, - ref_ptr& S, - const ref_ptr& descriptor, - ref_ptr* task_hnd) { - - assert(S); - assert(T); - - struct timespec step_start, step_end; - double step_time; - - const auto n = S->get_n_rows(); - int comp = n; - - auto parent = Vector::make(n, PAIR); - for (uint i = 0; i < n; i++) { - parent->set_pair(i, T_PAIR(0.0f, i)); - } - auto edge = Vector::make(n, PAIR); - auto cedge = Vector::make(n, PAIR); - auto t_vec = Vector::make(n, PAIR); - auto mask = Vector::make(n, PAIR); - for (uint i = 0; i < n; i++) { - mask->set_pair(i, T_PAIR(1.0f, 0)); - } - auto init_inf = Scalar::make(PAIR); - T_PAIR init_val; - init_inf->set_pair(init_val); - int iteration = 0; - auto new_S = S; +Status mst(const ref_ptr &T, ref_ptr &S, + const ref_ptr &descriptor, + ref_ptr *task_hnd) { + + assert(S); + assert(T); + + const auto n = S->get_n_rows(); + int comp = n; + + auto parent = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + parent->set_pair(i, T_PAIR(0.0f, i)); + } + auto edge = Vector::make(n, PAIR); + auto cedge = Vector::make(n, PAIR); + auto t_vec = Vector::make(n, PAIR); + auto mask = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + mask->set_pair(i, T_PAIR(1.0f, 0)); + } + auto init_inf = Scalar::make(PAIR); + T_PAIR init_val; + init_inf->set_pair(init_val); + int iteration = 0; + auto new_S = S; #ifdef SPLA_RELEASE - std::cout << "start Boruvka MST, vertices = " << n << "\n"; - Timer tight; + std::cout << "start Boruvka MST, vertices = " << n << "\n"; + Timer tight; #endif - while (comp > 1) { + while (comp > 1) { #ifdef SPLA_RELEASE tight.start(); #endif - iteration++; - int edges_added_this_iteration = 0; - // step 1, min edges for each vertices - clock_gettime(CLOCK_MONOTONIC, &step_start); - spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, spla::ALWAYS_PAIR, init_inf); - clock_gettime(CLOCK_MONOTONIC, &step_end); - step_time = (step_end.tv_sec - step_start.tv_sec) + - (step_end.tv_nsec - step_start.tv_nsec) / 1e9; - std::cout << "--- Step 1 (min edges for each vertex, gpu): " << step_time * 1000 << " ms" << std::endl; - #ifdef SPLA_DEBUG + iteration++; + int edges_added_this_iteration = 0; + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, + spla::ALWAYS_PAIR, init_inf); +#ifdef SPLA_DEBUG std::cout << "edge = ["; for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - edge->get_pair(i, p); - std::cout << "(" << p.weight << ", " << p.vertex << "), "; + spla::T_PAIR p; + edge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; } std::cout << "]\n"; #endif - // step 2, min edges for each component - clock_gettime(CLOCK_MONOTONIC, &step_start); - for (int32_t i = 0; i < n; i++) { - cedge->set_pair(i, init_val); - } - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - spla::T_PAIR p1; - spla::T_PAIR p2; - parent->get_pair(i, p); - auto p_i = p.vertex; // p_i = parent[i] - cedge->get_pair(p_i, p1); // p1 = cedge[parent[i]] - edge->get_pair(i, p2); // p2 = edge[i] - auto min_for_comp = p1.weight <= p2.weight? p1 : p2; // min(cedge[parent[i]], edge[i]) - cedge->set_pair(p_i, min_for_comp); - } - clock_gettime(CLOCK_MONOTONIC, &step_end); - step_time = (step_end.tv_sec - step_start.tv_sec) + - (step_end.tv_nsec - step_start.tv_nsec) / 1e9; - std::cout << "--- Step 2 (min edges for each component): " << step_time * 1000 << " ms" << std::endl; - + for (int32_t i = 0; i < n; i++) { + cedge->set_pair(i, init_val); + } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + spla::T_PAIR p1; + spla::T_PAIR p2; + parent->get_pair(i, p); + auto p_i = p.vertex; + cedge->get_pair(p_i, p1); + edge->get_pair(i, p2); + auto min_for_comp = p1.weight <= p2.weight ? p1 : p2; + cedge->set_pair(p_i, min_for_comp); + } + #ifdef SPLA_DEBUG std::cout << "cedge = ["; for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - cedge->get_pair(i, p); - std::cout << "(" << p.weight << ", " << p.vertex << "), "; + spla::T_PAIR p; + cedge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; } std::cout << "]\n"; #endif - // step 3, когда нашли лучшее ребро компоненты распространяем его на все вершины - clock_gettime(CLOCK_MONOTONIC, &step_start); - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR parent_v; - spla::T_PAIR cedge_v; - parent->get_pair(i, parent_v); - cedge->get_pair(parent_v.vertex, cedge_v); - t_vec->set_pair(i, cedge_v); //t[i] = cedge[parent[i]] - } - clock_gettime(CLOCK_MONOTONIC, &step_end); - step_time = (step_end.tv_sec - step_start.tv_sec) + - (step_end.tv_nsec - step_start.tv_nsec) / 1e9; - std::cout << "--- Step 3: " << step_time * 1000 << " ms" << std::endl; - - //step 4 выбор представителя для каждой компоненты(когда лучшее ребро в edges совпадает с лучшим ребром компоненты t) - clock_gettime(CLOCK_MONOTONIC, &step_start); - auto index = spla::Vector::make(n, spla::INT); - - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR edge_v, t_v; - edge->get_pair(i, edge_v); - t_vec->get_pair(i, t_v); - if (edge_v == t_v) index->set_int(i, i); - else index->set_int(i, n); - } - auto temp = spla::Vector::make(n, spla::INT); - for (int32_t i = 0; i < n; i++) temp->set_int(i, n); - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR parent_v; - parent->get_pair(i, parent_v); - auto p_i = parent_v.vertex; - spla::T_INT temp_v, ind_v; - temp->get_int(p_i, temp_v); - index->get_int(i, ind_v); - spla::T_INT min_v = temp_v < ind_v? temp_v : ind_v; - temp->set_int(p_i, min_v); //temp[parent[i]] = min(temp[parent[i]], index[i]) - } + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + spla::T_PAIR cedge_v; + parent->get_pair(i, parent_v); + cedge->get_pair(parent_v.vertex, cedge_v); + t_vec->set_pair(i, cedge_v); + } + + auto index = spla::Vector::make(n, spla::INT); + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR edge_v, t_v; + edge->get_pair(i, edge_v); + t_vec->get_pair(i, t_v); + if (edge_v == t_v) + index->set_int(i, i); + else + index->set_int(i, n); + } + auto temp = spla::Vector::make(n, spla::INT); + for (int32_t i = 0; i < n; i++) + temp->set_int(i, n); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v, ind_v; + temp->get_int(p_i, temp_v); + index->get_int(i, ind_v); + spla::T_INT min_v = temp_v < ind_v ? temp_v : ind_v; + temp->set_int(p_i, min_v); + } #ifdef SPLA_DEBUG - std::cout << "t = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_INT p; - temp->get_int(i, p); - std::cout << p << ", "; - } - std::cout << "]\n"; + std::cout << "t = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + temp->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; #endif - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR parent_v; - parent->get_pair(i, parent_v); - auto p_i = parent_v.vertex; - spla::T_INT temp_v; - temp->get_int(p_i, temp_v); - index->set_int(i, temp_v); - } - clock_gettime(CLOCK_MONOTONIC, &step_end); - step_time = (step_end.tv_sec - step_start.tv_sec) + - (step_end.tv_nsec - step_start.tv_nsec) / 1e9; - std::cout << "--- Step 4(выбор представителя для каждой компоненты): " << step_time * 1000 << " ms" << std::endl; - + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v; + temp->get_int(p_i, temp_v); + index->set_int(i, temp_v); + } + #ifdef SPLA_DEBUG - std::cout << "index = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_INT p; - index->get_int(i, p); - std::cout << p << ", "; - } - std::cout << "]\n"; + std::cout << "index = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + index->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; #endif - //step 5 добавляем найденные ребра в MST - clock_gettime(CLOCK_MONOTONIC, &step_start); - auto new_parent = spla::Vector::make(n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - parent->get_pair(i, p); - new_parent->set_pair(i, p); - } - for (int32_t i = 0; i < n; i++) { - spla::T_INT ind_v; - index->get_int(i, ind_v); - if (i == ind_v) { - auto row = spla::Vector::make(n, spla::PAIR); - spla::exec_m_extract_row(row, S, i, spla::IDENTITY_PAIR); - int min_vertex = -1; - float min_weight = INF; - - for (int32_t j = 0; j < n; j++) { - spla::T_PAIR pair_row; - row->get_pair(j, pair_row); - auto pair_row_weight = pair_row.weight; - auto pair_row_vertex = pair_row.vertex; - if (pair_row_weight < INF) { - spla::T_PAIR p1, p2; - parent->get_pair(i, p1); - parent->get_pair(pair_row_vertex, p2); - if (p1.vertex != p2.vertex) { //разные компоненты - if (pair_row_weight < min_weight) { - min_weight = pair_row_weight; - min_vertex = j; - } - - } - } - - - } - if (min_vertex == -1) continue; - T->set_float(i, min_vertex, min_weight); - T->set_float(min_vertex, i, min_weight); - edges_added_this_iteration++; - if (i < min_vertex) { - spla::T_PAIR p; - spla::T_PAIR old_p; - new_parent->get_pair(i, p); - new_parent->get_pair(min_vertex, old_p); - new_parent->set_pair(min_vertex, spla::T_PAIR(0.0f, p.vertex)); - for (int k = 0; k < n; k++) { - spla::T_PAIR p1; - new_parent->get_pair(k, p1); - if (p1.vertex == old_p.vertex) new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); - } - } - else { - spla::T_PAIR p; - spla::T_PAIR old_p; - new_parent->get_pair(min_vertex, p); - new_parent->get_pair(i, old_p); - new_parent->set_pair(i, spla::T_PAIR(0.0f, p.vertex)); - for (int k = 0; k < n; k++) { - spla::T_PAIR p1; - new_parent->get_pair(k, p1); - if (p1.vertex == old_p.vertex) new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); - } - } - + auto new_parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + new_parent->set_pair(i, p); + } + for (int32_t i = 0; i < n; i++) { + spla::T_INT ind_v; + index->get_int(i, ind_v); + if (i == ind_v) { + auto row = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row, S, i, spla::IDENTITY_PAIR); + int min_vertex = -1; + float min_weight = INF; + + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR pair_row; + row->get_pair(j, pair_row); + auto pair_row_weight = pair_row.weight; + auto pair_row_vertex = pair_row.vertex; + if (pair_row_weight < INF) { + spla::T_PAIR p1, p2; + parent->get_pair(i, p1); + parent->get_pair(pair_row_vertex, p2); + if (p1.vertex != p2.vertex) { + if (pair_row_weight < min_weight) { + min_weight = pair_row_weight; + min_vertex = j; + } } + } } - parent = new_parent; - std::vector seen(n, false); - for (uint i = 0; i < n; i++) { - T_PAIR p; - parent->get_pair(i, p); - seen[p.vertex] = true; - } - - comp = 0; - for (uint i = 0; i < n; i++) { - if (seen[i]) comp++; + if (min_vertex == -1) + continue; + T->set_float(i, min_vertex, min_weight); + T->set_float(min_vertex, i, min_weight); + edges_added_this_iteration++; + if (i < min_vertex) { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(i, p); + new_parent->get_pair(min_vertex, old_p); + new_parent->set_pair(min_vertex, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) + new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } + } else { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(min_vertex, p); + new_parent->get_pair(i, old_p); + new_parent->set_pair(i, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) + new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } } - clock_gettime(CLOCK_MONOTONIC, &step_end); - step_time = (step_end.tv_sec - step_start.tv_sec) + - (step_end.tv_nsec - step_start.tv_nsec) / 1e9; - std::cout << "--- Step 5(добавляем найденные ребра в MST): " << step_time * 1000 << " ms" << std::endl; - + } + } + parent = new_parent; + std::vector seen(n, false); + for (uint i = 0; i < n; i++) { + T_PAIR p; + parent->get_pair(i, p); + seen[p.vertex] = true; + } + + comp = 0; + for (uint i = 0; i < n; i++) { + if (seen[i]) + comp++; + } #ifdef SPLA_DEBUG - std::cout << "parent = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - parent->get_pair(i, p); - std::cout << p.vertex << ", "; - } - std::cout << "]\n"; + std::cout << "parent = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + std::cout << p.vertex << ", "; + } + std::cout << "]\n"; #endif #ifdef SPLA_RELEASE - tight.stop(); - std::cout << " - iteration " << iteration - << " components " << comp - << " " << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iteration " << iteration << " components " << comp << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - if (comp == 1) { - std::cout << "MST complete after " << iteration << " iterations" << std::endl; - return Status::Ok; - } - //обновляем матрицу смежности - clock_gettime(CLOCK_MONOTONIC, &step_start); - auto filtered_S = spla::Matrix::make(n, n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - for (int32_t j = 0; j < n; j++) { - spla::T_PAIR val; - S->get_pair(i, j, val); - if (val.weight != std::numeric_limits::infinity()) { - spla::T_PAIR parent_i, parent_j; - parent->get_pair(i, parent_i); - parent->get_pair(j, parent_j); - if ((parent_i.vertex != parent_j.vertex) && (val.weight != std::numeric_limits::infinity())) { - filtered_S->set_pair(i, j, val); - } - - } - } - } - S = filtered_S; - if (edges_added_this_iteration == 0) { - return Status::Ok; - } - clock_gettime(CLOCK_MONOTONIC, &step_end); - step_time = (step_end.tv_sec - step_start.tv_sec) + - (step_end.tv_nsec - step_start.tv_nsec) / 1e9; - std::cout << "--- Step 6(добавляем S): " << step_time * 1000 << " ms" << std::endl; - - - - + if (comp == 1) { + std::cout << "MST complete after " << iteration << " iterations" + << std::endl; + return Status::Ok; + } + auto filtered_S = spla::Matrix::make(n, n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR val; + S->get_pair(i, j, val); + if (val.weight != std::numeric_limits::infinity()) { + spla::T_PAIR parent_i, parent_j; + parent->get_pair(i, parent_i); + parent->get_pair(j, parent_j); + if ((parent_i.vertex != parent_j.vertex) && + (val.weight != std::numeric_limits::infinity())) { + filtered_S->set_pair(i, j, val); + } } - return Status::Ok; + } + } + S = filtered_S; + if (edges_added_this_iteration == 0) { + return Status::Ok; } + } + return Status::Ok; +} #pragma endregion Mst - -}// namespace spla +} // namespace spla diff --git a/src/binding/c_op.cpp b/src/binding/c_op.cpp index a6f71cc92..d959bf1cb 100644 --- a/src/binding/c_op.cpp +++ b/src/binding/c_op.cpp @@ -1,134 +1,347 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/JetBrains-Research/spla */ +/* This file is part of spla project */ +/* https://github.com/JetBrains-Research/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include "c_config.hpp" -spla_OpUnary spla_OpUnary_IDENTITY_INT() { return as_ptr(spla::IDENTITY_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_IDENTITY_UINT() { return as_ptr(spla::IDENTITY_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_IDENTITY_FLOAT() { return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_IDENTITY_PAIR() { return as_ptr(spla::IDENTITY_PAIR.ref_and_get()); } -spla_OpUnary spla_OpUnary_AINV_INT() { return as_ptr(spla::AINV_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_AINV_UINT() { return as_ptr(spla::AINV_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_AINV_FLOAT() { return as_ptr(spla::AINV_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_MINV_INT() { return as_ptr(spla::MINV_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_MINV_UINT() { return as_ptr(spla::MINV_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_MINV_FLOAT() { return as_ptr(spla::MINV_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LNOT_INT() { return as_ptr(spla::LNOT_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LNOT_UINT() { return as_ptr(spla::LNOT_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LNOT_FLOAT() { return as_ptr(spla::LNOT_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_UONE_INT() { return as_ptr(spla::UONE_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_UONE_UINT() { return as_ptr(spla::UONE_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_UONE_FLOAT() { return as_ptr(spla::UONE_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ABS_INT() { return as_ptr(spla::ABS_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ABS_UINT() { return as_ptr(spla::ABS_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ABS_FLOAT() { return as_ptr(spla::ABS_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_BNOT_INT() { return as_ptr(spla::BNOT_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_BNOT_UINT() { return as_ptr(spla::BNOT_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_SQRT_FLOAT() { return as_ptr(spla::SQRT_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LOG_FLOAT() { return as_ptr(spla::LOG_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_EXP_FLOAT() { return as_ptr(spla::EXP_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_SIN_FLOAT() { return as_ptr(spla::SIN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_COS_FLOAT() { return as_ptr(spla::COS_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_TAN_FLOAT() { return as_ptr(spla::TAN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ASIN_FLOAT() { return as_ptr(spla::ASIN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ACOS_FLOAT() { return as_ptr(spla::ACOS_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ATAN_FLOAT() { return as_ptr(spla::ATAN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_CEIL_FLOAT() { return as_ptr(spla::CEIL_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_FLOOR_FLOAT() { return as_ptr(spla::FLOOR_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ROUND_FLOAT() { return as_ptr(spla::ROUND_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_TRUNC_FLOAT() { return as_ptr(spla::TRUNC_FLOAT.ref_and_get()); } +spla_OpUnary spla_OpUnary_IDENTITY_INT() { + return as_ptr(spla::IDENTITY_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_IDENTITY_UINT() { + return as_ptr(spla::IDENTITY_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_IDENTITY_FLOAT() { + return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_IDENTITY_PAIR() { + return as_ptr(spla::IDENTITY_PAIR.ref_and_get()); +} +spla_OpUnary spla_OpUnary_AINV_INT() { + return as_ptr(spla::AINV_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_AINV_UINT() { + return as_ptr(spla::AINV_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_AINV_FLOAT() { + return as_ptr(spla::AINV_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_MINV_INT() { + return as_ptr(spla::MINV_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_MINV_UINT() { + return as_ptr(spla::MINV_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_MINV_FLOAT() { + return as_ptr(spla::MINV_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LNOT_INT() { + return as_ptr(spla::LNOT_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LNOT_UINT() { + return as_ptr(spla::LNOT_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LNOT_FLOAT() { + return as_ptr(spla::LNOT_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_UONE_INT() { + return as_ptr(spla::UONE_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_UONE_UINT() { + return as_ptr(spla::UONE_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_UONE_FLOAT() { + return as_ptr(spla::UONE_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ABS_INT() { + return as_ptr(spla::ABS_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ABS_UINT() { + return as_ptr(spla::ABS_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ABS_FLOAT() { + return as_ptr(spla::ABS_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_BNOT_INT() { + return as_ptr(spla::BNOT_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_BNOT_UINT() { + return as_ptr(spla::BNOT_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_SQRT_FLOAT() { + return as_ptr(spla::SQRT_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LOG_FLOAT() { + return as_ptr(spla::LOG_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_EXP_FLOAT() { + return as_ptr(spla::EXP_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_SIN_FLOAT() { + return as_ptr(spla::SIN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_COS_FLOAT() { + return as_ptr(spla::COS_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_TAN_FLOAT() { + return as_ptr(spla::TAN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ASIN_FLOAT() { + return as_ptr(spla::ASIN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ACOS_FLOAT() { + return as_ptr(spla::ACOS_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ATAN_FLOAT() { + return as_ptr(spla::ATAN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_CEIL_FLOAT() { + return as_ptr(spla::CEIL_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_FLOOR_FLOAT() { + return as_ptr(spla::FLOOR_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ROUND_FLOAT() { + return as_ptr(spla::ROUND_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_TRUNC_FLOAT() { + return as_ptr(spla::TRUNC_FLOAT.ref_and_get()); +} -spla_OpBinary spla_OpBinary_PLUS_INT() { return as_ptr(spla::PLUS_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_PLUS_UINT() { return as_ptr(spla::PLUS_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_PLUS_FLOAT() { return as_ptr(spla::PLUS_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_INT() { return as_ptr(spla::MINUS_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_UINT() { return as_ptr(spla::MINUS_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_FLOAT() { return as_ptr(spla::MINUS_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MULT_INT() { return as_ptr(spla::MULT_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MULT_UINT() { return as_ptr(spla::MULT_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MULT_FLOAT() { return as_ptr(spla::MULT_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_DIV_INT() { return as_ptr(spla::DIV_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_DIV_UINT() { return as_ptr(spla::DIV_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_DIV_FLOAT() { return as_ptr(spla::DIV_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_POW2_INT() { return as_ptr(spla::MINUS_POW2_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_POW2_UINT() { return as_ptr(spla::MINUS_POW2_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_POW2_FLOAT() { return as_ptr(spla::MINUS_POW2_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_FIRST_INT() { return as_ptr(spla::FIRST_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_FIRST_UINT() { return as_ptr(spla::FIRST_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_FIRST_FLOAT() { return as_ptr(spla::FIRST_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_SECOND_INT() { return as_ptr(spla::SECOND_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_SECOND_UINT() { return as_ptr(spla::SECOND_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_SECOND_FLOAT() { return as_ptr(spla::SECOND_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BONE_INT() { return as_ptr(spla::BONE_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BONE_UINT() { return as_ptr(spla::BONE_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BONE_FLOAT() { return as_ptr(spla::BONE_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_INT() { return as_ptr(spla::MIN_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_UINT() { return as_ptr(spla::MIN_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_FLOAT() { return as_ptr(spla::MIN_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MAX_INT() { return as_ptr(spla::MAX_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MAX_UINT() { return as_ptr(spla::MAX_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MAX_FLOAT() { return as_ptr(spla::MAX_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LOR_INT() { return as_ptr(spla::LOR_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LOR_UINT() { return as_ptr(spla::LOR_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LOR_FLOAT() { return as_ptr(spla::LOR_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LAND_INT() { return as_ptr(spla::LAND_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LAND_UINT() { return as_ptr(spla::LAND_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LAND_FLOAT() { return as_ptr(spla::LAND_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BOR_INT() { return as_ptr(spla::BOR_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BOR_UINT() { return as_ptr(spla::BOR_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BAND_INT() { return as_ptr(spla::BAND_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BAND_UINT() { return as_ptr(spla::BAND_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BXOR_INT() { return as_ptr(spla::BXOR_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BXOR_UINT() { return as_ptr(spla::BXOR_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_PAIR() { return as_ptr(spla::MIN_PAIR.ref_and_get()); } -spla_OpBinary spla_OpBinary_MUL_PAIR() { return as_ptr(spla::MUL_PAIR.ref_and_get()); } +spla_OpBinary spla_OpBinary_PLUS_INT() { + return as_ptr(spla::PLUS_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_PLUS_UINT() { + return as_ptr(spla::PLUS_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_PLUS_FLOAT() { + return as_ptr(spla::PLUS_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_INT() { + return as_ptr(spla::MINUS_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_UINT() { + return as_ptr(spla::MINUS_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_FLOAT() { + return as_ptr(spla::MINUS_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MULT_INT() { + return as_ptr(spla::MULT_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MULT_UINT() { + return as_ptr(spla::MULT_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MULT_FLOAT() { + return as_ptr(spla::MULT_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_DIV_INT() { + return as_ptr(spla::DIV_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_DIV_UINT() { + return as_ptr(spla::DIV_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_DIV_FLOAT() { + return as_ptr(spla::DIV_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_POW2_INT() { + return as_ptr(spla::MINUS_POW2_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_POW2_UINT() { + return as_ptr(spla::MINUS_POW2_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_POW2_FLOAT() { + return as_ptr(spla::MINUS_POW2_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_FIRST_INT() { + return as_ptr(spla::FIRST_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_FIRST_UINT() { + return as_ptr(spla::FIRST_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_FIRST_FLOAT() { + return as_ptr(spla::FIRST_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_SECOND_INT() { + return as_ptr(spla::SECOND_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_SECOND_UINT() { + return as_ptr(spla::SECOND_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_SECOND_FLOAT() { + return as_ptr(spla::SECOND_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BONE_INT() { + return as_ptr(spla::BONE_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BONE_UINT() { + return as_ptr(spla::BONE_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BONE_FLOAT() { + return as_ptr(spla::BONE_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_INT() { + return as_ptr(spla::MIN_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_UINT() { + return as_ptr(spla::MIN_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_FLOAT() { + return as_ptr(spla::MIN_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MAX_INT() { + return as_ptr(spla::MAX_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MAX_UINT() { + return as_ptr(spla::MAX_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MAX_FLOAT() { + return as_ptr(spla::MAX_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LOR_INT() { + return as_ptr(spla::LOR_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LOR_UINT() { + return as_ptr(spla::LOR_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LOR_FLOAT() { + return as_ptr(spla::LOR_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LAND_INT() { + return as_ptr(spla::LAND_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LAND_UINT() { + return as_ptr(spla::LAND_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LAND_FLOAT() { + return as_ptr(spla::LAND_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BOR_INT() { + return as_ptr(spla::BOR_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BOR_UINT() { + return as_ptr(spla::BOR_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BAND_INT() { + return as_ptr(spla::BAND_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BAND_UINT() { + return as_ptr(spla::BAND_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BXOR_INT() { + return as_ptr(spla::BXOR_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BXOR_UINT() { + return as_ptr(spla::BXOR_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_PAIR() { + return as_ptr(spla::MIN_PAIR.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MUL_PAIR() { + return as_ptr(spla::MUL_PAIR.ref_and_get()); +} -spla_OpSelect spla_OpSelect_EQZERO_INT() { return as_ptr(spla::EQZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_EQZERO_UINT() { return as_ptr(spla::EQZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_EQZERO_FLOAT() { return as_ptr(spla::EQZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NQZERO_INT() { return as_ptr(spla::NQZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NQZERO_UINT() { return as_ptr(spla::NQZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NQZERO_FLOAT() { return as_ptr(spla::NQZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GTZERO_INT() { return as_ptr(spla::GTZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GTZERO_UINT() { return as_ptr(spla::GTZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GTZERO_FLOAT() { return as_ptr(spla::GTZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GEZERO_INT() { return as_ptr(spla::GEZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GEZERO_UINT() { return as_ptr(spla::GEZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GEZERO_FLOAT() { return as_ptr(spla::GEZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LTZERO_INT() { return as_ptr(spla::LTZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LTZERO_UINT() { return as_ptr(spla::LTZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LTZERO_FLOAT() { return as_ptr(spla::LTZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LEZERO_INT() { return as_ptr(spla::LEZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LEZERO_UINT() { return as_ptr(spla::LEZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LEZERO_FLOAT() { return as_ptr(spla::LEZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_INT() { return as_ptr(spla::ALWAYS_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_UINT() { return as_ptr(spla::ALWAYS_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_FLOAT() { return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_PAIR() { return as_ptr(spla::ALWAYS_PAIR.ref_and_get()); } -spla_OpSelect spla_OpSelect_NEVER_INT() { return as_ptr(spla::NEVER_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NEVER_UINT() { return as_ptr(spla::NEVER_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NEVER_FLOAT() { return as_ptr(spla::NEVER_FLOAT.ref_and_get()); } \ No newline at end of file +spla_OpSelect spla_OpSelect_EQZERO_INT() { + return as_ptr(spla::EQZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_EQZERO_UINT() { + return as_ptr(spla::EQZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_EQZERO_FLOAT() { + return as_ptr(spla::EQZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NQZERO_INT() { + return as_ptr(spla::NQZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NQZERO_UINT() { + return as_ptr(spla::NQZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NQZERO_FLOAT() { + return as_ptr(spla::NQZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GTZERO_INT() { + return as_ptr(spla::GTZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GTZERO_UINT() { + return as_ptr(spla::GTZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GTZERO_FLOAT() { + return as_ptr(spla::GTZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GEZERO_INT() { + return as_ptr(spla::GEZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GEZERO_UINT() { + return as_ptr(spla::GEZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GEZERO_FLOAT() { + return as_ptr(spla::GEZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LTZERO_INT() { + return as_ptr(spla::LTZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LTZERO_UINT() { + return as_ptr(spla::LTZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LTZERO_FLOAT() { + return as_ptr(spla::LTZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LEZERO_INT() { + return as_ptr(spla::LEZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LEZERO_UINT() { + return as_ptr(spla::LEZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LEZERO_FLOAT() { + return as_ptr(spla::LEZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_INT() { + return as_ptr(spla::ALWAYS_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_UINT() { + return as_ptr(spla::ALWAYS_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_FLOAT() { + return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_PAIR() { + return as_ptr(spla::ALWAYS_PAIR.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NEVER_INT() { + return as_ptr(spla::NEVER_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NEVER_UINT() { + return as_ptr(spla::NEVER_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NEVER_FLOAT() { + return as_ptr(spla::NEVER_FLOAT.ref_and_get()); +} \ No newline at end of file diff --git a/src/binding/c_type.cpp b/src/binding/c_type.cpp index 09fbaddce..1c294878a 100644 --- a/src/binding/c_type.cpp +++ b/src/binding/c_type.cpp @@ -36,9 +36,5 @@ spla_Type spla_Type_INT() { spla_Type spla_Type_UINT() { return as_ptr(spla::UINT.get()); } -spla_Type spla_Type_FLOAT() { - return as_ptr(spla::FLOAT.get()); -} -spla_Type spla_Type_PAIR() { - return as_ptr(spla::PAIR.get()); -} \ No newline at end of file +spla_Type spla_Type_FLOAT() { return as_ptr(spla::FLOAT.get()); } +spla_Type spla_Type_PAIR() { return as_ptr(spla::PAIR.get()); } \ No newline at end of file diff --git a/src/core/tmatrix.hpp b/src/core/tmatrix.hpp index 552035df3..d8d173867 100644 --- a/src/core/tmatrix.hpp +++ b/src/core/tmatrix.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_TMATRIX_HPP @@ -42,340 +49,341 @@ namespace spla { - /** - * @addtogroup internal - * @{ - */ - - /** - * @class TMatrix - * @brief Matrix interface implementation with type information bound - * - * @tparam T Type of stored elements - */ - template - class TMatrix final : public Matrix { - public: - TMatrix(uint n_rows, uint n_cols); - ~TMatrix() override = default; - - uint get_n_rows() override; - uint get_n_cols() override; - ref_ptr get_type() override; - void set_label(std::string label) override; - const std::string& get_label() const override; - Status set_format(FormatMatrix format) override; - Status set_fill_value(const ref_ptr& value) override; - Status set_reduce(ref_ptr resolve_duplicates) override; - Status set_int(uint row_id, uint col_id, std::int32_t value) override; - Status set_uint(uint row_id, uint col_id, std::uint32_t value) override; - Status set_float(uint row_id, uint col_id, float value) override; - Status set_pair(uint row_id, uint col_id, Pair value) override { return Status::InvalidArgument;} - Status get_int(uint row_id, uint col_id, int32_t& value) override; - Status get_uint(uint row_id, uint col_id, uint32_t& value) override; - Status get_float(uint row_id, uint col_id, float& value) override; - Status get_pair(uint row_id, uint col_id, Pair& value) override { return Status::InvalidArgument;} - Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) override; - Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) override; - Status clear() override; - - template - Decorator* get() { return m_storage.template get(); } - - void validate_rw(FormatMatrix format); - void validate_rwd(FormatMatrix format); - void validate_wd(FormatMatrix format); - void validate_ctor(FormatMatrix format); - bool is_valid(FormatMatrix format) const; - T get_fill_value() const { return m_storage.get_fill_value(); } - - static StorageManagerMatrix* get_storage_manager(); - - private: - typename StorageManagerMatrix::Storage m_storage; - std::string m_label; - }; - - template - TMatrix::TMatrix(uint n_rows, uint n_cols) { - m_storage.set_dims(n_rows, n_cols); - } - - template - uint TMatrix::get_n_rows() { - return m_storage.get_n_rows(); - } - template - uint TMatrix::get_n_cols() { - return m_storage.get_n_cols(); - } - template - ref_ptr TMatrix::get_type() { - return get_ttype().template as(); - } - - template - void TMatrix::set_label(std::string label) { - m_label = std::move(label); - LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void*) this); - } - template - const std::string& TMatrix::get_label() const { - return m_label; - } - - template - Status TMatrix::set_format(FormatMatrix format) { - validate_rw(format); - return Status::Ok; - } - template - Status TMatrix::set_fill_value(const ref_ptr& value) { - if (value) { - m_storage.invalidate(); - - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_int()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_uint()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_float()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_pair()); - - - return Status::Ok; - } - - return Status::InvalidArgument; - } - template - Status TMatrix::set_reduce(ref_ptr resolve_duplicates) { - auto reduce = resolve_duplicates.template cast_safe>(); - - if (reduce) { - validate_ctor(FormatMatrix::CpuLil); - get>()->reduce = reduce->function; - validate_ctor(FormatMatrix::CpuDok); - get>()->reduce = reduce->function; - } - - return Status::InvalidArgument; - } - - template - Status TMatrix::set_int(uint row_id, uint col_id, std::int32_t value) { - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); - return Status::Ok; - } - template<> - inline Status TMatrix::set_int(uint row_id, uint col_id, std::int32_t value) { - return Status::InvalidArgument; - } - template - Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); - return Status::Ok; - } - template<> - inline Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { - return Status::InvalidArgument; - } - template - Status TMatrix::set_float(uint row_id, uint col_id, float value) { - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); - return Status::Ok; - } - template<> - inline Status TMatrix::set_float(uint row_id, uint col_id, float value) { - return Status::InvalidArgument; - } - - - template - Status TMatrix::get_int(uint row_id, uint col_id, int32_t& value) { - validate_rw(FormatMatrix::CpuDok); - - auto& Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - template<> - inline Status TMatrix::get_int(uint row_id, uint col_id, std::int32_t& value) { - return Status::InvalidArgument; - } - template - Status TMatrix::get_uint(uint row_id, uint col_id, uint32_t& value) { - validate_rw(FormatMatrix::CpuDok); - - auto& Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - template<> - inline Status TMatrix::get_uint(uint row_id, uint col_id, std::uint32_t& value) { - return Status::InvalidArgument; - } - template - Status TMatrix::get_float(uint row_id, uint col_id, float& value) { - validate_rw(FormatMatrix::CpuDok); - - auto& Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - template<> - inline Status TMatrix::get_float(uint row_id, uint col_id, float& value) { - return Status::InvalidArgument; - } - - template - Status TMatrix::build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) { - assert(keys1); - assert(keys2); - assert(values); - - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - const auto elements_count = keys1->get_size() / key_size; - - if (elements_count != values->get_size() / value_size) { - return Status::InvalidArgument; - } - if (elements_count * key_size != keys1->get_size()) { - return Status::InvalidArgument; - } - if (elements_count * key_size != keys2->get_size()) { - return Status::InvalidArgument; - } - - validate_rwd(FormatMatrix::CpuCoo); - CpuCoo& coo = *get>(); - - coo.Ai.resize(elements_count); - coo.Aj.resize(elements_count); - coo.Ax.resize(elements_count); - coo.values = uint(elements_count); - - keys1->read(0, key_size * elements_count, coo.Ai.data()); - keys2->read(0, key_size * elements_count, coo.Aj.data()); - values->read(0, value_size * elements_count, coo.Ax.data()); - - return Status::Ok; - } - template - Status TMatrix::read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) { - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - - validate_rw(FormatMatrix::CpuCoo); - CpuCoo& coo = *get>(); - - const auto elements_count = coo.Ai.size(); - - keys1 = MemView::make(coo.Ai.data(), key_size * elements_count, false); - keys2 = MemView::make(coo.Aj.data(), key_size * elements_count, false); - values = MemView::make(coo.Ax.data(), value_size * elements_count, false); - - return Status::Ok; - } - - template - Status TMatrix::clear() { - m_storage.invalidate(); - return Status::Ok; - } - - template - void TMatrix::validate_rw(FormatMatrix format) { - StorageManagerMatrix* manager = get_storage_manager(); - manager->validate_rw(format, m_storage); - } - - template - void TMatrix::validate_rwd(FormatMatrix format) { - StorageManagerMatrix* manager = get_storage_manager(); - manager->validate_rwd(format, m_storage); - } - - template - void TMatrix::validate_wd(FormatMatrix format) { - StorageManagerMatrix* manager = get_storage_manager(); - manager->validate_wd(format, m_storage); - } - - template - void TMatrix::validate_ctor(FormatMatrix format) { - StorageManagerMatrix* manager = get_storage_manager(); - manager->validate_ctor(format, m_storage); - } - - template - bool TMatrix::is_valid(FormatMatrix format) const { - return m_storage.is_valid(format); - } - - template - StorageManagerMatrix* TMatrix::get_storage_manager() { - static std::unique_ptr> storage_manager; - - if (!storage_manager) { - storage_manager = std::make_unique>(); - register_formats_matrix(*storage_manager); - } - - return storage_manager.get(); - } - template<> - inline Status TMatrix::set_pair(uint row_id, uint col_id, Pair value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, value, *get>()); - return Status::Ok; - } - template<> - inline Status TMatrix::get_pair(uint row_id, uint col_id, Pair& value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - validate_rw(FormatMatrix::CpuDok); - - auto& Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - - /** - * @} - */ - -}// namespace spla - - -#endif//SPLA_TMATRIX_HPP +/** + * @addtogroup internal + * @{ + */ + +/** + * @class TMatrix + * @brief Matrix interface implementation with type information bound + * + * @tparam T Type of stored elements + */ +template class TMatrix final : public Matrix { +public: + TMatrix(uint n_rows, uint n_cols); + ~TMatrix() override = default; + + uint get_n_rows() override; + uint get_n_cols() override; + ref_ptr get_type() override; + void set_label(std::string label) override; + const std::string &get_label() const override; + Status set_format(FormatMatrix format) override; + Status set_fill_value(const ref_ptr &value) override; + Status set_reduce(ref_ptr resolve_duplicates) override; + Status set_int(uint row_id, uint col_id, std::int32_t value) override; + Status set_uint(uint row_id, uint col_id, std::uint32_t value) override; + Status set_float(uint row_id, uint col_id, float value) override; + Status set_pair(uint row_id, uint col_id, Pair value) override { + return Status::InvalidArgument; + } + Status get_int(uint row_id, uint col_id, int32_t &value) override; + Status get_uint(uint row_id, uint col_id, uint32_t &value) override; + Status get_float(uint row_id, uint col_id, float &value) override; + Status get_pair(uint row_id, uint col_id, Pair &value) override { + return Status::InvalidArgument; + } + Status build(const ref_ptr &keys1, const ref_ptr &keys2, + const ref_ptr &values) override; + Status read(ref_ptr &keys1, ref_ptr &keys2, + ref_ptr &values) override; + Status clear() override; + + template Decorator *get() { + return m_storage.template get(); + } + + void validate_rw(FormatMatrix format); + void validate_rwd(FormatMatrix format); + void validate_wd(FormatMatrix format); + void validate_ctor(FormatMatrix format); + bool is_valid(FormatMatrix format) const; + T get_fill_value() const { return m_storage.get_fill_value(); } + + static StorageManagerMatrix *get_storage_manager(); + +private: + typename StorageManagerMatrix::Storage m_storage; + std::string m_label; +}; + +template TMatrix::TMatrix(uint n_rows, uint n_cols) { + m_storage.set_dims(n_rows, n_cols); +} + +template uint TMatrix::get_n_rows() { + return m_storage.get_n_rows(); +} +template uint TMatrix::get_n_cols() { + return m_storage.get_n_cols(); +} +template ref_ptr TMatrix::get_type() { + return get_ttype().template as(); +} + +template void TMatrix::set_label(std::string label) { + m_label = std::move(label); + LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void *)this); +} +template const std::string &TMatrix::get_label() const { + return m_label; +} + +template Status TMatrix::set_format(FormatMatrix format) { + validate_rw(format); + return Status::Ok; +} +template +Status TMatrix::set_fill_value(const ref_ptr &value) { + if (value) { + m_storage.invalidate(); + + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_int()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_uint()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_float()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_pair()); + + return Status::Ok; + } + + return Status::InvalidArgument; +} +template +Status TMatrix::set_reduce(ref_ptr resolve_duplicates) { + auto reduce = resolve_duplicates.template cast_safe>(); + + if (reduce) { + validate_ctor(FormatMatrix::CpuLil); + get>()->reduce = reduce->function; + validate_ctor(FormatMatrix::CpuDok); + get>()->reduce = reduce->function; + } + + return Status::InvalidArgument; +} + +template +Status TMatrix::set_int(uint row_id, uint col_id, std::int32_t value) { + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); + return Status::Ok; +} +template <> +inline Status TMatrix::set_int(uint row_id, uint col_id, + std::int32_t value) { + return Status::InvalidArgument; +} +template +Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); + return Status::Ok; +} +template <> +inline Status TMatrix::set_uint(uint row_id, uint col_id, + std::uint32_t value) { + return Status::InvalidArgument; +} +template +Status TMatrix::set_float(uint row_id, uint col_id, float value) { + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); + return Status::Ok; +} +template <> +inline Status TMatrix::set_float(uint row_id, uint col_id, float value) { + return Status::InvalidArgument; +} + +template +Status TMatrix::get_int(uint row_id, uint col_id, int32_t &value) { + validate_rw(FormatMatrix::CpuDok); + + auto &Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} +template <> +inline Status TMatrix::get_int(uint row_id, uint col_id, + std::int32_t &value) { + return Status::InvalidArgument; +} +template +Status TMatrix::get_uint(uint row_id, uint col_id, uint32_t &value) { + validate_rw(FormatMatrix::CpuDok); + + auto &Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} +template <> +inline Status TMatrix::get_uint(uint row_id, uint col_id, + std::uint32_t &value) { + return Status::InvalidArgument; +} +template +Status TMatrix::get_float(uint row_id, uint col_id, float &value) { + validate_rw(FormatMatrix::CpuDok); + + auto &Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} +template <> +inline Status TMatrix::get_float(uint row_id, uint col_id, float &value) { + return Status::InvalidArgument; +} + +template +Status TMatrix::build(const ref_ptr &keys1, + const ref_ptr &keys2, + const ref_ptr &values) { + assert(keys1); + assert(keys2); + assert(values); + + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + const auto elements_count = keys1->get_size() / key_size; + + if (elements_count != values->get_size() / value_size) { + return Status::InvalidArgument; + } + if (elements_count * key_size != keys1->get_size()) { + return Status::InvalidArgument; + } + if (elements_count * key_size != keys2->get_size()) { + return Status::InvalidArgument; + } + + validate_rwd(FormatMatrix::CpuCoo); + CpuCoo &coo = *get>(); + + coo.Ai.resize(elements_count); + coo.Aj.resize(elements_count); + coo.Ax.resize(elements_count); + coo.values = uint(elements_count); + + keys1->read(0, key_size * elements_count, coo.Ai.data()); + keys2->read(0, key_size * elements_count, coo.Aj.data()); + values->read(0, value_size * elements_count, coo.Ax.data()); + + return Status::Ok; +} +template +Status TMatrix::read(ref_ptr &keys1, ref_ptr &keys2, + ref_ptr &values) { + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + + validate_rw(FormatMatrix::CpuCoo); + CpuCoo &coo = *get>(); + + const auto elements_count = coo.Ai.size(); + + keys1 = MemView::make(coo.Ai.data(), key_size * elements_count, false); + keys2 = MemView::make(coo.Aj.data(), key_size * elements_count, false); + values = MemView::make(coo.Ax.data(), value_size * elements_count, false); + + return Status::Ok; +} + +template Status TMatrix::clear() { + m_storage.invalidate(); + return Status::Ok; +} + +template void TMatrix::validate_rw(FormatMatrix format) { + StorageManagerMatrix *manager = get_storage_manager(); + manager->validate_rw(format, m_storage); +} + +template void TMatrix::validate_rwd(FormatMatrix format) { + StorageManagerMatrix *manager = get_storage_manager(); + manager->validate_rwd(format, m_storage); +} + +template void TMatrix::validate_wd(FormatMatrix format) { + StorageManagerMatrix *manager = get_storage_manager(); + manager->validate_wd(format, m_storage); +} + +template void TMatrix::validate_ctor(FormatMatrix format) { + StorageManagerMatrix *manager = get_storage_manager(); + manager->validate_ctor(format, m_storage); +} + +template bool TMatrix::is_valid(FormatMatrix format) const { + return m_storage.is_valid(format); +} + +template +StorageManagerMatrix *TMatrix::get_storage_manager() { + static std::unique_ptr> storage_manager; + + if (!storage_manager) { + storage_manager = std::make_unique>(); + register_formats_matrix(*storage_manager); + } + + return storage_manager.get(); +} +template <> +inline Status TMatrix::set_pair(uint row_id, uint col_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, value, *get>()); + return Status::Ok; +} +template <> +inline Status TMatrix::get_pair(uint row_id, uint col_id, Pair &value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + validate_rw(FormatMatrix::CpuDok); + + auto &Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} + +/** + * @} + */ + +} // namespace spla + +#endif // SPLA_TMATRIX_HPP diff --git a/src/core/tscalar.hpp b/src/core/tscalar.hpp index 8268ceff3..75614139c 100644 --- a/src/core/tscalar.hpp +++ b/src/core/tscalar.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_TSCALAR_HPP @@ -34,198 +41,151 @@ namespace spla { - /** - * @addtogroup internal - * @{ - */ - - /** - * - * @tparam T - */ - template - class TScalar final : public Scalar { - public: - TScalar() = default; - explicit TScalar(T value); - ~TScalar() override = default; - - ref_ptr get_type() override; - Status set_int(std::int32_t value) override; - Status set_uint(std::uint32_t value) override; - Status set_float(float value) override; - Status get_int(std::int32_t& value) override; - Status get_uint(std::uint32_t& value) override; - Status get_float(float& value) override; - T_INT as_int() override { return static_cast(m_value); } - T_UINT as_uint() override { return static_cast(m_value); } - T_FLOAT as_float() override { return static_cast(m_value); } - T_PAIR as_pair() override { return static_cast(m_value); } - - - void set_label(std::string label) override; - const std::string& get_label() const override; - - T& get_value(); - T get_value() const; - - private: - std::string m_label; - T m_value = T(); - }; - - template - TScalar::TScalar(T value) : m_value(value) { - } - - template - ref_ptr TScalar::get_type() { - return get_ttype().template as(); - } - - template - Status TScalar::set_int(std::int32_t value) { - m_value = static_cast(value); - return Status::Ok; - } - template - Status TScalar::set_uint(std::uint32_t value) { - m_value = static_cast(value); - return Status::Ok; - } - template - Status TScalar::set_float(float value) { - m_value = static_cast(value); - return Status::Ok; - } - - template - Status TScalar::get_int(std::int32_t& value) { - value = static_cast(m_value); - return Status::Ok; - } - template - Status TScalar::get_uint(std::uint32_t& value) { - value = static_cast(m_value); - return Status::Ok; - } - template - Status TScalar::get_float(float& value) { - value = static_cast(m_value); - return Status::Ok; - } - - template - void TScalar::set_label(std::string label) { - m_label = std::move(label); - } - - template - const std::string& TScalar::get_label() const { - return m_label; - } - - template - T& TScalar::get_value() { - return m_value; - } - template - T TScalar::get_value() const { - return m_value; - } - template<> - inline T_PAIR TScalar::as_pair() { - return Pair(); - } - template<> - inline T_PAIR TScalar::as_pair() { - return Pair(); - } - template<> - inline T_PAIR TScalar::as_pair() { - return Pair(); - } - - template<> - class TScalar final : public Scalar { - public: - TScalar() = default; - explicit TScalar(Pair value) : m_value(value) {} - ~TScalar() override = default; - - Status set_pair(Pair value) { - m_value = value; - return Status::Ok; - } - - Status get_pair(Pair& value) const { - value = m_value; - return Status::Ok; - } - - ref_ptr get_type() override { - return PAIR; - } - - Status set_int(std::int32_t) override { - return Status::InvalidArgument; - } - - Status set_uint(std::uint32_t) override { - return Status::InvalidArgument; - } - - Status set_float(float) override { - return Status::InvalidArgument; - } - - Status get_int(std::int32_t& ) override { - return Status::InvalidArgument; - } - - Status get_uint(std::uint32_t& ) override { - return Status::InvalidArgument; - } - - Status get_float(float& ) override { - return Status::InvalidArgument; - } - - T_INT as_int() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to int"); - return 0; - } - - T_UINT as_uint() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to uint"); - return 0; - } - - T_FLOAT as_float() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to float"); - return 0.0f; - } - T_PAIR as_pair() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to pair"); - return Pair(); - } - - void set_label(std::string label) override { - m_label = std::move(label); - } - - const std::string& get_label() const override { - return m_label; - } - - Pair& get_value() { return m_value; } - Pair get_value() const { return m_value; } - - private: - std::string m_label; - Pair m_value = Pair(); - }; - -}// namespace spla - -#endif//SPLA_TSCALAR_HPP +/** + * @addtogroup internal + * @{ + */ + +/** + * + * @tparam T + */ +template class TScalar final : public Scalar { +public: + TScalar() = default; + explicit TScalar(T value); + ~TScalar() override = default; + + ref_ptr get_type() override; + Status set_int(std::int32_t value) override; + Status set_uint(std::uint32_t value) override; + Status set_float(float value) override; + Status get_int(std::int32_t &value) override; + Status get_uint(std::uint32_t &value) override; + Status get_float(float &value) override; + T_INT as_int() override { return static_cast(m_value); } + T_UINT as_uint() override { return static_cast(m_value); } + T_FLOAT as_float() override { return static_cast(m_value); } + T_PAIR as_pair() override { return static_cast(m_value); } + + void set_label(std::string label) override; + const std::string &get_label() const override; + + T &get_value(); + T get_value() const; + +private: + std::string m_label; + T m_value = T(); +}; + +template TScalar::TScalar(T value) : m_value(value) {} + +template ref_ptr TScalar::get_type() { + return get_ttype().template as(); +} + +template Status TScalar::set_int(std::int32_t value) { + m_value = static_cast(value); + return Status::Ok; +} +template Status TScalar::set_uint(std::uint32_t value) { + m_value = static_cast(value); + return Status::Ok; +} +template Status TScalar::set_float(float value) { + m_value = static_cast(value); + return Status::Ok; +} + +template Status TScalar::get_int(std::int32_t &value) { + value = static_cast(m_value); + return Status::Ok; +} +template Status TScalar::get_uint(std::uint32_t &value) { + value = static_cast(m_value); + return Status::Ok; +} +template Status TScalar::get_float(float &value) { + value = static_cast(m_value); + return Status::Ok; +} + +template void TScalar::set_label(std::string label) { + m_label = std::move(label); +} + +template const std::string &TScalar::get_label() const { + return m_label; +} + +template T &TScalar::get_value() { return m_value; } +template T TScalar::get_value() const { return m_value; } +template <> inline T_PAIR TScalar::as_pair() { return Pair(); } +template <> inline T_PAIR TScalar::as_pair() { return Pair(); } +template <> inline T_PAIR TScalar::as_pair() { return Pair(); } + +template <> class TScalar final : public Scalar { +public: + TScalar() = default; + explicit TScalar(Pair value) : m_value(value) {} + ~TScalar() override = default; + + Status set_pair(Pair value) { + m_value = value; + return Status::Ok; + } + + Status get_pair(Pair &value) const { + value = m_value; + return Status::Ok; + } + + ref_ptr get_type() override { return PAIR; } + + Status set_int(std::int32_t) override { return Status::InvalidArgument; } + + Status set_uint(std::uint32_t) override { return Status::InvalidArgument; } + + Status set_float(float) override { return Status::InvalidArgument; } + + Status get_int(std::int32_t &) override { return Status::InvalidArgument; } + + Status get_uint(std::uint32_t &) override { return Status::InvalidArgument; } + + Status get_float(float &) override { return Status::InvalidArgument; } + + T_INT as_int() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to int"); + return 0; + } + + T_UINT as_uint() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to uint"); + return 0; + } + + T_FLOAT as_float() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to float"); + return 0.0f; + } + T_PAIR as_pair() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to pair"); + return Pair(); + } + + void set_label(std::string label) override { m_label = std::move(label); } + + const std::string &get_label() const override { return m_label; } + + Pair &get_value() { return m_value; } + Pair get_value() const { return m_value; } + +private: + std::string m_label; + Pair m_value = Pair(); +}; + +} // namespace spla + +#endif // SPLA_TSCALAR_HPP diff --git a/src/core/ttype.hpp b/src/core/ttype.hpp index 6bb273548..2ca57ebfe 100644 --- a/src/core/ttype.hpp +++ b/src/core/ttype.hpp @@ -129,11 +129,9 @@ namespace spla { ref_ptr> get_ttype() { return FLOAT.cast_safe>(); } - template<> - ref_ptr> get_ttype() { - return PAIR.cast_safe>(); + template <> ref_ptr> get_ttype() { + return PAIR.cast_safe>(); } - /** * @} diff --git a/src/core/tvector.hpp b/src/core/tvector.hpp index 434f3e9bd..e565f9c28 100644 --- a/src/core/tvector.hpp +++ b/src/core/tvector.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_TVECTOR_HPP @@ -40,394 +47,386 @@ #include #include +#include "spla/pair.hpp" #include #include -#include "spla/pair.hpp" namespace spla { - /** - * @addtogroup internal - * @{ - */ - - /** - * @class TVector - * @brief Vector interface implementation with type information bound - * - * @tparam T Type of stored elements - */ - template - class TVector final : public Vector { - public: - explicit TVector(uint n_rows); - ~TVector() override = default; - - uint get_n_rows() override; - ref_ptr get_type() override; - void set_label(std::string label) override; - const std::string& get_label() const override; - Status set_format(FormatVector format) override; - Status set_fill_value(const ref_ptr& value) override; - Status set_reduce(ref_ptr resolve_duplicates) override; - Status set_int(uint row_id, std::int32_t value) override; - Status set_uint(uint row_id, std::uint32_t value) override; - Status set_float(uint row_id, float value) override; - Status set_pair(uint row_id, Pair value) override { return Status::InvalidArgument;} - Status get_int(uint row_id, int32_t& value) override; - Status get_uint(uint row_id, uint32_t& value) override; - Status get_float(uint row_id, float& value) override; - Status get_pair(uint row_id, Pair& value) override { return Status::InvalidArgument;} - Status fill_noize(uint seed) override; - Status fill_with(const ref_ptr& value) override; - Status build(const ref_ptr& keys, const ref_ptr& values) override; - Status read(ref_ptr& keys, ref_ptr& values) override; - Status clear() override; - - template - Decorator* get() { return m_storage.template get(); } - - void validate_rw(FormatVector format); - void validate_rwd(FormatVector format); - void validate_wd(FormatVector format); - void validate_ctor(FormatVector format); - bool is_valid(FormatVector format) const; - T get_fill_value() const { return m_storage.get_fill_value(); } - - static StorageManagerVector* get_storage_manager(); - - private: - typename StorageManagerVector::Storage m_storage; - std::string m_label; - }; - - template - TVector::TVector(uint n_rows) { - m_storage.set_dims(n_rows, 1); - } - - template - uint TVector::get_n_rows() { - return m_storage.get_n_rows(); - } - template - ref_ptr TVector::get_type() { - return get_ttype().template as(); - } - - template - void TVector::set_label(std::string label) { - m_label = std::move(label); - LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void*) this); - } - template - const std::string& TVector::get_label() const { - return m_label; - } - - template - Status TVector::set_format(FormatVector format) { - validate_rw(format); - return Status::Ok; - } - template - Status TVector::set_fill_value(const ref_ptr& value) { - if (value) { - m_storage.invalidate(); - - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_int()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_uint()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_float()); - - return Status::Ok; - } - - return Status::InvalidArgument; - } - template - Status TVector::set_reduce(ref_ptr resolve_duplicates) { - auto reduce = resolve_duplicates.template cast_safe>(); - - if (reduce) { - validate_ctor(FormatVector::CpuDok); - auto* vec = get>(); - vec->reduce = reduce->function; - return Status::Ok; - } - - return Status::InvalidArgument; - } - - template - Status TVector::set_int(uint row_id, std::int32_t value) { - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = static_cast(value); - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); - return Status::Ok; - } - template<> - inline Status TVector::set_int(uint row_id, std::int32_t value) { - return Status::InvalidArgument; - } - - template - Status TVector::set_uint(uint row_id, std::uint32_t value) { - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = static_cast(value); - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); - return Status::Ok; - } - template<> - inline Status TVector::set_uint(uint row_id, std::uint32_t value) { - return Status::InvalidArgument; - } - template - Status TVector::set_float(uint row_id, float value) { - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = static_cast(value); - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); - return Status::Ok; - } - template<> - inline Status TVector::set_float(uint row_id, float value) { - return Status::InvalidArgument; - } - - template - Status TVector::get_int(uint row_id, int32_t& value) { - validate_rw(FormatVector::CpuDok); - - const auto& Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - template<> - inline Status TVector::get_int(uint row_id, int32_t& value) { - return Status::InvalidArgument; - } - template - Status TVector::get_uint(uint row_id, uint32_t& value) { - validate_rw(FormatVector::CpuDok); - - const auto& Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - template<> - inline Status TVector::get_uint(uint row_id, uint32_t& value) { - return Status::InvalidArgument; - } - template - Status TVector::get_float(uint row_id, float& value) { - validate_rw(FormatVector::CpuDok); - - const auto& Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; - } - template<> - inline Status TVector::get_float(uint row_id, float& value) { - return Status::InvalidArgument; - } - - template - Status TVector::fill_noize(uint seed) { - validate_wd(FormatVector::CpuDense); - auto& Ax = get>()->Ax; - auto engine = std::default_random_engine(seed); - - if constexpr (std::is_integral_v) { - std::uniform_int_distribution dist; - for (auto& x : Ax) x = dist(engine); - } - if constexpr (std::is_floating_point_v) { - std::uniform_real_distribution dist; - for (auto& x : Ax) x = dist(engine); - } - - return Status::Ok; - } - template - Status TVector::fill_with(const ref_ptr& value) { - assert(value); - - T t = T(); - - if constexpr (std::is_same::value) t = value->as_int(); - if constexpr (std::is_same::value) t = value->as_uint(); - if constexpr (std::is_same::value) t = value->as_float(); - - validate_wd(FormatVector::CpuDense); - auto& Ax = get>()->Ax; - std::fill(Ax.begin(), Ax.end(), t); - - return Status::Ok; - } - - template - Status TVector::build(const ref_ptr& keys, const ref_ptr& values) { - assert(keys); - assert(values); - - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - const auto elements_count = keys->get_size() / key_size; - - if (elements_count != values->get_size() / value_size) { - return Status::InvalidArgument; - } - if (elements_count * key_size != keys->get_size()) { - return Status::InvalidArgument; - } - - validate_rwd(FormatVector::CpuCoo); - CpuCooVec& coo = *get>(); - - coo.Ai.resize(elements_count); - coo.Ax.resize(elements_count); - coo.values = uint(elements_count); - - keys->read(0, key_size * elements_count, coo.Ai.data()); - values->read(0, value_size * elements_count, coo.Ax.data()); - - return Status::Ok; - } - template - Status TVector::read(ref_ptr& keys, ref_ptr& values) { - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - - validate_rw(FormatVector::CpuCoo); - CpuCooVec& coo = *get>(); - - const auto elements_count = coo.Ai.size(); - - keys = MemView::make(coo.Ai.data(), key_size * elements_count, false); - values = MemView::make(coo.Ax.data(), value_size * elements_count, false); - - return Status::Ok; - } - - template - Status TVector::clear() { - m_storage.invalidate(); - return Status::Ok; - } - - template - void TVector::validate_rw(FormatVector format) { - StorageManagerVector* manager = get_storage_manager(); - manager->validate_rw(format, m_storage); - } - - template - void TVector::validate_rwd(FormatVector format) { - StorageManagerVector* manager = get_storage_manager(); - manager->validate_rwd(format, m_storage); - } - - template - void TVector::validate_wd(FormatVector format) { - StorageManagerVector* manager = get_storage_manager(); - manager->validate_wd(format, m_storage); - } - - template - void TVector::validate_ctor(FormatVector format) { - StorageManagerVector* manager = get_storage_manager(); - manager->validate_ctor(format, m_storage); - } - - template - bool TVector::is_valid(FormatVector format) const { - return m_storage.is_valid(format); - } - - template - StorageManagerVector* TVector::get_storage_manager() { - static std::unique_ptr> storage_manager; - - if (!storage_manager) { - storage_manager = std::make_unique>(); - register_formats_vector(*storage_manager); - } - - return storage_manager.get(); - } - template<> - inline Status TVector::set_pair(uint row_id, Pair value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = value; - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, value, *get>()); - return Status::Ok; - } - template<> - inline Status TVector::get_pair(uint row_id, Pair& value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - - validate_rw(FormatVector::CpuDok); - - const auto& Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - - if (entry != Ax.end()) { - value = entry->second; - } else { - value = m_storage.get_fill_value(); - } - - return Status::Ok; - } - - - /** - * @} - */ - -}// namespace spla - -#endif//SPLA_TVECTOR_HPP +/** + * @addtogroup internal + * @{ + */ + +/** + * @class TVector + * @brief Vector interface implementation with type information bound + * + * @tparam T Type of stored elements + */ +template class TVector final : public Vector { +public: + explicit TVector(uint n_rows); + ~TVector() override = default; + + uint get_n_rows() override; + ref_ptr get_type() override; + void set_label(std::string label) override; + const std::string &get_label() const override; + Status set_format(FormatVector format) override; + Status set_fill_value(const ref_ptr &value) override; + Status set_reduce(ref_ptr resolve_duplicates) override; + Status set_int(uint row_id, std::int32_t value) override; + Status set_uint(uint row_id, std::uint32_t value) override; + Status set_float(uint row_id, float value) override; + Status set_pair(uint row_id, Pair value) override { + return Status::InvalidArgument; + } + Status get_int(uint row_id, int32_t &value) override; + Status get_uint(uint row_id, uint32_t &value) override; + Status get_float(uint row_id, float &value) override; + Status get_pair(uint row_id, Pair &value) override { + return Status::InvalidArgument; + } + Status fill_noize(uint seed) override; + Status fill_with(const ref_ptr &value) override; + Status build(const ref_ptr &keys, + const ref_ptr &values) override; + Status read(ref_ptr &keys, ref_ptr &values) override; + Status clear() override; + + template Decorator *get() { + return m_storage.template get(); + } + + void validate_rw(FormatVector format); + void validate_rwd(FormatVector format); + void validate_wd(FormatVector format); + void validate_ctor(FormatVector format); + bool is_valid(FormatVector format) const; + T get_fill_value() const { return m_storage.get_fill_value(); } + + static StorageManagerVector *get_storage_manager(); + +private: + typename StorageManagerVector::Storage m_storage; + std::string m_label; +}; + +template TVector::TVector(uint n_rows) { + m_storage.set_dims(n_rows, 1); +} + +template uint TVector::get_n_rows() { + return m_storage.get_n_rows(); +} +template ref_ptr TVector::get_type() { + return get_ttype().template as(); +} + +template void TVector::set_label(std::string label) { + m_label = std::move(label); + LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void *)this); +} +template const std::string &TVector::get_label() const { + return m_label; +} + +template Status TVector::set_format(FormatVector format) { + validate_rw(format); + return Status::Ok; +} +template +Status TVector::set_fill_value(const ref_ptr &value) { + if (value) { + m_storage.invalidate(); + + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_int()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_uint()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_float()); + + return Status::Ok; + } + + return Status::InvalidArgument; +} +template +Status TVector::set_reduce(ref_ptr resolve_duplicates) { + auto reduce = resolve_duplicates.template cast_safe>(); + + if (reduce) { + validate_ctor(FormatVector::CpuDok); + auto *vec = get>(); + vec->reduce = reduce->function; + return Status::Ok; + } + + return Status::InvalidArgument; +} + +template +Status TVector::set_int(uint row_id, std::int32_t value) { + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = static_cast(value); + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); + return Status::Ok; +} +template <> +inline Status TVector::set_int(uint row_id, std::int32_t value) { + return Status::InvalidArgument; +} + +template +Status TVector::set_uint(uint row_id, std::uint32_t value) { + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = static_cast(value); + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); + return Status::Ok; +} +template <> +inline Status TVector::set_uint(uint row_id, std::uint32_t value) { + return Status::InvalidArgument; +} +template Status TVector::set_float(uint row_id, float value) { + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = static_cast(value); + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); + return Status::Ok; +} +template <> inline Status TVector::set_float(uint row_id, float value) { + return Status::InvalidArgument; +} + +template Status TVector::get_int(uint row_id, int32_t &value) { + validate_rw(FormatVector::CpuDok); + + const auto &Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} +template <> inline Status TVector::get_int(uint row_id, int32_t &value) { + return Status::InvalidArgument; +} +template +Status TVector::get_uint(uint row_id, uint32_t &value) { + validate_rw(FormatVector::CpuDok); + + const auto &Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} +template <> +inline Status TVector::get_uint(uint row_id, uint32_t &value) { + return Status::InvalidArgument; +} +template Status TVector::get_float(uint row_id, float &value) { + validate_rw(FormatVector::CpuDok); + + const auto &Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; +} +template <> inline Status TVector::get_float(uint row_id, float &value) { + return Status::InvalidArgument; +} + +template Status TVector::fill_noize(uint seed) { + validate_wd(FormatVector::CpuDense); + auto &Ax = get>()->Ax; + auto engine = std::default_random_engine(seed); + + if constexpr (std::is_integral_v) { + std::uniform_int_distribution dist; + for (auto &x : Ax) + x = dist(engine); + } + if constexpr (std::is_floating_point_v) { + std::uniform_real_distribution dist; + for (auto &x : Ax) + x = dist(engine); + } + + return Status::Ok; +} +template +Status TVector::fill_with(const ref_ptr &value) { + assert(value); + + T t = T(); + + if constexpr (std::is_same::value) + t = value->as_int(); + if constexpr (std::is_same::value) + t = value->as_uint(); + if constexpr (std::is_same::value) + t = value->as_float(); + + validate_wd(FormatVector::CpuDense); + auto &Ax = get>()->Ax; + std::fill(Ax.begin(), Ax.end(), t); + + return Status::Ok; +} + +template +Status TVector::build(const ref_ptr &keys, + const ref_ptr &values) { + assert(keys); + assert(values); + + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + const auto elements_count = keys->get_size() / key_size; + + if (elements_count != values->get_size() / value_size) { + return Status::InvalidArgument; + } + if (elements_count * key_size != keys->get_size()) { + return Status::InvalidArgument; + } + + validate_rwd(FormatVector::CpuCoo); + CpuCooVec &coo = *get>(); + + coo.Ai.resize(elements_count); + coo.Ax.resize(elements_count); + coo.values = uint(elements_count); + + keys->read(0, key_size * elements_count, coo.Ai.data()); + values->read(0, value_size * elements_count, coo.Ax.data()); + + return Status::Ok; +} +template +Status TVector::read(ref_ptr &keys, ref_ptr &values) { + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + + validate_rw(FormatVector::CpuCoo); + CpuCooVec &coo = *get>(); + + const auto elements_count = coo.Ai.size(); + + keys = MemView::make(coo.Ai.data(), key_size * elements_count, false); + values = MemView::make(coo.Ax.data(), value_size * elements_count, false); + + return Status::Ok; +} + +template Status TVector::clear() { + m_storage.invalidate(); + return Status::Ok; +} + +template void TVector::validate_rw(FormatVector format) { + StorageManagerVector *manager = get_storage_manager(); + manager->validate_rw(format, m_storage); +} + +template void TVector::validate_rwd(FormatVector format) { + StorageManagerVector *manager = get_storage_manager(); + manager->validate_rwd(format, m_storage); +} + +template void TVector::validate_wd(FormatVector format) { + StorageManagerVector *manager = get_storage_manager(); + manager->validate_wd(format, m_storage); +} + +template void TVector::validate_ctor(FormatVector format) { + StorageManagerVector *manager = get_storage_manager(); + manager->validate_ctor(format, m_storage); +} + +template bool TVector::is_valid(FormatVector format) const { + return m_storage.is_valid(format); +} + +template +StorageManagerVector *TVector::get_storage_manager() { + static std::unique_ptr> storage_manager; + + if (!storage_manager) { + storage_manager = std::make_unique>(); + register_formats_vector(*storage_manager); + } + + return storage_manager.get(); +} +template <> inline Status TVector::set_pair(uint row_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = value; + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, value, *get>()); + return Status::Ok; +} +template <> inline Status TVector::get_pair(uint row_id, Pair &value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rw(FormatVector::CpuDok); + + const auto &Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + + if (entry != Ax.end()) { + value = entry->second; + } else { + value = m_storage.get_fill_value(); + } + + return Status::Ok; +} + +/** + * @} + */ + +} // namespace spla + +#endif // SPLA_TVECTOR_HPP diff --git a/src/cpu/cpu_algo_registry.cpp b/src/cpu/cpu_algo_registry.cpp index 379e602c7..2e23d1ce0 100644 --- a/src/cpu/cpu_algo_registry.cpp +++ b/src/cpu/cpu_algo_registry.cpp @@ -127,8 +127,8 @@ namespace spla { g_registry->add(MAKE_KEY_CPU_0("m_extract_row", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_extract_row", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_extract_row", FLOAT), std::make_shared>()); - g_registry->add(MAKE_KEY_CPU_0("m_extract_row", PAIR), std::make_shared>()); - + g_registry->add(MAKE_KEY_CPU_0("m_extract_row", PAIR), + std::make_shared>()); // algorthm m_extract_column g_registry->add(MAKE_KEY_CPU_0("m_extract_column", INT), std::make_shared>()); diff --git a/src/io.cpp b/src/io.cpp index d21a79fd6..578a20136 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -44,398 +51,421 @@ namespace spla { - MtxLoader::MtxLoader(std::string name) : m_name(std::move(name)) { +MtxLoader::MtxLoader(std::string name) : m_name(std::move(name)) {} + +bool MtxLoader::load(std::filesystem::path file_path, bool offset_indices, + bool make_undirected, bool remove_loops) { + m_file_path = std::move(file_path); + m_base_is_zero = offset_indices; + + std::fstream file(m_file_path, std::ios::in); + if (!file.is_open()) { + LOG_MSG(Status::Error, "failed to open file " << m_file_path); + return false; + } + + Timer t; + t.start(); + + std::size_t n_lines = 0; + std::size_t n_sort = 0; + + std::string line; + while (std::getline(file, line)) { + if (line[0] != '%') + break; + n_lines++; + } + + std::size_t nnz; + std::stringstream header(line); + header >> m_n_rows >> m_n_cols >> nnz; + + bool file_has_values = false; + if (line.find("pattern") == + std::string::npos) { // есть подстрока pattern => граф невзвешенный + file_has_values = true; + } + + std::cout << "Loading matrix-market coordinate format data... " << std::endl; + std::cout << " Reading from " << m_file_path << std::endl; + std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" + << std::endl; + std::cout << " Data: " << nnz << " directed edges" << std::endl; + if (remove_loops) + std::cout << " Opt: remove self-loops" << std::endl; + if (offset_indices) + std::cout << " Opt: offset indices by -1" << std::endl; + if (make_undirected) + std::cout << " Opt: double edges" << std::endl; + std::cout << " Reading data: "; + + // optimized reading by sliding window + const std::size_t BUFFER_CAPACITY = 1024 * 8; + std::size_t buffer_size = 0; + std::size_t buffer_offset = 0; + char buffer[BUFFER_CAPACITY + 1]; + + // read data + std::size_t to_count = 0; + std::size_t to_read = nnz; + std::size_t to_preallocate = to_read * (make_undirected ? 2 : 1); + std::vector Ai; + std::vector Aj; + std::vector Av; + + // preallocate to avoid copy + Ai.reserve(to_preallocate); + Aj.reserve(to_preallocate); + if (file_has_values) + Av.reserve(to_preallocate); + + float job_done = 0.0f; + float job_total = 35.0f; + + while (to_read > 0) { + to_count++; + to_read--; + n_lines++; + + // display current progress of reading the file + while (float(to_count) / float(nnz) > job_done / job_total) { + job_done += 1.0f; + std::cout << "|"; } - bool MtxLoader::load(std::filesystem::path file_path, bool offset_indices, bool make_undirected, bool remove_loops) { - m_file_path = std::move(file_path); - m_base_is_zero = offset_indices; - - std::fstream file(m_file_path, std::ios::in); - if (!file.is_open()) { - LOG_MSG(Status::Error, "failed to open file " << m_file_path); - return false; + // try to find where next line is ends up + bool line_found = false; + std::size_t line_end; + while (!line_found) { + line_end = buffer_offset; + + // travers buffer to find ending + while (line_end < buffer_size && buffer[line_end] != '\n') { + line_end += 1; + } + + // buffer not ended of file is ended + line_found = line_end < buffer_size || file.eof(); + + // not found in buffer, need to fetch more data + if (!line_found) { + assert(!file.eof()); + assert(buffer_offset <= BUFFER_CAPACITY); + + if (buffer_offset > 0) { + if (buffer_offset < BUFFER_CAPACITY) { + std::memcpy(buffer, buffer + buffer_offset, + BUFFER_CAPACITY - buffer_offset); + } + buffer_offset = BUFFER_CAPACITY - buffer_offset; } - Timer t; - t.start(); - - std::size_t n_lines = 0; - std::size_t n_sort = 0; - - std::string line; - while (std::getline(file, line)) { - if (line[0] != '%') break; - n_lines++; - } - - std::size_t nnz; - std::stringstream header(line); - header >> m_n_rows >> m_n_cols >> nnz; - - bool file_has_values = false; - if (line.find("pattern") == std::string::npos) { //есть подстрока pattern => граф невзвешенный - file_has_values = true; - } - - std::cout << "Loading matrix-market coordinate format data... " << std::endl; - std::cout << " Reading from " << m_file_path << std::endl; - std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" << std::endl; - std::cout << " Data: " << nnz << " directed edges" << std::endl; - if (remove_loops) std::cout << " Opt: remove self-loops" << std::endl; - if (offset_indices) std::cout << " Opt: offset indices by -1" << std::endl; - if (make_undirected) std::cout << " Opt: double edges" << std::endl; - std::cout << " Reading data: "; - - // optimized reading by sliding window - const std::size_t BUFFER_CAPACITY = 1024 * 8; - std::size_t buffer_size = 0; - std::size_t buffer_offset = 0; - char buffer[BUFFER_CAPACITY + 1]; - - // read data - std::size_t to_count = 0; - std::size_t to_read = nnz; - std::size_t to_preallocate = to_read * (make_undirected ? 2 : 1); - std::vector Ai; - std::vector Aj; - std::vector Av; - - // preallocate to avoid copy - Ai.reserve(to_preallocate); - Aj.reserve(to_preallocate); - if (file_has_values) Av.reserve(to_preallocate); - - float job_done = 0.0f; - float job_total = 35.0f; - - while (to_read > 0) { - to_count++; - to_read--; - n_lines++; - - // display current progress of reading the file - while (float(to_count) / float(nnz) > job_done / job_total) { - job_done += 1.0f; - std::cout << "|"; - } - - // try to find where next line is ends up - bool line_found = false; - std::size_t line_end; - while (!line_found) { - line_end = buffer_offset; - - // travers buffer to find ending - while (line_end < buffer_size && buffer[line_end] != '\n') { - line_end += 1; - } - - // buffer not ended of file is ended - line_found = line_end < buffer_size || file.eof(); - - // not found in buffer, need to fetch more data - if (!line_found) { - assert(!file.eof()); - assert(buffer_offset <= BUFFER_CAPACITY); - - if (buffer_offset > 0) { - if (buffer_offset < BUFFER_CAPACITY) { - std::memcpy(buffer, buffer + buffer_offset, BUFFER_CAPACITY - buffer_offset); - } - buffer_offset = BUFFER_CAPACITY - buffer_offset; - } - - auto bytes_to_read = BUFFER_CAPACITY - buffer_offset; - file.read(buffer + buffer_offset, std::streamsize(bytes_to_read)); - auto bytes_actually_read = file.gcount(); - buffer_size = buffer_offset + bytes_actually_read; - buffer_offset = 0; - buffer[buffer_size] = '\0'; - assert(buffer_size <= BUFFER_CAPACITY + 1); - } - } - - char* end = nullptr; - auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); - auto j = uint(std::strtoll(end, &end, 10)); - float val = 1.0f; //default value - - if (file_has_values) { - char* next = end; - while (*next == ' ' || *next == '\t') next++; - if (*next != '\n' && *next != '\0') { - val = static_cast(std::strtod(next, &end)); - } - } - buffer_offset = line_end + 1; - - assert(i > 0 && j > 0); - - if (remove_loops) { - if (i == j) continue; - } - if (offset_indices) { - i -= 1; - j -= 1; - } - if (make_undirected) { - Ai.push_back(j); - Aj.push_back(i); - if (file_has_values) Av.push_back(val); - } - - Ai.push_back(i); - Aj.push_back(j); - if (file_has_values) Av.push_back(val); - } - t.lap_end();// parsing - - if (file_has_values) { - struct Edge { - uint i, j; - float w; - bool operator<(const Edge& other) const { - if (i != other.i) return i < other.i; - return j < other.j; - } - }; - - std::vector edges; - edges.reserve(Ai.size()); - for (std::size_t k = 0; k < Ai.size(); k++) { - edges.push_back({Ai[k], Aj[k], Av[k]}); - } - - std::sort(edges.begin(), edges.end()); - - std::vector reduced_Ai; - std::vector reduced_Aj; - std::vector reduced_Av; - reduced_Ai.reserve(edges.size()); - reduced_Aj.reserve(edges.size()); - reduced_Av.reserve(edges.size()); - - for (std::size_t k = 0; k < edges.size(); k++) { - if (k == 0 || edges[k].i != edges[k-1].i || edges[k].j != edges[k-1].j) { - reduced_Ai.push_back(edges[k].i); - reduced_Aj.push_back(edges[k].j); - reduced_Av.push_back(edges[k].w); - } - } - - m_n_values = reduced_Ai.size(); - m_Ai = std::move(reduced_Ai); - m_Aj = std::move(reduced_Aj); - m_Aw = std::move(reduced_Av); - - } else { - std::vector sorted; - { - sorted.reserve(Ai.size()); - n_sort = Ai.size(); - - for (std::size_t k = 0; k < Ai.size(); k++) { - std::uint64_t entry = 0; - entry |= std::uint64_t(Ai[k]) << 32u; - entry |= std::uint64_t(Aj[k]) << 0u; - sorted.push_back(entry); - } - Ai.clear(); - Aj.clear(); - - std::sort(sorted.begin(), sorted.end()); - } - t.lap_end();// sorting - - std::vector reduced_Ai; - std::vector reduced_Aj; - { - reduced_Ai.reserve(sorted.size()); - reduced_Aj.reserve(sorted.size()); - - std::uint64_t entry_prev = 0xffffffffffffffff; - for (std::uint64_t entry : sorted) { - if (entry_prev != entry) { - uint i = uint((entry >> 32u) & 0xffffffff); - uint j = uint((entry >> 0u) & 0xffffffff); - reduced_Ai.push_back(i); - reduced_Aj.push_back(j); - } - entry_prev = entry; - } - - m_n_values = reduced_Ai.size(); - m_Ai = std::move(reduced_Ai); - m_Aj = std::move(reduced_Aj); - } - } - - calc_stats(); - t.lap_end();// stats - - t.stop(); - - std::cout << " 100%" << std::endl; - std::cout << " Parsed in " << t.get_laps_ms()[0] * 1e-3 << " sec " << n_lines << " lines" - << " speed " << float(n_lines) / (t.get_laps_ms()[0] * 1e-3) << " lines/sec" << std::endl; - std::cout << " Sorted in " << t.get_laps_ms()[1] * 1e-3 << " sec " << n_sort << " lines" << std::endl; - std::cout << " Reduced in " << t.get_laps_ms()[2] * 1e-3 << " sec " << m_n_values << " lines" << std::endl; - std::cout << " Calc stats in " << t.get_laps_ms()[3] * 1e-3 << " sec" << std::endl; - std::cout << " Loaded in " << t.get_elapsed_ms() * 1e-3 << " sec, " << m_n_values << " edges total" << std::endl; - - output_stats(); - - return true; + auto bytes_to_read = BUFFER_CAPACITY - buffer_offset; + file.read(buffer + buffer_offset, std::streamsize(bytes_to_read)); + auto bytes_actually_read = file.gcount(); + buffer_size = buffer_offset + bytes_actually_read; + buffer_offset = 0; + buffer[buffer_size] = '\0'; + assert(buffer_size <= BUFFER_CAPACITY + 1); + } } - bool MtxLoader::save(const std::filesystem::path& file_path, bool stats_only) { - std::fstream file(file_path, std::ios::out); - - if (!file.is_open()) { - LOG_MSG(Status::Error, "failed to open file " << file_path); - return false; - } - - file << "%%MatrixMarket matrix coordinate pattern general\n"; - file << "%-------------------------------------------------------------------------------\n"; - file << "%-------------------------------------------------------------------------------\n"; - - file << "% meta-info:\n"; - file << "% name: " << m_name << "\n"; - file << "% source-file: " << m_file_path << "\n"; - file << "% deg-avg: " << m_deg_avg << "\n"; - file << "% deg-sd: " << m_deg_sd << "\n"; - file << "% deg-min: " << m_deg_min << "\n"; - file << "% deg-max: " << m_deg_max << "\n"; - file << "% deg-distribution: \n"; - - for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { - file << "% " << m_deg_ranges[i] << " " << m_deg_ranges[i + 1] << " " << m_deg_distribution[i] << "\n"; - } - - file << "%-------------------------------------------------------------------------------\n"; - file << m_n_rows << " " << m_n_cols << " " << m_n_values << "\n"; - - if (!stats_only) { - const uint offset = m_base_is_zero ? 1 : 0; - for (std::size_t k = 0; k < m_n_values; k++) { - file << m_Ai[k] + offset << " " << m_Aj[k] + offset << "\n"; - } - } - - return true; + char *end = nullptr; + auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); + auto j = uint(std::strtoll(end, &end, 10)); + float val = 1.0f; // default value + + if (file_has_values) { + char *next = end; + while (*next == ' ' || *next == '\t') + next++; + if (*next != '\n' && *next != '\0') { + val = static_cast(std::strtod(next, &end)); + } } + buffer_offset = line_end + 1; - void MtxLoader::calc_stats() { - std::vector deg_pre_vertex(m_n_rows, 0.0f); - - for (auto i : m_Ai) { - deg_pre_vertex[m_base_is_zero ? i : i - 1] += 1; - } + assert(i > 0 && j > 0); - m_deg_sd = 0.0; - m_deg_avg = 0.0; - m_deg_max = -1.0; - m_deg_min = 1.0 + static_cast(m_n_values); - - for (auto deg : deg_pre_vertex) { - m_deg_min = std::min(m_deg_min, static_cast(deg)); - m_deg_max = std::max(m_deg_max, static_cast(deg)); - m_deg_avg += deg; - m_deg_sd += deg * deg; - } - - auto n = static_cast(m_n_rows); - - m_deg_avg = m_deg_avg / n; - m_deg_sd = std::sqrt(n * (m_deg_sd / n - m_deg_avg * m_deg_avg) / (n > 1.0 ? n - 1.0 : 1.0)); - - const uint GROUPS_COUNT_MAX = std::max(uint(10), uint(std::log2(double(m_n_rows) * 0.77))); - - std::vector count_per_deg(static_cast(m_deg_max) + 2, 0); - std::vector count_per_deg_offsets(static_cast(m_deg_max) + 2, 0); - - for (uint i = 0; i < m_n_rows; i++) { - count_per_deg[std::min(deg_pre_vertex[i], uint(m_deg_max))] += 1; - } - - std::exclusive_scan(count_per_deg.begin(), count_per_deg.end(), count_per_deg_offsets.begin(), 0); - count_per_deg_offsets.back() += 1; - - std::vector distributions; - std::vector ranges; - - auto range = m_deg_max - m_deg_min; - auto groups_count = std::max(std::min(GROUPS_COUNT_MAX, static_cast(range)), 1u); - auto g = static_cast(groups_count); - - auto total = static_cast(count_per_deg_offsets.back()); - auto from = count_per_deg_offsets.begin(); - - ranges.push_back(static_cast(m_deg_min)); - for (uint i = 0; i < groups_count; ++i) { - auto next = (from + 1 == count_per_deg_offsets.end()) ? from : from + 1; - auto to = std::lower_bound(next, count_per_deg_offsets.end(), static_cast(total / g * static_cast(i + 1))); - auto to_offset = std::distance(count_per_deg_offsets.begin(), to); - - assert(to != count_per_deg_offsets.end()); - - distributions.push_back(static_cast(*to - *from) / total); - ranges.push_back(static_cast(to_offset)); - from = to; - } - - m_deg_distribution = std::move(distributions); - m_deg_ranges = std::move(ranges); + if (remove_loops) { + if (i == j) + continue; } - void MtxLoader::output_stats() { - std::cout << " " - << "deg: " - << "min " << m_deg_min << ", " - << "max " << m_deg_max << ", " - << "avg " << m_deg_avg << ", " - << "sd " << m_deg_sd << std::endl; - - std::cout << " distribution:" << std::endl; - - const auto n = static_cast(m_n_rows); - const auto default_precision{std::cout.precision()}; - const auto n_digits = static_cast(std::log10(n) + 1.0); - - const double DISPLAY_DENSITY = std::max(double(100), double(m_deg_distribution.size())); - - for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { - auto deg = m_deg_distribution[i] >= 0.01 ? m_deg_distribution[i] : 0.0; - auto k_start = m_deg_ranges[i]; - auto k_end = m_deg_ranges[i + 1]; - auto k_count = std::round(static_cast(deg * DISPLAY_DENSITY)); - - std::cout << " [" << std::setw(n_digits) << k_start << " - " << std::setw(n_digits) << k_end << "): "; - std::cout << std::setw(6) << std::setprecision(2) << deg * 100.0 << std::setprecision(default_precision) << "% "; - for (uint s = 0; s < k_count; ++s) { std::cout << "*"; } - std::cout << std::endl; - } + if (offset_indices) { + i -= 1; + j -= 1; + } + if (make_undirected) { + Ai.push_back(j); + Aj.push_back(i); + if (file_has_values) + Av.push_back(val); } - const std::vector& MtxLoader::get_Ai() const { - return m_Ai; + Ai.push_back(i); + Aj.push_back(j); + if (file_has_values) + Av.push_back(val); + } + t.lap_end(); // parsing + + if (file_has_values) { + struct Edge { + uint i, j; + float w; + bool operator<(const Edge &other) const { + if (i != other.i) + return i < other.i; + return j < other.j; + } + }; + + std::vector edges; + edges.reserve(Ai.size()); + for (std::size_t k = 0; k < Ai.size(); k++) { + edges.push_back({Ai[k], Aj[k], Av[k]}); } - const std::vector& MtxLoader::get_Aj() const { - return m_Aj; + + std::sort(edges.begin(), edges.end()); + + std::vector reduced_Ai; + std::vector reduced_Aj; + std::vector reduced_Av; + reduced_Ai.reserve(edges.size()); + reduced_Aj.reserve(edges.size()); + reduced_Av.reserve(edges.size()); + + for (std::size_t k = 0; k < edges.size(); k++) { + if (k == 0 || edges[k].i != edges[k - 1].i || + edges[k].j != edges[k - 1].j) { + reduced_Ai.push_back(edges[k].i); + reduced_Aj.push_back(edges[k].j); + reduced_Av.push_back(edges[k].w); + } } - const std::vector& MtxLoader::get_Aw() const { - return m_Aw; + + m_n_values = reduced_Ai.size(); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); + m_Aw = std::move(reduced_Av); + + } else { + std::vector sorted; + { + sorted.reserve(Ai.size()); + n_sort = Ai.size(); + + for (std::size_t k = 0; k < Ai.size(); k++) { + std::uint64_t entry = 0; + entry |= std::uint64_t(Ai[k]) << 32u; + entry |= std::uint64_t(Aj[k]) << 0u; + sorted.push_back(entry); + } + Ai.clear(); + Aj.clear(); + + std::sort(sorted.begin(), sorted.end()); } + t.lap_end(); // sorting + + std::vector reduced_Ai; + std::vector reduced_Aj; + { + reduced_Ai.reserve(sorted.size()); + reduced_Aj.reserve(sorted.size()); + + std::uint64_t entry_prev = 0xffffffffffffffff; + for (std::uint64_t entry : sorted) { + if (entry_prev != entry) { + uint i = uint((entry >> 32u) & 0xffffffff); + uint j = uint((entry >> 0u) & 0xffffffff); + reduced_Ai.push_back(i); + reduced_Aj.push_back(j); + } + entry_prev = entry; + } - uint MtxLoader::get_n_rows() const { - return m_n_rows; + m_n_values = reduced_Ai.size(); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); } - uint MtxLoader::get_n_cols() const { - return m_n_cols; + } + + calc_stats(); + t.lap_end(); // stats + + t.stop(); + + std::cout << " 100%" << std::endl; + std::cout << " Parsed in " << t.get_laps_ms()[0] * 1e-3 << " sec " << n_lines + << " lines" + << " speed " << float(n_lines) / (t.get_laps_ms()[0] * 1e-3) + << " lines/sec" << std::endl; + std::cout << " Sorted in " << t.get_laps_ms()[1] * 1e-3 << " sec " << n_sort + << " lines" << std::endl; + std::cout << " Reduced in " << t.get_laps_ms()[2] * 1e-3 << " sec " + << m_n_values << " lines" << std::endl; + std::cout << " Calc stats in " << t.get_laps_ms()[3] * 1e-3 << " sec" + << std::endl; + std::cout << " Loaded in " << t.get_elapsed_ms() * 1e-3 << " sec, " + << m_n_values << " edges total" << std::endl; + + output_stats(); + + return true; +} + +bool MtxLoader::save(const std::filesystem::path &file_path, bool stats_only) { + std::fstream file(file_path, std::ios::out); + + if (!file.is_open()) { + LOG_MSG(Status::Error, "failed to open file " << file_path); + return false; + } + + file << "%%MatrixMarket matrix coordinate pattern general\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; + + file << "% meta-info:\n"; + file << "% name: " << m_name << "\n"; + file << "% source-file: " << m_file_path << "\n"; + file << "% deg-avg: " << m_deg_avg << "\n"; + file << "% deg-sd: " << m_deg_sd << "\n"; + file << "% deg-min: " << m_deg_min << "\n"; + file << "% deg-max: " << m_deg_max << "\n"; + file << "% deg-distribution: \n"; + + for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { + file << "% " << m_deg_ranges[i] << " " << m_deg_ranges[i + 1] << " " + << m_deg_distribution[i] << "\n"; + } + + file << "%-------------------------------------------------------------------" + "------------\n"; + file << m_n_rows << " " << m_n_cols << " " << m_n_values << "\n"; + + if (!stats_only) { + const uint offset = m_base_is_zero ? 1 : 0; + for (std::size_t k = 0; k < m_n_values; k++) { + file << m_Ai[k] + offset << " " << m_Aj[k] + offset << "\n"; } - std::size_t MtxLoader::get_n_values() const { - return m_n_values; + } + + return true; +} + +void MtxLoader::calc_stats() { + std::vector deg_pre_vertex(m_n_rows, 0.0f); + + for (auto i : m_Ai) { + deg_pre_vertex[m_base_is_zero ? i : i - 1] += 1; + } + + m_deg_sd = 0.0; + m_deg_avg = 0.0; + m_deg_max = -1.0; + m_deg_min = 1.0 + static_cast(m_n_values); + + for (auto deg : deg_pre_vertex) { + m_deg_min = std::min(m_deg_min, static_cast(deg)); + m_deg_max = std::max(m_deg_max, static_cast(deg)); + m_deg_avg += deg; + m_deg_sd += deg * deg; + } + + auto n = static_cast(m_n_rows); + + m_deg_avg = m_deg_avg / n; + m_deg_sd = std::sqrt(n * (m_deg_sd / n - m_deg_avg * m_deg_avg) / + (n > 1.0 ? n - 1.0 : 1.0)); + + const uint GROUPS_COUNT_MAX = + std::max(uint(10), uint(std::log2(double(m_n_rows) * 0.77))); + + std::vector count_per_deg(static_cast(m_deg_max) + 2, 0); + std::vector count_per_deg_offsets(static_cast(m_deg_max) + 2, 0); + + for (uint i = 0; i < m_n_rows; i++) { + count_per_deg[std::min(deg_pre_vertex[i], uint(m_deg_max))] += 1; + } + + std::exclusive_scan(count_per_deg.begin(), count_per_deg.end(), + count_per_deg_offsets.begin(), 0); + count_per_deg_offsets.back() += 1; + + std::vector distributions; + std::vector ranges; + + auto range = m_deg_max - m_deg_min; + auto groups_count = + std::max(std::min(GROUPS_COUNT_MAX, static_cast(range)), 1u); + auto g = static_cast(groups_count); + + auto total = static_cast(count_per_deg_offsets.back()); + auto from = count_per_deg_offsets.begin(); + + ranges.push_back(static_cast(m_deg_min)); + for (uint i = 0; i < groups_count; ++i) { + auto next = (from + 1 == count_per_deg_offsets.end()) ? from : from + 1; + auto to = std::lower_bound( + next, count_per_deg_offsets.end(), + static_cast(total / g * static_cast(i + 1))); + auto to_offset = std::distance(count_per_deg_offsets.begin(), to); + + assert(to != count_per_deg_offsets.end()); + + distributions.push_back(static_cast(*to - *from) / total); + ranges.push_back(static_cast(to_offset)); + from = to; + } + + m_deg_distribution = std::move(distributions); + m_deg_ranges = std::move(ranges); +} +void MtxLoader::output_stats() { + std::cout << " " + << "deg: " + << "min " << m_deg_min << ", " + << "max " << m_deg_max << ", " + << "avg " << m_deg_avg << ", " + << "sd " << m_deg_sd << std::endl; + + std::cout << " distribution:" << std::endl; + + const auto n = static_cast(m_n_rows); + const auto default_precision{std::cout.precision()}; + const auto n_digits = static_cast(std::log10(n) + 1.0); + + const double DISPLAY_DENSITY = + std::max(double(100), double(m_deg_distribution.size())); + + for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { + auto deg = m_deg_distribution[i] >= 0.01 ? m_deg_distribution[i] : 0.0; + auto k_start = m_deg_ranges[i]; + auto k_end = m_deg_ranges[i + 1]; + auto k_count = std::round(static_cast(deg * DISPLAY_DENSITY)); + + std::cout << " [" << std::setw(n_digits) << k_start << " - " + << std::setw(n_digits) << k_end << "): "; + std::cout << std::setw(6) << std::setprecision(2) << deg * 100.0 + << std::setprecision(default_precision) << "% "; + for (uint s = 0; s < k_count; ++s) { + std::cout << "*"; } + std::cout << std::endl; + } +} + +const std::vector &MtxLoader::get_Ai() const { return m_Ai; } +const std::vector &MtxLoader::get_Aj() const { return m_Aj; } +const std::vector &MtxLoader::get_Aw() const { return m_Aw; } + +uint MtxLoader::get_n_rows() const { return m_n_rows; } +uint MtxLoader::get_n_cols() const { return m_n_cols; } +std::size_t MtxLoader::get_n_values() const { return m_n_values; } -}// namespace spla \ No newline at end of file +} // namespace spla \ No newline at end of file diff --git a/src/matrix.cpp b/src/matrix.cpp index aa319ca3b..ea9839250 100644 --- a/src/matrix.cpp +++ b/src/matrix.cpp @@ -52,7 +52,7 @@ namespace spla { return ref_ptr(new TMatrix(n_rows, n_cols)); } if (type == PAIR) { - return ref_ptr(new TMatrix(n_rows, n_cols)); + return ref_ptr(new TMatrix(n_rows, n_cols)); } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); diff --git a/src/op.cpp b/src/op.cpp index de85937e8..9ac1049ab 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -36,354 +43,379 @@ namespace spla { - ref_ptr IDENTITY_INT; - ref_ptr IDENTITY_UINT; - ref_ptr IDENTITY_FLOAT; - ref_ptr AINV_INT; - ref_ptr AINV_UINT; - ref_ptr AINV_FLOAT; - ref_ptr MINV_INT; - ref_ptr MINV_UINT; - ref_ptr MINV_FLOAT; - ref_ptr LNOT_INT; - ref_ptr LNOT_UINT; - ref_ptr LNOT_FLOAT; - ref_ptr UONE_INT; - ref_ptr UONE_UINT; - ref_ptr UONE_FLOAT; - ref_ptr ABS_INT; - ref_ptr ABS_UINT; - ref_ptr ABS_FLOAT; - - ref_ptr BNOT_INT; - ref_ptr BNOT_UINT; - - ref_ptr SQRT_FLOAT; - ref_ptr LOG_FLOAT; - ref_ptr EXP_FLOAT; - ref_ptr SIN_FLOAT; - ref_ptr COS_FLOAT; - ref_ptr TAN_FLOAT; - ref_ptr ASIN_FLOAT; - ref_ptr ACOS_FLOAT; - ref_ptr ATAN_FLOAT; - ref_ptr CEIL_FLOAT; - ref_ptr FLOOR_FLOAT; - ref_ptr ROUND_FLOAT; - ref_ptr TRUNC_FLOAT; - ref_ptr IDENTITY_PAIR; - - - ////////////////////////////////////////////////////////////////////////////// - - ref_ptr PLUS_INT; - ref_ptr PLUS_UINT; - ref_ptr PLUS_FLOAT; - ref_ptr MINUS_INT; - ref_ptr MINUS_UINT; - ref_ptr MINUS_FLOAT; - ref_ptr MULT_INT; - ref_ptr MULT_UINT; - ref_ptr MULT_FLOAT; - ref_ptr DIV_INT; - ref_ptr DIV_UINT; - ref_ptr DIV_FLOAT; - - ref_ptr MINUS_POW2_INT; - ref_ptr MINUS_POW2_UINT; - ref_ptr MINUS_POW2_FLOAT; - - ref_ptr FIRST_INT; - ref_ptr FIRST_UINT; - ref_ptr FIRST_FLOAT; - ref_ptr SECOND_INT; - ref_ptr SECOND_UINT; - ref_ptr SECOND_FLOAT; - - ref_ptr BONE_INT; - ref_ptr BONE_UINT; - ref_ptr BONE_FLOAT; - - ref_ptr MIN_INT; - ref_ptr MIN_UINT; - ref_ptr MIN_FLOAT; - ref_ptr MAX_INT; - ref_ptr MAX_UINT; - ref_ptr MAX_FLOAT; - - ref_ptr LOR_INT; - ref_ptr LOR_UINT; - ref_ptr LOR_FLOAT; - ref_ptr LAND_INT; - ref_ptr LAND_UINT; - ref_ptr LAND_FLOAT; - - ref_ptr BOR_INT; - ref_ptr BOR_UINT; - ref_ptr BAND_INT; - ref_ptr BAND_UINT; - ref_ptr BXOR_INT; - ref_ptr BXOR_UINT; - - ref_ptr MIN_PAIR; - ref_ptr MUL_PAIR; - - ////////////////////////////////////////////////////////////////////////////// - - ref_ptr EQZERO_INT; - ref_ptr EQZERO_UINT; - ref_ptr EQZERO_FLOAT; - ref_ptr NQZERO_INT; - ref_ptr NQZERO_UINT; - ref_ptr NQZERO_FLOAT; - ref_ptr GTZERO_INT; - ref_ptr GTZERO_UINT; - ref_ptr GTZERO_FLOAT; - ref_ptr GEZERO_INT; - ref_ptr GEZERO_UINT; - ref_ptr GEZERO_FLOAT; - ref_ptr LTZERO_INT; - ref_ptr LTZERO_UINT; - ref_ptr LTZERO_FLOAT; - ref_ptr LEZERO_INT; - ref_ptr LEZERO_UINT; - ref_ptr LEZERO_FLOAT; - ref_ptr ALWAYS_INT; - ref_ptr ALWAYS_UINT; - ref_ptr ALWAYS_FLOAT; - ref_ptr ALWAYS_PAIR; - ref_ptr NEVER_INT; - ref_ptr NEVER_UINT; - ref_ptr NEVER_FLOAT; - - - template - inline T min(T a, T b) { return std::min(a, b); } - - template - inline T max(T a, T b) { return std::max(a, b); } - - void register_ops() { - DECL_OP_UNA_S(IDENTITY_INT, IDENSTITY, T_INT, { return a; }); - DECL_OP_UNA_S(IDENTITY_UINT, IDENSTITY, T_UINT, { return a; }); - DECL_OP_UNA_S(IDENTITY_FLOAT, IDENSTITY, T_FLOAT, { return a; }); - DECL_OP_UNA_S(AINV_INT, AINV, T_INT, { return -a; }); - DECL_OP_UNA_S(AINV_UINT, AINV, T_UINT, { return -a; }); - DECL_OP_UNA_S(AINV_FLOAT, AINV, T_FLOAT, { return -a; }); - DECL_OP_UNA_S(MINV_INT, MINV, T_INT, { return 1 / a; }); - DECL_OP_UNA_S(MINV_UINT, MINV, T_UINT, { return 1 / a; }); - DECL_OP_UNA_S(MINV_FLOAT, MINV, T_FLOAT, { return 1.0f / a; }); - DECL_OP_UNA_S(LNOT_INT, LNOT, T_INT, { return !(a != 0); }); - DECL_OP_UNA_S(LNOT_UINT, LNOT, T_UINT, { return !(a != 0); }); - DECL_OP_UNA_S(LNOT_FLOAT, LNOT, T_FLOAT, { return !(a != 0); }); - DECL_OP_UNA_S(UONE_INT, UONE, T_INT, { return 1; }); - DECL_OP_UNA_S(UONE_UINT, UONE, T_UINT, { return 1; }); - DECL_OP_UNA_S(UONE_FLOAT, UONE, T_FLOAT, { return 1; }); - DECL_OP_UNA_S(ABS_INT, ABS, T_INT, { return abs(a); }); - DECL_OP_UNA_S(ABS_UINT, ABS, T_UINT, { return a; }); - DECL_OP_UNA_S(ABS_FLOAT, ABS, T_FLOAT, { return fabs(a); }); - - DECL_OP_UNA_S(BNOT_INT, BNOT, T_INT, { return ~a; }); - DECL_OP_UNA_S(BNOT_UINT, BNOT, T_UINT, { return ~a; }); - - DECL_OP_UNA_S(SQRT_FLOAT, SQRT, T_FLOAT, { return sqrt(a); }); - DECL_OP_UNA_S(LOG_FLOAT, LOG, T_FLOAT, { return log(a); }); - DECL_OP_UNA_S(EXP_FLOAT, EXP, T_FLOAT, { return exp(a); }); - DECL_OP_UNA_S(SIN_FLOAT, SIN, T_FLOAT, { return sin(a); }); - DECL_OP_UNA_S(COS_FLOAT, COS, T_FLOAT, { return cos(a); }); - DECL_OP_UNA_S(TAN_FLOAT, TAN, T_FLOAT, { return tan(a); }); - DECL_OP_UNA_S(ASIN_FLOAT, ASIN, T_FLOAT, { return asin(a); }); - DECL_OP_UNA_S(ACOS_FLOAT, ACOS, T_FLOAT, { return acos(a); }); - DECL_OP_UNA_S(ATAN_FLOAT, ATAN, T_FLOAT, { return atan(a); }); - DECL_OP_UNA_S(CEIL_FLOAT, CEIL, T_FLOAT, { return ceil(a); }); - DECL_OP_UNA_S(FLOOR_FLOAT, FLOOR, T_FLOAT, { return floor(a); }); - DECL_OP_UNA_S(ROUND_FLOAT, ROUND, T_FLOAT, { return round(a); }); - DECL_OP_UNA_S(TRUNC_FLOAT, TRUNC, T_FLOAT, { return trunc(a); }); - IDENTITY_PAIR = spla::OpUnary::make_pair("IDENTITY_PAIR", "(a) identity_pair(a)", [](Pair a) { return a; }); - - DECL_OP_BIN_S(PLUS_INT, PLUS, T_INT, { return a + b; }); - DECL_OP_BIN_S(PLUS_UINT, PLUS, T_UINT, { return a + b; }); - DECL_OP_BIN_S(PLUS_FLOAT, PLUS, T_FLOAT, { return a + b; }); - DECL_OP_BIN_S(MINUS_INT, MINUS, T_INT, { return a - b; }); - DECL_OP_BIN_S(MINUS_UINT, MINUS, T_UINT, { return a - b; }); - DECL_OP_BIN_S(MINUS_FLOAT, MINUS, T_FLOAT, { return a - b; }); - DECL_OP_BIN_S(MULT_INT, MULT, T_INT, { return a * b; }); - DECL_OP_BIN_S(MULT_UINT, MULT, T_UINT, { return a * b; }); - DECL_OP_BIN_S(MULT_FLOAT, MULT, T_FLOAT, { return a * b; }); - DECL_OP_BIN_S(DIV_INT, DIV, T_INT, { return a / b; }); - DECL_OP_BIN_S(DIV_UINT, DIV, T_UINT, { return a / b; }); - DECL_OP_BIN_S(DIV_FLOAT, DIV, T_FLOAT, { return a / b; }); - - DECL_OP_BIN_S(MINUS_POW2_INT, MINUS_POW2, T_INT, { return (a - b) * (a - b); }); - DECL_OP_BIN_S(MINUS_POW2_UINT, MINUS_POW2, T_UINT, { return (a - b) * (a - b); }); - DECL_OP_BIN_S(MINUS_POW2_FLOAT, MINUS_POW2, T_FLOAT, { return (a - b) * (a - b); }); - - DECL_OP_BIN_S(FIRST_INT, FIRST, T_INT, { return a; }); - DECL_OP_BIN_S(FIRST_UINT, FIRST, T_UINT, { return a; }); - DECL_OP_BIN_S(FIRST_FLOAT, FIRST, T_FLOAT, { return a; }); - DECL_OP_BIN_S(SECOND_INT, SECOND, T_INT, { return b; }); - DECL_OP_BIN_S(SECOND_UINT, SECOND, T_UINT, { return b; }); - DECL_OP_BIN_S(SECOND_FLOAT, SECOND, T_FLOAT, { return b; }); - - DECL_OP_BIN_S(BONE_INT, BONE, T_INT, { return 1; }); - DECL_OP_BIN_S(BONE_UINT, BONE, T_UINT, { return 1; }); - DECL_OP_BIN_S(BONE_FLOAT, BONE, T_FLOAT, { return 1; }); - - DECL_OP_BIN_S(MIN_INT, MIN, T_INT, { return min(a, b); }); - DECL_OP_BIN_S(MIN_UINT, MIN, T_UINT, { return min(a, b); }); - DECL_OP_BIN_S(MIN_FLOAT, MIN, T_FLOAT, { return min(a, b); }); - DECL_OP_BIN_S(MAX_INT, MAX, T_INT, { return max(a, b); }); - DECL_OP_BIN_S(MAX_UINT, MAX, T_UINT, { return max(a, b); }); - DECL_OP_BIN_S(MAX_FLOAT, MAX, T_FLOAT, { return max(a, b); }); - - DECL_OP_BIN_S(LOR_INT, LOR, T_INT, { return a || b; }); - DECL_OP_BIN_S(LOR_UINT, LOR, T_UINT, { return a || b; }); - DECL_OP_BIN_S(LOR_FLOAT, LOR, T_FLOAT, { return a || b; }); - DECL_OP_BIN_S(LAND_INT, LAND, T_INT, { return a && b; }); - DECL_OP_BIN_S(LAND_UINT, LAND, T_UINT, { return a && b; }); - DECL_OP_BIN_S(LAND_FLOAT, LAND, T_FLOAT, { return a && b; }); - - DECL_OP_BIN_S(BOR_INT, BOR, T_INT, { return a | b; }); - DECL_OP_BIN_S(BOR_UINT, BOR, T_UINT, { return a | b; }); - DECL_OP_BIN_S(BAND_INT, BAND, T_INT, { return a & b; }); - DECL_OP_BIN_S(BAND_UINT, BAND, T_UINT, { return a & b; }); - DECL_OP_BIN_S(BXOR_INT, BXOR, T_INT, { return a ^ b; }); - DECL_OP_BIN_S(BXOR_UINT, BXOR, T_UINT, { return a ^ b; }); - - MUL_PAIR = OpBinary::make_pair("MUL_PAIR", - "(a, b) make_pair(a.weight, b.vertex)", - [](Pair a, Pair b) { return Pair(a.weight, b.vertex); }); - MIN_PAIR = OpBinary::make_pair("MIN_PAIR", - "(a, b) min_pair(a, b)", - [](Pair a, Pair b) { - if (a.weight == b.weight) return a.vertex < b.vertex? a : b; - return a.weight < b.weight? a : b; }); - - DECL_OP_SELECT(EQZERO_INT, EQZERO, T_INT, { return a == 0; }); - DECL_OP_SELECT(EQZERO_UINT, EQZERO, T_UINT, { return a == 0; }); - DECL_OP_SELECT(EQZERO_FLOAT, EQZERO, T_FLOAT, { return a == 0; }); - DECL_OP_SELECT(NQZERO_INT, NQZERO, T_INT, { return a != 0; }); - DECL_OP_SELECT(NQZERO_UINT, NQZERO, T_UINT, { return a != 0; }); - DECL_OP_SELECT(NQZERO_FLOAT, NQZERO, T_FLOAT, { return a != 0; }); - DECL_OP_SELECT(GTZERO_INT, GTZERO, T_INT, { return a > 0; }); - DECL_OP_SELECT(GTZERO_UINT, GTZERO, T_UINT, { return a > 0; }); - DECL_OP_SELECT(GTZERO_FLOAT, GTZERO, T_FLOAT, { return a > 0; }); - DECL_OP_SELECT(GEZERO_INT, GEZERO, T_INT, { return a >= 0; }); - DECL_OP_SELECT(GEZERO_UINT, GEZERO, T_UINT, { return a >= 0; }); - DECL_OP_SELECT(GEZERO_FLOAT, GEZERO, T_FLOAT, { return a >= 0; }); - DECL_OP_SELECT(LTZERO_INT, LTZERO, T_INT, { return a < 0; }); - DECL_OP_SELECT(LTZERO_UINT, LTZERO, T_UINT, { return a < 0; }); - DECL_OP_SELECT(LTZERO_FLOAT, LTZERO, T_FLOAT, { return a < 0; }); - DECL_OP_SELECT(LEZERO_INT, LEZERO, T_INT, { return a <= 0; }); - DECL_OP_SELECT(LEZERO_UINT, LEZERO, T_UINT, { return a <= 0; }); - DECL_OP_SELECT(LEZERO_FLOAT, LEZERO, T_FLOAT, { return a <= 0; }); - DECL_OP_SELECT(ALWAYS_INT, ALWAYS, T_INT, { return 1; }); - DECL_OP_SELECT(ALWAYS_UINT, ALWAYS, T_UINT, { return 1; }); - DECL_OP_SELECT(ALWAYS_FLOAT, ALWAYS, T_FLOAT, { return 1; }); - ALWAYS_PAIR = OpSelect::make_pair("ALWAYS_PAIR", "(a) pair_always(a)", [](Pair a) { return 1; }); - DECL_OP_SELECT(NEVER_INT, NEVER, T_INT, { return 0; }); - DECL_OP_SELECT(NEVER_UINT, NEVER, T_UINT, { return 0; }); - DECL_OP_SELECT(NEVER_FLOAT, NEVER, T_FLOAT, { return 0; }); - - } - - ref_ptr OpUnary::make_int(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - ref_ptr OpUnary::make_uint(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - ref_ptr OpUnary::make_float(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - ref_ptr OpUnary::make_pair(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - - ref_ptr OpBinary::make_int(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - ref_ptr OpBinary::make_uint(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - ref_ptr OpBinary::make_float(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - ref_ptr OpBinary::make_pair(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); - } - - ref_ptr OpSelect::make_int(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); - } - ref_ptr OpSelect::make_uint(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); - } - ref_ptr OpSelect::make_float(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); - } - ref_ptr OpSelect::make_pair(std::string name, std::string code, std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); - } - -}// namespace spla \ No newline at end of file +ref_ptr IDENTITY_INT; +ref_ptr IDENTITY_UINT; +ref_ptr IDENTITY_FLOAT; +ref_ptr AINV_INT; +ref_ptr AINV_UINT; +ref_ptr AINV_FLOAT; +ref_ptr MINV_INT; +ref_ptr MINV_UINT; +ref_ptr MINV_FLOAT; +ref_ptr LNOT_INT; +ref_ptr LNOT_UINT; +ref_ptr LNOT_FLOAT; +ref_ptr UONE_INT; +ref_ptr UONE_UINT; +ref_ptr UONE_FLOAT; +ref_ptr ABS_INT; +ref_ptr ABS_UINT; +ref_ptr ABS_FLOAT; + +ref_ptr BNOT_INT; +ref_ptr BNOT_UINT; + +ref_ptr SQRT_FLOAT; +ref_ptr LOG_FLOAT; +ref_ptr EXP_FLOAT; +ref_ptr SIN_FLOAT; +ref_ptr COS_FLOAT; +ref_ptr TAN_FLOAT; +ref_ptr ASIN_FLOAT; +ref_ptr ACOS_FLOAT; +ref_ptr ATAN_FLOAT; +ref_ptr CEIL_FLOAT; +ref_ptr FLOOR_FLOAT; +ref_ptr ROUND_FLOAT; +ref_ptr TRUNC_FLOAT; +ref_ptr IDENTITY_PAIR; + +////////////////////////////////////////////////////////////////////////////// + +ref_ptr PLUS_INT; +ref_ptr PLUS_UINT; +ref_ptr PLUS_FLOAT; +ref_ptr MINUS_INT; +ref_ptr MINUS_UINT; +ref_ptr MINUS_FLOAT; +ref_ptr MULT_INT; +ref_ptr MULT_UINT; +ref_ptr MULT_FLOAT; +ref_ptr DIV_INT; +ref_ptr DIV_UINT; +ref_ptr DIV_FLOAT; + +ref_ptr MINUS_POW2_INT; +ref_ptr MINUS_POW2_UINT; +ref_ptr MINUS_POW2_FLOAT; + +ref_ptr FIRST_INT; +ref_ptr FIRST_UINT; +ref_ptr FIRST_FLOAT; +ref_ptr SECOND_INT; +ref_ptr SECOND_UINT; +ref_ptr SECOND_FLOAT; + +ref_ptr BONE_INT; +ref_ptr BONE_UINT; +ref_ptr BONE_FLOAT; + +ref_ptr MIN_INT; +ref_ptr MIN_UINT; +ref_ptr MIN_FLOAT; +ref_ptr MAX_INT; +ref_ptr MAX_UINT; +ref_ptr MAX_FLOAT; + +ref_ptr LOR_INT; +ref_ptr LOR_UINT; +ref_ptr LOR_FLOAT; +ref_ptr LAND_INT; +ref_ptr LAND_UINT; +ref_ptr LAND_FLOAT; + +ref_ptr BOR_INT; +ref_ptr BOR_UINT; +ref_ptr BAND_INT; +ref_ptr BAND_UINT; +ref_ptr BXOR_INT; +ref_ptr BXOR_UINT; + +ref_ptr MIN_PAIR; +ref_ptr MUL_PAIR; + +////////////////////////////////////////////////////////////////////////////// + +ref_ptr EQZERO_INT; +ref_ptr EQZERO_UINT; +ref_ptr EQZERO_FLOAT; +ref_ptr NQZERO_INT; +ref_ptr NQZERO_UINT; +ref_ptr NQZERO_FLOAT; +ref_ptr GTZERO_INT; +ref_ptr GTZERO_UINT; +ref_ptr GTZERO_FLOAT; +ref_ptr GEZERO_INT; +ref_ptr GEZERO_UINT; +ref_ptr GEZERO_FLOAT; +ref_ptr LTZERO_INT; +ref_ptr LTZERO_UINT; +ref_ptr LTZERO_FLOAT; +ref_ptr LEZERO_INT; +ref_ptr LEZERO_UINT; +ref_ptr LEZERO_FLOAT; +ref_ptr ALWAYS_INT; +ref_ptr ALWAYS_UINT; +ref_ptr ALWAYS_FLOAT; +ref_ptr ALWAYS_PAIR; +ref_ptr NEVER_INT; +ref_ptr NEVER_UINT; +ref_ptr NEVER_FLOAT; + +template inline T min(T a, T b) { return std::min(a, b); } + +template inline T max(T a, T b) { return std::max(a, b); } + +void register_ops() { + DECL_OP_UNA_S(IDENTITY_INT, IDENSTITY, T_INT, { return a; }); + DECL_OP_UNA_S(IDENTITY_UINT, IDENSTITY, T_UINT, { return a; }); + DECL_OP_UNA_S(IDENTITY_FLOAT, IDENSTITY, T_FLOAT, { return a; }); + DECL_OP_UNA_S(AINV_INT, AINV, T_INT, { return -a; }); + DECL_OP_UNA_S(AINV_UINT, AINV, T_UINT, { return -a; }); + DECL_OP_UNA_S(AINV_FLOAT, AINV, T_FLOAT, { return -a; }); + DECL_OP_UNA_S(MINV_INT, MINV, T_INT, { return 1 / a; }); + DECL_OP_UNA_S(MINV_UINT, MINV, T_UINT, { return 1 / a; }); + DECL_OP_UNA_S(MINV_FLOAT, MINV, T_FLOAT, { return 1.0f / a; }); + DECL_OP_UNA_S(LNOT_INT, LNOT, T_INT, { return !(a != 0); }); + DECL_OP_UNA_S(LNOT_UINT, LNOT, T_UINT, { return !(a != 0); }); + DECL_OP_UNA_S(LNOT_FLOAT, LNOT, T_FLOAT, { return !(a != 0); }); + DECL_OP_UNA_S(UONE_INT, UONE, T_INT, { return 1; }); + DECL_OP_UNA_S(UONE_UINT, UONE, T_UINT, { return 1; }); + DECL_OP_UNA_S(UONE_FLOAT, UONE, T_FLOAT, { return 1; }); + DECL_OP_UNA_S(ABS_INT, ABS, T_INT, { return abs(a); }); + DECL_OP_UNA_S(ABS_UINT, ABS, T_UINT, { return a; }); + DECL_OP_UNA_S(ABS_FLOAT, ABS, T_FLOAT, { return fabs(a); }); + + DECL_OP_UNA_S(BNOT_INT, BNOT, T_INT, { return ~a; }); + DECL_OP_UNA_S(BNOT_UINT, BNOT, T_UINT, { return ~a; }); + + DECL_OP_UNA_S(SQRT_FLOAT, SQRT, T_FLOAT, { return sqrt(a); }); + DECL_OP_UNA_S(LOG_FLOAT, LOG, T_FLOAT, { return log(a); }); + DECL_OP_UNA_S(EXP_FLOAT, EXP, T_FLOAT, { return exp(a); }); + DECL_OP_UNA_S(SIN_FLOAT, SIN, T_FLOAT, { return sin(a); }); + DECL_OP_UNA_S(COS_FLOAT, COS, T_FLOAT, { return cos(a); }); + DECL_OP_UNA_S(TAN_FLOAT, TAN, T_FLOAT, { return tan(a); }); + DECL_OP_UNA_S(ASIN_FLOAT, ASIN, T_FLOAT, { return asin(a); }); + DECL_OP_UNA_S(ACOS_FLOAT, ACOS, T_FLOAT, { return acos(a); }); + DECL_OP_UNA_S(ATAN_FLOAT, ATAN, T_FLOAT, { return atan(a); }); + DECL_OP_UNA_S(CEIL_FLOAT, CEIL, T_FLOAT, { return ceil(a); }); + DECL_OP_UNA_S(FLOOR_FLOAT, FLOOR, T_FLOAT, { return floor(a); }); + DECL_OP_UNA_S(ROUND_FLOAT, ROUND, T_FLOAT, { return round(a); }); + DECL_OP_UNA_S(TRUNC_FLOAT, TRUNC, T_FLOAT, { return trunc(a); }); + IDENTITY_PAIR = spla::OpUnary::make_pair( + "IDENTITY_PAIR", "(a) identity_pair(a)", [](Pair a) { return a; }); + + DECL_OP_BIN_S(PLUS_INT, PLUS, T_INT, { return a + b; }); + DECL_OP_BIN_S(PLUS_UINT, PLUS, T_UINT, { return a + b; }); + DECL_OP_BIN_S(PLUS_FLOAT, PLUS, T_FLOAT, { return a + b; }); + DECL_OP_BIN_S(MINUS_INT, MINUS, T_INT, { return a - b; }); + DECL_OP_BIN_S(MINUS_UINT, MINUS, T_UINT, { return a - b; }); + DECL_OP_BIN_S(MINUS_FLOAT, MINUS, T_FLOAT, { return a - b; }); + DECL_OP_BIN_S(MULT_INT, MULT, T_INT, { return a * b; }); + DECL_OP_BIN_S(MULT_UINT, MULT, T_UINT, { return a * b; }); + DECL_OP_BIN_S(MULT_FLOAT, MULT, T_FLOAT, { return a * b; }); + DECL_OP_BIN_S(DIV_INT, DIV, T_INT, { return a / b; }); + DECL_OP_BIN_S(DIV_UINT, DIV, T_UINT, { return a / b; }); + DECL_OP_BIN_S(DIV_FLOAT, DIV, T_FLOAT, { return a / b; }); + + DECL_OP_BIN_S(MINUS_POW2_INT, MINUS_POW2, T_INT, + { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_UINT, MINUS_POW2, T_UINT, + { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_FLOAT, MINUS_POW2, T_FLOAT, + { return (a - b) * (a - b); }); + + DECL_OP_BIN_S(FIRST_INT, FIRST, T_INT, { return a; }); + DECL_OP_BIN_S(FIRST_UINT, FIRST, T_UINT, { return a; }); + DECL_OP_BIN_S(FIRST_FLOAT, FIRST, T_FLOAT, { return a; }); + DECL_OP_BIN_S(SECOND_INT, SECOND, T_INT, { return b; }); + DECL_OP_BIN_S(SECOND_UINT, SECOND, T_UINT, { return b; }); + DECL_OP_BIN_S(SECOND_FLOAT, SECOND, T_FLOAT, { return b; }); + + DECL_OP_BIN_S(BONE_INT, BONE, T_INT, { return 1; }); + DECL_OP_BIN_S(BONE_UINT, BONE, T_UINT, { return 1; }); + DECL_OP_BIN_S(BONE_FLOAT, BONE, T_FLOAT, { return 1; }); + + DECL_OP_BIN_S(MIN_INT, MIN, T_INT, { return min(a, b); }); + DECL_OP_BIN_S(MIN_UINT, MIN, T_UINT, { return min(a, b); }); + DECL_OP_BIN_S(MIN_FLOAT, MIN, T_FLOAT, { return min(a, b); }); + DECL_OP_BIN_S(MAX_INT, MAX, T_INT, { return max(a, b); }); + DECL_OP_BIN_S(MAX_UINT, MAX, T_UINT, { return max(a, b); }); + DECL_OP_BIN_S(MAX_FLOAT, MAX, T_FLOAT, { return max(a, b); }); + + DECL_OP_BIN_S(LOR_INT, LOR, T_INT, { return a || b; }); + DECL_OP_BIN_S(LOR_UINT, LOR, T_UINT, { return a || b; }); + DECL_OP_BIN_S(LOR_FLOAT, LOR, T_FLOAT, { return a || b; }); + DECL_OP_BIN_S(LAND_INT, LAND, T_INT, { return a && b; }); + DECL_OP_BIN_S(LAND_UINT, LAND, T_UINT, { return a && b; }); + DECL_OP_BIN_S(LAND_FLOAT, LAND, T_FLOAT, { return a && b; }); + + DECL_OP_BIN_S(BOR_INT, BOR, T_INT, { return a | b; }); + DECL_OP_BIN_S(BOR_UINT, BOR, T_UINT, { return a | b; }); + DECL_OP_BIN_S(BAND_INT, BAND, T_INT, { return a & b; }); + DECL_OP_BIN_S(BAND_UINT, BAND, T_UINT, { return a & b; }); + DECL_OP_BIN_S(BXOR_INT, BXOR, T_INT, { return a ^ b; }); + DECL_OP_BIN_S(BXOR_UINT, BXOR, T_UINT, { return a ^ b; }); + + MUL_PAIR = OpBinary::make_pair( + "MUL_PAIR", "(a, b) make_pair(a.weight, b.vertex)", + [](Pair a, Pair b) { return Pair(a.weight, b.vertex); }); + MIN_PAIR = OpBinary::make_pair("MIN_PAIR", "(a, b) min_pair(a, b)", + [](Pair a, Pair b) { + if (a.weight == b.weight) + return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? a : b; + }); + + DECL_OP_SELECT(EQZERO_INT, EQZERO, T_INT, { return a == 0; }); + DECL_OP_SELECT(EQZERO_UINT, EQZERO, T_UINT, { return a == 0; }); + DECL_OP_SELECT(EQZERO_FLOAT, EQZERO, T_FLOAT, { return a == 0; }); + DECL_OP_SELECT(NQZERO_INT, NQZERO, T_INT, { return a != 0; }); + DECL_OP_SELECT(NQZERO_UINT, NQZERO, T_UINT, { return a != 0; }); + DECL_OP_SELECT(NQZERO_FLOAT, NQZERO, T_FLOAT, { return a != 0; }); + DECL_OP_SELECT(GTZERO_INT, GTZERO, T_INT, { return a > 0; }); + DECL_OP_SELECT(GTZERO_UINT, GTZERO, T_UINT, { return a > 0; }); + DECL_OP_SELECT(GTZERO_FLOAT, GTZERO, T_FLOAT, { return a > 0; }); + DECL_OP_SELECT(GEZERO_INT, GEZERO, T_INT, { return a >= 0; }); + DECL_OP_SELECT(GEZERO_UINT, GEZERO, T_UINT, { return a >= 0; }); + DECL_OP_SELECT(GEZERO_FLOAT, GEZERO, T_FLOAT, { return a >= 0; }); + DECL_OP_SELECT(LTZERO_INT, LTZERO, T_INT, { return a < 0; }); + DECL_OP_SELECT(LTZERO_UINT, LTZERO, T_UINT, { return a < 0; }); + DECL_OP_SELECT(LTZERO_FLOAT, LTZERO, T_FLOAT, { return a < 0; }); + DECL_OP_SELECT(LEZERO_INT, LEZERO, T_INT, { return a <= 0; }); + DECL_OP_SELECT(LEZERO_UINT, LEZERO, T_UINT, { return a <= 0; }); + DECL_OP_SELECT(LEZERO_FLOAT, LEZERO, T_FLOAT, { return a <= 0; }); + DECL_OP_SELECT(ALWAYS_INT, ALWAYS, T_INT, { return 1; }); + DECL_OP_SELECT(ALWAYS_UINT, ALWAYS, T_UINT, { return 1; }); + DECL_OP_SELECT(ALWAYS_FLOAT, ALWAYS, T_FLOAT, { return 1; }); + ALWAYS_PAIR = OpSelect::make_pair("ALWAYS_PAIR", "(a) pair_always(a)", + [](Pair a) { return 1; }); + DECL_OP_SELECT(NEVER_INT, NEVER, T_INT, { return 0; }); + DECL_OP_SELECT(NEVER_UINT, NEVER, T_UINT, { return 0; }); + DECL_OP_SELECT(NEVER_FLOAT, NEVER, T_FLOAT, { return 0; }); +} + +ref_ptr OpUnary::make_int(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); +} +ref_ptr OpUnary::make_uint(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); +} +ref_ptr OpUnary::make_float(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); +} +ref_ptr OpUnary::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); +} + +ref_ptr +OpBinary::make_int(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); +} +ref_ptr +OpBinary::make_uint(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); +} +ref_ptr +OpBinary::make_float(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); +} +ref_ptr +OpBinary::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); +} + +ref_ptr OpSelect::make_int(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); +} +ref_ptr OpSelect::make_uint(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); +} +ref_ptr OpSelect::make_float(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); +} +ref_ptr OpSelect::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); +} + +} // namespace spla \ No newline at end of file diff --git a/src/opencl/cl_algo_registry.cpp b/src/opencl/cl_algo_registry.cpp index 803316bb5..715acde66 100644 --- a/src/opencl/cl_algo_registry.cpp +++ b/src/opencl/cl_algo_registry.cpp @@ -41,7 +41,6 @@ #include #include - namespace spla { void register_algo_cl(class Registry* g_registry) { @@ -84,7 +83,8 @@ namespace spla { g_registry->add(MAKE_KEY_CL_0("mxv_masked", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxv_masked", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxv_masked", FLOAT), std::make_shared>()); - g_registry->add(MAKE_KEY_CL_0("mxv_masked", PAIR), std::make_shared>()); + g_registry->add(MAKE_KEY_CL_0("mxv_masked", PAIR), + std::make_shared>()); // algorthm vxm_masked g_registry->add(MAKE_KEY_CL_0("vxm_masked", INT), std::make_shared>()); @@ -95,7 +95,6 @@ namespace spla { g_registry->add(MAKE_KEY_CL_0("mxmT_masked", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxmT_masked", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxmT_masked", FLOAT), std::make_shared>()); - } }// namespace spla diff --git a/src/opencl/cl_mxv.hpp b/src/opencl/cl_mxv.hpp index 862ae35b2..f7a04317c 100644 --- a/src/opencl/cl_mxv.hpp +++ b/src/opencl/cl_mxv.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_CL_MXV_HPP @@ -49,239 +56,259 @@ namespace spla { - template - class Algo_mxv_masked_cl final : public RegistryAlgo { - public: - ~Algo_mxv_masked_cl() override = default; - - std::string get_name() override { - return "mxv_masked"; - } - - std::string get_description() override { - return "parallel matrix-vector masked product on opencl device"; - } - - Status execute(const DispatchContext& ctx) override { - auto t = ctx.task.template cast_safe(); - auto early_exit = t->get_desc_or_default()->get_early_exit(); - - if (early_exit) { - return execute_config_scalar(ctx); - } else { - return execute_vector(ctx); - } - } - - private: - Status execute_vector(const DispatchContext& ctx) { - TIME_PROFILE_SCOPE("opencl/mxv/vector"); - - auto t = ctx.task.template cast_safe(); - - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = t->op_multiply.template cast_safe>(); - ref_ptr> op_add = t->op_add.template cast_safe>(); - ref_ptr> op_select = t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); - - r->validate_wd(FormatVector::AccDense); - mask->validate_rw(FormatVector::AccDense); - M->validate_rw(FormatMatrix::AccCsr); - v->validate_rw(FormatVector::AccDense); - - std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError; - - auto* p_cl_r = r->template get>(); - auto* p_cl_mask = mask->template get>(); - auto* p_cl_M = M->template get>(); - auto* p_cl_v = v->template get>(); - - auto* p_cl_acc = get_acc_cl(); - auto& queue = p_cl_acc->get_queue_default(); - - auto kernel_vector = program->make_kernel("mxv_vector"); - kernel_vector.setArg(0, p_cl_M->Ap); - kernel_vector.setArg(1, p_cl_M->Aj); - kernel_vector.setArg(2, p_cl_M->Ax); - kernel_vector.setArg(3, p_cl_v->Ax); - kernel_vector.setArg(4, p_cl_mask->Ax); - kernel_vector.setArg(5, p_cl_r->Ax); - kernel_vector.setArg(6, init->get_value()); - kernel_vector.setArg(7, r->get_n_rows()); - - uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_count, 1, 512); - - cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size); - cl::NDRange exec_local(m_block_count, m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), exec_global, exec_local); - - return Status::Ok; - } - - Status execute_scalar(const DispatchContext& ctx) { - TIME_PROFILE_SCOPE("opencl/mxv/scalar"); - - auto t = ctx.task.template cast_safe(); - - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = t->op_multiply.template cast_safe>(); - ref_ptr> op_add = t->op_add.template cast_safe>(); - ref_ptr> op_select = t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); - - r->validate_wd(FormatVector::AccDense); - mask->validate_rw(FormatVector::AccDense); - M->validate_rw(FormatMatrix::AccCsr); - v->validate_rw(FormatVector::AccDense); - - std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError; - - auto* p_cl_r = r->template get>(); - auto* p_cl_mask = mask->template get>(); - auto* p_cl_M = M->template get>(); - auto* p_cl_v = v->template get>(); - auto early_exit = t->get_desc_or_default()->get_early_exit(); - - auto* p_cl_acc = get_acc_cl(); - auto& queue = p_cl_acc->get_queue_default(); - - auto kernel_scalar = program->make_kernel("mxv_scalar"); - kernel_scalar.setArg(0, p_cl_M->Ap); - kernel_scalar.setArg(1, p_cl_M->Aj); - kernel_scalar.setArg(2, p_cl_M->Ax); - kernel_scalar.setArg(3, p_cl_v->Ax); - kernel_scalar.setArg(4, p_cl_mask->Ax); - kernel_scalar.setArg(5, p_cl_r->Ax); - kernel_scalar.setArg(6, init->get_value()); - kernel_scalar.setArg(7, r->get_n_rows()); - kernel_scalar.setArg(8, uint(early_exit)); - - uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 512); - - cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); - cl::NDRange exec_local(m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), exec_global, exec_local); - - return Status::Ok; - } - - Status execute_config_scalar(const DispatchContext& ctx) { - TIME_PROFILE_SCOPE("opencl/mxv/config-scalar"); - - auto t = ctx.task.template cast_safe(); - - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = t->op_multiply.template cast_safe>(); - ref_ptr> op_add = t->op_add.template cast_safe>(); - ref_ptr> op_select = t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); - - r->validate_wd(FormatVector::AccDense); - mask->validate_rw(FormatVector::AccDense); - M->validate_rw(FormatMatrix::AccCsr); - v->validate_rw(FormatVector::AccDense); - - std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError; - - auto* p_cl_r = r->template get>(); - auto* p_cl_mask = mask->template get>(); - auto* p_cl_M = M->template get>(); - auto* p_cl_v = v->template get>(); - auto early_exit = t->get_desc_or_default()->get_early_exit(); - - auto* p_cl_acc = get_acc_cl(); - auto& queue = p_cl_acc->get_queue_default(); - - uint config_size = 0; - cl::Buffer cl_config(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, sizeof(uint) * M->get_n_rows()); - cl::Buffer cl_config_size(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeof(uint), &config_size); - - auto kernel_config = program->make_kernel("mxv_config"); - kernel_config.setArg(0, p_cl_mask->Ax); - kernel_config.setArg(1, p_cl_r->Ax); - kernel_config.setArg(2, cl_config); - kernel_config.setArg(3, cl_config_size); - kernel_config.setArg(4, init->get_value()); - kernel_config.setArg(5, M->get_n_rows()); - - uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024); - - cl::NDRange config_global(m_block_size * n_groups_to_dispatch); - cl::NDRange config_local(m_block_size); - CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), config_global, config_local); - - CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, sizeof(config_size), &config_size); - - auto kernel_config_scalar = program->make_kernel("mxv_config_scalar"); - kernel_config_scalar.setArg(0, p_cl_M->Ap); - kernel_config_scalar.setArg(1, p_cl_M->Aj); - kernel_config_scalar.setArg(2, p_cl_M->Ax); - kernel_config_scalar.setArg(3, p_cl_v->Ax); - kernel_config_scalar.setArg(4, cl_config); - kernel_config_scalar.setArg(5, p_cl_r->Ax); - kernel_config_scalar.setArg(6, init->get_value()); - kernel_config_scalar.setArg(7, config_size); - kernel_config_scalar.setArg(8, uint(early_exit)); - - n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024); - - cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); - cl::NDRange exec_local(m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), exec_global, exec_local); - - return Status::Ok; - } - - bool ensure_kernel(const ref_ptr>& op_multiply, - const ref_ptr>& op_add, - const ref_ptr>& op_select, - std::shared_ptr& program) { - m_block_size = get_acc_cl()->get_wave_size(); - m_block_count = 1; - - assert(m_block_count >= 1); - - CLProgramBuilder program_builder; - program_builder - .set_name("mxv") - .add_define("WARP_SIZE", get_acc_cl()->get_wave_size()) - .add_define("BLOCK_SIZE", m_block_size) - .add_define("BLOCK_COUNT", m_block_count) - .add_type("TYPE", get_ttype().template as()); - - if constexpr (std::is_same_v) { - program_builder.add_define("USE_PAIR_SEMANTICS", 1); - program_builder.add_define("USE_PAIR_COMPARISON", 1); - } else { - program_builder - .add_op("OP_BINARY1", op_multiply.template as()) - .add_op("OP_BINARY2", op_add.template as()) - .add_op("OP_SELECT", op_select.template as()); - } - program_builder.set_source(source_mxv).acquire(); - program = program_builder.get_program(); - - return true; - } - - private: - uint m_block_size = 0; - uint m_block_count = 0; - }; - -}// namespace spla - -#endif//SPLA_CL_MXV_HPP +template class Algo_mxv_masked_cl final : public RegistryAlgo { +public: + ~Algo_mxv_masked_cl() override = default; + + std::string get_name() override { return "mxv_masked"; } + + std::string get_description() override { + return "parallel matrix-vector masked product on opencl device"; + } + + Status execute(const DispatchContext &ctx) override { + auto t = ctx.task.template cast_safe(); + auto early_exit = t->get_desc_or_default()->get_early_exit(); + + if (early_exit) { + return execute_config_scalar(ctx); + } else { + return execute_vector(ctx); + } + } + +private: + Status execute_vector(const DispatchContext &ctx) { + TIME_PROFILE_SCOPE("opencl/mxv/vector"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); + + r->validate_wd(FormatVector::AccDense); + mask->validate_rw(FormatVector::AccDense); + M->validate_rw(FormatMatrix::AccCsr); + v->validate_rw(FormatVector::AccDense); + + std::shared_ptr program; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; + + auto *p_cl_r = r->template get>(); + auto *p_cl_mask = mask->template get>(); + auto *p_cl_M = M->template get>(); + auto *p_cl_v = v->template get>(); + + auto *p_cl_acc = get_acc_cl(); + auto &queue = p_cl_acc->get_queue_default(); + + auto kernel_vector = program->make_kernel("mxv_vector"); + kernel_vector.setArg(0, p_cl_M->Ap); + kernel_vector.setArg(1, p_cl_M->Aj); + kernel_vector.setArg(2, p_cl_M->Ax); + kernel_vector.setArg(3, p_cl_v->Ax); + kernel_vector.setArg(4, p_cl_mask->Ax); + kernel_vector.setArg(5, p_cl_r->Ax); + kernel_vector.setArg(6, init->get_value()); + kernel_vector.setArg(7, r->get_n_rows()); + + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_count, 1, 512); + + cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size); + cl::NDRange exec_local(m_block_count, m_block_size); + CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), + exec_global, exec_local); + + return Status::Ok; + } + + Status execute_scalar(const DispatchContext &ctx) { + TIME_PROFILE_SCOPE("opencl/mxv/scalar"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); + + r->validate_wd(FormatVector::AccDense); + mask->validate_rw(FormatVector::AccDense); + M->validate_rw(FormatMatrix::AccCsr); + v->validate_rw(FormatVector::AccDense); + + std::shared_ptr program; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; + + auto *p_cl_r = r->template get>(); + auto *p_cl_mask = mask->template get>(); + auto *p_cl_M = M->template get>(); + auto *p_cl_v = v->template get>(); + auto early_exit = t->get_desc_or_default()->get_early_exit(); + + auto *p_cl_acc = get_acc_cl(); + auto &queue = p_cl_acc->get_queue_default(); + + auto kernel_scalar = program->make_kernel("mxv_scalar"); + kernel_scalar.setArg(0, p_cl_M->Ap); + kernel_scalar.setArg(1, p_cl_M->Aj); + kernel_scalar.setArg(2, p_cl_M->Ax); + kernel_scalar.setArg(3, p_cl_v->Ax); + kernel_scalar.setArg(4, p_cl_mask->Ax); + kernel_scalar.setArg(5, p_cl_r->Ax); + kernel_scalar.setArg(6, init->get_value()); + kernel_scalar.setArg(7, r->get_n_rows()); + kernel_scalar.setArg(8, uint(early_exit)); + + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_size, 1, 512); + + cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); + cl::NDRange exec_local(m_block_size); + CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), + exec_global, exec_local); + + return Status::Ok; + } + + Status execute_config_scalar(const DispatchContext &ctx) { + TIME_PROFILE_SCOPE("opencl/mxv/config-scalar"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); + + r->validate_wd(FormatVector::AccDense); + mask->validate_rw(FormatVector::AccDense); + M->validate_rw(FormatMatrix::AccCsr); + v->validate_rw(FormatVector::AccDense); + + std::shared_ptr program; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; + + auto *p_cl_r = r->template get>(); + auto *p_cl_mask = mask->template get>(); + auto *p_cl_M = M->template get>(); + auto *p_cl_v = v->template get>(); + auto early_exit = t->get_desc_or_default()->get_early_exit(); + + auto *p_cl_acc = get_acc_cl(); + auto &queue = p_cl_acc->get_queue_default(); + + uint config_size = 0; + cl::Buffer cl_config(p_cl_acc->get_context(), + CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, + sizeof(uint) * M->get_n_rows()); + cl::Buffer cl_config_size(p_cl_acc->get_context(), + CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | + CL_MEM_COPY_HOST_PTR, + sizeof(uint), &config_size); + + auto kernel_config = program->make_kernel("mxv_config"); + kernel_config.setArg(0, p_cl_mask->Ax); + kernel_config.setArg(1, p_cl_r->Ax); + kernel_config.setArg(2, cl_config); + kernel_config.setArg(3, cl_config_size); + kernel_config.setArg(4, init->get_value()); + kernel_config.setArg(5, M->get_n_rows()); + + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024); + + cl::NDRange config_global(m_block_size * n_groups_to_dispatch); + cl::NDRange config_local(m_block_size); + CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), + config_global, config_local); + + CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, + sizeof(config_size), &config_size); + + auto kernel_config_scalar = program->make_kernel("mxv_config_scalar"); + kernel_config_scalar.setArg(0, p_cl_M->Ap); + kernel_config_scalar.setArg(1, p_cl_M->Aj); + kernel_config_scalar.setArg(2, p_cl_M->Ax); + kernel_config_scalar.setArg(3, p_cl_v->Ax); + kernel_config_scalar.setArg(4, cl_config); + kernel_config_scalar.setArg(5, p_cl_r->Ax); + kernel_config_scalar.setArg(6, init->get_value()); + kernel_config_scalar.setArg(7, config_size); + kernel_config_scalar.setArg(8, uint(early_exit)); + + n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024); + + cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); + cl::NDRange exec_local(m_block_size); + CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), + exec_global, exec_local); + + return Status::Ok; + } + + bool ensure_kernel(const ref_ptr> &op_multiply, + const ref_ptr> &op_add, + const ref_ptr> &op_select, + std::shared_ptr &program) { + m_block_size = get_acc_cl()->get_wave_size(); + m_block_count = 1; + + assert(m_block_count >= 1); + + CLProgramBuilder program_builder; + program_builder.set_name("mxv") + .add_define("WARP_SIZE", get_acc_cl()->get_wave_size()) + .add_define("BLOCK_SIZE", m_block_size) + .add_define("BLOCK_COUNT", m_block_count) + .add_type("TYPE", get_ttype().template as()); + + if constexpr (std::is_same_v) { + program_builder.add_define("USE_PAIR_SEMANTICS", 1); + program_builder.add_define("USE_PAIR_COMPARISON", 1); + } else { + program_builder.add_op("OP_BINARY1", op_multiply.template as()) + .add_op("OP_BINARY2", op_add.template as()) + .add_op("OP_SELECT", op_select.template as()); + } + program_builder.set_source(source_mxv).acquire(); + program = program_builder.get_program(); + + return true; + } + +private: + uint m_block_size = 0; + uint m_block_count = 0; +}; + +} // namespace spla + +#endif // SPLA_CL_MXV_HPP diff --git a/src/opencl/cl_program_builder.cpp b/src/opencl/cl_program_builder.cpp index e40b80a94..c1a4ab67d 100644 --- a/src/opencl/cl_program_builder.cpp +++ b/src/opencl/cl_program_builder.cpp @@ -30,8 +30,8 @@ #include #include -#include #include +#include namespace spla { @@ -64,7 +64,7 @@ namespace spla { return *this; } void CLProgramBuilder::acquire() { - + CLAccelerator* acc = get_acc_cl(); CLProgramCache* cache = acc->get_cache(); @@ -86,17 +86,19 @@ namespace spla { bool needs_pair_override = false; for (const auto& define : m_defines) { builder << "#define " << define.first << " " << define.second << "\n"; - if (define.first == "TYPE" && define.second.find("Pair") != std::string::npos) { - needs_pair_override = true; + if (define.first == "TYPE" && + define.second.find("Pair") != std::string::npos) { + needs_pair_override = true; } } - + builder << source_common_api; if (needs_pair_override) { - builder << "#define OP_BINARY1(a, b) make_pair((a).weight, (b).vertex)\n\n"; - builder << "#define OP_BINARY2(a, b) min_pair(a, b)\n\n"; - builder << "#define OP_SELECT(a) pair_always(a)\n\n"; + builder << "#define OP_BINARY1(a, b) make_pair((a).weight, " + "(b).vertex)\n\n"; + builder << "#define OP_BINARY2(a, b) min_pair(a, b)\n\n"; + builder << "#define OP_SELECT(a) pair_always(a)\n\n"; } for (const auto& function : m_functions) { @@ -104,9 +106,8 @@ namespace spla { << function.first << function.second->get_source_cl() << "\n"; } builder << m_source; - - m_program_code = builder.str(); + m_program_code = builder.str(); Timer t; t.start(); auto status = m_program->m_program.build("-cl-std=CL1.2"); diff --git a/src/opencl/kernels/common_def.cl b/src/opencl/kernels/common_def.cl index 880708f42..792ce452a 100644 --- a/src/opencl/kernels/common_def.cl +++ b/src/opencl/kernels/common_def.cl @@ -1,42 +1,54 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #pragma once -#define TYPE int -#define BLOCK_SIZE 32 +#define TYPE int +#define BLOCK_SIZE 32 #define LM_NUM_MEM_BANKS 32 -#define BLOCK_COUNT 1 -#define WARP_SIZE 32 -#define OP_SELECT(a) a -#define OP_UNARY(a) a -#define OP_BINARY(a, b) a + b +#define BLOCK_COUNT 1 +#define WARP_SIZE 32 +#define OP_UNARY(a) a +#ifndef USE_PAIR_SEMANTICS #define OP_BINARY1(a, b) a + b #define OP_BINARY2(a, b) a + b +#define OP_SELECT(a) a +#else +#define OP_BINARY1(a, b) make_pair((a).weight, (b).vertex) +#define OP_BINARY2(a, b) min_pair(a, b) +#define OP_SELECT(a) pair_always(a) +#endif #define __kernel #define __global @@ -48,22 +60,19 @@ #define half float struct float2 { - float x; + float x; }; struct float3 { - float x, y, z; + float x, y, z; }; struct float4 { - float x, y, z, w; + float x, y, z, w; }; -#define uint unsigned int +#define uint unsigned int #define ulong unsigned long int -enum cl_mem_fence_flags { - CLK_LOCAL_MEM_FENCE, - CLK_GLOBAL_MEM_FENCE -}; +enum cl_mem_fence_flags { CLK_LOCAL_MEM_FENCE, CLK_GLOBAL_MEM_FENCE }; void barrier(cl_mem_fence_flags flags); @@ -74,16 +83,16 @@ size_t get_local_id(uint dimindx); size_t get_num_groups(uint dimindx); size_t get_group_id(uint dimindx); size_t get_global_offset(uint dimindx); -uint get_work_dim(); +uint get_work_dim(); -#define atomic_add(p, val) p[0] += val -#define atomic_sub(p, val) p[0] -= val -#define atomic_inc(p) (p)[0] -#define atomic_dec(p) (p)[0] +#define atomic_add(p, val) p[0] += val +#define atomic_sub(p, val) p[0] -= val +#define atomic_inc(p) (p)[0] +#define atomic_dec(p) (p)[0] #define atomic_cmpxchg(p, cmp, val) ((p)[0] == cmp ? val : (p)[0]) -#define min(x, y) (x < y ? x : y) -#define max(x, y) (x > y ? x : y) -#define sin(x) x -#define cos(x) x +#define min(x, y) (x < y ? x : y) +#define max(x, y) (x > y ? x : y) +#define sin(x) x +#define cos(x) x #define fract(x, ptr) x \ No newline at end of file diff --git a/src/scalar.cpp b/src/scalar.cpp index b090a2757..a6235c9c9 100644 --- a/src/scalar.cpp +++ b/src/scalar.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -30,40 +37,39 @@ namespace spla { - ref_ptr Scalar::make(const ref_ptr& type) { - if (!type) { - LOG_MSG(Status::InvalidArgument, "passed null type"); - return ref_ptr{}; - } +ref_ptr Scalar::make(const ref_ptr &type) { + if (!type) { + LOG_MSG(Status::InvalidArgument, "passed null type"); + return ref_ptr{}; + } - Library::get(); + Library::get(); - if (type == INT) { - return ref_ptr(new TScalar()); - } - if (type == UINT) { - return ref_ptr(new TScalar()); - } - if (type == FLOAT) { - return ref_ptr(new TScalar()); - } - if (type == PAIR) { - return ref_ptr(new TScalar()); - } + if (type == INT) { + return ref_ptr(new TScalar()); + } + if (type == UINT) { + return ref_ptr(new TScalar()); + } + if (type == FLOAT) { + return ref_ptr(new TScalar()); + } + if (type == PAIR) { + return ref_ptr(new TScalar()); + } - LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); - return ref_ptr(); - } + LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); + return ref_ptr(); +} - ref_ptr Scalar::make_int(std::int32_t value) { - return ref_ptr(new TScalar(value)); - } - ref_ptr Scalar::make_uint(std::uint32_t value) { - return ref_ptr(new TScalar(value)); - } - ref_ptr Scalar::Scalar::make_float(float value) { - return ref_ptr(new TScalar(value)); - } - +ref_ptr Scalar::make_int(std::int32_t value) { + return ref_ptr(new TScalar(value)); +} +ref_ptr Scalar::make_uint(std::uint32_t value) { + return ref_ptr(new TScalar(value)); +} +ref_ptr Scalar::Scalar::make_float(float value) { + return ref_ptr(new TScalar(value)); +} -}// namespace spla \ No newline at end of file +} // namespace spla \ No newline at end of file diff --git a/src/type.cpp b/src/type.cpp index 28a97edc3..ce3f4c64a 100644 --- a/src/type.cpp +++ b/src/type.cpp @@ -1,38 +1,50 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include namespace spla { - ref_ptr BOOL = TType::make_type("BOOL", "B", "bool", "4 byte logical type", 1); - ref_ptr INT = TType::make_type("INT", "I", "int", "signed 4 byte integral type", 2); - ref_ptr UINT = TType::make_type("UINT", "U", "uint", "unsigned 4 byte integral type", 3); - ref_ptr FLOAT = TType::make_type("FLOAT", "F", "float", "4 byte floating point type", 4); - ref_ptr PAIR = TType::make_type("PAIR", "P", "struct Pair", "weight-vertex pair float-int", 5); +ref_ptr BOOL = + TType::make_type("BOOL", "B", "bool", "4 byte logical type", 1); +ref_ptr INT = TType::make_type("INT", "I", "int", + "signed 4 byte integral type", 2); +ref_ptr UINT = TType::make_type( + "UINT", "U", "uint", "unsigned 4 byte integral type", 3); +ref_ptr FLOAT = TType::make_type( + "FLOAT", "F", "float", "4 byte floating point type", 4); +ref_ptr PAIR = TType::make_type( + "PAIR", "P", "struct Pair", "weight-vertex pair float-int", 5); -}// namespace spla \ No newline at end of file +} // namespace spla \ No newline at end of file diff --git a/src/vector.cpp b/src/vector.cpp index 0e3a47863..e152cc80d 100644 --- a/src/vector.cpp +++ b/src/vector.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -30,33 +37,33 @@ namespace spla { - ref_ptr Vector::make(uint n_rows, const ref_ptr& type) { - if (n_rows <= 0) { - LOG_MSG(Status::InvalidArgument, "passed 0 dim"); - return ref_ptr{}; - } - if (!type) { - LOG_MSG(Status::InvalidArgument, "passed null type"); - return ref_ptr{}; - } +ref_ptr Vector::make(uint n_rows, const ref_ptr &type) { + if (n_rows <= 0) { + LOG_MSG(Status::InvalidArgument, "passed 0 dim"); + return ref_ptr{}; + } + if (!type) { + LOG_MSG(Status::InvalidArgument, "passed null type"); + return ref_ptr{}; + } - Library::get(); + Library::get(); - if (type == INT) { - return ref_ptr(new TVector(n_rows)); - } - if (type == UINT) { - return ref_ptr(new TVector(n_rows)); - } - if (type == FLOAT) { - return ref_ptr(new TVector(n_rows)); - } - if (type == spla::PAIR) { - return ref_ptr(new TVector(n_rows)); - } + if (type == INT) { + return ref_ptr(new TVector(n_rows)); + } + if (type == UINT) { + return ref_ptr(new TVector(n_rows)); + } + if (type == FLOAT) { + return ref_ptr(new TVector(n_rows)); + } + if (type == spla::PAIR) { + return ref_ptr(new TVector(n_rows)); + } - LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); - return ref_ptr{}; - } + LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); + return ref_ptr{}; +} -}// namespace spla \ No newline at end of file +} // namespace spla \ No newline at end of file diff --git a/tests/test_pair.cpp b/tests/test_pair.cpp index 9b91954b9..429f0cd58 100644 --- a/tests/test_pair.cpp +++ b/tests/test_pair.cpp @@ -1,148 +1,137 @@ - #include "test_common.hpp" +#include "test_common.hpp" #include #include "spla.hpp" TEST(pair, struct_creation) { - spla::T_PAIR p(2.5f, 2); - EXPECT_EQ(p.weight, 2.5f); - EXPECT_EQ(p.vertex, 2); + spla::T_PAIR p(2.5f, 2); + EXPECT_EQ(p.weight, 2.5f); + EXPECT_EQ(p.vertex, 2); } TEST(pair, basic_operations) { - spla::T_PAIR p1(2.5f, 2); - spla::T_PAIR p2(1.5f, 5); - - EXPECT_EQ(p1.weight, 2.5f); - EXPECT_EQ(p1.vertex, 2); - EXPECT_TRUE(p2.weight < p1.weight); + spla::T_PAIR p1(2.5f, 2); + spla::T_PAIR p2(1.5f, 5); + + EXPECT_EQ(p1.weight, 2.5f); + EXPECT_EQ(p1.vertex, 2); + EXPECT_TRUE(p2.weight < p1.weight); } TEST(pair, type_registration) { - auto type = spla::PAIR; - ASSERT_TRUE(type); - EXPECT_EQ(type->get_name(), "PAIR"); - EXPECT_EQ(type->get_code(), "P"); - EXPECT_EQ(type->get_cpp(), "struct Pair"); - EXPECT_EQ(type->get_description(), "weight-vertex pair float-int"); - EXPECT_EQ(type->get_size(), sizeof(spla::Pair)); - EXPECT_EQ(type->get_id(), 5); + auto type = spla::PAIR; + ASSERT_TRUE(type); + EXPECT_EQ(type->get_name(), "PAIR"); + EXPECT_EQ(type->get_code(), "P"); + EXPECT_EQ(type->get_cpp(), "struct Pair"); + EXPECT_EQ(type->get_description(), "weight-vertex pair float-int"); + EXPECT_EQ(type->get_size(), sizeof(spla::Pair)); + EXPECT_EQ(type->get_id(), 5); } TEST(pair, op_registration) { - spla::Library::get(); - EXPECT_EQ(spla::MIN_PAIR->get_name(), "MIN_PAIR"); + spla::Library::get(); + EXPECT_EQ(spla::MIN_PAIR->get_name(), "MIN_PAIR"); } TEST(pair, set_get_pair_matrix) { - auto S = spla::Matrix::make(2, 2, spla::PAIR); - S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); - spla::T_PAIR pair; - S->get_pair(0, 1, pair); - EXPECT_EQ(pair.vertex, 1); - EXPECT_EQ(pair.weight, 7.0f); - + auto S = spla::Matrix::make(2, 2, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + spla::T_PAIR pair; + S->get_pair(0, 1, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 7.0f); } TEST(pair, set_get_pair_vector) { - auto V = spla::Vector::make(2, spla::PAIR); - V->set_pair(0, spla::T_PAIR(1.0f, 1)); - V->set_pair(1, spla::T_PAIR(2.0f, 2)); - spla::T_PAIR pair; - V->get_pair(0, pair); - EXPECT_EQ(pair.vertex, 1); - EXPECT_EQ(pair.weight, 1.0f); - V->get_pair(1, pair); - EXPECT_EQ(pair.vertex, 2); - EXPECT_EQ(pair.weight, 2.0f); - + auto V = spla::Vector::make(2, spla::PAIR); + V->set_pair(0, spla::T_PAIR(1.0f, 1)); + V->set_pair(1, spla::T_PAIR(2.0f, 2)); + spla::T_PAIR pair; + V->get_pair(0, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); + V->get_pair(1, pair); + EXPECT_EQ(pair.vertex, 2); + EXPECT_EQ(pair.weight, 2.0f); } TEST(pair, set_get_pair_scalar) { - auto V = spla::Scalar::make(spla::PAIR); - spla::T_PAIR pair = spla::T_PAIR(1.0f, 1); - V->set_pair(pair); - V->get_pair(pair); - EXPECT_EQ(pair.vertex, 1); - EXPECT_EQ(pair.weight, 1.0f); + auto V = spla::Scalar::make(spla::PAIR); + spla::T_PAIR pair = spla::T_PAIR(1.0f, 1); + V->set_pair(pair); + V->get_pair(pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); } TEST(pair, mxv_pair) { - int32_t n = 7; - auto S = spla::Matrix::make(n, n, spla::PAIR); - S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); - S->set_pair(0, 4, spla::T_PAIR(4.0f, 4)); - S->set_pair(1, 0, spla::T_PAIR(7.0f, 0)); - S->set_pair(1, 2, spla::T_PAIR(11.0f, 2)); - S->set_pair(1, 3, spla::T_PAIR(10.0f, 3)); - S->set_pair(1, 4, spla::T_PAIR(9.0f, 4)); - S->set_pair(2, 1, spla::T_PAIR(11.0f, 1)); - S->set_pair(2, 3, spla::T_PAIR(5.0f, 3)); - S->set_pair(3, 1, spla::T_PAIR(10.0f, 1)); - S->set_pair(3, 2, spla::T_PAIR(5.0f, 2)); - S->set_pair(3, 4, spla::T_PAIR(15.0f, 4)); - S->set_pair(3, 5, spla::T_PAIR(12.0f, 5)); - S->set_pair(3, 6, spla::T_PAIR(8.0f, 6)); - S->set_pair(4, 0, spla::T_PAIR(4.0f, 0)); - S->set_pair(4, 1, spla::T_PAIR(9.0f, 1)); - S->set_pair(4, 3, spla::T_PAIR(15.0f, 3)); - S->set_pair(4, 5, spla::T_PAIR(6.0f, 5)); - S->set_pair(5, 3, spla::T_PAIR(12.0f, 3)); - S->set_pair(5, 4, spla::T_PAIR(6.0f, 4)); - S->set_pair(5, 6, spla::T_PAIR(13.0f, 6)); - S->set_pair(6, 3, spla::T_PAIR(8.0f, 3)); - S->set_pair(6, 5, spla::T_PAIR(13.0f, 5)); + int32_t n = 7; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(0, 4, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 0, spla::T_PAIR(7.0f, 0)); + S->set_pair(1, 2, spla::T_PAIR(11.0f, 2)); + S->set_pair(1, 3, spla::T_PAIR(10.0f, 3)); + S->set_pair(1, 4, spla::T_PAIR(9.0f, 4)); + S->set_pair(2, 1, spla::T_PAIR(11.0f, 1)); + S->set_pair(2, 3, spla::T_PAIR(5.0f, 3)); + S->set_pair(3, 1, spla::T_PAIR(10.0f, 1)); + S->set_pair(3, 2, spla::T_PAIR(5.0f, 2)); + S->set_pair(3, 4, spla::T_PAIR(15.0f, 4)); + S->set_pair(3, 5, spla::T_PAIR(12.0f, 5)); + S->set_pair(3, 6, spla::T_PAIR(8.0f, 6)); + S->set_pair(4, 0, spla::T_PAIR(4.0f, 0)); + S->set_pair(4, 1, spla::T_PAIR(9.0f, 1)); + S->set_pair(4, 3, spla::T_PAIR(15.0f, 3)); + S->set_pair(4, 5, spla::T_PAIR(6.0f, 5)); + S->set_pair(5, 3, spla::T_PAIR(12.0f, 3)); + S->set_pair(5, 4, spla::T_PAIR(6.0f, 4)); + S->set_pair(5, 6, spla::T_PAIR(13.0f, 6)); + S->set_pair(6, 3, spla::T_PAIR(8.0f, 3)); + S->set_pair(6, 5, spla::T_PAIR(13.0f, 5)); + + auto parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + parent->set_pair(i, spla::T_PAIR(0.0f, i)); + } + auto edge = spla::Vector::make(n, spla::PAIR); - auto parent = spla::Vector::make(n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - parent->set_pair(i, spla::T_PAIR(0.0f, i)); - } - auto edge = spla::Vector::make(n, spla::PAIR); + auto mask = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + mask->set_pair(i, spla::T_PAIR(0.0f, 0)); + } + auto init_inf = spla::Scalar::make(spla::PAIR); + spla::T_PAIR init_val(1e9f, -1); + init_inf->set_pair(init_val); + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, + spla::ALWAYS_PAIR, init_inf); - - auto mask = spla::Vector::make(n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - mask->set_pair(i, spla::T_PAIR(0.0f, 0)); - } - auto init_inf = spla::Scalar::make(spla::PAIR); - spla::T_PAIR init_val(1e9f, -1); - init_inf->set_pair(init_val); - spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, spla::ALWAYS_PAIR, init_inf); + spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(7.0f, 0), + spla::T_PAIR(5.0f, 3), spla::T_PAIR(5.0f, 2), + spla::T_PAIR(4.0f, 0), spla::T_PAIR(6.0f, 4), + spla::T_PAIR(8.0f, 3)}; - spla::T_PAIR expected[] = { - spla::T_PAIR(4.0f, 4), - spla::T_PAIR(7.0f, 0), - spla::T_PAIR(5.0f, 3), - spla::T_PAIR(5.0f, 2), - spla::T_PAIR(4.0f, 0), - spla::T_PAIR(6.0f, 4), - spla::T_PAIR(8.0f, 3) - }; - - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - edge->get_pair(i, p); - EXPECT_FLOAT_EQ(p.weight, expected[i].weight); - EXPECT_EQ(p.vertex, expected[i].vertex); - } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } } TEST(pair, extract_row) { - int32_t n = 3; - auto S = spla::Matrix::make(n, n, spla::PAIR); - S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); - S->set_pair(1, 0, spla::T_PAIR(4.0f, 4)); - S->set_pair(1, 2, spla::T_PAIR(7.0f, 0)); - - auto row1 = spla::Vector::make(n, spla::PAIR); - spla::exec_m_extract_row(row1, S, 1, spla::IDENTITY_PAIR); - spla::T_PAIR expected[] = { - spla::T_PAIR(4.0f, 4), - spla::T_PAIR(), - spla::T_PAIR(7.0f, 0) - }; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - row1->get_pair(i, p); - EXPECT_FLOAT_EQ(p.weight, expected[i].weight); - EXPECT_EQ(p.vertex, expected[i].vertex); - } + int32_t n = 3; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(1, 0, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 2, spla::T_PAIR(7.0f, 0)); + auto row1 = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row1, S, 1, spla::IDENTITY_PAIR); + spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(), + spla::T_PAIR(7.0f, 0)}; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + row1->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } } SPLA_GTEST_MAIN From 6952fa03da42fedc3d197b7e37723f7082bf0795 Mon Sep 17 00:00:00 2001 From: polka777 Date: Sun, 26 Apr 2026 11:56:53 +0300 Subject: [PATCH 09/14] Clang-tidy format tests/ --- tests/test_pair.cpp | 206 ++++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 103 deletions(-) diff --git a/tests/test_pair.cpp b/tests/test_pair.cpp index 429f0cd58..3f460061e 100644 --- a/tests/test_pair.cpp +++ b/tests/test_pair.cpp @@ -4,134 +4,134 @@ #include "spla.hpp" TEST(pair, struct_creation) { - spla::T_PAIR p(2.5f, 2); - EXPECT_EQ(p.weight, 2.5f); - EXPECT_EQ(p.vertex, 2); + spla::T_PAIR p(2.5f, 2); + EXPECT_EQ(p.weight, 2.5f); + EXPECT_EQ(p.vertex, 2); } TEST(pair, basic_operations) { - spla::T_PAIR p1(2.5f, 2); - spla::T_PAIR p2(1.5f, 5); + spla::T_PAIR p1(2.5f, 2); + spla::T_PAIR p2(1.5f, 5); - EXPECT_EQ(p1.weight, 2.5f); - EXPECT_EQ(p1.vertex, 2); - EXPECT_TRUE(p2.weight < p1.weight); + EXPECT_EQ(p1.weight, 2.5f); + EXPECT_EQ(p1.vertex, 2); + EXPECT_TRUE(p2.weight < p1.weight); } TEST(pair, type_registration) { - auto type = spla::PAIR; - ASSERT_TRUE(type); - EXPECT_EQ(type->get_name(), "PAIR"); - EXPECT_EQ(type->get_code(), "P"); - EXPECT_EQ(type->get_cpp(), "struct Pair"); - EXPECT_EQ(type->get_description(), "weight-vertex pair float-int"); - EXPECT_EQ(type->get_size(), sizeof(spla::Pair)); - EXPECT_EQ(type->get_id(), 5); + auto type = spla::PAIR; + ASSERT_TRUE(type); + EXPECT_EQ(type->get_name(), "PAIR"); + EXPECT_EQ(type->get_code(), "P"); + EXPECT_EQ(type->get_cpp(), "struct Pair"); + EXPECT_EQ(type->get_description(), "weight-vertex pair float-int"); + EXPECT_EQ(type->get_size(), sizeof(spla::Pair)); + EXPECT_EQ(type->get_id(), 5); } TEST(pair, op_registration) { - spla::Library::get(); - EXPECT_EQ(spla::MIN_PAIR->get_name(), "MIN_PAIR"); + spla::Library::get(); + EXPECT_EQ(spla::MIN_PAIR->get_name(), "MIN_PAIR"); } TEST(pair, set_get_pair_matrix) { - auto S = spla::Matrix::make(2, 2, spla::PAIR); - S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); - spla::T_PAIR pair; - S->get_pair(0, 1, pair); - EXPECT_EQ(pair.vertex, 1); - EXPECT_EQ(pair.weight, 7.0f); + auto S = spla::Matrix::make(2, 2, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + spla::T_PAIR pair; + S->get_pair(0, 1, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 7.0f); } TEST(pair, set_get_pair_vector) { - auto V = spla::Vector::make(2, spla::PAIR); - V->set_pair(0, spla::T_PAIR(1.0f, 1)); - V->set_pair(1, spla::T_PAIR(2.0f, 2)); - spla::T_PAIR pair; - V->get_pair(0, pair); - EXPECT_EQ(pair.vertex, 1); - EXPECT_EQ(pair.weight, 1.0f); - V->get_pair(1, pair); - EXPECT_EQ(pair.vertex, 2); - EXPECT_EQ(pair.weight, 2.0f); + auto V = spla::Vector::make(2, spla::PAIR); + V->set_pair(0, spla::T_PAIR(1.0f, 1)); + V->set_pair(1, spla::T_PAIR(2.0f, 2)); + spla::T_PAIR pair; + V->get_pair(0, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); + V->get_pair(1, pair); + EXPECT_EQ(pair.vertex, 2); + EXPECT_EQ(pair.weight, 2.0f); } TEST(pair, set_get_pair_scalar) { - auto V = spla::Scalar::make(spla::PAIR); - spla::T_PAIR pair = spla::T_PAIR(1.0f, 1); - V->set_pair(pair); - V->get_pair(pair); - EXPECT_EQ(pair.vertex, 1); - EXPECT_EQ(pair.weight, 1.0f); + auto V = spla::Scalar::make(spla::PAIR); + spla::T_PAIR pair = spla::T_PAIR(1.0f, 1); + V->set_pair(pair); + V->get_pair(pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); } TEST(pair, mxv_pair) { - int32_t n = 7; - auto S = spla::Matrix::make(n, n, spla::PAIR); - S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); - S->set_pair(0, 4, spla::T_PAIR(4.0f, 4)); - S->set_pair(1, 0, spla::T_PAIR(7.0f, 0)); - S->set_pair(1, 2, spla::T_PAIR(11.0f, 2)); - S->set_pair(1, 3, spla::T_PAIR(10.0f, 3)); - S->set_pair(1, 4, spla::T_PAIR(9.0f, 4)); - S->set_pair(2, 1, spla::T_PAIR(11.0f, 1)); - S->set_pair(2, 3, spla::T_PAIR(5.0f, 3)); - S->set_pair(3, 1, spla::T_PAIR(10.0f, 1)); - S->set_pair(3, 2, spla::T_PAIR(5.0f, 2)); - S->set_pair(3, 4, spla::T_PAIR(15.0f, 4)); - S->set_pair(3, 5, spla::T_PAIR(12.0f, 5)); - S->set_pair(3, 6, spla::T_PAIR(8.0f, 6)); - S->set_pair(4, 0, spla::T_PAIR(4.0f, 0)); - S->set_pair(4, 1, spla::T_PAIR(9.0f, 1)); - S->set_pair(4, 3, spla::T_PAIR(15.0f, 3)); - S->set_pair(4, 5, spla::T_PAIR(6.0f, 5)); - S->set_pair(5, 3, spla::T_PAIR(12.0f, 3)); - S->set_pair(5, 4, spla::T_PAIR(6.0f, 4)); - S->set_pair(5, 6, spla::T_PAIR(13.0f, 6)); - S->set_pair(6, 3, spla::T_PAIR(8.0f, 3)); - S->set_pair(6, 5, spla::T_PAIR(13.0f, 5)); + int32_t n = 7; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(0, 4, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 0, spla::T_PAIR(7.0f, 0)); + S->set_pair(1, 2, spla::T_PAIR(11.0f, 2)); + S->set_pair(1, 3, spla::T_PAIR(10.0f, 3)); + S->set_pair(1, 4, spla::T_PAIR(9.0f, 4)); + S->set_pair(2, 1, spla::T_PAIR(11.0f, 1)); + S->set_pair(2, 3, spla::T_PAIR(5.0f, 3)); + S->set_pair(3, 1, spla::T_PAIR(10.0f, 1)); + S->set_pair(3, 2, spla::T_PAIR(5.0f, 2)); + S->set_pair(3, 4, spla::T_PAIR(15.0f, 4)); + S->set_pair(3, 5, spla::T_PAIR(12.0f, 5)); + S->set_pair(3, 6, spla::T_PAIR(8.0f, 6)); + S->set_pair(4, 0, spla::T_PAIR(4.0f, 0)); + S->set_pair(4, 1, spla::T_PAIR(9.0f, 1)); + S->set_pair(4, 3, spla::T_PAIR(15.0f, 3)); + S->set_pair(4, 5, spla::T_PAIR(6.0f, 5)); + S->set_pair(5, 3, spla::T_PAIR(12.0f, 3)); + S->set_pair(5, 4, spla::T_PAIR(6.0f, 4)); + S->set_pair(5, 6, spla::T_PAIR(13.0f, 6)); + S->set_pair(6, 3, spla::T_PAIR(8.0f, 3)); + S->set_pair(6, 5, spla::T_PAIR(13.0f, 5)); - auto parent = spla::Vector::make(n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - parent->set_pair(i, spla::T_PAIR(0.0f, i)); - } - auto edge = spla::Vector::make(n, spla::PAIR); + auto parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + parent->set_pair(i, spla::T_PAIR(0.0f, i)); + } + auto edge = spla::Vector::make(n, spla::PAIR); - auto mask = spla::Vector::make(n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - mask->set_pair(i, spla::T_PAIR(0.0f, 0)); - } - auto init_inf = spla::Scalar::make(spla::PAIR); - spla::T_PAIR init_val(1e9f, -1); - init_inf->set_pair(init_val); - spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, - spla::ALWAYS_PAIR, init_inf); + auto mask = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + mask->set_pair(i, spla::T_PAIR(0.0f, 0)); + } + auto init_inf = spla::Scalar::make(spla::PAIR); + spla::T_PAIR init_val(1e9f, -1); + init_inf->set_pair(init_val); + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, + spla::ALWAYS_PAIR, init_inf); - spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(7.0f, 0), - spla::T_PAIR(5.0f, 3), spla::T_PAIR(5.0f, 2), - spla::T_PAIR(4.0f, 0), spla::T_PAIR(6.0f, 4), - spla::T_PAIR(8.0f, 3)}; + spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(7.0f, 0), + spla::T_PAIR(5.0f, 3), spla::T_PAIR(5.0f, 2), + spla::T_PAIR(4.0f, 0), spla::T_PAIR(6.0f, 4), + spla::T_PAIR(8.0f, 3)}; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - edge->get_pair(i, p); - EXPECT_FLOAT_EQ(p.weight, expected[i].weight); - EXPECT_EQ(p.vertex, expected[i].vertex); - } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } } TEST(pair, extract_row) { - int32_t n = 3; - auto S = spla::Matrix::make(n, n, spla::PAIR); - S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); - S->set_pair(1, 0, spla::T_PAIR(4.0f, 4)); - S->set_pair(1, 2, spla::T_PAIR(7.0f, 0)); + int32_t n = 3; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(1, 0, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 2, spla::T_PAIR(7.0f, 0)); - auto row1 = spla::Vector::make(n, spla::PAIR); - spla::exec_m_extract_row(row1, S, 1, spla::IDENTITY_PAIR); - spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(), - spla::T_PAIR(7.0f, 0)}; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - row1->get_pair(i, p); - EXPECT_FLOAT_EQ(p.weight, expected[i].weight); - EXPECT_EQ(p.vertex, expected[i].vertex); - } + auto row1 = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row1, S, 1, spla::IDENTITY_PAIR); + spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(), + spla::T_PAIR(7.0f, 0)}; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + row1->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } } SPLA_GTEST_MAIN From 155e412c5a76d82ea9bc0490fa0dc0a7b462e842 Mon Sep 17 00:00:00 2001 From: polka777 Date: Sun, 26 Apr 2026 12:09:10 +0300 Subject: [PATCH 10/14] Clang-tidy format include/ --- include/spla/algorithm.hpp | 88 +++++++++++++++++++------------------- include/spla/io.hpp | 8 ++-- include/spla/matrix.hpp | 8 ++-- include/spla/pair.hpp | 26 +++++------ include/spla/scalar.hpp | 32 +++++++------- include/spla/type.hpp | 2 +- include/spla/vector.hpp | 4 +- 7 files changed, 84 insertions(+), 84 deletions(-) diff --git a/include/spla/algorithm.hpp b/include/spla/algorithm.hpp index b373265c8..c4f1cbbb4 100644 --- a/include/spla/algorithm.hpp +++ b/include/spla/algorithm.hpp @@ -44,12 +44,12 @@ namespace spla { -/** + /** * @addtogroup spla * @{ */ -/** + /** * @brief Breadth-first search algorithm * * @param v int vector to store reached distances @@ -59,11 +59,11 @@ namespace spla { * * @return ok on success */ -SPLA_API Status -bfs(const ref_ptr &v, const ref_ptr &A, uint s, - const ref_ptr &descriptor = spla::Descriptor::make()); + SPLA_API Status + bfs(const ref_ptr& v, const ref_ptr& A, uint s, + const ref_ptr& descriptor = spla::Descriptor::make()); -/** + /** * @brief Naive breadth-first search algorithm (reference cpu implementation) * * @param v int vector to store reached distances @@ -73,11 +73,11 @@ bfs(const ref_ptr &v, const ref_ptr &A, uint s, * * @return ok on success */ -SPLA_API Status -bfs_naive(std::vector &v, std::vector> &A, uint s, - const ref_ptr &descriptor = spla::Descriptor::make()); + SPLA_API Status + bfs_naive(std::vector& v, std::vector>& A, uint s, + const ref_ptr& descriptor = spla::Descriptor::make()); -/** + /** * @brief Single-source shortest path algorithm * * @param v float vector to store reached distances @@ -88,11 +88,11 @@ bfs_naive(std::vector &v, std::vector> &A, uint s, * * @return ok on success */ -SPLA_API Status -sssp(const ref_ptr &v, const ref_ptr &A, uint s, - const ref_ptr &descriptor = ref_ptr()); + SPLA_API Status + sssp(const ref_ptr& v, const ref_ptr& A, uint s, + const ref_ptr& descriptor = ref_ptr()); -/** + /** * @brief Naive single-source shortest path algorithm (reference cpu * implementation) * @@ -105,12 +105,12 @@ sssp(const ref_ptr &v, const ref_ptr &A, uint s, * * @return ok on success */ -SPLA_API Status -sssp_naive(std::vector &v, std::vector> &Ai, - std::vector> &Ax, uint s, - const ref_ptr &descriptor = spla::Descriptor::make()); + SPLA_API Status + sssp_naive(std::vector& v, std::vector>& Ai, + std::vector>& Ax, uint s, + const ref_ptr& descriptor = spla::Descriptor::make()); -/** + /** * @brief PageRank algorithm * * @param p float vector to store result vertices weights @@ -121,12 +121,12 @@ sssp_naive(std::vector &v, std::vector> &Ai, * * @return ok on success */ -SPLA_API Status -pr(ref_ptr &p, const ref_ptr &A, float alpha = 0.85, - float eps = 1e-6, - const ref_ptr &descriptor = spla::Descriptor::make()); + SPLA_API Status + pr(ref_ptr& p, const ref_ptr& A, float alpha = 0.85, + float eps = 1e-6, + const ref_ptr& descriptor = spla::Descriptor::make()); -/** + /** * @brief Naive PageRank algorithm (reference cpu implementation) * * @param p float vector to store result vertices weights @@ -138,12 +138,12 @@ pr(ref_ptr &p, const ref_ptr &A, float alpha = 0.85, * * @return ok on success */ -SPLA_API Status pr_naive( - std::vector &p, std::vector> &Ai, - std::vector> &Ax, float alpha = 0.85, float eps = 1e-6, - const ref_ptr &descriptor = spla::Descriptor::make()); + SPLA_API Status pr_naive( + std::vector& p, std::vector>& Ai, + std::vector>& Ax, float alpha = 0.85, float eps = 1e-6, + const ref_ptr& descriptor = spla::Descriptor::make()); -/** + /** * @brief Triangles counting algorithm * * @param ntrins Number of triangles counted @@ -153,11 +153,11 @@ SPLA_API Status pr_naive( * * @return ok on success */ -SPLA_API Status -tc(int &ntrins, const ref_ptr &A, const ref_ptr &B, - const ref_ptr &descriptor = spla::Descriptor::make()); + SPLA_API Status + tc(int& ntrins, const ref_ptr& A, const ref_ptr& B, + const ref_ptr& descriptor = spla::Descriptor::make()); -/** + /** * @brief Naive triangles counting algorithm (reference cpu implementation) * * @param ntrins Number of triangles counted @@ -166,10 +166,10 @@ tc(int &ntrins, const ref_ptr &A, const ref_ptr &B, * * @return ok on success */ -SPLA_API Status -tc_naive(int &ntrins, std::vector> &Ai, - const ref_ptr &descriptor = spla::Descriptor::make()); -/** + SPLA_API Status + tc_naive(int& ntrins, std::vector>& Ai, + const ref_ptr& descriptor = spla::Descriptor::make()); + /** * @brief Boruvka's Minimum Spanning Tree algorithm * * Finds the Minimum Spanning Tree of a weighted undirected graph using @@ -184,15 +184,15 @@ tc_naive(int &ntrins, std::vector> &Ai, * * @return ok on success */ -SPLA_API Status -mst(const ref_ptr &T, ref_ptr &S, - const ref_ptr &descriptor = spla::Descriptor::make(), - ref_ptr *task_hnd = nullptr); + SPLA_API Status + mst(const ref_ptr& T, ref_ptr& S, + const ref_ptr& descriptor = spla::Descriptor::make(), + ref_ptr* task_hnd = nullptr); -/** + /** * @} */ -} // namespace spla +}// namespace spla -#endif // SPLA_ALGORITHM_HPP +#endif// SPLA_ALGORITHM_HPP diff --git a/include/spla/io.hpp b/include/spla/io.hpp index 542baf3d2..f60763015 100644 --- a/include/spla/io.hpp +++ b/include/spla/io.hpp @@ -80,9 +80,9 @@ namespace spla { [[nodiscard]] SPLA_API const std::vector& get_Ai() const; [[nodiscard]] SPLA_API const std::vector& get_Aj() const; - [[nodiscard]] SPLA_API const std::vector &get_Aw() const; - [[nodiscard]] SPLA_API uint get_n_rows() const; - [[nodiscard]] SPLA_API uint get_n_cols() const; + [[nodiscard]] SPLA_API const std::vector& get_Aw() const; + [[nodiscard]] SPLA_API uint get_n_rows() const; + [[nodiscard]] SPLA_API uint get_n_cols() const; [[nodiscard]] SPLA_API std::size_t get_n_values() const; private: @@ -90,7 +90,7 @@ namespace spla { std::filesystem::path m_file_path; std::vector m_Ai; std::vector m_Aj; - std::vector m_Aw; + std::vector m_Aw; bool m_base_is_zero = false; uint m_n_rows = 0; uint m_n_cols = 0; diff --git a/include/spla/matrix.hpp b/include/spla/matrix.hpp index c9c32a6e8..645142b0e 100644 --- a/include/spla/matrix.hpp +++ b/include/spla/matrix.hpp @@ -57,13 +57,13 @@ namespace spla { SPLA_API virtual Status set_int(uint row_id, uint col_id, std::int32_t value) = 0; SPLA_API virtual Status set_uint(uint row_id, uint col_id, std::uint32_t value) = 0; SPLA_API virtual Status set_float(uint row_id, uint col_id, float value) = 0; - SPLA_API virtual Status set_pair(uint row_id, uint col_id, - Pair value) = 0; + SPLA_API virtual Status set_pair(uint row_id, uint col_id, + Pair value) = 0; SPLA_API virtual Status get_int(uint row_id, uint col_id, std::int32_t& value) = 0; SPLA_API virtual Status get_uint(uint row_id, uint col_id, std::uint32_t& value) = 0; SPLA_API virtual Status get_float(uint row_id, uint col_id, float& value) = 0; - SPLA_API virtual Status get_pair(uint row_id, uint col_id, - Pair &value) = 0; + SPLA_API virtual Status get_pair(uint row_id, uint col_id, + Pair& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) = 0; SPLA_API virtual Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) = 0; SPLA_API virtual Status clear() = 0; diff --git a/include/spla/pair.hpp b/include/spla/pair.hpp index 8799be29f..cf1c93220 100644 --- a/include/spla/pair.hpp +++ b/include/spla/pair.hpp @@ -3,21 +3,21 @@ #include namespace spla { -struct Pair { - float weight; - int vertex; + struct Pair { + float weight; + int vertex; - Pair() : weight(std::numeric_limits::infinity()), vertex(-1) {} - Pair(float w, int v) : weight(w), vertex(v) {} + Pair() : weight(std::numeric_limits::infinity()), vertex(-1) {} + Pair(float w, int v) : weight(w), vertex(v) {} - bool operator<(const Pair &other) const { return weight < other.weight; } - bool operator==(const Pair &other) const { - return weight == other.weight && vertex == other.vertex; - } + bool operator<(const Pair& other) const { return weight < other.weight; } + bool operator==(const Pair& other) const { + return weight == other.weight && vertex == other.vertex; + } - bool operator!=(const Pair &other) const { return !(*this == other); } + bool operator!=(const Pair& other) const { return !(*this == other); } - Pair &operator=(const Pair &other) = default; -}; -} // namespace spla + Pair& operator=(const Pair& other) = default; + }; +}// namespace spla #endif \ No newline at end of file diff --git a/include/spla/scalar.hpp b/include/spla/scalar.hpp index 71533ac17..b8e7f463e 100644 --- a/include/spla/scalar.hpp +++ b/include/spla/scalar.hpp @@ -44,24 +44,24 @@ namespace spla { */ class Scalar : public Object { public: - SPLA_API ~Scalar() override = default; - SPLA_API virtual ref_ptr get_type() = 0; - SPLA_API virtual Status set_int(std::int32_t value) = 0; - SPLA_API virtual Status set_uint(std::uint32_t value) = 0; - SPLA_API virtual Status set_float(float value) = 0; - SPLA_API virtual Status set_pair(Pair value) { - return Status::InvalidArgument; + SPLA_API ~Scalar() override = default; + SPLA_API virtual ref_ptr get_type() = 0; + SPLA_API virtual Status set_int(std::int32_t value) = 0; + SPLA_API virtual Status set_uint(std::uint32_t value) = 0; + SPLA_API virtual Status set_float(float value) = 0; + SPLA_API virtual Status set_pair(Pair value) { + return Status::InvalidArgument; } - SPLA_API virtual Status get_int(std::int32_t& value) = 0; - SPLA_API virtual Status get_uint(std::uint32_t& value) = 0; - SPLA_API virtual Status get_float(float& value) = 0; - SPLA_API virtual Status get_pair(Pair &value) { - return Status::InvalidArgument; + SPLA_API virtual Status get_int(std::int32_t& value) = 0; + SPLA_API virtual Status get_uint(std::uint32_t& value) = 0; + SPLA_API virtual Status get_float(float& value) = 0; + SPLA_API virtual Status get_pair(Pair& value) { + return Status::InvalidArgument; } - SPLA_API virtual T_INT as_int() = 0; - SPLA_API virtual T_UINT as_uint() = 0; - SPLA_API virtual T_FLOAT as_float() = 0; - SPLA_API virtual T_PAIR as_pair() = 0; + SPLA_API virtual T_INT as_int() = 0; + SPLA_API virtual T_UINT as_uint() = 0; + SPLA_API virtual T_FLOAT as_float() = 0; + SPLA_API virtual T_PAIR as_pair() = 0; SPLA_API static ref_ptr make(const ref_ptr& type); SPLA_API static ref_ptr make_int(std::int32_t value); diff --git a/include/spla/type.hpp b/include/spla/type.hpp index 7ed7919c0..afaf70c48 100644 --- a/include/spla/type.hpp +++ b/include/spla/type.hpp @@ -59,7 +59,7 @@ namespace spla { using T_INT = std::int32_t; using T_UINT = std::uint32_t; using T_FLOAT = float; - using T_PAIR = Pair; + using T_PAIR = Pair; SPLA_API extern ref_ptr BOOL; SPLA_API extern ref_ptr INT; diff --git a/include/spla/vector.hpp b/include/spla/vector.hpp index dd54d55df..2af5546c3 100644 --- a/include/spla/vector.hpp +++ b/include/spla/vector.hpp @@ -59,8 +59,8 @@ namespace spla { SPLA_API virtual Status get_int(uint row_id, T_INT& value) = 0; SPLA_API virtual Status get_uint(uint row_id, T_UINT& value) = 0; SPLA_API virtual Status get_float(uint row_id, float& value) = 0; - SPLA_API virtual Status get_pair(uint row_id, Pair &value) = 0; - SPLA_API virtual Status set_pair(uint row_id, Pair value) = 0; + SPLA_API virtual Status get_pair(uint row_id, Pair& value) = 0; + SPLA_API virtual Status set_pair(uint row_id, Pair value) = 0; SPLA_API virtual Status fill_noize(uint seed) = 0; SPLA_API virtual Status fill_with(const ref_ptr& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys, const ref_ptr& values) = 0; From b3f52669a9492b739a0f4bcbe5f0c0b0a9320b32 Mon Sep 17 00:00:00 2001 From: polka777 Date: Sun, 26 Apr 2026 12:11:14 +0300 Subject: [PATCH 11/14] Clang-tidy format src/ --- src/algorithm.cpp | 1026 ++++++++++++++++++++++----------------------- src/io.cpp | 800 +++++++++++++++++------------------ src/matrix.cpp | 2 +- src/op.cpp | 754 ++++++++++++++++----------------- src/scalar.cpp | 62 +-- src/type.cpp | 22 +- src/vector.cpp | 52 +-- 7 files changed, 1360 insertions(+), 1358 deletions(-) diff --git a/src/algorithm.cpp b/src/algorithm.cpp index dc3cb493d..74f2c1901 100644 --- a/src/algorithm.cpp +++ b/src/algorithm.cpp @@ -53,663 +53,663 @@ namespace spla { #pragma region Bfs -Status bfs(const ref_ptr &v, const ref_ptr &A, uint s, - const ref_ptr &descriptor) { - assert(v); - assert(A); + Status bfs(const ref_ptr& v, const ref_ptr& A, uint s, + const ref_ptr& descriptor) { + assert(v); + assert(A); - const auto N = v->get_n_rows(); + const auto N = v->get_n_rows(); - ref_ptr frontier_prev = Vector::make(N, INT); - ref_ptr frontier_new = Vector::make(N, INT); - ref_ptr frontier_size = Scalar::make_int(1); - ref_ptr depth = Scalar::make_int(1); - ref_ptr zero = Scalar::make_int(0); - int current_level = 1; - int discovered = 1; - bool frontier_empty = false; + ref_ptr frontier_prev = Vector::make(N, INT); + ref_ptr frontier_new = Vector::make(N, INT); + ref_ptr frontier_size = Scalar::make_int(1); + ref_ptr depth = Scalar::make_int(1); + ref_ptr zero = Scalar::make_int(0); + int current_level = 1; + int discovered = 1; + bool frontier_empty = false; - ref_ptr desc = Descriptor::make(); - desc->set_early_exit(true); - desc->set_struct_only(true); + ref_ptr desc = Descriptor::make(); + desc->set_early_exit(true); + desc->set_struct_only(true); - frontier_prev->set_int(s, 1); + frontier_prev->set_int(s, 1); - bool push = descriptor->get_push_only(); - bool pull = descriptor->get_pull_only(); - bool push_pull = descriptor->get_push_pull(); - float front_factor = descriptor->get_front_factor(); + bool push = descriptor->get_push_only(); + bool pull = descriptor->get_pull_only(); + bool push_pull = descriptor->get_push_pull(); + float front_factor = descriptor->get_front_factor(); - if (!(push || pull || push_pull)) - push = true; + if (!(push || pull || push_pull)) + push = true; #ifndef SPLA_RELEASE - std::string mode; - if (push_pull) - mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; - if (pull) - mode = "(pull)"; - if (push) - mode = "(push)"; + std::string mode; + if (push_pull) + mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; + if (pull) + mode = "(pull)"; + if (push) + mode = "(push)"; - std::cout << "start bfs from " << s << " " << mode << std::endl; + std::cout << "start bfs from " << s << " " << mode << std::endl; - Timer tight; + Timer tight; #endif - while (!frontier_empty) { + while (!frontier_empty) { #ifndef SPLA_RELEASE - tight.start(); + tight.start(); #endif - depth->set_int(current_level); - exec_v_assign_masked(v, frontier_prev, depth, SECOND_INT, NQZERO_INT); - - float front_density = float(frontier_size->as_int()) / float(N); - bool is_push_better = (front_density <= front_factor); - - if (push || (push_pull && is_push_better)) { - exec_vxm_masked(frontier_new, v, frontier_prev, A, BAND_INT, BOR_INT, - EQZERO_INT, zero, desc); - } else { - exec_mxv_masked(frontier_new, v, A, frontier_prev, BAND_INT, BOR_INT, - EQZERO_INT, zero, desc); - } + depth->set_int(current_level); + exec_v_assign_masked(v, frontier_prev, depth, SECOND_INT, NQZERO_INT); + + float front_density = float(frontier_size->as_int()) / float(N); + bool is_push_better = (front_density <= front_factor); + + if (push || (push_pull && is_push_better)) { + exec_vxm_masked(frontier_new, v, frontier_prev, A, BAND_INT, BOR_INT, + EQZERO_INT, zero, desc); + } else { + exec_mxv_masked(frontier_new, v, A, frontier_prev, BAND_INT, BOR_INT, + EQZERO_INT, zero, desc); + } - exec_v_count_mf(frontier_size, frontier_new); + exec_v_count_mf(frontier_size, frontier_new); #ifndef SPLA_RELEASE - tight.stop(); - std::cout << " - iter " << current_level << " front " - << frontier_size->as_int() << " discovered " << discovered << " " - << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iter " << current_level << " front " + << frontier_size->as_int() << " discovered " << discovered << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - frontier_empty = frontier_size->as_int() == 0; - discovered += frontier_size->as_int(); - current_level += 1; + frontier_empty = frontier_size->as_int() == 0; + discovered += frontier_size->as_int(); + current_level += 1; - std::swap(frontier_prev, frontier_new); - } + std::swap(frontier_prev, frontier_new); + } - return Status::Ok; -} + return Status::Ok; + } -Status bfs_naive(std::vector &v, std::vector> &A, - uint s, const ref_ptr &descriptor) { + Status bfs_naive(std::vector& v, std::vector>& A, + uint s, const ref_ptr& descriptor) { - const auto N = v.size(); + const auto N = v.size(); - std::queue front; - std::vector visited(N, false); + std::queue front; + std::vector visited(N, false); - std::fill(v.begin(), v.end(), 0); + std::fill(v.begin(), v.end(), 0); - front.push(s); - visited[s] = true; - v[s] = 1; + front.push(s); + visited[s] = true; + v[s] = 1; - while (!front.empty()) { - auto i = front.front(); - front.pop(); + while (!front.empty()) { + auto i = front.front(); + front.pop(); - for (auto j : A[i]) { - if (!visited[j]) { - visited[j] = true; - v[j] = v[i] + 1; - front.push(j); - } - } - } + for (auto j : A[i]) { + if (!visited[j]) { + visited[j] = true; + v[j] = v[i] + 1; + front.push(j); + } + } + } - return Status::Ok; -} + return Status::Ok; + } #pragma endregion Bfs #pragma region Sssp -Status sssp(const ref_ptr &v, const ref_ptr &A, uint s, - const ref_ptr &descriptor) { - assert(v); - assert(A); + Status sssp(const ref_ptr& v, const ref_ptr& A, uint s, + const ref_ptr& descriptor) { + assert(v); + assert(A); - const auto N = v->get_n_rows(); - const auto inf = std::numeric_limits::max(); + const auto N = v->get_n_rows(); + const auto inf = std::numeric_limits::max(); - ref_ptr dummy_mask = Vector::make(N, FLOAT); - ref_ptr frontier = Vector::make(N, FLOAT); - ref_ptr feedback = Vector::make(N, FLOAT); - ref_ptr feedback_size = Scalar::make_int(0); - ref_ptr inf_init = Scalar::make_float(inf); - int current_level = 1; - bool feedback_empty = false; + ref_ptr dummy_mask = Vector::make(N, FLOAT); + ref_ptr frontier = Vector::make(N, FLOAT); + ref_ptr feedback = Vector::make(N, FLOAT); + ref_ptr feedback_size = Scalar::make_int(0); + ref_ptr inf_init = Scalar::make_float(inf); + int current_level = 1; + bool feedback_empty = false; - v->set_fill_value(inf_init); - feedback->set_fill_value(inf_init); - frontier->set_fill_value(inf_init); + v->set_fill_value(inf_init); + feedback->set_fill_value(inf_init); + frontier->set_fill_value(inf_init); - v->set_float(s, 0.0f); - feedback->set_float(s, 0.0f); + v->set_float(s, 0.0f); + feedback->set_float(s, 0.0f); - bool push = descriptor->get_push_only(); - bool pull = descriptor->get_pull_only(); - bool push_pull = descriptor->get_push_pull(); - float front_factor = descriptor->get_front_factor(); + bool push = descriptor->get_push_only(); + bool pull = descriptor->get_pull_only(); + bool push_pull = descriptor->get_push_pull(); + float front_factor = descriptor->get_front_factor(); - if (!(push || pull || push_pull)) - push = true; + if (!(push || pull || push_pull)) + push = true; #ifndef SPLA_RELEASE - std::string mode; - if (push_pull) - mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; - if (pull) - mode = "(pull)"; - if (push) - mode = "(push)"; + std::string mode; + if (push_pull) + mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; + if (pull) + mode = "(pull)"; + if (push) + mode = "(push)"; - std::cout << "start sssp from " << s << " " << mode << std::endl; + std::cout << "start sssp from " << s << " " << mode << std::endl; - Timer tight; + Timer tight; #endif - while (!feedback_empty) { + while (!feedback_empty) { #ifndef SPLA_RELEASE - tight.start(); + tight.start(); #endif - float front_density = float(feedback_size->as_int()) / float(N); - bool is_push_better = (front_density <= front_factor); - - if (push || (push_pull && is_push_better)) { - exec_vxm_masked(frontier, dummy_mask, feedback, A, PLUS_FLOAT, MIN_FLOAT, - ALWAYS_FLOAT, inf_init); - } else { - exec_mxv_masked(frontier, dummy_mask, A, feedback, PLUS_FLOAT, MIN_FLOAT, - ALWAYS_FLOAT, inf_init); - } + float front_density = float(feedback_size->as_int()) / float(N); + bool is_push_better = (front_density <= front_factor); + + if (push || (push_pull && is_push_better)) { + exec_vxm_masked(frontier, dummy_mask, feedback, A, PLUS_FLOAT, MIN_FLOAT, + ALWAYS_FLOAT, inf_init); + } else { + exec_mxv_masked(frontier, dummy_mask, A, feedback, PLUS_FLOAT, MIN_FLOAT, + ALWAYS_FLOAT, inf_init); + } - exec_v_eadd_fdb(v, frontier, feedback, MIN_FLOAT); - exec_v_count_mf(feedback_size, feedback); + exec_v_eadd_fdb(v, frontier, feedback, MIN_FLOAT); + exec_v_count_mf(feedback_size, feedback); #ifndef SPLA_RELEASE - tight.stop(); - std::cout << " - iter " << current_level << " feed " - << feedback_size->as_int() << " " << tight.get_elapsed_ms() - << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iter " << current_level << " feed " + << feedback_size->as_int() << " " << tight.get_elapsed_ms() + << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - feedback_empty = feedback_size->as_int() == 0; - current_level += 1; - } - - return Status::Ok; -} - -Status sssp_naive(std::vector &v, std::vector> &Ai, - std::vector> &Ax, uint s, - const ref_ptr &descriptor) { - - const auto N = v.size(); - const auto inf = std::numeric_limits::max(); - - std::queue front; - std::vector in_queue(N, false); - std::fill(v.begin(), v.end(), inf); - - front.push(s); - in_queue[s] = true; - v[s] = 0.0f; - - while (!front.empty()) { - auto i = front.front(); - front.pop(); - in_queue[i] = false; - - const auto &col_ids = Ai[i]; - const auto &col_vals = Ax[i]; - const auto n_vals = col_ids.size(); - - for (std::size_t k = 0; k < n_vals; k += 1) { - const uint j = col_ids[k]; - const float w = col_vals[k]; - - if (v[j] == inf || v[i] + w < v[j]) { - v[j] = v[i] + w; - if (!in_queue[j]) { - in_queue[j] = true; - front.push(j); + feedback_empty = feedback_size->as_int() == 0; + current_level += 1; } - } + + return Status::Ok; } - } - return Status::Ok; -} + Status sssp_naive(std::vector& v, std::vector>& Ai, + std::vector>& Ax, uint s, + const ref_ptr& descriptor) { + + const auto N = v.size(); + const auto inf = std::numeric_limits::max(); + + std::queue front; + std::vector in_queue(N, false); + std::fill(v.begin(), v.end(), inf); + + front.push(s); + in_queue[s] = true; + v[s] = 0.0f; + + while (!front.empty()) { + auto i = front.front(); + front.pop(); + in_queue[i] = false; + + const auto& col_ids = Ai[i]; + const auto& col_vals = Ax[i]; + const auto n_vals = col_ids.size(); + + for (std::size_t k = 0; k < n_vals; k += 1) { + const uint j = col_ids[k]; + const float w = col_vals[k]; + + if (v[j] == inf || v[i] + w < v[j]) { + v[j] = v[i] + w; + if (!in_queue[j]) { + in_queue[j] = true; + front.push(j); + } + } + } + } + + return Status::Ok; + } #pragma endregion Sssp #pragma region Pr -Status pr(ref_ptr &p, const ref_ptr &A, float alpha, float eps, - const ref_ptr &descriptor) { - assert(p); - assert(A); + Status pr(ref_ptr& p, const ref_ptr& A, float alpha, float eps, + const ref_ptr& descriptor) { + assert(p); + assert(A); - const auto N = p->get_n_rows(); + const auto N = p->get_n_rows(); - ref_ptr dummy_mask = Vector::make(N, FLOAT); - ref_ptr p_prev = Vector::make(N, FLOAT); - ref_ptr p_tmp = Vector::make(N, FLOAT); - ref_ptr addition = Vector::make(N, FLOAT); - ref_ptr errors = Vector::make(N, FLOAT); - ref_ptr error2 = Scalar::make(FLOAT); - ref_ptr zero = Scalar::make_float(0.0f); + ref_ptr dummy_mask = Vector::make(N, FLOAT); + ref_ptr p_prev = Vector::make(N, FLOAT); + ref_ptr p_tmp = Vector::make(N, FLOAT); + ref_ptr addition = Vector::make(N, FLOAT); + ref_ptr errors = Vector::make(N, FLOAT); + ref_ptr error2 = Scalar::make(FLOAT); + ref_ptr zero = Scalar::make_float(0.0f); - addition->fill_with(Scalar::make_float((1.0f - alpha) / float(N))); - p_prev->fill_with(Scalar::make_float(1.0f / float(N))); + addition->fill_with(Scalar::make_float((1.0f - alpha) / float(N))); + p_prev->fill_with(Scalar::make_float(1.0f / float(N))); - float error = eps + 0.1f; + float error = eps + 0.1f; #ifndef SPLA_RELEASE - int iter = 0; + int iter = 0; - std::cout << "start pr alpha=" << alpha << " eps " << eps << std::endl; + std::cout << "start pr alpha=" << alpha << " eps " << eps << std::endl; - Timer tight; + Timer tight; #endif - while (error > eps) { + while (error > eps) { #ifndef SPLA_RELEASE - tight.start(); + tight.start(); #endif - // p = A*p + (1-alpha)/N - exec_mxv_masked(p_tmp, dummy_mask, A, p_prev, MULT_FLOAT, PLUS_FLOAT, - ALWAYS_FLOAT, zero); - exec_v_eadd(p, p_tmp, addition, PLUS_FLOAT); + // p = A*p + (1-alpha)/N + exec_mxv_masked(p_tmp, dummy_mask, A, p_prev, MULT_FLOAT, PLUS_FLOAT, + ALWAYS_FLOAT, zero); + exec_v_eadd(p, p_tmp, addition, PLUS_FLOAT); - // error = sqrt((p[01]-prev[0])^2 + ... + p[N-1]-prev[N-1])^2) - exec_v_eadd(errors, p, p_prev, MINUS_POW2_FLOAT); - exec_v_reduce(error2, zero, errors, PLUS_FLOAT); + // error = sqrt((p[01]-prev[0])^2 + ... + p[N-1]-prev[N-1])^2) + exec_v_eadd(errors, p, p_prev, MINUS_POW2_FLOAT); + exec_v_reduce(error2, zero, errors, PLUS_FLOAT); - error = std::sqrt(error2->as_float()); + error = std::sqrt(error2->as_float()); - std::swap(p, p_prev); + std::swap(p, p_prev); #ifndef SPLA_RELEASE - tight.stop(); - std::cout << " - iter " << iter++ << " error " << error << " " - << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iter " << iter++ << " error " << error << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - } + } - std::swap(p, p_prev); - return Status::Ok; -} + std::swap(p, p_prev); + return Status::Ok; + } -Status pr_naive(std::vector &p, std::vector> &Ai, - std::vector> &Ax, float alpha, float eps, - const ref_ptr &descriptor) { + Status pr_naive(std::vector& p, std::vector>& Ai, + std::vector>& Ax, float alpha, float eps, + const ref_ptr& descriptor) { - const auto N = p.size(); + const auto N = p.size(); - std::vector p_prev(N, 1.0f / float(N)); + std::vector p_prev(N, 1.0f / float(N)); - float error = eps + 0.1f; + float error = eps + 0.1f; - while (error > eps) { - for (std::size_t i = 0; i < N; i++) { - p[i] = 0; + while (error > eps) { + for (std::size_t i = 0; i < N; i++) { + p[i] = 0; - for (std::size_t k = 0; k < Ai[i].size(); k++) { - p[i] += Ax[i][k] * p_prev[Ai[i][k]]; - } + for (std::size_t k = 0; k < Ai[i].size(); k++) { + p[i] += Ax[i][k] * p_prev[Ai[i][k]]; + } - p[i] += (1.0f - alpha) / float(N); - } + p[i] += (1.0f - alpha) / float(N); + } - error = 0.0f; + error = 0.0f; - for (std::size_t i = 0; i < N; i++) { - error += (p[i] - p_prev[i]) * (p[i] - p_prev[i]); - } + for (std::size_t i = 0; i < N; i++) { + error += (p[i] - p_prev[i]) * (p[i] - p_prev[i]); + } - error = std::sqrt(error); + error = std::sqrt(error); - std::swap(p, p_prev); - } + std::swap(p, p_prev); + } - std::swap(p, p_prev); - return Status::Ok; -} + std::swap(p, p_prev); + return Status::Ok; + } #pragma endregion Pr #pragma region Tc -Status tc(int &ntrins, const ref_ptr &A, const ref_ptr &B, - const ref_ptr &descriptor) { - assert(A); - assert(B); + Status tc(int& ntrins, const ref_ptr& A, const ref_ptr& B, + const ref_ptr& descriptor) { + assert(A); + assert(B); - ref_ptr zero = Scalar::make_int(0); - ref_ptr result = Scalar::make(INT); + ref_ptr zero = Scalar::make_int(0); + ref_ptr result = Scalar::make(INT); #ifndef SPLA_RELEASE - std::cout << "start tc" << std::endl; + std::cout << "start tc" << std::endl; - Timer tight; - tight.start(); + Timer tight; + tight.start(); #endif - spla::exec_mxmT_masked(B, A, A, A, MULT_INT, PLUS_INT, GTZERO_INT, zero); - spla::exec_m_reduce(result, zero, B, PLUS_INT); + spla::exec_mxmT_masked(B, A, A, A, MULT_INT, PLUS_INT, GTZERO_INT, zero); + spla::exec_m_reduce(result, zero, B, PLUS_INT); - ntrins = result->as_int(); + ntrins = result->as_int(); #ifndef SPLA_RELEASE - tight.stop(); + tight.stop(); - std::cout << " - ntrins " << ntrins << " " << tight.get_elapsed_ms() << " ms" - << std::endl; + std::cout << " - ntrins " << ntrins << " " << tight.get_elapsed_ms() << " ms" + << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - return Status::Ok; -} + return Status::Ok; + } -Status tc_naive(int &ntrins, std::vector> &Ai, - const ref_ptr &descriptor) { + Status tc_naive(int& ntrins, std::vector>& Ai, + const ref_ptr& descriptor) { - ntrins = 0; + ntrins = 0; - for (const auto &row_Ai : Ai) { - for (const auto neighbor : row_Ai) { - const auto &row_neighbor = Ai[neighbor]; + for (const auto& row_Ai : Ai) { + for (const auto neighbor : row_Ai) { + const auto& row_neighbor = Ai[neighbor]; - auto it1 = row_Ai.begin(); - auto it2 = row_neighbor.begin(); + auto it1 = row_Ai.begin(); + auto it2 = row_neighbor.begin(); - auto end1 = row_Ai.end(); - auto end2 = row_neighbor.end(); + auto end1 = row_Ai.end(); + auto end2 = row_neighbor.end(); - while (it1 != end1 && it2 != end2) { - if (*it1 == *it2) { - ++ntrins; - ++it1; - ++it2; - } else if (*it1 < *it2) { - ++it1; - } else { - ++it2; + while (it1 != end1 && it2 != end2) { + if (*it1 == *it2) { + ++ntrins; + ++it1; + ++it2; + } else if (*it1 < *it2) { + ++it1; + } else { + ++it2; + } + } + } } - } - } - } - return Status::Ok; -} + return Status::Ok; + } #pragma endregion Pr #pragma region Mst -Status mst(const ref_ptr &T, ref_ptr &S, - const ref_ptr &descriptor, - ref_ptr *task_hnd) { - - assert(S); - assert(T); - - const auto n = S->get_n_rows(); - int comp = n; - - auto parent = Vector::make(n, PAIR); - for (uint i = 0; i < n; i++) { - parent->set_pair(i, T_PAIR(0.0f, i)); - } - auto edge = Vector::make(n, PAIR); - auto cedge = Vector::make(n, PAIR); - auto t_vec = Vector::make(n, PAIR); - auto mask = Vector::make(n, PAIR); - for (uint i = 0; i < n; i++) { - mask->set_pair(i, T_PAIR(1.0f, 0)); - } - auto init_inf = Scalar::make(PAIR); - T_PAIR init_val; - init_inf->set_pair(init_val); - int iteration = 0; - auto new_S = S; + Status mst(const ref_ptr& T, ref_ptr& S, + const ref_ptr& descriptor, + ref_ptr* task_hnd) { + + assert(S); + assert(T); + + const auto n = S->get_n_rows(); + int comp = n; + + auto parent = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + parent->set_pair(i, T_PAIR(0.0f, i)); + } + auto edge = Vector::make(n, PAIR); + auto cedge = Vector::make(n, PAIR); + auto t_vec = Vector::make(n, PAIR); + auto mask = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + mask->set_pair(i, T_PAIR(1.0f, 0)); + } + auto init_inf = Scalar::make(PAIR); + T_PAIR init_val; + init_inf->set_pair(init_val); + int iteration = 0; + auto new_S = S; #ifdef SPLA_RELEASE - std::cout << "start Boruvka MST, vertices = " << n << "\n"; - Timer tight; + std::cout << "start Boruvka MST, vertices = " << n << "\n"; + Timer tight; #endif - while (comp > 1) { + while (comp > 1) { #ifdef SPLA_RELEASE - tight.start(); + tight.start(); #endif - iteration++; - int edges_added_this_iteration = 0; - spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, - spla::ALWAYS_PAIR, init_inf); + iteration++; + int edges_added_this_iteration = 0; + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, + spla::ALWAYS_PAIR, init_inf); #ifdef SPLA_DEBUG - std::cout << "edge = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - edge->get_pair(i, p); - std::cout << "(" << p.weight << ", " << p.vertex << "), "; - } - std::cout << "]\n"; + std::cout << "edge = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; + } + std::cout << "]\n"; #endif - for (int32_t i = 0; i < n; i++) { - cedge->set_pair(i, init_val); - } - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - spla::T_PAIR p1; - spla::T_PAIR p2; - parent->get_pair(i, p); - auto p_i = p.vertex; - cedge->get_pair(p_i, p1); - edge->get_pair(i, p2); - auto min_for_comp = p1.weight <= p2.weight ? p1 : p2; - cedge->set_pair(p_i, min_for_comp); - } + for (int32_t i = 0; i < n; i++) { + cedge->set_pair(i, init_val); + } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + spla::T_PAIR p1; + spla::T_PAIR p2; + parent->get_pair(i, p); + auto p_i = p.vertex; + cedge->get_pair(p_i, p1); + edge->get_pair(i, p2); + auto min_for_comp = p1.weight <= p2.weight ? p1 : p2; + cedge->set_pair(p_i, min_for_comp); + } #ifdef SPLA_DEBUG - std::cout << "cedge = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - cedge->get_pair(i, p); - std::cout << "(" << p.weight << ", " << p.vertex << "), "; - } - std::cout << "]\n"; + std::cout << "cedge = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + cedge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; + } + std::cout << "]\n"; #endif - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR parent_v; - spla::T_PAIR cedge_v; - parent->get_pair(i, parent_v); - cedge->get_pair(parent_v.vertex, cedge_v); - t_vec->set_pair(i, cedge_v); - } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + spla::T_PAIR cedge_v; + parent->get_pair(i, parent_v); + cedge->get_pair(parent_v.vertex, cedge_v); + t_vec->set_pair(i, cedge_v); + } - auto index = spla::Vector::make(n, spla::INT); + auto index = spla::Vector::make(n, spla::INT); - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR edge_v, t_v; - edge->get_pair(i, edge_v); - t_vec->get_pair(i, t_v); - if (edge_v == t_v) - index->set_int(i, i); - else - index->set_int(i, n); - } - auto temp = spla::Vector::make(n, spla::INT); - for (int32_t i = 0; i < n; i++) - temp->set_int(i, n); - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR parent_v; - parent->get_pair(i, parent_v); - auto p_i = parent_v.vertex; - spla::T_INT temp_v, ind_v; - temp->get_int(p_i, temp_v); - index->get_int(i, ind_v); - spla::T_INT min_v = temp_v < ind_v ? temp_v : ind_v; - temp->set_int(p_i, min_v); - } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR edge_v, t_v; + edge->get_pair(i, edge_v); + t_vec->get_pair(i, t_v); + if (edge_v == t_v) + index->set_int(i, i); + else + index->set_int(i, n); + } + auto temp = spla::Vector::make(n, spla::INT); + for (int32_t i = 0; i < n; i++) + temp->set_int(i, n); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v, ind_v; + temp->get_int(p_i, temp_v); + index->get_int(i, ind_v); + spla::T_INT min_v = temp_v < ind_v ? temp_v : ind_v; + temp->set_int(p_i, min_v); + } #ifdef SPLA_DEBUG - std::cout << "t = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_INT p; - temp->get_int(i, p); - std::cout << p << ", "; - } - std::cout << "]\n"; + std::cout << "t = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + temp->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; #endif - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR parent_v; - parent->get_pair(i, parent_v); - auto p_i = parent_v.vertex; - spla::T_INT temp_v; - temp->get_int(p_i, temp_v); - index->set_int(i, temp_v); - } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v; + temp->get_int(p_i, temp_v); + index->set_int(i, temp_v); + } #ifdef SPLA_DEBUG - std::cout << "index = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_INT p; - index->get_int(i, p); - std::cout << p << ", "; - } - std::cout << "]\n"; + std::cout << "index = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + index->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; #endif - auto new_parent = spla::Vector::make(n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - parent->get_pair(i, p); - new_parent->set_pair(i, p); - } - for (int32_t i = 0; i < n; i++) { - spla::T_INT ind_v; - index->get_int(i, ind_v); - if (i == ind_v) { - auto row = spla::Vector::make(n, spla::PAIR); - spla::exec_m_extract_row(row, S, i, spla::IDENTITY_PAIR); - int min_vertex = -1; - float min_weight = INF; - - for (int32_t j = 0; j < n; j++) { - spla::T_PAIR pair_row; - row->get_pair(j, pair_row); - auto pair_row_weight = pair_row.weight; - auto pair_row_vertex = pair_row.vertex; - if (pair_row_weight < INF) { - spla::T_PAIR p1, p2; - parent->get_pair(i, p1); - parent->get_pair(pair_row_vertex, p2); - if (p1.vertex != p2.vertex) { - if (pair_row_weight < min_weight) { - min_weight = pair_row_weight; - min_vertex = j; - } + auto new_parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + new_parent->set_pair(i, p); + } + for (int32_t i = 0; i < n; i++) { + spla::T_INT ind_v; + index->get_int(i, ind_v); + if (i == ind_v) { + auto row = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row, S, i, spla::IDENTITY_PAIR); + int min_vertex = -1; + float min_weight = INF; + + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR pair_row; + row->get_pair(j, pair_row); + auto pair_row_weight = pair_row.weight; + auto pair_row_vertex = pair_row.vertex; + if (pair_row_weight < INF) { + spla::T_PAIR p1, p2; + parent->get_pair(i, p1); + parent->get_pair(pair_row_vertex, p2); + if (p1.vertex != p2.vertex) { + if (pair_row_weight < min_weight) { + min_weight = pair_row_weight; + min_vertex = j; + } + } + } + } + if (min_vertex == -1) + continue; + T->set_float(i, min_vertex, min_weight); + T->set_float(min_vertex, i, min_weight); + edges_added_this_iteration++; + if (i < min_vertex) { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(i, p); + new_parent->get_pair(min_vertex, old_p); + new_parent->set_pair(min_vertex, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) + new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } + } else { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(min_vertex, p); + new_parent->get_pair(i, old_p); + new_parent->set_pair(i, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) + new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } + } + } + } + parent = new_parent; + std::vector seen(n, false); + for (uint i = 0; i < n; i++) { + T_PAIR p; + parent->get_pair(i, p); + seen[p.vertex] = true; } - } - } - if (min_vertex == -1) - continue; - T->set_float(i, min_vertex, min_weight); - T->set_float(min_vertex, i, min_weight); - edges_added_this_iteration++; - if (i < min_vertex) { - spla::T_PAIR p; - spla::T_PAIR old_p; - new_parent->get_pair(i, p); - new_parent->get_pair(min_vertex, old_p); - new_parent->set_pair(min_vertex, spla::T_PAIR(0.0f, p.vertex)); - for (int k = 0; k < n; k++) { - spla::T_PAIR p1; - new_parent->get_pair(k, p1); - if (p1.vertex == old_p.vertex) - new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); - } - } else { - spla::T_PAIR p; - spla::T_PAIR old_p; - new_parent->get_pair(min_vertex, p); - new_parent->get_pair(i, old_p); - new_parent->set_pair(i, spla::T_PAIR(0.0f, p.vertex)); - for (int k = 0; k < n; k++) { - spla::T_PAIR p1; - new_parent->get_pair(k, p1); - if (p1.vertex == old_p.vertex) - new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); - } - } - } - } - parent = new_parent; - std::vector seen(n, false); - for (uint i = 0; i < n; i++) { - T_PAIR p; - parent->get_pair(i, p); - seen[p.vertex] = true; - } - comp = 0; - for (uint i = 0; i < n; i++) { - if (seen[i]) - comp++; - } + comp = 0; + for (uint i = 0; i < n; i++) { + if (seen[i]) + comp++; + } #ifdef SPLA_DEBUG - std::cout << "parent = ["; - for (int32_t i = 0; i < n; i++) { - spla::T_PAIR p; - parent->get_pair(i, p); - std::cout << p.vertex << ", "; - } - std::cout << "]\n"; + std::cout << "parent = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + std::cout << p.vertex << ", "; + } + std::cout << "]\n"; #endif #ifdef SPLA_RELEASE - tight.stop(); - std::cout << " - iteration " << iteration << " components " << comp << " " - << tight.get_elapsed_ms() << " ms" << std::endl; - Library::get()->time_profile_dump(); - Library::get()->time_profile_reset(); + tight.stop(); + std::cout << " - iteration " << iteration << " components " << comp << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); #endif - if (comp == 1) { - std::cout << "MST complete after " << iteration << " iterations" - << std::endl; - return Status::Ok; - } - auto filtered_S = spla::Matrix::make(n, n, spla::PAIR); - for (int32_t i = 0; i < n; i++) { - for (int32_t j = 0; j < n; j++) { - spla::T_PAIR val; - S->get_pair(i, j, val); - if (val.weight != std::numeric_limits::infinity()) { - spla::T_PAIR parent_i, parent_j; - parent->get_pair(i, parent_i); - parent->get_pair(j, parent_j); - if ((parent_i.vertex != parent_j.vertex) && - (val.weight != std::numeric_limits::infinity())) { - filtered_S->set_pair(i, j, val); - } + if (comp == 1) { + std::cout << "MST complete after " << iteration << " iterations" + << std::endl; + return Status::Ok; + } + auto filtered_S = spla::Matrix::make(n, n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR val; + S->get_pair(i, j, val); + if (val.weight != std::numeric_limits::infinity()) { + spla::T_PAIR parent_i, parent_j; + parent->get_pair(i, parent_i); + parent->get_pair(j, parent_j); + if ((parent_i.vertex != parent_j.vertex) && + (val.weight != std::numeric_limits::infinity())) { + filtered_S->set_pair(i, j, val); + } + } + } + } + S = filtered_S; + if (edges_added_this_iteration == 0) { + return Status::Ok; + } } - } - } - S = filtered_S; - if (edges_added_this_iteration == 0) { - return Status::Ok; + return Status::Ok; } - } - return Status::Ok; -} #pragma endregion Mst -} // namespace spla +}// namespace spla diff --git a/src/io.cpp b/src/io.cpp index 578a20136..65c753441 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -51,421 +51,421 @@ namespace spla { -MtxLoader::MtxLoader(std::string name) : m_name(std::move(name)) {} - -bool MtxLoader::load(std::filesystem::path file_path, bool offset_indices, - bool make_undirected, bool remove_loops) { - m_file_path = std::move(file_path); - m_base_is_zero = offset_indices; - - std::fstream file(m_file_path, std::ios::in); - if (!file.is_open()) { - LOG_MSG(Status::Error, "failed to open file " << m_file_path); - return false; - } - - Timer t; - t.start(); - - std::size_t n_lines = 0; - std::size_t n_sort = 0; - - std::string line; - while (std::getline(file, line)) { - if (line[0] != '%') - break; - n_lines++; - } - - std::size_t nnz; - std::stringstream header(line); - header >> m_n_rows >> m_n_cols >> nnz; - - bool file_has_values = false; - if (line.find("pattern") == - std::string::npos) { // есть подстрока pattern => граф невзвешенный - file_has_values = true; - } - - std::cout << "Loading matrix-market coordinate format data... " << std::endl; - std::cout << " Reading from " << m_file_path << std::endl; - std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" - << std::endl; - std::cout << " Data: " << nnz << " directed edges" << std::endl; - if (remove_loops) - std::cout << " Opt: remove self-loops" << std::endl; - if (offset_indices) - std::cout << " Opt: offset indices by -1" << std::endl; - if (make_undirected) - std::cout << " Opt: double edges" << std::endl; - std::cout << " Reading data: "; - - // optimized reading by sliding window - const std::size_t BUFFER_CAPACITY = 1024 * 8; - std::size_t buffer_size = 0; - std::size_t buffer_offset = 0; - char buffer[BUFFER_CAPACITY + 1]; - - // read data - std::size_t to_count = 0; - std::size_t to_read = nnz; - std::size_t to_preallocate = to_read * (make_undirected ? 2 : 1); - std::vector Ai; - std::vector Aj; - std::vector Av; - - // preallocate to avoid copy - Ai.reserve(to_preallocate); - Aj.reserve(to_preallocate); - if (file_has_values) - Av.reserve(to_preallocate); - - float job_done = 0.0f; - float job_total = 35.0f; - - while (to_read > 0) { - to_count++; - to_read--; - n_lines++; - - // display current progress of reading the file - while (float(to_count) / float(nnz) > job_done / job_total) { - job_done += 1.0f; - std::cout << "|"; - } + MtxLoader::MtxLoader(std::string name) : m_name(std::move(name)) {} + + bool MtxLoader::load(std::filesystem::path file_path, bool offset_indices, + bool make_undirected, bool remove_loops) { + m_file_path = std::move(file_path); + m_base_is_zero = offset_indices; - // try to find where next line is ends up - bool line_found = false; - std::size_t line_end; - while (!line_found) { - line_end = buffer_offset; - - // travers buffer to find ending - while (line_end < buffer_size && buffer[line_end] != '\n') { - line_end += 1; - } - - // buffer not ended of file is ended - line_found = line_end < buffer_size || file.eof(); - - // not found in buffer, need to fetch more data - if (!line_found) { - assert(!file.eof()); - assert(buffer_offset <= BUFFER_CAPACITY); - - if (buffer_offset > 0) { - if (buffer_offset < BUFFER_CAPACITY) { - std::memcpy(buffer, buffer + buffer_offset, - BUFFER_CAPACITY - buffer_offset); - } - buffer_offset = BUFFER_CAPACITY - buffer_offset; + std::fstream file(m_file_path, std::ios::in); + if (!file.is_open()) { + LOG_MSG(Status::Error, "failed to open file " << m_file_path); + return false; } - auto bytes_to_read = BUFFER_CAPACITY - buffer_offset; - file.read(buffer + buffer_offset, std::streamsize(bytes_to_read)); - auto bytes_actually_read = file.gcount(); - buffer_size = buffer_offset + bytes_actually_read; - buffer_offset = 0; - buffer[buffer_size] = '\0'; - assert(buffer_size <= BUFFER_CAPACITY + 1); - } - } + Timer t; + t.start(); - char *end = nullptr; - auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); - auto j = uint(std::strtoll(end, &end, 10)); - float val = 1.0f; // default value - - if (file_has_values) { - char *next = end; - while (*next == ' ' || *next == '\t') - next++; - if (*next != '\n' && *next != '\0') { - val = static_cast(std::strtod(next, &end)); - } - } - buffer_offset = line_end + 1; + std::size_t n_lines = 0; + std::size_t n_sort = 0; - assert(i > 0 && j > 0); + std::string line; + while (std::getline(file, line)) { + if (line[0] != '%') + break; + n_lines++; + } - if (remove_loops) { - if (i == j) - continue; - } - if (offset_indices) { - i -= 1; - j -= 1; - } - if (make_undirected) { - Ai.push_back(j); - Aj.push_back(i); - if (file_has_values) - Av.push_back(val); - } + std::size_t nnz; + std::stringstream header(line); + header >> m_n_rows >> m_n_cols >> nnz; - Ai.push_back(i); - Aj.push_back(j); - if (file_has_values) - Av.push_back(val); - } - t.lap_end(); // parsing - - if (file_has_values) { - struct Edge { - uint i, j; - float w; - bool operator<(const Edge &other) const { - if (i != other.i) - return i < other.i; - return j < other.j; - } - }; - - std::vector edges; - edges.reserve(Ai.size()); - for (std::size_t k = 0; k < Ai.size(); k++) { - edges.push_back({Ai[k], Aj[k], Av[k]}); - } + bool file_has_values = false; + if (line.find("pattern") == + std::string::npos) {// есть подстрока pattern => граф невзвешенный + file_has_values = true; + } - std::sort(edges.begin(), edges.end()); - - std::vector reduced_Ai; - std::vector reduced_Aj; - std::vector reduced_Av; - reduced_Ai.reserve(edges.size()); - reduced_Aj.reserve(edges.size()); - reduced_Av.reserve(edges.size()); - - for (std::size_t k = 0; k < edges.size(); k++) { - if (k == 0 || edges[k].i != edges[k - 1].i || - edges[k].j != edges[k - 1].j) { - reduced_Ai.push_back(edges[k].i); - reduced_Aj.push_back(edges[k].j); - reduced_Av.push_back(edges[k].w); - } - } + std::cout << "Loading matrix-market coordinate format data... " << std::endl; + std::cout << " Reading from " << m_file_path << std::endl; + std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" + << std::endl; + std::cout << " Data: " << nnz << " directed edges" << std::endl; + if (remove_loops) + std::cout << " Opt: remove self-loops" << std::endl; + if (offset_indices) + std::cout << " Opt: offset indices by -1" << std::endl; + if (make_undirected) + std::cout << " Opt: double edges" << std::endl; + std::cout << " Reading data: "; + + // optimized reading by sliding window + const std::size_t BUFFER_CAPACITY = 1024 * 8; + std::size_t buffer_size = 0; + std::size_t buffer_offset = 0; + char buffer[BUFFER_CAPACITY + 1]; + + // read data + std::size_t to_count = 0; + std::size_t to_read = nnz; + std::size_t to_preallocate = to_read * (make_undirected ? 2 : 1); + std::vector Ai; + std::vector Aj; + std::vector Av; + + // preallocate to avoid copy + Ai.reserve(to_preallocate); + Aj.reserve(to_preallocate); + if (file_has_values) + Av.reserve(to_preallocate); + + float job_done = 0.0f; + float job_total = 35.0f; + + while (to_read > 0) { + to_count++; + to_read--; + n_lines++; + + // display current progress of reading the file + while (float(to_count) / float(nnz) > job_done / job_total) { + job_done += 1.0f; + std::cout << "|"; + } + + // try to find where next line is ends up + bool line_found = false; + std::size_t line_end; + while (!line_found) { + line_end = buffer_offset; + + // travers buffer to find ending + while (line_end < buffer_size && buffer[line_end] != '\n') { + line_end += 1; + } + + // buffer not ended of file is ended + line_found = line_end < buffer_size || file.eof(); + + // not found in buffer, need to fetch more data + if (!line_found) { + assert(!file.eof()); + assert(buffer_offset <= BUFFER_CAPACITY); + + if (buffer_offset > 0) { + if (buffer_offset < BUFFER_CAPACITY) { + std::memcpy(buffer, buffer + buffer_offset, + BUFFER_CAPACITY - buffer_offset); + } + buffer_offset = BUFFER_CAPACITY - buffer_offset; + } + + auto bytes_to_read = BUFFER_CAPACITY - buffer_offset; + file.read(buffer + buffer_offset, std::streamsize(bytes_to_read)); + auto bytes_actually_read = file.gcount(); + buffer_size = buffer_offset + bytes_actually_read; + buffer_offset = 0; + buffer[buffer_size] = '\0'; + assert(buffer_size <= BUFFER_CAPACITY + 1); + } + } + + char* end = nullptr; + auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); + auto j = uint(std::strtoll(end, &end, 10)); + float val = 1.0f;// default value + + if (file_has_values) { + char* next = end; + while (*next == ' ' || *next == '\t') + next++; + if (*next != '\n' && *next != '\0') { + val = static_cast(std::strtod(next, &end)); + } + } + buffer_offset = line_end + 1; + + assert(i > 0 && j > 0); + + if (remove_loops) { + if (i == j) + continue; + } + if (offset_indices) { + i -= 1; + j -= 1; + } + if (make_undirected) { + Ai.push_back(j); + Aj.push_back(i); + if (file_has_values) + Av.push_back(val); + } + + Ai.push_back(i); + Aj.push_back(j); + if (file_has_values) + Av.push_back(val); + } + t.lap_end();// parsing + + if (file_has_values) { + struct Edge { + uint i, j; + float w; + bool operator<(const Edge& other) const { + if (i != other.i) + return i < other.i; + return j < other.j; + } + }; + + std::vector edges; + edges.reserve(Ai.size()); + for (std::size_t k = 0; k < Ai.size(); k++) { + edges.push_back({Ai[k], Aj[k], Av[k]}); + } + + std::sort(edges.begin(), edges.end()); + + std::vector reduced_Ai; + std::vector reduced_Aj; + std::vector reduced_Av; + reduced_Ai.reserve(edges.size()); + reduced_Aj.reserve(edges.size()); + reduced_Av.reserve(edges.size()); + + for (std::size_t k = 0; k < edges.size(); k++) { + if (k == 0 || edges[k].i != edges[k - 1].i || + edges[k].j != edges[k - 1].j) { + reduced_Ai.push_back(edges[k].i); + reduced_Aj.push_back(edges[k].j); + reduced_Av.push_back(edges[k].w); + } + } + + m_n_values = reduced_Ai.size(); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); + m_Aw = std::move(reduced_Av); + + } else { + std::vector sorted; + { + sorted.reserve(Ai.size()); + n_sort = Ai.size(); + + for (std::size_t k = 0; k < Ai.size(); k++) { + std::uint64_t entry = 0; + entry |= std::uint64_t(Ai[k]) << 32u; + entry |= std::uint64_t(Aj[k]) << 0u; + sorted.push_back(entry); + } + Ai.clear(); + Aj.clear(); + + std::sort(sorted.begin(), sorted.end()); + } + t.lap_end();// sorting + + std::vector reduced_Ai; + std::vector reduced_Aj; + { + reduced_Ai.reserve(sorted.size()); + reduced_Aj.reserve(sorted.size()); + + std::uint64_t entry_prev = 0xffffffffffffffff; + for (std::uint64_t entry : sorted) { + if (entry_prev != entry) { + uint i = uint((entry >> 32u) & 0xffffffff); + uint j = uint((entry >> 0u) & 0xffffffff); + reduced_Ai.push_back(i); + reduced_Aj.push_back(j); + } + entry_prev = entry; + } + + m_n_values = reduced_Ai.size(); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); + } + } + + calc_stats(); + t.lap_end();// stats + + t.stop(); - m_n_values = reduced_Ai.size(); - m_Ai = std::move(reduced_Ai); - m_Aj = std::move(reduced_Aj); - m_Aw = std::move(reduced_Av); - - } else { - std::vector sorted; - { - sorted.reserve(Ai.size()); - n_sort = Ai.size(); - - for (std::size_t k = 0; k < Ai.size(); k++) { - std::uint64_t entry = 0; - entry |= std::uint64_t(Ai[k]) << 32u; - entry |= std::uint64_t(Aj[k]) << 0u; - sorted.push_back(entry); - } - Ai.clear(); - Aj.clear(); - - std::sort(sorted.begin(), sorted.end()); + std::cout << " 100%" << std::endl; + std::cout << " Parsed in " << t.get_laps_ms()[0] * 1e-3 << " sec " << n_lines + << " lines" + << " speed " << float(n_lines) / (t.get_laps_ms()[0] * 1e-3) + << " lines/sec" << std::endl; + std::cout << " Sorted in " << t.get_laps_ms()[1] * 1e-3 << " sec " << n_sort + << " lines" << std::endl; + std::cout << " Reduced in " << t.get_laps_ms()[2] * 1e-3 << " sec " + << m_n_values << " lines" << std::endl; + std::cout << " Calc stats in " << t.get_laps_ms()[3] * 1e-3 << " sec" + << std::endl; + std::cout << " Loaded in " << t.get_elapsed_ms() * 1e-3 << " sec, " + << m_n_values << " edges total" << std::endl; + + output_stats(); + + return true; } - t.lap_end(); // sorting - - std::vector reduced_Ai; - std::vector reduced_Aj; - { - reduced_Ai.reserve(sorted.size()); - reduced_Aj.reserve(sorted.size()); - - std::uint64_t entry_prev = 0xffffffffffffffff; - for (std::uint64_t entry : sorted) { - if (entry_prev != entry) { - uint i = uint((entry >> 32u) & 0xffffffff); - uint j = uint((entry >> 0u) & 0xffffffff); - reduced_Ai.push_back(i); - reduced_Aj.push_back(j); + + bool MtxLoader::save(const std::filesystem::path& file_path, bool stats_only) { + std::fstream file(file_path, std::ios::out); + + if (!file.is_open()) { + LOG_MSG(Status::Error, "failed to open file " << file_path); + return false; + } + + file << "%%MatrixMarket matrix coordinate pattern general\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; + + file << "% meta-info:\n"; + file << "% name: " << m_name << "\n"; + file << "% source-file: " << m_file_path << "\n"; + file << "% deg-avg: " << m_deg_avg << "\n"; + file << "% deg-sd: " << m_deg_sd << "\n"; + file << "% deg-min: " << m_deg_min << "\n"; + file << "% deg-max: " << m_deg_max << "\n"; + file << "% deg-distribution: \n"; + + for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { + file << "% " << m_deg_ranges[i] << " " << m_deg_ranges[i + 1] << " " + << m_deg_distribution[i] << "\n"; + } + + file << "%-------------------------------------------------------------------" + "------------\n"; + file << m_n_rows << " " << m_n_cols << " " << m_n_values << "\n"; + + if (!stats_only) { + const uint offset = m_base_is_zero ? 1 : 0; + for (std::size_t k = 0; k < m_n_values; k++) { + file << m_Ai[k] + offset << " " << m_Aj[k] + offset << "\n"; + } } - entry_prev = entry; - } - m_n_values = reduced_Ai.size(); - m_Ai = std::move(reduced_Ai); - m_Aj = std::move(reduced_Aj); + return true; } - } - - calc_stats(); - t.lap_end(); // stats - - t.stop(); - - std::cout << " 100%" << std::endl; - std::cout << " Parsed in " << t.get_laps_ms()[0] * 1e-3 << " sec " << n_lines - << " lines" - << " speed " << float(n_lines) / (t.get_laps_ms()[0] * 1e-3) - << " lines/sec" << std::endl; - std::cout << " Sorted in " << t.get_laps_ms()[1] * 1e-3 << " sec " << n_sort - << " lines" << std::endl; - std::cout << " Reduced in " << t.get_laps_ms()[2] * 1e-3 << " sec " - << m_n_values << " lines" << std::endl; - std::cout << " Calc stats in " << t.get_laps_ms()[3] * 1e-3 << " sec" - << std::endl; - std::cout << " Loaded in " << t.get_elapsed_ms() * 1e-3 << " sec, " - << m_n_values << " edges total" << std::endl; - - output_stats(); - - return true; -} - -bool MtxLoader::save(const std::filesystem::path &file_path, bool stats_only) { - std::fstream file(file_path, std::ios::out); - - if (!file.is_open()) { - LOG_MSG(Status::Error, "failed to open file " << file_path); - return false; - } - - file << "%%MatrixMarket matrix coordinate pattern general\n"; - file << "%-------------------------------------------------------------------" - "------------\n"; - file << "%-------------------------------------------------------------------" - "------------\n"; - - file << "% meta-info:\n"; - file << "% name: " << m_name << "\n"; - file << "% source-file: " << m_file_path << "\n"; - file << "% deg-avg: " << m_deg_avg << "\n"; - file << "% deg-sd: " << m_deg_sd << "\n"; - file << "% deg-min: " << m_deg_min << "\n"; - file << "% deg-max: " << m_deg_max << "\n"; - file << "% deg-distribution: \n"; - - for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { - file << "% " << m_deg_ranges[i] << " " << m_deg_ranges[i + 1] << " " - << m_deg_distribution[i] << "\n"; - } - - file << "%-------------------------------------------------------------------" - "------------\n"; - file << m_n_rows << " " << m_n_cols << " " << m_n_values << "\n"; - - if (!stats_only) { - const uint offset = m_base_is_zero ? 1 : 0; - for (std::size_t k = 0; k < m_n_values; k++) { - file << m_Ai[k] + offset << " " << m_Aj[k] + offset << "\n"; + + void MtxLoader::calc_stats() { + std::vector deg_pre_vertex(m_n_rows, 0.0f); + + for (auto i : m_Ai) { + deg_pre_vertex[m_base_is_zero ? i : i - 1] += 1; + } + + m_deg_sd = 0.0; + m_deg_avg = 0.0; + m_deg_max = -1.0; + m_deg_min = 1.0 + static_cast(m_n_values); + + for (auto deg : deg_pre_vertex) { + m_deg_min = std::min(m_deg_min, static_cast(deg)); + m_deg_max = std::max(m_deg_max, static_cast(deg)); + m_deg_avg += deg; + m_deg_sd += deg * deg; + } + + auto n = static_cast(m_n_rows); + + m_deg_avg = m_deg_avg / n; + m_deg_sd = std::sqrt(n * (m_deg_sd / n - m_deg_avg * m_deg_avg) / + (n > 1.0 ? n - 1.0 : 1.0)); + + const uint GROUPS_COUNT_MAX = + std::max(uint(10), uint(std::log2(double(m_n_rows) * 0.77))); + + std::vector count_per_deg(static_cast(m_deg_max) + 2, 0); + std::vector count_per_deg_offsets(static_cast(m_deg_max) + 2, 0); + + for (uint i = 0; i < m_n_rows; i++) { + count_per_deg[std::min(deg_pre_vertex[i], uint(m_deg_max))] += 1; + } + + std::exclusive_scan(count_per_deg.begin(), count_per_deg.end(), + count_per_deg_offsets.begin(), 0); + count_per_deg_offsets.back() += 1; + + std::vector distributions; + std::vector ranges; + + auto range = m_deg_max - m_deg_min; + auto groups_count = + std::max(std::min(GROUPS_COUNT_MAX, static_cast(range)), 1u); + auto g = static_cast(groups_count); + + auto total = static_cast(count_per_deg_offsets.back()); + auto from = count_per_deg_offsets.begin(); + + ranges.push_back(static_cast(m_deg_min)); + for (uint i = 0; i < groups_count; ++i) { + auto next = (from + 1 == count_per_deg_offsets.end()) ? from : from + 1; + auto to = std::lower_bound( + next, count_per_deg_offsets.end(), + static_cast(total / g * static_cast(i + 1))); + auto to_offset = std::distance(count_per_deg_offsets.begin(), to); + + assert(to != count_per_deg_offsets.end()); + + distributions.push_back(static_cast(*to - *from) / total); + ranges.push_back(static_cast(to_offset)); + from = to; + } + + m_deg_distribution = std::move(distributions); + m_deg_ranges = std::move(ranges); } - } - - return true; -} - -void MtxLoader::calc_stats() { - std::vector deg_pre_vertex(m_n_rows, 0.0f); - - for (auto i : m_Ai) { - deg_pre_vertex[m_base_is_zero ? i : i - 1] += 1; - } - - m_deg_sd = 0.0; - m_deg_avg = 0.0; - m_deg_max = -1.0; - m_deg_min = 1.0 + static_cast(m_n_values); - - for (auto deg : deg_pre_vertex) { - m_deg_min = std::min(m_deg_min, static_cast(deg)); - m_deg_max = std::max(m_deg_max, static_cast(deg)); - m_deg_avg += deg; - m_deg_sd += deg * deg; - } - - auto n = static_cast(m_n_rows); - - m_deg_avg = m_deg_avg / n; - m_deg_sd = std::sqrt(n * (m_deg_sd / n - m_deg_avg * m_deg_avg) / - (n > 1.0 ? n - 1.0 : 1.0)); - - const uint GROUPS_COUNT_MAX = - std::max(uint(10), uint(std::log2(double(m_n_rows) * 0.77))); - - std::vector count_per_deg(static_cast(m_deg_max) + 2, 0); - std::vector count_per_deg_offsets(static_cast(m_deg_max) + 2, 0); - - for (uint i = 0; i < m_n_rows; i++) { - count_per_deg[std::min(deg_pre_vertex[i], uint(m_deg_max))] += 1; - } - - std::exclusive_scan(count_per_deg.begin(), count_per_deg.end(), - count_per_deg_offsets.begin(), 0); - count_per_deg_offsets.back() += 1; - - std::vector distributions; - std::vector ranges; - - auto range = m_deg_max - m_deg_min; - auto groups_count = - std::max(std::min(GROUPS_COUNT_MAX, static_cast(range)), 1u); - auto g = static_cast(groups_count); - - auto total = static_cast(count_per_deg_offsets.back()); - auto from = count_per_deg_offsets.begin(); - - ranges.push_back(static_cast(m_deg_min)); - for (uint i = 0; i < groups_count; ++i) { - auto next = (from + 1 == count_per_deg_offsets.end()) ? from : from + 1; - auto to = std::lower_bound( - next, count_per_deg_offsets.end(), - static_cast(total / g * static_cast(i + 1))); - auto to_offset = std::distance(count_per_deg_offsets.begin(), to); - - assert(to != count_per_deg_offsets.end()); - - distributions.push_back(static_cast(*to - *from) / total); - ranges.push_back(static_cast(to_offset)); - from = to; - } - - m_deg_distribution = std::move(distributions); - m_deg_ranges = std::move(ranges); -} -void MtxLoader::output_stats() { - std::cout << " " - << "deg: " - << "min " << m_deg_min << ", " - << "max " << m_deg_max << ", " - << "avg " << m_deg_avg << ", " - << "sd " << m_deg_sd << std::endl; - - std::cout << " distribution:" << std::endl; - - const auto n = static_cast(m_n_rows); - const auto default_precision{std::cout.precision()}; - const auto n_digits = static_cast(std::log10(n) + 1.0); - - const double DISPLAY_DENSITY = - std::max(double(100), double(m_deg_distribution.size())); - - for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { - auto deg = m_deg_distribution[i] >= 0.01 ? m_deg_distribution[i] : 0.0; - auto k_start = m_deg_ranges[i]; - auto k_end = m_deg_ranges[i + 1]; - auto k_count = std::round(static_cast(deg * DISPLAY_DENSITY)); - - std::cout << " [" << std::setw(n_digits) << k_start << " - " - << std::setw(n_digits) << k_end << "): "; - std::cout << std::setw(6) << std::setprecision(2) << deg * 100.0 - << std::setprecision(default_precision) << "% "; - for (uint s = 0; s < k_count; ++s) { - std::cout << "*"; + void MtxLoader::output_stats() { + std::cout << " " + << "deg: " + << "min " << m_deg_min << ", " + << "max " << m_deg_max << ", " + << "avg " << m_deg_avg << ", " + << "sd " << m_deg_sd << std::endl; + + std::cout << " distribution:" << std::endl; + + const auto n = static_cast(m_n_rows); + const auto default_precision{std::cout.precision()}; + const auto n_digits = static_cast(std::log10(n) + 1.0); + + const double DISPLAY_DENSITY = + std::max(double(100), double(m_deg_distribution.size())); + + for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { + auto deg = m_deg_distribution[i] >= 0.01 ? m_deg_distribution[i] : 0.0; + auto k_start = m_deg_ranges[i]; + auto k_end = m_deg_ranges[i + 1]; + auto k_count = std::round(static_cast(deg * DISPLAY_DENSITY)); + + std::cout << " [" << std::setw(n_digits) << k_start << " - " + << std::setw(n_digits) << k_end << "): "; + std::cout << std::setw(6) << std::setprecision(2) << deg * 100.0 + << std::setprecision(default_precision) << "% "; + for (uint s = 0; s < k_count; ++s) { + std::cout << "*"; + } + std::cout << std::endl; + } } - std::cout << std::endl; - } -} -const std::vector &MtxLoader::get_Ai() const { return m_Ai; } -const std::vector &MtxLoader::get_Aj() const { return m_Aj; } -const std::vector &MtxLoader::get_Aw() const { return m_Aw; } + const std::vector& MtxLoader::get_Ai() const { return m_Ai; } + const std::vector& MtxLoader::get_Aj() const { return m_Aj; } + const std::vector& MtxLoader::get_Aw() const { return m_Aw; } -uint MtxLoader::get_n_rows() const { return m_n_rows; } -uint MtxLoader::get_n_cols() const { return m_n_cols; } -std::size_t MtxLoader::get_n_values() const { return m_n_values; } + uint MtxLoader::get_n_rows() const { return m_n_rows; } + uint MtxLoader::get_n_cols() const { return m_n_cols; } + std::size_t MtxLoader::get_n_values() const { return m_n_values; } -} // namespace spla \ No newline at end of file +}// namespace spla \ No newline at end of file diff --git a/src/matrix.cpp b/src/matrix.cpp index ea9839250..aa319ca3b 100644 --- a/src/matrix.cpp +++ b/src/matrix.cpp @@ -52,7 +52,7 @@ namespace spla { return ref_ptr(new TMatrix(n_rows, n_cols)); } if (type == PAIR) { - return ref_ptr(new TMatrix(n_rows, n_cols)); + return ref_ptr(new TMatrix(n_rows, n_cols)); } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); diff --git a/src/op.cpp b/src/op.cpp index 9ac1049ab..8509dc750 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -43,379 +43,381 @@ namespace spla { -ref_ptr IDENTITY_INT; -ref_ptr IDENTITY_UINT; -ref_ptr IDENTITY_FLOAT; -ref_ptr AINV_INT; -ref_ptr AINV_UINT; -ref_ptr AINV_FLOAT; -ref_ptr MINV_INT; -ref_ptr MINV_UINT; -ref_ptr MINV_FLOAT; -ref_ptr LNOT_INT; -ref_ptr LNOT_UINT; -ref_ptr LNOT_FLOAT; -ref_ptr UONE_INT; -ref_ptr UONE_UINT; -ref_ptr UONE_FLOAT; -ref_ptr ABS_INT; -ref_ptr ABS_UINT; -ref_ptr ABS_FLOAT; - -ref_ptr BNOT_INT; -ref_ptr BNOT_UINT; - -ref_ptr SQRT_FLOAT; -ref_ptr LOG_FLOAT; -ref_ptr EXP_FLOAT; -ref_ptr SIN_FLOAT; -ref_ptr COS_FLOAT; -ref_ptr TAN_FLOAT; -ref_ptr ASIN_FLOAT; -ref_ptr ACOS_FLOAT; -ref_ptr ATAN_FLOAT; -ref_ptr CEIL_FLOAT; -ref_ptr FLOOR_FLOAT; -ref_ptr ROUND_FLOAT; -ref_ptr TRUNC_FLOAT; -ref_ptr IDENTITY_PAIR; - -////////////////////////////////////////////////////////////////////////////// - -ref_ptr PLUS_INT; -ref_ptr PLUS_UINT; -ref_ptr PLUS_FLOAT; -ref_ptr MINUS_INT; -ref_ptr MINUS_UINT; -ref_ptr MINUS_FLOAT; -ref_ptr MULT_INT; -ref_ptr MULT_UINT; -ref_ptr MULT_FLOAT; -ref_ptr DIV_INT; -ref_ptr DIV_UINT; -ref_ptr DIV_FLOAT; - -ref_ptr MINUS_POW2_INT; -ref_ptr MINUS_POW2_UINT; -ref_ptr MINUS_POW2_FLOAT; - -ref_ptr FIRST_INT; -ref_ptr FIRST_UINT; -ref_ptr FIRST_FLOAT; -ref_ptr SECOND_INT; -ref_ptr SECOND_UINT; -ref_ptr SECOND_FLOAT; - -ref_ptr BONE_INT; -ref_ptr BONE_UINT; -ref_ptr BONE_FLOAT; - -ref_ptr MIN_INT; -ref_ptr MIN_UINT; -ref_ptr MIN_FLOAT; -ref_ptr MAX_INT; -ref_ptr MAX_UINT; -ref_ptr MAX_FLOAT; - -ref_ptr LOR_INT; -ref_ptr LOR_UINT; -ref_ptr LOR_FLOAT; -ref_ptr LAND_INT; -ref_ptr LAND_UINT; -ref_ptr LAND_FLOAT; - -ref_ptr BOR_INT; -ref_ptr BOR_UINT; -ref_ptr BAND_INT; -ref_ptr BAND_UINT; -ref_ptr BXOR_INT; -ref_ptr BXOR_UINT; - -ref_ptr MIN_PAIR; -ref_ptr MUL_PAIR; - -////////////////////////////////////////////////////////////////////////////// - -ref_ptr EQZERO_INT; -ref_ptr EQZERO_UINT; -ref_ptr EQZERO_FLOAT; -ref_ptr NQZERO_INT; -ref_ptr NQZERO_UINT; -ref_ptr NQZERO_FLOAT; -ref_ptr GTZERO_INT; -ref_ptr GTZERO_UINT; -ref_ptr GTZERO_FLOAT; -ref_ptr GEZERO_INT; -ref_ptr GEZERO_UINT; -ref_ptr GEZERO_FLOAT; -ref_ptr LTZERO_INT; -ref_ptr LTZERO_UINT; -ref_ptr LTZERO_FLOAT; -ref_ptr LEZERO_INT; -ref_ptr LEZERO_UINT; -ref_ptr LEZERO_FLOAT; -ref_ptr ALWAYS_INT; -ref_ptr ALWAYS_UINT; -ref_ptr ALWAYS_FLOAT; -ref_ptr ALWAYS_PAIR; -ref_ptr NEVER_INT; -ref_ptr NEVER_UINT; -ref_ptr NEVER_FLOAT; - -template inline T min(T a, T b) { return std::min(a, b); } - -template inline T max(T a, T b) { return std::max(a, b); } - -void register_ops() { - DECL_OP_UNA_S(IDENTITY_INT, IDENSTITY, T_INT, { return a; }); - DECL_OP_UNA_S(IDENTITY_UINT, IDENSTITY, T_UINT, { return a; }); - DECL_OP_UNA_S(IDENTITY_FLOAT, IDENSTITY, T_FLOAT, { return a; }); - DECL_OP_UNA_S(AINV_INT, AINV, T_INT, { return -a; }); - DECL_OP_UNA_S(AINV_UINT, AINV, T_UINT, { return -a; }); - DECL_OP_UNA_S(AINV_FLOAT, AINV, T_FLOAT, { return -a; }); - DECL_OP_UNA_S(MINV_INT, MINV, T_INT, { return 1 / a; }); - DECL_OP_UNA_S(MINV_UINT, MINV, T_UINT, { return 1 / a; }); - DECL_OP_UNA_S(MINV_FLOAT, MINV, T_FLOAT, { return 1.0f / a; }); - DECL_OP_UNA_S(LNOT_INT, LNOT, T_INT, { return !(a != 0); }); - DECL_OP_UNA_S(LNOT_UINT, LNOT, T_UINT, { return !(a != 0); }); - DECL_OP_UNA_S(LNOT_FLOAT, LNOT, T_FLOAT, { return !(a != 0); }); - DECL_OP_UNA_S(UONE_INT, UONE, T_INT, { return 1; }); - DECL_OP_UNA_S(UONE_UINT, UONE, T_UINT, { return 1; }); - DECL_OP_UNA_S(UONE_FLOAT, UONE, T_FLOAT, { return 1; }); - DECL_OP_UNA_S(ABS_INT, ABS, T_INT, { return abs(a); }); - DECL_OP_UNA_S(ABS_UINT, ABS, T_UINT, { return a; }); - DECL_OP_UNA_S(ABS_FLOAT, ABS, T_FLOAT, { return fabs(a); }); - - DECL_OP_UNA_S(BNOT_INT, BNOT, T_INT, { return ~a; }); - DECL_OP_UNA_S(BNOT_UINT, BNOT, T_UINT, { return ~a; }); - - DECL_OP_UNA_S(SQRT_FLOAT, SQRT, T_FLOAT, { return sqrt(a); }); - DECL_OP_UNA_S(LOG_FLOAT, LOG, T_FLOAT, { return log(a); }); - DECL_OP_UNA_S(EXP_FLOAT, EXP, T_FLOAT, { return exp(a); }); - DECL_OP_UNA_S(SIN_FLOAT, SIN, T_FLOAT, { return sin(a); }); - DECL_OP_UNA_S(COS_FLOAT, COS, T_FLOAT, { return cos(a); }); - DECL_OP_UNA_S(TAN_FLOAT, TAN, T_FLOAT, { return tan(a); }); - DECL_OP_UNA_S(ASIN_FLOAT, ASIN, T_FLOAT, { return asin(a); }); - DECL_OP_UNA_S(ACOS_FLOAT, ACOS, T_FLOAT, { return acos(a); }); - DECL_OP_UNA_S(ATAN_FLOAT, ATAN, T_FLOAT, { return atan(a); }); - DECL_OP_UNA_S(CEIL_FLOAT, CEIL, T_FLOAT, { return ceil(a); }); - DECL_OP_UNA_S(FLOOR_FLOAT, FLOOR, T_FLOAT, { return floor(a); }); - DECL_OP_UNA_S(ROUND_FLOAT, ROUND, T_FLOAT, { return round(a); }); - DECL_OP_UNA_S(TRUNC_FLOAT, TRUNC, T_FLOAT, { return trunc(a); }); - IDENTITY_PAIR = spla::OpUnary::make_pair( - "IDENTITY_PAIR", "(a) identity_pair(a)", [](Pair a) { return a; }); - - DECL_OP_BIN_S(PLUS_INT, PLUS, T_INT, { return a + b; }); - DECL_OP_BIN_S(PLUS_UINT, PLUS, T_UINT, { return a + b; }); - DECL_OP_BIN_S(PLUS_FLOAT, PLUS, T_FLOAT, { return a + b; }); - DECL_OP_BIN_S(MINUS_INT, MINUS, T_INT, { return a - b; }); - DECL_OP_BIN_S(MINUS_UINT, MINUS, T_UINT, { return a - b; }); - DECL_OP_BIN_S(MINUS_FLOAT, MINUS, T_FLOAT, { return a - b; }); - DECL_OP_BIN_S(MULT_INT, MULT, T_INT, { return a * b; }); - DECL_OP_BIN_S(MULT_UINT, MULT, T_UINT, { return a * b; }); - DECL_OP_BIN_S(MULT_FLOAT, MULT, T_FLOAT, { return a * b; }); - DECL_OP_BIN_S(DIV_INT, DIV, T_INT, { return a / b; }); - DECL_OP_BIN_S(DIV_UINT, DIV, T_UINT, { return a / b; }); - DECL_OP_BIN_S(DIV_FLOAT, DIV, T_FLOAT, { return a / b; }); - - DECL_OP_BIN_S(MINUS_POW2_INT, MINUS_POW2, T_INT, - { return (a - b) * (a - b); }); - DECL_OP_BIN_S(MINUS_POW2_UINT, MINUS_POW2, T_UINT, - { return (a - b) * (a - b); }); - DECL_OP_BIN_S(MINUS_POW2_FLOAT, MINUS_POW2, T_FLOAT, - { return (a - b) * (a - b); }); - - DECL_OP_BIN_S(FIRST_INT, FIRST, T_INT, { return a; }); - DECL_OP_BIN_S(FIRST_UINT, FIRST, T_UINT, { return a; }); - DECL_OP_BIN_S(FIRST_FLOAT, FIRST, T_FLOAT, { return a; }); - DECL_OP_BIN_S(SECOND_INT, SECOND, T_INT, { return b; }); - DECL_OP_BIN_S(SECOND_UINT, SECOND, T_UINT, { return b; }); - DECL_OP_BIN_S(SECOND_FLOAT, SECOND, T_FLOAT, { return b; }); - - DECL_OP_BIN_S(BONE_INT, BONE, T_INT, { return 1; }); - DECL_OP_BIN_S(BONE_UINT, BONE, T_UINT, { return 1; }); - DECL_OP_BIN_S(BONE_FLOAT, BONE, T_FLOAT, { return 1; }); - - DECL_OP_BIN_S(MIN_INT, MIN, T_INT, { return min(a, b); }); - DECL_OP_BIN_S(MIN_UINT, MIN, T_UINT, { return min(a, b); }); - DECL_OP_BIN_S(MIN_FLOAT, MIN, T_FLOAT, { return min(a, b); }); - DECL_OP_BIN_S(MAX_INT, MAX, T_INT, { return max(a, b); }); - DECL_OP_BIN_S(MAX_UINT, MAX, T_UINT, { return max(a, b); }); - DECL_OP_BIN_S(MAX_FLOAT, MAX, T_FLOAT, { return max(a, b); }); - - DECL_OP_BIN_S(LOR_INT, LOR, T_INT, { return a || b; }); - DECL_OP_BIN_S(LOR_UINT, LOR, T_UINT, { return a || b; }); - DECL_OP_BIN_S(LOR_FLOAT, LOR, T_FLOAT, { return a || b; }); - DECL_OP_BIN_S(LAND_INT, LAND, T_INT, { return a && b; }); - DECL_OP_BIN_S(LAND_UINT, LAND, T_UINT, { return a && b; }); - DECL_OP_BIN_S(LAND_FLOAT, LAND, T_FLOAT, { return a && b; }); - - DECL_OP_BIN_S(BOR_INT, BOR, T_INT, { return a | b; }); - DECL_OP_BIN_S(BOR_UINT, BOR, T_UINT, { return a | b; }); - DECL_OP_BIN_S(BAND_INT, BAND, T_INT, { return a & b; }); - DECL_OP_BIN_S(BAND_UINT, BAND, T_UINT, { return a & b; }); - DECL_OP_BIN_S(BXOR_INT, BXOR, T_INT, { return a ^ b; }); - DECL_OP_BIN_S(BXOR_UINT, BXOR, T_UINT, { return a ^ b; }); - - MUL_PAIR = OpBinary::make_pair( - "MUL_PAIR", "(a, b) make_pair(a.weight, b.vertex)", - [](Pair a, Pair b) { return Pair(a.weight, b.vertex); }); - MIN_PAIR = OpBinary::make_pair("MIN_PAIR", "(a, b) min_pair(a, b)", - [](Pair a, Pair b) { - if (a.weight == b.weight) - return a.vertex < b.vertex ? a : b; - return a.weight < b.weight ? a : b; - }); - - DECL_OP_SELECT(EQZERO_INT, EQZERO, T_INT, { return a == 0; }); - DECL_OP_SELECT(EQZERO_UINT, EQZERO, T_UINT, { return a == 0; }); - DECL_OP_SELECT(EQZERO_FLOAT, EQZERO, T_FLOAT, { return a == 0; }); - DECL_OP_SELECT(NQZERO_INT, NQZERO, T_INT, { return a != 0; }); - DECL_OP_SELECT(NQZERO_UINT, NQZERO, T_UINT, { return a != 0; }); - DECL_OP_SELECT(NQZERO_FLOAT, NQZERO, T_FLOAT, { return a != 0; }); - DECL_OP_SELECT(GTZERO_INT, GTZERO, T_INT, { return a > 0; }); - DECL_OP_SELECT(GTZERO_UINT, GTZERO, T_UINT, { return a > 0; }); - DECL_OP_SELECT(GTZERO_FLOAT, GTZERO, T_FLOAT, { return a > 0; }); - DECL_OP_SELECT(GEZERO_INT, GEZERO, T_INT, { return a >= 0; }); - DECL_OP_SELECT(GEZERO_UINT, GEZERO, T_UINT, { return a >= 0; }); - DECL_OP_SELECT(GEZERO_FLOAT, GEZERO, T_FLOAT, { return a >= 0; }); - DECL_OP_SELECT(LTZERO_INT, LTZERO, T_INT, { return a < 0; }); - DECL_OP_SELECT(LTZERO_UINT, LTZERO, T_UINT, { return a < 0; }); - DECL_OP_SELECT(LTZERO_FLOAT, LTZERO, T_FLOAT, { return a < 0; }); - DECL_OP_SELECT(LEZERO_INT, LEZERO, T_INT, { return a <= 0; }); - DECL_OP_SELECT(LEZERO_UINT, LEZERO, T_UINT, { return a <= 0; }); - DECL_OP_SELECT(LEZERO_FLOAT, LEZERO, T_FLOAT, { return a <= 0; }); - DECL_OP_SELECT(ALWAYS_INT, ALWAYS, T_INT, { return 1; }); - DECL_OP_SELECT(ALWAYS_UINT, ALWAYS, T_UINT, { return 1; }); - DECL_OP_SELECT(ALWAYS_FLOAT, ALWAYS, T_FLOAT, { return 1; }); - ALWAYS_PAIR = OpSelect::make_pair("ALWAYS_PAIR", "(a) pair_always(a)", - [](Pair a) { return 1; }); - DECL_OP_SELECT(NEVER_INT, NEVER, T_INT, { return 0; }); - DECL_OP_SELECT(NEVER_UINT, NEVER, T_UINT, { return 0; }); - DECL_OP_SELECT(NEVER_FLOAT, NEVER, T_FLOAT, { return 0; }); -} - -ref_ptr OpUnary::make_int(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_res()->get_code(); - return op.as(); -} -ref_ptr OpUnary::make_uint(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_res()->get_code(); - return op.as(); -} -ref_ptr OpUnary::make_float(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_res()->get_code(); - return op.as(); -} -ref_ptr OpUnary::make_pair(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_res()->get_code(); - return op.as(); -} - -ref_ptr -OpBinary::make_int(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); -} -ref_ptr -OpBinary::make_uint(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); -} -ref_ptr -OpBinary::make_float(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); -} -ref_ptr -OpBinary::make_pair(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + - op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); - return op.as(); -} - -ref_ptr OpSelect::make_int(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); -} -ref_ptr OpSelect::make_uint(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); -} -ref_ptr OpSelect::make_float(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); -} -ref_ptr OpSelect::make_pair(std::string name, std::string code, - std::function function) { - auto op = make_ref>(); - op->name = std::move(name); - op->function = std::move(function); - op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code(); - return op.as(); -} - -} // namespace spla \ No newline at end of file + ref_ptr IDENTITY_INT; + ref_ptr IDENTITY_UINT; + ref_ptr IDENTITY_FLOAT; + ref_ptr AINV_INT; + ref_ptr AINV_UINT; + ref_ptr AINV_FLOAT; + ref_ptr MINV_INT; + ref_ptr MINV_UINT; + ref_ptr MINV_FLOAT; + ref_ptr LNOT_INT; + ref_ptr LNOT_UINT; + ref_ptr LNOT_FLOAT; + ref_ptr UONE_INT; + ref_ptr UONE_UINT; + ref_ptr UONE_FLOAT; + ref_ptr ABS_INT; + ref_ptr ABS_UINT; + ref_ptr ABS_FLOAT; + + ref_ptr BNOT_INT; + ref_ptr BNOT_UINT; + + ref_ptr SQRT_FLOAT; + ref_ptr LOG_FLOAT; + ref_ptr EXP_FLOAT; + ref_ptr SIN_FLOAT; + ref_ptr COS_FLOAT; + ref_ptr TAN_FLOAT; + ref_ptr ASIN_FLOAT; + ref_ptr ACOS_FLOAT; + ref_ptr ATAN_FLOAT; + ref_ptr CEIL_FLOAT; + ref_ptr FLOOR_FLOAT; + ref_ptr ROUND_FLOAT; + ref_ptr TRUNC_FLOAT; + ref_ptr IDENTITY_PAIR; + + ////////////////////////////////////////////////////////////////////////////// + + ref_ptr PLUS_INT; + ref_ptr PLUS_UINT; + ref_ptr PLUS_FLOAT; + ref_ptr MINUS_INT; + ref_ptr MINUS_UINT; + ref_ptr MINUS_FLOAT; + ref_ptr MULT_INT; + ref_ptr MULT_UINT; + ref_ptr MULT_FLOAT; + ref_ptr DIV_INT; + ref_ptr DIV_UINT; + ref_ptr DIV_FLOAT; + + ref_ptr MINUS_POW2_INT; + ref_ptr MINUS_POW2_UINT; + ref_ptr MINUS_POW2_FLOAT; + + ref_ptr FIRST_INT; + ref_ptr FIRST_UINT; + ref_ptr FIRST_FLOAT; + ref_ptr SECOND_INT; + ref_ptr SECOND_UINT; + ref_ptr SECOND_FLOAT; + + ref_ptr BONE_INT; + ref_ptr BONE_UINT; + ref_ptr BONE_FLOAT; + + ref_ptr MIN_INT; + ref_ptr MIN_UINT; + ref_ptr MIN_FLOAT; + ref_ptr MAX_INT; + ref_ptr MAX_UINT; + ref_ptr MAX_FLOAT; + + ref_ptr LOR_INT; + ref_ptr LOR_UINT; + ref_ptr LOR_FLOAT; + ref_ptr LAND_INT; + ref_ptr LAND_UINT; + ref_ptr LAND_FLOAT; + + ref_ptr BOR_INT; + ref_ptr BOR_UINT; + ref_ptr BAND_INT; + ref_ptr BAND_UINT; + ref_ptr BXOR_INT; + ref_ptr BXOR_UINT; + + ref_ptr MIN_PAIR; + ref_ptr MUL_PAIR; + + ////////////////////////////////////////////////////////////////////////////// + + ref_ptr EQZERO_INT; + ref_ptr EQZERO_UINT; + ref_ptr EQZERO_FLOAT; + ref_ptr NQZERO_INT; + ref_ptr NQZERO_UINT; + ref_ptr NQZERO_FLOAT; + ref_ptr GTZERO_INT; + ref_ptr GTZERO_UINT; + ref_ptr GTZERO_FLOAT; + ref_ptr GEZERO_INT; + ref_ptr GEZERO_UINT; + ref_ptr GEZERO_FLOAT; + ref_ptr LTZERO_INT; + ref_ptr LTZERO_UINT; + ref_ptr LTZERO_FLOAT; + ref_ptr LEZERO_INT; + ref_ptr LEZERO_UINT; + ref_ptr LEZERO_FLOAT; + ref_ptr ALWAYS_INT; + ref_ptr ALWAYS_UINT; + ref_ptr ALWAYS_FLOAT; + ref_ptr ALWAYS_PAIR; + ref_ptr NEVER_INT; + ref_ptr NEVER_UINT; + ref_ptr NEVER_FLOAT; + + template + inline T min(T a, T b) { return std::min(a, b); } + + template + inline T max(T a, T b) { return std::max(a, b); } + + void register_ops() { + DECL_OP_UNA_S(IDENTITY_INT, IDENSTITY, T_INT, { return a; }); + DECL_OP_UNA_S(IDENTITY_UINT, IDENSTITY, T_UINT, { return a; }); + DECL_OP_UNA_S(IDENTITY_FLOAT, IDENSTITY, T_FLOAT, { return a; }); + DECL_OP_UNA_S(AINV_INT, AINV, T_INT, { return -a; }); + DECL_OP_UNA_S(AINV_UINT, AINV, T_UINT, { return -a; }); + DECL_OP_UNA_S(AINV_FLOAT, AINV, T_FLOAT, { return -a; }); + DECL_OP_UNA_S(MINV_INT, MINV, T_INT, { return 1 / a; }); + DECL_OP_UNA_S(MINV_UINT, MINV, T_UINT, { return 1 / a; }); + DECL_OP_UNA_S(MINV_FLOAT, MINV, T_FLOAT, { return 1.0f / a; }); + DECL_OP_UNA_S(LNOT_INT, LNOT, T_INT, { return !(a != 0); }); + DECL_OP_UNA_S(LNOT_UINT, LNOT, T_UINT, { return !(a != 0); }); + DECL_OP_UNA_S(LNOT_FLOAT, LNOT, T_FLOAT, { return !(a != 0); }); + DECL_OP_UNA_S(UONE_INT, UONE, T_INT, { return 1; }); + DECL_OP_UNA_S(UONE_UINT, UONE, T_UINT, { return 1; }); + DECL_OP_UNA_S(UONE_FLOAT, UONE, T_FLOAT, { return 1; }); + DECL_OP_UNA_S(ABS_INT, ABS, T_INT, { return abs(a); }); + DECL_OP_UNA_S(ABS_UINT, ABS, T_UINT, { return a; }); + DECL_OP_UNA_S(ABS_FLOAT, ABS, T_FLOAT, { return fabs(a); }); + + DECL_OP_UNA_S(BNOT_INT, BNOT, T_INT, { return ~a; }); + DECL_OP_UNA_S(BNOT_UINT, BNOT, T_UINT, { return ~a; }); + + DECL_OP_UNA_S(SQRT_FLOAT, SQRT, T_FLOAT, { return sqrt(a); }); + DECL_OP_UNA_S(LOG_FLOAT, LOG, T_FLOAT, { return log(a); }); + DECL_OP_UNA_S(EXP_FLOAT, EXP, T_FLOAT, { return exp(a); }); + DECL_OP_UNA_S(SIN_FLOAT, SIN, T_FLOAT, { return sin(a); }); + DECL_OP_UNA_S(COS_FLOAT, COS, T_FLOAT, { return cos(a); }); + DECL_OP_UNA_S(TAN_FLOAT, TAN, T_FLOAT, { return tan(a); }); + DECL_OP_UNA_S(ASIN_FLOAT, ASIN, T_FLOAT, { return asin(a); }); + DECL_OP_UNA_S(ACOS_FLOAT, ACOS, T_FLOAT, { return acos(a); }); + DECL_OP_UNA_S(ATAN_FLOAT, ATAN, T_FLOAT, { return atan(a); }); + DECL_OP_UNA_S(CEIL_FLOAT, CEIL, T_FLOAT, { return ceil(a); }); + DECL_OP_UNA_S(FLOOR_FLOAT, FLOOR, T_FLOAT, { return floor(a); }); + DECL_OP_UNA_S(ROUND_FLOAT, ROUND, T_FLOAT, { return round(a); }); + DECL_OP_UNA_S(TRUNC_FLOAT, TRUNC, T_FLOAT, { return trunc(a); }); + IDENTITY_PAIR = spla::OpUnary::make_pair( + "IDENTITY_PAIR", "(a) identity_pair(a)", [](Pair a) { return a; }); + + DECL_OP_BIN_S(PLUS_INT, PLUS, T_INT, { return a + b; }); + DECL_OP_BIN_S(PLUS_UINT, PLUS, T_UINT, { return a + b; }); + DECL_OP_BIN_S(PLUS_FLOAT, PLUS, T_FLOAT, { return a + b; }); + DECL_OP_BIN_S(MINUS_INT, MINUS, T_INT, { return a - b; }); + DECL_OP_BIN_S(MINUS_UINT, MINUS, T_UINT, { return a - b; }); + DECL_OP_BIN_S(MINUS_FLOAT, MINUS, T_FLOAT, { return a - b; }); + DECL_OP_BIN_S(MULT_INT, MULT, T_INT, { return a * b; }); + DECL_OP_BIN_S(MULT_UINT, MULT, T_UINT, { return a * b; }); + DECL_OP_BIN_S(MULT_FLOAT, MULT, T_FLOAT, { return a * b; }); + DECL_OP_BIN_S(DIV_INT, DIV, T_INT, { return a / b; }); + DECL_OP_BIN_S(DIV_UINT, DIV, T_UINT, { return a / b; }); + DECL_OP_BIN_S(DIV_FLOAT, DIV, T_FLOAT, { return a / b; }); + + DECL_OP_BIN_S(MINUS_POW2_INT, MINUS_POW2, T_INT, + { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_UINT, MINUS_POW2, T_UINT, + { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_FLOAT, MINUS_POW2, T_FLOAT, + { return (a - b) * (a - b); }); + + DECL_OP_BIN_S(FIRST_INT, FIRST, T_INT, { return a; }); + DECL_OP_BIN_S(FIRST_UINT, FIRST, T_UINT, { return a; }); + DECL_OP_BIN_S(FIRST_FLOAT, FIRST, T_FLOAT, { return a; }); + DECL_OP_BIN_S(SECOND_INT, SECOND, T_INT, { return b; }); + DECL_OP_BIN_S(SECOND_UINT, SECOND, T_UINT, { return b; }); + DECL_OP_BIN_S(SECOND_FLOAT, SECOND, T_FLOAT, { return b; }); + + DECL_OP_BIN_S(BONE_INT, BONE, T_INT, { return 1; }); + DECL_OP_BIN_S(BONE_UINT, BONE, T_UINT, { return 1; }); + DECL_OP_BIN_S(BONE_FLOAT, BONE, T_FLOAT, { return 1; }); + + DECL_OP_BIN_S(MIN_INT, MIN, T_INT, { return min(a, b); }); + DECL_OP_BIN_S(MIN_UINT, MIN, T_UINT, { return min(a, b); }); + DECL_OP_BIN_S(MIN_FLOAT, MIN, T_FLOAT, { return min(a, b); }); + DECL_OP_BIN_S(MAX_INT, MAX, T_INT, { return max(a, b); }); + DECL_OP_BIN_S(MAX_UINT, MAX, T_UINT, { return max(a, b); }); + DECL_OP_BIN_S(MAX_FLOAT, MAX, T_FLOAT, { return max(a, b); }); + + DECL_OP_BIN_S(LOR_INT, LOR, T_INT, { return a || b; }); + DECL_OP_BIN_S(LOR_UINT, LOR, T_UINT, { return a || b; }); + DECL_OP_BIN_S(LOR_FLOAT, LOR, T_FLOAT, { return a || b; }); + DECL_OP_BIN_S(LAND_INT, LAND, T_INT, { return a && b; }); + DECL_OP_BIN_S(LAND_UINT, LAND, T_UINT, { return a && b; }); + DECL_OP_BIN_S(LAND_FLOAT, LAND, T_FLOAT, { return a && b; }); + + DECL_OP_BIN_S(BOR_INT, BOR, T_INT, { return a | b; }); + DECL_OP_BIN_S(BOR_UINT, BOR, T_UINT, { return a | b; }); + DECL_OP_BIN_S(BAND_INT, BAND, T_INT, { return a & b; }); + DECL_OP_BIN_S(BAND_UINT, BAND, T_UINT, { return a & b; }); + DECL_OP_BIN_S(BXOR_INT, BXOR, T_INT, { return a ^ b; }); + DECL_OP_BIN_S(BXOR_UINT, BXOR, T_UINT, { return a ^ b; }); + + MUL_PAIR = OpBinary::make_pair( + "MUL_PAIR", "(a, b) make_pair(a.weight, b.vertex)", + [](Pair a, Pair b) { return Pair(a.weight, b.vertex); }); + MIN_PAIR = OpBinary::make_pair("MIN_PAIR", "(a, b) min_pair(a, b)", + [](Pair a, Pair b) { + if (a.weight == b.weight) + return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? a : b; + }); + + DECL_OP_SELECT(EQZERO_INT, EQZERO, T_INT, { return a == 0; }); + DECL_OP_SELECT(EQZERO_UINT, EQZERO, T_UINT, { return a == 0; }); + DECL_OP_SELECT(EQZERO_FLOAT, EQZERO, T_FLOAT, { return a == 0; }); + DECL_OP_SELECT(NQZERO_INT, NQZERO, T_INT, { return a != 0; }); + DECL_OP_SELECT(NQZERO_UINT, NQZERO, T_UINT, { return a != 0; }); + DECL_OP_SELECT(NQZERO_FLOAT, NQZERO, T_FLOAT, { return a != 0; }); + DECL_OP_SELECT(GTZERO_INT, GTZERO, T_INT, { return a > 0; }); + DECL_OP_SELECT(GTZERO_UINT, GTZERO, T_UINT, { return a > 0; }); + DECL_OP_SELECT(GTZERO_FLOAT, GTZERO, T_FLOAT, { return a > 0; }); + DECL_OP_SELECT(GEZERO_INT, GEZERO, T_INT, { return a >= 0; }); + DECL_OP_SELECT(GEZERO_UINT, GEZERO, T_UINT, { return a >= 0; }); + DECL_OP_SELECT(GEZERO_FLOAT, GEZERO, T_FLOAT, { return a >= 0; }); + DECL_OP_SELECT(LTZERO_INT, LTZERO, T_INT, { return a < 0; }); + DECL_OP_SELECT(LTZERO_UINT, LTZERO, T_UINT, { return a < 0; }); + DECL_OP_SELECT(LTZERO_FLOAT, LTZERO, T_FLOAT, { return a < 0; }); + DECL_OP_SELECT(LEZERO_INT, LEZERO, T_INT, { return a <= 0; }); + DECL_OP_SELECT(LEZERO_UINT, LEZERO, T_UINT, { return a <= 0; }); + DECL_OP_SELECT(LEZERO_FLOAT, LEZERO, T_FLOAT, { return a <= 0; }); + DECL_OP_SELECT(ALWAYS_INT, ALWAYS, T_INT, { return 1; }); + DECL_OP_SELECT(ALWAYS_UINT, ALWAYS, T_UINT, { return 1; }); + DECL_OP_SELECT(ALWAYS_FLOAT, ALWAYS, T_FLOAT, { return 1; }); + ALWAYS_PAIR = OpSelect::make_pair("ALWAYS_PAIR", "(a) pair_always(a)", + [](Pair a) { return 1; }); + DECL_OP_SELECT(NEVER_INT, NEVER, T_INT, { return 0; }); + DECL_OP_SELECT(NEVER_UINT, NEVER, T_UINT, { return 0; }); + DECL_OP_SELECT(NEVER_FLOAT, NEVER, T_FLOAT, { return 0; }); + } + + ref_ptr OpUnary::make_int(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr OpUnary::make_uint(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr OpUnary::make_float(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr OpUnary::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); + } + + ref_ptr + OpBinary::make_int(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr + OpBinary::make_uint(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr + OpBinary::make_float(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr + OpBinary::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } + + ref_ptr OpSelect::make_int(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); + } + ref_ptr OpSelect::make_uint(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); + } + ref_ptr OpSelect::make_float(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); + } + ref_ptr OpSelect::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); + } + +}// namespace spla \ No newline at end of file diff --git a/src/scalar.cpp b/src/scalar.cpp index a6235c9c9..551e34765 100644 --- a/src/scalar.cpp +++ b/src/scalar.cpp @@ -37,39 +37,39 @@ namespace spla { -ref_ptr Scalar::make(const ref_ptr &type) { - if (!type) { - LOG_MSG(Status::InvalidArgument, "passed null type"); - return ref_ptr{}; - } + ref_ptr Scalar::make(const ref_ptr& type) { + if (!type) { + LOG_MSG(Status::InvalidArgument, "passed null type"); + return ref_ptr{}; + } - Library::get(); + Library::get(); - if (type == INT) { - return ref_ptr(new TScalar()); - } - if (type == UINT) { - return ref_ptr(new TScalar()); - } - if (type == FLOAT) { - return ref_ptr(new TScalar()); - } - if (type == PAIR) { - return ref_ptr(new TScalar()); - } + if (type == INT) { + return ref_ptr(new TScalar()); + } + if (type == UINT) { + return ref_ptr(new TScalar()); + } + if (type == FLOAT) { + return ref_ptr(new TScalar()); + } + if (type == PAIR) { + return ref_ptr(new TScalar()); + } - LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); - return ref_ptr(); -} + LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); + return ref_ptr(); + } -ref_ptr Scalar::make_int(std::int32_t value) { - return ref_ptr(new TScalar(value)); -} -ref_ptr Scalar::make_uint(std::uint32_t value) { - return ref_ptr(new TScalar(value)); -} -ref_ptr Scalar::Scalar::make_float(float value) { - return ref_ptr(new TScalar(value)); -} + ref_ptr Scalar::make_int(std::int32_t value) { + return ref_ptr(new TScalar(value)); + } + ref_ptr Scalar::make_uint(std::uint32_t value) { + return ref_ptr(new TScalar(value)); + } + ref_ptr Scalar::Scalar::make_float(float value) { + return ref_ptr(new TScalar(value)); + } -} // namespace spla \ No newline at end of file +}// namespace spla \ No newline at end of file diff --git a/src/type.cpp b/src/type.cpp index ce3f4c64a..3e32829f2 100644 --- a/src/type.cpp +++ b/src/type.cpp @@ -36,15 +36,15 @@ namespace spla { -ref_ptr BOOL = - TType::make_type("BOOL", "B", "bool", "4 byte logical type", 1); -ref_ptr INT = TType::make_type("INT", "I", "int", - "signed 4 byte integral type", 2); -ref_ptr UINT = TType::make_type( - "UINT", "U", "uint", "unsigned 4 byte integral type", 3); -ref_ptr FLOAT = TType::make_type( - "FLOAT", "F", "float", "4 byte floating point type", 4); -ref_ptr PAIR = TType::make_type( - "PAIR", "P", "struct Pair", "weight-vertex pair float-int", 5); + ref_ptr BOOL = + TType::make_type("BOOL", "B", "bool", "4 byte logical type", 1); + ref_ptr INT = TType::make_type("INT", "I", "int", + "signed 4 byte integral type", 2); + ref_ptr UINT = TType::make_type( + "UINT", "U", "uint", "unsigned 4 byte integral type", 3); + ref_ptr FLOAT = TType::make_type( + "FLOAT", "F", "float", "4 byte floating point type", 4); + ref_ptr PAIR = TType::make_type( + "PAIR", "P", "struct Pair", "weight-vertex pair float-int", 5); -} // namespace spla \ No newline at end of file +}// namespace spla \ No newline at end of file diff --git a/src/vector.cpp b/src/vector.cpp index e152cc80d..a25fc6d0c 100644 --- a/src/vector.cpp +++ b/src/vector.cpp @@ -37,33 +37,33 @@ namespace spla { -ref_ptr Vector::make(uint n_rows, const ref_ptr &type) { - if (n_rows <= 0) { - LOG_MSG(Status::InvalidArgument, "passed 0 dim"); - return ref_ptr{}; - } - if (!type) { - LOG_MSG(Status::InvalidArgument, "passed null type"); - return ref_ptr{}; - } + ref_ptr Vector::make(uint n_rows, const ref_ptr& type) { + if (n_rows <= 0) { + LOG_MSG(Status::InvalidArgument, "passed 0 dim"); + return ref_ptr{}; + } + if (!type) { + LOG_MSG(Status::InvalidArgument, "passed null type"); + return ref_ptr{}; + } - Library::get(); + Library::get(); - if (type == INT) { - return ref_ptr(new TVector(n_rows)); - } - if (type == UINT) { - return ref_ptr(new TVector(n_rows)); - } - if (type == FLOAT) { - return ref_ptr(new TVector(n_rows)); - } - if (type == spla::PAIR) { - return ref_ptr(new TVector(n_rows)); - } + if (type == INT) { + return ref_ptr(new TVector(n_rows)); + } + if (type == UINT) { + return ref_ptr(new TVector(n_rows)); + } + if (type == FLOAT) { + return ref_ptr(new TVector(n_rows)); + } + if (type == spla::PAIR) { + return ref_ptr(new TVector(n_rows)); + } - LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); - return ref_ptr{}; -} + LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); + return ref_ptr{}; + } -} // namespace spla \ No newline at end of file +}// namespace spla \ No newline at end of file From b1aaa66f04bd4727547a8bbb93450968103ddcc4 Mon Sep 17 00:00:00 2001 From: polka777 Date: Sun, 26 Apr 2026 12:11:47 +0300 Subject: [PATCH 12/14] Clang-tidy format examples/ --- examples/mst.cpp | 136 +++++++++++++++++++++++------------------------ 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/examples/mst.cpp b/examples/mst.cpp index 23b31d051..60bcb45f0 100644 --- a/examples/mst.cpp +++ b/examples/mst.cpp @@ -38,95 +38,95 @@ #include #include -int main(int argc, const char *const *argv) { - auto options = make_options( - "mst", "Boruvka's Minimum Spanning Tree algorithm with spla library"); +int main(int argc, const char* const* argv) { + auto options = make_options( + "mst", "Boruvka's Minimum Spanning Tree algorithm with spla library"); - cxxopts::ParseResult args; - int ret; + cxxopts::ParseResult args; + int ret; - if (parse_options(argc, argv, options, args, ret)) { - std::cerr << "failed to parse options" << std::endl; - return ret; - } + if (parse_options(argc, argv, options, args, ret)) { + std::cerr << "failed to parse options" << std::endl; + return ret; + } - spla::Timer timer_total; - spla::Timer timer_gpu; - spla::Timer timer_ref; - spla::MtxLoader loader; + spla::Timer timer_total; + spla::Timer timer_gpu; + spla::Timer timer_ref; + spla::MtxLoader loader; - timer_total.start(); + timer_total.start(); - if (!loader.load(args["mtxpath"].as())) { - std::cerr << "failed to load graph"; - return 1; - } + if (!loader.load(args["mtxpath"].as())) { + std::cerr << "failed to load graph"; + return 1; + } - std::string acc_info; - spla::Library *library = spla::Library::get(); + std::string acc_info; + spla::Library* library = spla::Library::get(); - library->set_platform(args["platform"].as()); - library->set_device(args["device"].as()); - library->set_queues_count(1); - library->get_accelerator_info(acc_info); - std::cout << "env: " << acc_info << std::endl; + library->set_platform(args["platform"].as()); + library->set_device(args["device"].as()); + library->set_queues_count(1); + library->get_accelerator_info(acc_info); + std::cout << "env: " << acc_info << std::endl; - const spla::uint N = loader.get_n_rows(); - auto S = spla::Matrix::make(N, N, spla::PAIR); + const spla::uint N = loader.get_n_rows(); + auto S = spla::Matrix::make(N, N, spla::PAIR); - const auto &Ai = loader.get_Ai(); - const auto &Aj = loader.get_Aj(); - const auto &Aw = loader.get_Aw(); + const auto& Ai = loader.get_Ai(); + const auto& Aj = loader.get_Aj(); + const auto& Aw = loader.get_Aw(); - for (std::size_t k = 0; k < loader.get_n_values(); ++k) { - S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); - } + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } - auto T_gpu = spla::Matrix::make(N, N, spla::FLOAT); + auto T_gpu = spla::Matrix::make(N, N, spla::FLOAT); - auto desc = spla::Descriptor::make(); + auto desc = spla::Descriptor::make(); - const int n_iters = args["niters"].as(); + const int n_iters = args["niters"].as(); - double total_weight_gpu = 0.0; + double total_weight_gpu = 0.0; - if (args["run-gpu"].as()) { - library->set_force_no_acceleration(false); + if (args["run-gpu"].as()) { + library->set_force_no_acceleration(false); - for (int i = 0; i < n_iters; ++i) { - T_gpu->clear(); - S = spla::Matrix::make(N, N, spla::PAIR); - for (std::size_t k = 0; k < loader.get_n_values(); ++k) { - S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); - } - timer_gpu.lap_begin(); - spla::mst(T_gpu, S, desc, nullptr); - timer_gpu.lap_end(); - } + for (int i = 0; i < n_iters; ++i) { + T_gpu->clear(); + S = spla::Matrix::make(N, N, spla::PAIR); + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + timer_gpu.lap_begin(); + spla::mst(T_gpu, S, desc, nullptr); + timer_gpu.lap_end(); + } - total_weight_gpu = 0; - for (spla::uint i = 0; i < N; ++i) { - for (spla::uint j = i + 1; j < N; ++j) { - float w; - T_gpu->get_float(i, j, w); - if (w != 0.0) { - total_weight_gpu += w; + total_weight_gpu = 0; + for (spla::uint i = 0; i < N; ++i) { + for (spla::uint j = i + 1; j < N; ++j) { + float w; + T_gpu->get_float(i, j, w); + if (w != 0.0) { + total_weight_gpu += w; + } + } } - } - } - std::cout << "GPU MST total weight: " << total_weight_gpu << std::endl; - } + std::cout << "GPU MST total weight: " << total_weight_gpu << std::endl; + } - spla::Library::get()->finalize(); + spla::Library::get()->finalize(); - timer_total.stop(); + timer_total.stop(); - std::cout << "\n=== Timing Results ===" << std::endl; - std::cout << "total(ms):" << timer_total.get_elapsed_ms() << std::endl; - std::cout << "gpu(ms): "; - timer_gpu.print(); - std::cout << std::endl; + std::cout << "\n=== Timing Results ===" << std::endl; + std::cout << "total(ms):" << timer_total.get_elapsed_ms() << std::endl; + std::cout << "gpu(ms): "; + timer_gpu.print(); + std::cout << std::endl; - return 0; + return 0; } \ No newline at end of file From 3706036e11c4e64148bf7c7b4b3e69f0459a816d Mon Sep 17 00:00:00 2001 From: polka777 Date: Sun, 26 Apr 2026 12:16:00 +0300 Subject: [PATCH 13/14] Clang-tidy format src/ --- src/binding/c_op.cpp | 206 +++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 103 deletions(-) diff --git a/src/binding/c_op.cpp b/src/binding/c_op.cpp index d959bf1cb..3feb1ac1c 100644 --- a/src/binding/c_op.cpp +++ b/src/binding/c_op.cpp @@ -35,313 +35,313 @@ #include "c_config.hpp" spla_OpUnary spla_OpUnary_IDENTITY_INT() { - return as_ptr(spla::IDENTITY_INT.ref_and_get()); + return as_ptr(spla::IDENTITY_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_IDENTITY_UINT() { - return as_ptr(spla::IDENTITY_UINT.ref_and_get()); + return as_ptr(spla::IDENTITY_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_IDENTITY_FLOAT() { - return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); + return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_IDENTITY_PAIR() { - return as_ptr(spla::IDENTITY_PAIR.ref_and_get()); + return as_ptr(spla::IDENTITY_PAIR.ref_and_get()); } spla_OpUnary spla_OpUnary_AINV_INT() { - return as_ptr(spla::AINV_INT.ref_and_get()); + return as_ptr(spla::AINV_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_AINV_UINT() { - return as_ptr(spla::AINV_UINT.ref_and_get()); + return as_ptr(spla::AINV_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_AINV_FLOAT() { - return as_ptr(spla::AINV_FLOAT.ref_and_get()); + return as_ptr(spla::AINV_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_MINV_INT() { - return as_ptr(spla::MINV_INT.ref_and_get()); + return as_ptr(spla::MINV_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_MINV_UINT() { - return as_ptr(spla::MINV_UINT.ref_and_get()); + return as_ptr(spla::MINV_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_MINV_FLOAT() { - return as_ptr(spla::MINV_FLOAT.ref_and_get()); + return as_ptr(spla::MINV_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_LNOT_INT() { - return as_ptr(spla::LNOT_INT.ref_and_get()); + return as_ptr(spla::LNOT_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_LNOT_UINT() { - return as_ptr(spla::LNOT_UINT.ref_and_get()); + return as_ptr(spla::LNOT_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_LNOT_FLOAT() { - return as_ptr(spla::LNOT_FLOAT.ref_and_get()); + return as_ptr(spla::LNOT_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_UONE_INT() { - return as_ptr(spla::UONE_INT.ref_and_get()); + return as_ptr(spla::UONE_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_UONE_UINT() { - return as_ptr(spla::UONE_UINT.ref_and_get()); + return as_ptr(spla::UONE_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_UONE_FLOAT() { - return as_ptr(spla::UONE_FLOAT.ref_and_get()); + return as_ptr(spla::UONE_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_ABS_INT() { - return as_ptr(spla::ABS_INT.ref_and_get()); + return as_ptr(spla::ABS_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_ABS_UINT() { - return as_ptr(spla::ABS_UINT.ref_and_get()); + return as_ptr(spla::ABS_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_ABS_FLOAT() { - return as_ptr(spla::ABS_FLOAT.ref_and_get()); + return as_ptr(spla::ABS_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_BNOT_INT() { - return as_ptr(spla::BNOT_INT.ref_and_get()); + return as_ptr(spla::BNOT_INT.ref_and_get()); } spla_OpUnary spla_OpUnary_BNOT_UINT() { - return as_ptr(spla::BNOT_UINT.ref_and_get()); + return as_ptr(spla::BNOT_UINT.ref_and_get()); } spla_OpUnary spla_OpUnary_SQRT_FLOAT() { - return as_ptr(spla::SQRT_FLOAT.ref_and_get()); + return as_ptr(spla::SQRT_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_LOG_FLOAT() { - return as_ptr(spla::LOG_FLOAT.ref_and_get()); + return as_ptr(spla::LOG_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_EXP_FLOAT() { - return as_ptr(spla::EXP_FLOAT.ref_and_get()); + return as_ptr(spla::EXP_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_SIN_FLOAT() { - return as_ptr(spla::SIN_FLOAT.ref_and_get()); + return as_ptr(spla::SIN_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_COS_FLOAT() { - return as_ptr(spla::COS_FLOAT.ref_and_get()); + return as_ptr(spla::COS_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_TAN_FLOAT() { - return as_ptr(spla::TAN_FLOAT.ref_and_get()); + return as_ptr(spla::TAN_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_ASIN_FLOAT() { - return as_ptr(spla::ASIN_FLOAT.ref_and_get()); + return as_ptr(spla::ASIN_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_ACOS_FLOAT() { - return as_ptr(spla::ACOS_FLOAT.ref_and_get()); + return as_ptr(spla::ACOS_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_ATAN_FLOAT() { - return as_ptr(spla::ATAN_FLOAT.ref_and_get()); + return as_ptr(spla::ATAN_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_CEIL_FLOAT() { - return as_ptr(spla::CEIL_FLOAT.ref_and_get()); + return as_ptr(spla::CEIL_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_FLOOR_FLOAT() { - return as_ptr(spla::FLOOR_FLOAT.ref_and_get()); + return as_ptr(spla::FLOOR_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_ROUND_FLOAT() { - return as_ptr(spla::ROUND_FLOAT.ref_and_get()); + return as_ptr(spla::ROUND_FLOAT.ref_and_get()); } spla_OpUnary spla_OpUnary_TRUNC_FLOAT() { - return as_ptr(spla::TRUNC_FLOAT.ref_and_get()); + return as_ptr(spla::TRUNC_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_PLUS_INT() { - return as_ptr(spla::PLUS_INT.ref_and_get()); + return as_ptr(spla::PLUS_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_PLUS_UINT() { - return as_ptr(spla::PLUS_UINT.ref_and_get()); + return as_ptr(spla::PLUS_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_PLUS_FLOAT() { - return as_ptr(spla::PLUS_FLOAT.ref_and_get()); + return as_ptr(spla::PLUS_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_MINUS_INT() { - return as_ptr(spla::MINUS_INT.ref_and_get()); + return as_ptr(spla::MINUS_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_MINUS_UINT() { - return as_ptr(spla::MINUS_UINT.ref_and_get()); + return as_ptr(spla::MINUS_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_MINUS_FLOAT() { - return as_ptr(spla::MINUS_FLOAT.ref_and_get()); + return as_ptr(spla::MINUS_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_MULT_INT() { - return as_ptr(spla::MULT_INT.ref_and_get()); + return as_ptr(spla::MULT_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_MULT_UINT() { - return as_ptr(spla::MULT_UINT.ref_and_get()); + return as_ptr(spla::MULT_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_MULT_FLOAT() { - return as_ptr(spla::MULT_FLOAT.ref_and_get()); + return as_ptr(spla::MULT_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_DIV_INT() { - return as_ptr(spla::DIV_INT.ref_and_get()); + return as_ptr(spla::DIV_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_DIV_UINT() { - return as_ptr(spla::DIV_UINT.ref_and_get()); + return as_ptr(spla::DIV_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_DIV_FLOAT() { - return as_ptr(spla::DIV_FLOAT.ref_and_get()); + return as_ptr(spla::DIV_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_MINUS_POW2_INT() { - return as_ptr(spla::MINUS_POW2_INT.ref_and_get()); + return as_ptr(spla::MINUS_POW2_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_MINUS_POW2_UINT() { - return as_ptr(spla::MINUS_POW2_UINT.ref_and_get()); + return as_ptr(spla::MINUS_POW2_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_MINUS_POW2_FLOAT() { - return as_ptr(spla::MINUS_POW2_FLOAT.ref_and_get()); + return as_ptr(spla::MINUS_POW2_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_FIRST_INT() { - return as_ptr(spla::FIRST_INT.ref_and_get()); + return as_ptr(spla::FIRST_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_FIRST_UINT() { - return as_ptr(spla::FIRST_UINT.ref_and_get()); + return as_ptr(spla::FIRST_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_FIRST_FLOAT() { - return as_ptr(spla::FIRST_FLOAT.ref_and_get()); + return as_ptr(spla::FIRST_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_SECOND_INT() { - return as_ptr(spla::SECOND_INT.ref_and_get()); + return as_ptr(spla::SECOND_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_SECOND_UINT() { - return as_ptr(spla::SECOND_UINT.ref_and_get()); + return as_ptr(spla::SECOND_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_SECOND_FLOAT() { - return as_ptr(spla::SECOND_FLOAT.ref_and_get()); + return as_ptr(spla::SECOND_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_BONE_INT() { - return as_ptr(spla::BONE_INT.ref_and_get()); + return as_ptr(spla::BONE_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_BONE_UINT() { - return as_ptr(spla::BONE_UINT.ref_and_get()); + return as_ptr(spla::BONE_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_BONE_FLOAT() { - return as_ptr(spla::BONE_FLOAT.ref_and_get()); + return as_ptr(spla::BONE_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_MIN_INT() { - return as_ptr(spla::MIN_INT.ref_and_get()); + return as_ptr(spla::MIN_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_MIN_UINT() { - return as_ptr(spla::MIN_UINT.ref_and_get()); + return as_ptr(spla::MIN_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_MIN_FLOAT() { - return as_ptr(spla::MIN_FLOAT.ref_and_get()); + return as_ptr(spla::MIN_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_MAX_INT() { - return as_ptr(spla::MAX_INT.ref_and_get()); + return as_ptr(spla::MAX_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_MAX_UINT() { - return as_ptr(spla::MAX_UINT.ref_and_get()); + return as_ptr(spla::MAX_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_MAX_FLOAT() { - return as_ptr(spla::MAX_FLOAT.ref_and_get()); + return as_ptr(spla::MAX_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_LOR_INT() { - return as_ptr(spla::LOR_INT.ref_and_get()); + return as_ptr(spla::LOR_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_LOR_UINT() { - return as_ptr(spla::LOR_UINT.ref_and_get()); + return as_ptr(spla::LOR_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_LOR_FLOAT() { - return as_ptr(spla::LOR_FLOAT.ref_and_get()); + return as_ptr(spla::LOR_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_LAND_INT() { - return as_ptr(spla::LAND_INT.ref_and_get()); + return as_ptr(spla::LAND_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_LAND_UINT() { - return as_ptr(spla::LAND_UINT.ref_and_get()); + return as_ptr(spla::LAND_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_LAND_FLOAT() { - return as_ptr(spla::LAND_FLOAT.ref_and_get()); + return as_ptr(spla::LAND_FLOAT.ref_and_get()); } spla_OpBinary spla_OpBinary_BOR_INT() { - return as_ptr(spla::BOR_INT.ref_and_get()); + return as_ptr(spla::BOR_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_BOR_UINT() { - return as_ptr(spla::BOR_UINT.ref_and_get()); + return as_ptr(spla::BOR_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_BAND_INT() { - return as_ptr(spla::BAND_INT.ref_and_get()); + return as_ptr(spla::BAND_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_BAND_UINT() { - return as_ptr(spla::BAND_UINT.ref_and_get()); + return as_ptr(spla::BAND_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_BXOR_INT() { - return as_ptr(spla::BXOR_INT.ref_and_get()); + return as_ptr(spla::BXOR_INT.ref_and_get()); } spla_OpBinary spla_OpBinary_BXOR_UINT() { - return as_ptr(spla::BXOR_UINT.ref_and_get()); + return as_ptr(spla::BXOR_UINT.ref_and_get()); } spla_OpBinary spla_OpBinary_MIN_PAIR() { - return as_ptr(spla::MIN_PAIR.ref_and_get()); + return as_ptr(spla::MIN_PAIR.ref_and_get()); } spla_OpBinary spla_OpBinary_MUL_PAIR() { - return as_ptr(spla::MUL_PAIR.ref_and_get()); + return as_ptr(spla::MUL_PAIR.ref_and_get()); } spla_OpSelect spla_OpSelect_EQZERO_INT() { - return as_ptr(spla::EQZERO_INT.ref_and_get()); + return as_ptr(spla::EQZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_EQZERO_UINT() { - return as_ptr(spla::EQZERO_UINT.ref_and_get()); + return as_ptr(spla::EQZERO_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_EQZERO_FLOAT() { - return as_ptr(spla::EQZERO_FLOAT.ref_and_get()); + return as_ptr(spla::EQZERO_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_NQZERO_INT() { - return as_ptr(spla::NQZERO_INT.ref_and_get()); + return as_ptr(spla::NQZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_NQZERO_UINT() { - return as_ptr(spla::NQZERO_UINT.ref_and_get()); + return as_ptr(spla::NQZERO_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_NQZERO_FLOAT() { - return as_ptr(spla::NQZERO_FLOAT.ref_and_get()); + return as_ptr(spla::NQZERO_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_GTZERO_INT() { - return as_ptr(spla::GTZERO_INT.ref_and_get()); + return as_ptr(spla::GTZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_GTZERO_UINT() { - return as_ptr(spla::GTZERO_UINT.ref_and_get()); + return as_ptr(spla::GTZERO_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_GTZERO_FLOAT() { - return as_ptr(spla::GTZERO_FLOAT.ref_and_get()); + return as_ptr(spla::GTZERO_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_GEZERO_INT() { - return as_ptr(spla::GEZERO_INT.ref_and_get()); + return as_ptr(spla::GEZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_GEZERO_UINT() { - return as_ptr(spla::GEZERO_UINT.ref_and_get()); + return as_ptr(spla::GEZERO_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_GEZERO_FLOAT() { - return as_ptr(spla::GEZERO_FLOAT.ref_and_get()); + return as_ptr(spla::GEZERO_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_LTZERO_INT() { - return as_ptr(spla::LTZERO_INT.ref_and_get()); + return as_ptr(spla::LTZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_LTZERO_UINT() { - return as_ptr(spla::LTZERO_UINT.ref_and_get()); + return as_ptr(spla::LTZERO_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_LTZERO_FLOAT() { - return as_ptr(spla::LTZERO_FLOAT.ref_and_get()); + return as_ptr(spla::LTZERO_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_LEZERO_INT() { - return as_ptr(spla::LEZERO_INT.ref_and_get()); + return as_ptr(spla::LEZERO_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_LEZERO_UINT() { - return as_ptr(spla::LEZERO_UINT.ref_and_get()); + return as_ptr(spla::LEZERO_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_LEZERO_FLOAT() { - return as_ptr(spla::LEZERO_FLOAT.ref_and_get()); + return as_ptr(spla::LEZERO_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_ALWAYS_INT() { - return as_ptr(spla::ALWAYS_INT.ref_and_get()); + return as_ptr(spla::ALWAYS_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_ALWAYS_UINT() { - return as_ptr(spla::ALWAYS_UINT.ref_and_get()); + return as_ptr(spla::ALWAYS_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_ALWAYS_FLOAT() { - return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); + return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); } spla_OpSelect spla_OpSelect_ALWAYS_PAIR() { - return as_ptr(spla::ALWAYS_PAIR.ref_and_get()); + return as_ptr(spla::ALWAYS_PAIR.ref_and_get()); } spla_OpSelect spla_OpSelect_NEVER_INT() { - return as_ptr(spla::NEVER_INT.ref_and_get()); + return as_ptr(spla::NEVER_INT.ref_and_get()); } spla_OpSelect spla_OpSelect_NEVER_UINT() { - return as_ptr(spla::NEVER_UINT.ref_and_get()); + return as_ptr(spla::NEVER_UINT.ref_and_get()); } spla_OpSelect spla_OpSelect_NEVER_FLOAT() { - return as_ptr(spla::NEVER_FLOAT.ref_and_get()); + return as_ptr(spla::NEVER_FLOAT.ref_and_get()); } \ No newline at end of file From bdf4709bfed12b70ddda3e67a4b33b59427ad714 Mon Sep 17 00:00:00 2001 From: polka777 Date: Sun, 26 Apr 2026 12:18:08 +0300 Subject: [PATCH 14/14] Clang-tidy format --- src/core/tmatrix.hpp | 665 +++++++++++++------------- src/core/top.hpp | 6 +- src/core/tscalar.hpp | 299 ++++++------ src/core/ttype.hpp | 5 +- src/core/tvector.hpp | 751 +++++++++++++++--------------- src/opencl/cl_mxv.hpp | 513 ++++++++++---------- src/opencl/cl_program_builder.cpp | 12 +- 7 files changed, 1154 insertions(+), 1097 deletions(-) diff --git a/src/core/tmatrix.hpp b/src/core/tmatrix.hpp index d8d173867..20c8344b8 100644 --- a/src/core/tmatrix.hpp +++ b/src/core/tmatrix.hpp @@ -49,341 +49,356 @@ namespace spla { -/** + /** * @addtogroup internal * @{ */ -/** + /** * @class TMatrix * @brief Matrix interface implementation with type information bound * * @tparam T Type of stored elements */ -template class TMatrix final : public Matrix { -public: - TMatrix(uint n_rows, uint n_cols); - ~TMatrix() override = default; - - uint get_n_rows() override; - uint get_n_cols() override; - ref_ptr get_type() override; - void set_label(std::string label) override; - const std::string &get_label() const override; - Status set_format(FormatMatrix format) override; - Status set_fill_value(const ref_ptr &value) override; - Status set_reduce(ref_ptr resolve_duplicates) override; - Status set_int(uint row_id, uint col_id, std::int32_t value) override; - Status set_uint(uint row_id, uint col_id, std::uint32_t value) override; - Status set_float(uint row_id, uint col_id, float value) override; - Status set_pair(uint row_id, uint col_id, Pair value) override { - return Status::InvalidArgument; - } - Status get_int(uint row_id, uint col_id, int32_t &value) override; - Status get_uint(uint row_id, uint col_id, uint32_t &value) override; - Status get_float(uint row_id, uint col_id, float &value) override; - Status get_pair(uint row_id, uint col_id, Pair &value) override { - return Status::InvalidArgument; - } - Status build(const ref_ptr &keys1, const ref_ptr &keys2, - const ref_ptr &values) override; - Status read(ref_ptr &keys1, ref_ptr &keys2, - ref_ptr &values) override; - Status clear() override; - - template Decorator *get() { - return m_storage.template get(); - } - - void validate_rw(FormatMatrix format); - void validate_rwd(FormatMatrix format); - void validate_wd(FormatMatrix format); - void validate_ctor(FormatMatrix format); - bool is_valid(FormatMatrix format) const; - T get_fill_value() const { return m_storage.get_fill_value(); } - - static StorageManagerMatrix *get_storage_manager(); - -private: - typename StorageManagerMatrix::Storage m_storage; - std::string m_label; -}; - -template TMatrix::TMatrix(uint n_rows, uint n_cols) { - m_storage.set_dims(n_rows, n_cols); -} - -template uint TMatrix::get_n_rows() { - return m_storage.get_n_rows(); -} -template uint TMatrix::get_n_cols() { - return m_storage.get_n_cols(); -} -template ref_ptr TMatrix::get_type() { - return get_ttype().template as(); -} - -template void TMatrix::set_label(std::string label) { - m_label = std::move(label); - LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void *)this); -} -template const std::string &TMatrix::get_label() const { - return m_label; -} - -template Status TMatrix::set_format(FormatMatrix format) { - validate_rw(format); - return Status::Ok; -} -template -Status TMatrix::set_fill_value(const ref_ptr &value) { - if (value) { - m_storage.invalidate(); - - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_int()); - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_uint()); - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_float()); - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_pair()); - - return Status::Ok; - } - - return Status::InvalidArgument; -} -template -Status TMatrix::set_reduce(ref_ptr resolve_duplicates) { - auto reduce = resolve_duplicates.template cast_safe>(); - - if (reduce) { - validate_ctor(FormatMatrix::CpuLil); - get>()->reduce = reduce->function; - validate_ctor(FormatMatrix::CpuDok); - get>()->reduce = reduce->function; - } - - return Status::InvalidArgument; -} - -template -Status TMatrix::set_int(uint row_id, uint col_id, std::int32_t value) { - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); - return Status::Ok; -} -template <> -inline Status TMatrix::set_int(uint row_id, uint col_id, - std::int32_t value) { - return Status::InvalidArgument; -} -template -Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); - return Status::Ok; -} -template <> -inline Status TMatrix::set_uint(uint row_id, uint col_id, - std::uint32_t value) { - return Status::InvalidArgument; -} -template -Status TMatrix::set_float(uint row_id, uint col_id, float value) { - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); - return Status::Ok; -} -template <> -inline Status TMatrix::set_float(uint row_id, uint col_id, float value) { - return Status::InvalidArgument; -} - -template -Status TMatrix::get_int(uint row_id, uint col_id, int32_t &value) { - validate_rw(FormatMatrix::CpuDok); - - auto &Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} -template <> -inline Status TMatrix::get_int(uint row_id, uint col_id, - std::int32_t &value) { - return Status::InvalidArgument; -} -template -Status TMatrix::get_uint(uint row_id, uint col_id, uint32_t &value) { - validate_rw(FormatMatrix::CpuDok); - - auto &Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} -template <> -inline Status TMatrix::get_uint(uint row_id, uint col_id, - std::uint32_t &value) { - return Status::InvalidArgument; -} -template -Status TMatrix::get_float(uint row_id, uint col_id, float &value) { - validate_rw(FormatMatrix::CpuDok); - - auto &Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} -template <> -inline Status TMatrix::get_float(uint row_id, uint col_id, float &value) { - return Status::InvalidArgument; -} - -template -Status TMatrix::build(const ref_ptr &keys1, - const ref_ptr &keys2, - const ref_ptr &values) { - assert(keys1); - assert(keys2); - assert(values); - - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - const auto elements_count = keys1->get_size() / key_size; - - if (elements_count != values->get_size() / value_size) { - return Status::InvalidArgument; - } - if (elements_count * key_size != keys1->get_size()) { - return Status::InvalidArgument; - } - if (elements_count * key_size != keys2->get_size()) { - return Status::InvalidArgument; - } - - validate_rwd(FormatMatrix::CpuCoo); - CpuCoo &coo = *get>(); - - coo.Ai.resize(elements_count); - coo.Aj.resize(elements_count); - coo.Ax.resize(elements_count); - coo.values = uint(elements_count); - - keys1->read(0, key_size * elements_count, coo.Ai.data()); - keys2->read(0, key_size * elements_count, coo.Aj.data()); - values->read(0, value_size * elements_count, coo.Ax.data()); - - return Status::Ok; -} -template -Status TMatrix::read(ref_ptr &keys1, ref_ptr &keys2, - ref_ptr &values) { - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - - validate_rw(FormatMatrix::CpuCoo); - CpuCoo &coo = *get>(); - - const auto elements_count = coo.Ai.size(); - - keys1 = MemView::make(coo.Ai.data(), key_size * elements_count, false); - keys2 = MemView::make(coo.Aj.data(), key_size * elements_count, false); - values = MemView::make(coo.Ax.data(), value_size * elements_count, false); - - return Status::Ok; -} - -template Status TMatrix::clear() { - m_storage.invalidate(); - return Status::Ok; -} - -template void TMatrix::validate_rw(FormatMatrix format) { - StorageManagerMatrix *manager = get_storage_manager(); - manager->validate_rw(format, m_storage); -} - -template void TMatrix::validate_rwd(FormatMatrix format) { - StorageManagerMatrix *manager = get_storage_manager(); - manager->validate_rwd(format, m_storage); -} - -template void TMatrix::validate_wd(FormatMatrix format) { - StorageManagerMatrix *manager = get_storage_manager(); - manager->validate_wd(format, m_storage); -} - -template void TMatrix::validate_ctor(FormatMatrix format) { - StorageManagerMatrix *manager = get_storage_manager(); - manager->validate_ctor(format, m_storage); -} - -template bool TMatrix::is_valid(FormatMatrix format) const { - return m_storage.is_valid(format); -} - -template -StorageManagerMatrix *TMatrix::get_storage_manager() { - static std::unique_ptr> storage_manager; - - if (!storage_manager) { - storage_manager = std::make_unique>(); - register_formats_matrix(*storage_manager); - } - - return storage_manager.get(); -} -template <> -inline Status TMatrix::set_pair(uint row_id, uint col_id, Pair value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - - validate_rwd(FormatMatrix::CpuLil); - cpu_lil_add_element(row_id, col_id, value, *get>()); - return Status::Ok; -} -template <> -inline Status TMatrix::get_pair(uint row_id, uint col_id, Pair &value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - validate_rw(FormatMatrix::CpuDok); - - auto &Ax = get>()->Ax; - auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} - -/** + template + class TMatrix final : public Matrix { + public: + TMatrix(uint n_rows, uint n_cols); + ~TMatrix() override = default; + + uint get_n_rows() override; + uint get_n_cols() override; + ref_ptr get_type() override; + void set_label(std::string label) override; + const std::string& get_label() const override; + Status set_format(FormatMatrix format) override; + Status set_fill_value(const ref_ptr& value) override; + Status set_reduce(ref_ptr resolve_duplicates) override; + Status set_int(uint row_id, uint col_id, std::int32_t value) override; + Status set_uint(uint row_id, uint col_id, std::uint32_t value) override; + Status set_float(uint row_id, uint col_id, float value) override; + Status set_pair(uint row_id, uint col_id, Pair value) override { + return Status::InvalidArgument; + } + Status get_int(uint row_id, uint col_id, int32_t& value) override; + Status get_uint(uint row_id, uint col_id, uint32_t& value) override; + Status get_float(uint row_id, uint col_id, float& value) override; + Status get_pair(uint row_id, uint col_id, Pair& value) override { + return Status::InvalidArgument; + } + Status build(const ref_ptr& keys1, const ref_ptr& keys2, + const ref_ptr& values) override; + Status read(ref_ptr& keys1, ref_ptr& keys2, + ref_ptr& values) override; + Status clear() override; + + template + Decorator* get() { + return m_storage.template get(); + } + + void validate_rw(FormatMatrix format); + void validate_rwd(FormatMatrix format); + void validate_wd(FormatMatrix format); + void validate_ctor(FormatMatrix format); + bool is_valid(FormatMatrix format) const; + T get_fill_value() const { return m_storage.get_fill_value(); } + + static StorageManagerMatrix* get_storage_manager(); + + private: + typename StorageManagerMatrix::Storage m_storage; + std::string m_label; + }; + + template + TMatrix::TMatrix(uint n_rows, uint n_cols) { + m_storage.set_dims(n_rows, n_cols); + } + + template + uint TMatrix::get_n_rows() { + return m_storage.get_n_rows(); + } + template + uint TMatrix::get_n_cols() { + return m_storage.get_n_cols(); + } + template + ref_ptr TMatrix::get_type() { + return get_ttype().template as(); + } + + template + void TMatrix::set_label(std::string label) { + m_label = std::move(label); + LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void*) this); + } + template + const std::string& TMatrix::get_label() const { + return m_label; + } + + template + Status TMatrix::set_format(FormatMatrix format) { + validate_rw(format); + return Status::Ok; + } + template + Status TMatrix::set_fill_value(const ref_ptr& value) { + if (value) { + m_storage.invalidate(); + + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_int()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_uint()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_float()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_pair()); + + return Status::Ok; + } + + return Status::InvalidArgument; + } + template + Status TMatrix::set_reduce(ref_ptr resolve_duplicates) { + auto reduce = resolve_duplicates.template cast_safe>(); + + if (reduce) { + validate_ctor(FormatMatrix::CpuLil); + get>()->reduce = reduce->function; + validate_ctor(FormatMatrix::CpuDok); + get>()->reduce = reduce->function; + } + + return Status::InvalidArgument; + } + + template + Status TMatrix::set_int(uint row_id, uint col_id, std::int32_t value) { + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); + return Status::Ok; + } + template<> + inline Status TMatrix::set_int(uint row_id, uint col_id, + std::int32_t value) { + return Status::InvalidArgument; + } + template + Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); + return Status::Ok; + } + template<> + inline Status TMatrix::set_uint(uint row_id, uint col_id, + std::uint32_t value) { + return Status::InvalidArgument; + } + template + Status TMatrix::set_float(uint row_id, uint col_id, float value) { + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); + return Status::Ok; + } + template<> + inline Status TMatrix::set_float(uint row_id, uint col_id, float value) { + return Status::InvalidArgument; + } + + template + Status TMatrix::get_int(uint row_id, uint col_id, int32_t& value) { + validate_rw(FormatMatrix::CpuDok); + + auto& Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + template<> + inline Status TMatrix::get_int(uint row_id, uint col_id, + std::int32_t& value) { + return Status::InvalidArgument; + } + template + Status TMatrix::get_uint(uint row_id, uint col_id, uint32_t& value) { + validate_rw(FormatMatrix::CpuDok); + + auto& Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + template<> + inline Status TMatrix::get_uint(uint row_id, uint col_id, + std::uint32_t& value) { + return Status::InvalidArgument; + } + template + Status TMatrix::get_float(uint row_id, uint col_id, float& value) { + validate_rw(FormatMatrix::CpuDok); + + auto& Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + template<> + inline Status TMatrix::get_float(uint row_id, uint col_id, float& value) { + return Status::InvalidArgument; + } + + template + Status TMatrix::build(const ref_ptr& keys1, + const ref_ptr& keys2, + const ref_ptr& values) { + assert(keys1); + assert(keys2); + assert(values); + + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + const auto elements_count = keys1->get_size() / key_size; + + if (elements_count != values->get_size() / value_size) { + return Status::InvalidArgument; + } + if (elements_count * key_size != keys1->get_size()) { + return Status::InvalidArgument; + } + if (elements_count * key_size != keys2->get_size()) { + return Status::InvalidArgument; + } + + validate_rwd(FormatMatrix::CpuCoo); + CpuCoo& coo = *get>(); + + coo.Ai.resize(elements_count); + coo.Aj.resize(elements_count); + coo.Ax.resize(elements_count); + coo.values = uint(elements_count); + + keys1->read(0, key_size * elements_count, coo.Ai.data()); + keys2->read(0, key_size * elements_count, coo.Aj.data()); + values->read(0, value_size * elements_count, coo.Ax.data()); + + return Status::Ok; + } + template + Status TMatrix::read(ref_ptr& keys1, ref_ptr& keys2, + ref_ptr& values) { + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + + validate_rw(FormatMatrix::CpuCoo); + CpuCoo& coo = *get>(); + + const auto elements_count = coo.Ai.size(); + + keys1 = MemView::make(coo.Ai.data(), key_size * elements_count, false); + keys2 = MemView::make(coo.Aj.data(), key_size * elements_count, false); + values = MemView::make(coo.Ax.data(), value_size * elements_count, false); + + return Status::Ok; + } + + template + Status TMatrix::clear() { + m_storage.invalidate(); + return Status::Ok; + } + + template + void TMatrix::validate_rw(FormatMatrix format) { + StorageManagerMatrix* manager = get_storage_manager(); + manager->validate_rw(format, m_storage); + } + + template + void TMatrix::validate_rwd(FormatMatrix format) { + StorageManagerMatrix* manager = get_storage_manager(); + manager->validate_rwd(format, m_storage); + } + + template + void TMatrix::validate_wd(FormatMatrix format) { + StorageManagerMatrix* manager = get_storage_manager(); + manager->validate_wd(format, m_storage); + } + + template + void TMatrix::validate_ctor(FormatMatrix format) { + StorageManagerMatrix* manager = get_storage_manager(); + manager->validate_ctor(format, m_storage); + } + + template + bool TMatrix::is_valid(FormatMatrix format) const { + return m_storage.is_valid(format); + } + + template + StorageManagerMatrix* TMatrix::get_storage_manager() { + static std::unique_ptr> storage_manager; + + if (!storage_manager) { + storage_manager = std::make_unique>(); + register_formats_matrix(*storage_manager); + } + + return storage_manager.get(); + } + template<> + inline Status TMatrix::set_pair(uint row_id, uint col_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, value, *get>()); + return Status::Ok; + } + template<> + inline Status TMatrix::get_pair(uint row_id, uint col_id, Pair& value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + validate_rw(FormatMatrix::CpuDok); + + auto& Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + + /** * @} */ -} // namespace spla +}// namespace spla -#endif // SPLA_TMATRIX_HPP +#endif// SPLA_TMATRIX_HPP diff --git a/src/core/top.hpp b/src/core/top.hpp index 654e928eb..a313564a0 100644 --- a/src/core/top.hpp +++ b/src/core/top.hpp @@ -43,7 +43,7 @@ namespace spla { { \ auto func = make_ref>(); \ \ - func->function = [](A0 a) -> R __VA_ARGS__; \ + func->function = [](A0 a)->R __VA_ARGS__; \ func->name = #fname; \ \ std::stringstream source_builder; \ @@ -69,7 +69,7 @@ namespace spla { { \ auto func = make_ref>(); \ \ - func->function = [](A0 a, A1 b) -> R __VA_ARGS__; \ + func->function = [](A0 a, A1 b)->R __VA_ARGS__; \ func->name = #fname; \ \ std::stringstream source_builder; \ @@ -98,7 +98,7 @@ namespace spla { { \ auto func = make_ref>(); \ \ - func->function = [](A0 a) -> bool __VA_ARGS__; \ + func->function = [](A0 a)->bool __VA_ARGS__; \ func->name = #fname; \ \ std::stringstream source_builder; \ diff --git a/src/core/tscalar.hpp b/src/core/tscalar.hpp index 75614139c..4a07060f1 100644 --- a/src/core/tscalar.hpp +++ b/src/core/tscalar.hpp @@ -41,151 +41,168 @@ namespace spla { -/** + /** * @addtogroup internal * @{ */ -/** + /** * * @tparam T */ -template class TScalar final : public Scalar { -public: - TScalar() = default; - explicit TScalar(T value); - ~TScalar() override = default; - - ref_ptr get_type() override; - Status set_int(std::int32_t value) override; - Status set_uint(std::uint32_t value) override; - Status set_float(float value) override; - Status get_int(std::int32_t &value) override; - Status get_uint(std::uint32_t &value) override; - Status get_float(float &value) override; - T_INT as_int() override { return static_cast(m_value); } - T_UINT as_uint() override { return static_cast(m_value); } - T_FLOAT as_float() override { return static_cast(m_value); } - T_PAIR as_pair() override { return static_cast(m_value); } - - void set_label(std::string label) override; - const std::string &get_label() const override; - - T &get_value(); - T get_value() const; - -private: - std::string m_label; - T m_value = T(); -}; - -template TScalar::TScalar(T value) : m_value(value) {} - -template ref_ptr TScalar::get_type() { - return get_ttype().template as(); -} - -template Status TScalar::set_int(std::int32_t value) { - m_value = static_cast(value); - return Status::Ok; -} -template Status TScalar::set_uint(std::uint32_t value) { - m_value = static_cast(value); - return Status::Ok; -} -template Status TScalar::set_float(float value) { - m_value = static_cast(value); - return Status::Ok; -} - -template Status TScalar::get_int(std::int32_t &value) { - value = static_cast(m_value); - return Status::Ok; -} -template Status TScalar::get_uint(std::uint32_t &value) { - value = static_cast(m_value); - return Status::Ok; -} -template Status TScalar::get_float(float &value) { - value = static_cast(m_value); - return Status::Ok; -} - -template void TScalar::set_label(std::string label) { - m_label = std::move(label); -} - -template const std::string &TScalar::get_label() const { - return m_label; -} - -template T &TScalar::get_value() { return m_value; } -template T TScalar::get_value() const { return m_value; } -template <> inline T_PAIR TScalar::as_pair() { return Pair(); } -template <> inline T_PAIR TScalar::as_pair() { return Pair(); } -template <> inline T_PAIR TScalar::as_pair() { return Pair(); } - -template <> class TScalar final : public Scalar { -public: - TScalar() = default; - explicit TScalar(Pair value) : m_value(value) {} - ~TScalar() override = default; - - Status set_pair(Pair value) { - m_value = value; - return Status::Ok; - } - - Status get_pair(Pair &value) const { - value = m_value; - return Status::Ok; - } - - ref_ptr get_type() override { return PAIR; } - - Status set_int(std::int32_t) override { return Status::InvalidArgument; } - - Status set_uint(std::uint32_t) override { return Status::InvalidArgument; } - - Status set_float(float) override { return Status::InvalidArgument; } - - Status get_int(std::int32_t &) override { return Status::InvalidArgument; } - - Status get_uint(std::uint32_t &) override { return Status::InvalidArgument; } - - Status get_float(float &) override { return Status::InvalidArgument; } - - T_INT as_int() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to int"); - return 0; - } - - T_UINT as_uint() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to uint"); - return 0; - } - - T_FLOAT as_float() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to float"); - return 0.0f; - } - T_PAIR as_pair() override { - LOG_MSG(Status::InvalidArgument, "cannot convert Pair to pair"); - return Pair(); - } - - void set_label(std::string label) override { m_label = std::move(label); } - - const std::string &get_label() const override { return m_label; } - - Pair &get_value() { return m_value; } - Pair get_value() const { return m_value; } - -private: - std::string m_label; - Pair m_value = Pair(); -}; - -} // namespace spla - -#endif // SPLA_TSCALAR_HPP + template + class TScalar final : public Scalar { + public: + TScalar() = default; + explicit TScalar(T value); + ~TScalar() override = default; + + ref_ptr get_type() override; + Status set_int(std::int32_t value) override; + Status set_uint(std::uint32_t value) override; + Status set_float(float value) override; + Status get_int(std::int32_t& value) override; + Status get_uint(std::uint32_t& value) override; + Status get_float(float& value) override; + T_INT as_int() override { return static_cast(m_value); } + T_UINT as_uint() override { return static_cast(m_value); } + T_FLOAT as_float() override { return static_cast(m_value); } + T_PAIR as_pair() override { return static_cast(m_value); } + + void set_label(std::string label) override; + const std::string& get_label() const override; + + T& get_value(); + T get_value() const; + + private: + std::string m_label; + T m_value = T(); + }; + + template + TScalar::TScalar(T value) : m_value(value) {} + + template + ref_ptr TScalar::get_type() { + return get_ttype().template as(); + } + + template + Status TScalar::set_int(std::int32_t value) { + m_value = static_cast(value); + return Status::Ok; + } + template + Status TScalar::set_uint(std::uint32_t value) { + m_value = static_cast(value); + return Status::Ok; + } + template + Status TScalar::set_float(float value) { + m_value = static_cast(value); + return Status::Ok; + } + + template + Status TScalar::get_int(std::int32_t& value) { + value = static_cast(m_value); + return Status::Ok; + } + template + Status TScalar::get_uint(std::uint32_t& value) { + value = static_cast(m_value); + return Status::Ok; + } + template + Status TScalar::get_float(float& value) { + value = static_cast(m_value); + return Status::Ok; + } + + template + void TScalar::set_label(std::string label) { + m_label = std::move(label); + } + + template + const std::string& TScalar::get_label() const { + return m_label; + } + + template + T& TScalar::get_value() { return m_value; } + template + T TScalar::get_value() const { return m_value; } + template<> + inline T_PAIR TScalar::as_pair() { return Pair(); } + template<> + inline T_PAIR TScalar::as_pair() { return Pair(); } + template<> + inline T_PAIR TScalar::as_pair() { return Pair(); } + + template<> + class TScalar final : public Scalar { + public: + TScalar() = default; + explicit TScalar(Pair value) : m_value(value) {} + ~TScalar() override = default; + + Status set_pair(Pair value) { + m_value = value; + return Status::Ok; + } + + Status get_pair(Pair& value) const { + value = m_value; + return Status::Ok; + } + + ref_ptr get_type() override { return PAIR; } + + Status set_int(std::int32_t) override { return Status::InvalidArgument; } + + Status set_uint(std::uint32_t) override { return Status::InvalidArgument; } + + Status set_float(float) override { return Status::InvalidArgument; } + + Status get_int(std::int32_t&) override { return Status::InvalidArgument; } + + Status get_uint(std::uint32_t&) override { return Status::InvalidArgument; } + + Status get_float(float&) override { return Status::InvalidArgument; } + + T_INT as_int() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to int"); + return 0; + } + + T_UINT as_uint() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to uint"); + return 0; + } + + T_FLOAT as_float() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to float"); + return 0.0f; + } + T_PAIR as_pair() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to pair"); + return Pair(); + } + + void set_label(std::string label) override { m_label = std::move(label); } + + const std::string& get_label() const override { return m_label; } + + Pair& get_value() { return m_value; } + Pair get_value() const { return m_value; } + + private: + std::string m_label; + Pair m_value = Pair(); + }; + +}// namespace spla + +#endif// SPLA_TSCALAR_HPP diff --git a/src/core/ttype.hpp b/src/core/ttype.hpp index 2ca57ebfe..136401a7f 100644 --- a/src/core/ttype.hpp +++ b/src/core/ttype.hpp @@ -129,8 +129,9 @@ namespace spla { ref_ptr> get_ttype() { return FLOAT.cast_safe>(); } - template <> ref_ptr> get_ttype() { - return PAIR.cast_safe>(); + template<> + ref_ptr> get_ttype() { + return PAIR.cast_safe>(); } /** diff --git a/src/core/tvector.hpp b/src/core/tvector.hpp index e565f9c28..ed48a3a44 100644 --- a/src/core/tvector.hpp +++ b/src/core/tvector.hpp @@ -53,380 +53,403 @@ namespace spla { -/** + /** * @addtogroup internal * @{ */ -/** + /** * @class TVector * @brief Vector interface implementation with type information bound * * @tparam T Type of stored elements */ -template class TVector final : public Vector { -public: - explicit TVector(uint n_rows); - ~TVector() override = default; - - uint get_n_rows() override; - ref_ptr get_type() override; - void set_label(std::string label) override; - const std::string &get_label() const override; - Status set_format(FormatVector format) override; - Status set_fill_value(const ref_ptr &value) override; - Status set_reduce(ref_ptr resolve_duplicates) override; - Status set_int(uint row_id, std::int32_t value) override; - Status set_uint(uint row_id, std::uint32_t value) override; - Status set_float(uint row_id, float value) override; - Status set_pair(uint row_id, Pair value) override { - return Status::InvalidArgument; - } - Status get_int(uint row_id, int32_t &value) override; - Status get_uint(uint row_id, uint32_t &value) override; - Status get_float(uint row_id, float &value) override; - Status get_pair(uint row_id, Pair &value) override { - return Status::InvalidArgument; - } - Status fill_noize(uint seed) override; - Status fill_with(const ref_ptr &value) override; - Status build(const ref_ptr &keys, - const ref_ptr &values) override; - Status read(ref_ptr &keys, ref_ptr &values) override; - Status clear() override; - - template Decorator *get() { - return m_storage.template get(); - } - - void validate_rw(FormatVector format); - void validate_rwd(FormatVector format); - void validate_wd(FormatVector format); - void validate_ctor(FormatVector format); - bool is_valid(FormatVector format) const; - T get_fill_value() const { return m_storage.get_fill_value(); } - - static StorageManagerVector *get_storage_manager(); - -private: - typename StorageManagerVector::Storage m_storage; - std::string m_label; -}; - -template TVector::TVector(uint n_rows) { - m_storage.set_dims(n_rows, 1); -} - -template uint TVector::get_n_rows() { - return m_storage.get_n_rows(); -} -template ref_ptr TVector::get_type() { - return get_ttype().template as(); -} - -template void TVector::set_label(std::string label) { - m_label = std::move(label); - LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void *)this); -} -template const std::string &TVector::get_label() const { - return m_label; -} - -template Status TVector::set_format(FormatVector format) { - validate_rw(format); - return Status::Ok; -} -template -Status TVector::set_fill_value(const ref_ptr &value) { - if (value) { - m_storage.invalidate(); - - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_int()); - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_uint()); - if constexpr (std::is_same::value) - m_storage.set_fill_value(value->as_float()); - - return Status::Ok; - } - - return Status::InvalidArgument; -} -template -Status TVector::set_reduce(ref_ptr resolve_duplicates) { - auto reduce = resolve_duplicates.template cast_safe>(); - - if (reduce) { - validate_ctor(FormatVector::CpuDok); - auto *vec = get>(); - vec->reduce = reduce->function; - return Status::Ok; - } - - return Status::InvalidArgument; -} - -template -Status TVector::set_int(uint row_id, std::int32_t value) { - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = static_cast(value); - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); - return Status::Ok; -} -template <> -inline Status TVector::set_int(uint row_id, std::int32_t value) { - return Status::InvalidArgument; -} - -template -Status TVector::set_uint(uint row_id, std::uint32_t value) { - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = static_cast(value); - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); - return Status::Ok; -} -template <> -inline Status TVector::set_uint(uint row_id, std::uint32_t value) { - return Status::InvalidArgument; -} -template Status TVector::set_float(uint row_id, float value) { - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = static_cast(value); - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); - return Status::Ok; -} -template <> inline Status TVector::set_float(uint row_id, float value) { - return Status::InvalidArgument; -} - -template Status TVector::get_int(uint row_id, int32_t &value) { - validate_rw(FormatVector::CpuDok); - - const auto &Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} -template <> inline Status TVector::get_int(uint row_id, int32_t &value) { - return Status::InvalidArgument; -} -template -Status TVector::get_uint(uint row_id, uint32_t &value) { - validate_rw(FormatVector::CpuDok); - - const auto &Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} -template <> -inline Status TVector::get_uint(uint row_id, uint32_t &value) { - return Status::InvalidArgument; -} -template Status TVector::get_float(uint row_id, float &value) { - validate_rw(FormatVector::CpuDok); - - const auto &Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - value = m_storage.get_fill_value(); - - if (entry != Ax.end()) { - value = static_cast(entry->second); - } - - return Status::Ok; -} -template <> inline Status TVector::get_float(uint row_id, float &value) { - return Status::InvalidArgument; -} - -template Status TVector::fill_noize(uint seed) { - validate_wd(FormatVector::CpuDense); - auto &Ax = get>()->Ax; - auto engine = std::default_random_engine(seed); - - if constexpr (std::is_integral_v) { - std::uniform_int_distribution dist; - for (auto &x : Ax) - x = dist(engine); - } - if constexpr (std::is_floating_point_v) { - std::uniform_real_distribution dist; - for (auto &x : Ax) - x = dist(engine); - } - - return Status::Ok; -} -template -Status TVector::fill_with(const ref_ptr &value) { - assert(value); - - T t = T(); - - if constexpr (std::is_same::value) - t = value->as_int(); - if constexpr (std::is_same::value) - t = value->as_uint(); - if constexpr (std::is_same::value) - t = value->as_float(); - - validate_wd(FormatVector::CpuDense); - auto &Ax = get>()->Ax; - std::fill(Ax.begin(), Ax.end(), t); - - return Status::Ok; -} - -template -Status TVector::build(const ref_ptr &keys, - const ref_ptr &values) { - assert(keys); - assert(values); - - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - const auto elements_count = keys->get_size() / key_size; - - if (elements_count != values->get_size() / value_size) { - return Status::InvalidArgument; - } - if (elements_count * key_size != keys->get_size()) { - return Status::InvalidArgument; - } - - validate_rwd(FormatVector::CpuCoo); - CpuCooVec &coo = *get>(); - - coo.Ai.resize(elements_count); - coo.Ax.resize(elements_count); - coo.values = uint(elements_count); - - keys->read(0, key_size * elements_count, coo.Ai.data()); - values->read(0, value_size * elements_count, coo.Ax.data()); - - return Status::Ok; -} -template -Status TVector::read(ref_ptr &keys, ref_ptr &values) { - const auto key_size = sizeof(uint); - const auto value_size = sizeof(T); - - validate_rw(FormatVector::CpuCoo); - CpuCooVec &coo = *get>(); - - const auto elements_count = coo.Ai.size(); - - keys = MemView::make(coo.Ai.data(), key_size * elements_count, false); - values = MemView::make(coo.Ax.data(), value_size * elements_count, false); - - return Status::Ok; -} - -template Status TVector::clear() { - m_storage.invalidate(); - return Status::Ok; -} - -template void TVector::validate_rw(FormatVector format) { - StorageManagerVector *manager = get_storage_manager(); - manager->validate_rw(format, m_storage); -} - -template void TVector::validate_rwd(FormatVector format) { - StorageManagerVector *manager = get_storage_manager(); - manager->validate_rwd(format, m_storage); -} - -template void TVector::validate_wd(FormatVector format) { - StorageManagerVector *manager = get_storage_manager(); - manager->validate_wd(format, m_storage); -} - -template void TVector::validate_ctor(FormatVector format) { - StorageManagerVector *manager = get_storage_manager(); - manager->validate_ctor(format, m_storage); -} - -template bool TVector::is_valid(FormatVector format) const { - return m_storage.is_valid(format); -} - -template -StorageManagerVector *TVector::get_storage_manager() { - static std::unique_ptr> storage_manager; - - if (!storage_manager) { - storage_manager = std::make_unique>(); - register_formats_vector(*storage_manager); - } - - return storage_manager.get(); -} -template <> inline Status TVector::set_pair(uint row_id, Pair value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - - if (is_valid(FormatVector::CpuDense)) { - validate_rwd(FormatVector::CpuDense); - get>()->Ax[row_id] = value; - return Status::Ok; - } - - validate_rwd(FormatVector::CpuDok); - cpu_dok_vec_add_element(row_id, value, *get>()); - return Status::Ok; -} -template <> inline Status TVector::get_pair(uint row_id, Pair &value) { - if (get_type() != PAIR) { - return Status::InvalidArgument; - } - - validate_rw(FormatVector::CpuDok); - - const auto &Ax = get>()->Ax; - const auto entry = Ax.find(row_id); - - if (entry != Ax.end()) { - value = entry->second; - } else { - value = m_storage.get_fill_value(); - } - - return Status::Ok; -} - -/** + template + class TVector final : public Vector { + public: + explicit TVector(uint n_rows); + ~TVector() override = default; + + uint get_n_rows() override; + ref_ptr get_type() override; + void set_label(std::string label) override; + const std::string& get_label() const override; + Status set_format(FormatVector format) override; + Status set_fill_value(const ref_ptr& value) override; + Status set_reduce(ref_ptr resolve_duplicates) override; + Status set_int(uint row_id, std::int32_t value) override; + Status set_uint(uint row_id, std::uint32_t value) override; + Status set_float(uint row_id, float value) override; + Status set_pair(uint row_id, Pair value) override { + return Status::InvalidArgument; + } + Status get_int(uint row_id, int32_t& value) override; + Status get_uint(uint row_id, uint32_t& value) override; + Status get_float(uint row_id, float& value) override; + Status get_pair(uint row_id, Pair& value) override { + return Status::InvalidArgument; + } + Status fill_noize(uint seed) override; + Status fill_with(const ref_ptr& value) override; + Status build(const ref_ptr& keys, + const ref_ptr& values) override; + Status read(ref_ptr& keys, ref_ptr& values) override; + Status clear() override; + + template + Decorator* get() { + return m_storage.template get(); + } + + void validate_rw(FormatVector format); + void validate_rwd(FormatVector format); + void validate_wd(FormatVector format); + void validate_ctor(FormatVector format); + bool is_valid(FormatVector format) const; + T get_fill_value() const { return m_storage.get_fill_value(); } + + static StorageManagerVector* get_storage_manager(); + + private: + typename StorageManagerVector::Storage m_storage; + std::string m_label; + }; + + template + TVector::TVector(uint n_rows) { + m_storage.set_dims(n_rows, 1); + } + + template + uint TVector::get_n_rows() { + return m_storage.get_n_rows(); + } + template + ref_ptr TVector::get_type() { + return get_ttype().template as(); + } + + template + void TVector::set_label(std::string label) { + m_label = std::move(label); + LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void*) this); + } + template + const std::string& TVector::get_label() const { + return m_label; + } + + template + Status TVector::set_format(FormatVector format) { + validate_rw(format); + return Status::Ok; + } + template + Status TVector::set_fill_value(const ref_ptr& value) { + if (value) { + m_storage.invalidate(); + + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_int()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_uint()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_float()); + + return Status::Ok; + } + + return Status::InvalidArgument; + } + template + Status TVector::set_reduce(ref_ptr resolve_duplicates) { + auto reduce = resolve_duplicates.template cast_safe>(); + + if (reduce) { + validate_ctor(FormatVector::CpuDok); + auto* vec = get>(); + vec->reduce = reduce->function; + return Status::Ok; + } + + return Status::InvalidArgument; + } + + template + Status TVector::set_int(uint row_id, std::int32_t value) { + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = static_cast(value); + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); + return Status::Ok; + } + template<> + inline Status TVector::set_int(uint row_id, std::int32_t value) { + return Status::InvalidArgument; + } + + template + Status TVector::set_uint(uint row_id, std::uint32_t value) { + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = static_cast(value); + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); + return Status::Ok; + } + template<> + inline Status TVector::set_uint(uint row_id, std::uint32_t value) { + return Status::InvalidArgument; + } + template + Status TVector::set_float(uint row_id, float value) { + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = static_cast(value); + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); + return Status::Ok; + } + template<> + inline Status TVector::set_float(uint row_id, float value) { + return Status::InvalidArgument; + } + + template + Status TVector::get_int(uint row_id, int32_t& value) { + validate_rw(FormatVector::CpuDok); + + const auto& Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + template<> + inline Status TVector::get_int(uint row_id, int32_t& value) { + return Status::InvalidArgument; + } + template + Status TVector::get_uint(uint row_id, uint32_t& value) { + validate_rw(FormatVector::CpuDok); + + const auto& Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + template<> + inline Status TVector::get_uint(uint row_id, uint32_t& value) { + return Status::InvalidArgument; + } + template + Status TVector::get_float(uint row_id, float& value) { + validate_rw(FormatVector::CpuDok); + + const auto& Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } + template<> + inline Status TVector::get_float(uint row_id, float& value) { + return Status::InvalidArgument; + } + + template + Status TVector::fill_noize(uint seed) { + validate_wd(FormatVector::CpuDense); + auto& Ax = get>()->Ax; + auto engine = std::default_random_engine(seed); + + if constexpr (std::is_integral_v) { + std::uniform_int_distribution dist; + for (auto& x : Ax) + x = dist(engine); + } + if constexpr (std::is_floating_point_v) { + std::uniform_real_distribution dist; + for (auto& x : Ax) + x = dist(engine); + } + + return Status::Ok; + } + template + Status TVector::fill_with(const ref_ptr& value) { + assert(value); + + T t = T(); + + if constexpr (std::is_same::value) + t = value->as_int(); + if constexpr (std::is_same::value) + t = value->as_uint(); + if constexpr (std::is_same::value) + t = value->as_float(); + + validate_wd(FormatVector::CpuDense); + auto& Ax = get>()->Ax; + std::fill(Ax.begin(), Ax.end(), t); + + return Status::Ok; + } + + template + Status TVector::build(const ref_ptr& keys, + const ref_ptr& values) { + assert(keys); + assert(values); + + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + const auto elements_count = keys->get_size() / key_size; + + if (elements_count != values->get_size() / value_size) { + return Status::InvalidArgument; + } + if (elements_count * key_size != keys->get_size()) { + return Status::InvalidArgument; + } + + validate_rwd(FormatVector::CpuCoo); + CpuCooVec& coo = *get>(); + + coo.Ai.resize(elements_count); + coo.Ax.resize(elements_count); + coo.values = uint(elements_count); + + keys->read(0, key_size * elements_count, coo.Ai.data()); + values->read(0, value_size * elements_count, coo.Ax.data()); + + return Status::Ok; + } + template + Status TVector::read(ref_ptr& keys, ref_ptr& values) { + const auto key_size = sizeof(uint); + const auto value_size = sizeof(T); + + validate_rw(FormatVector::CpuCoo); + CpuCooVec& coo = *get>(); + + const auto elements_count = coo.Ai.size(); + + keys = MemView::make(coo.Ai.data(), key_size * elements_count, false); + values = MemView::make(coo.Ax.data(), value_size * elements_count, false); + + return Status::Ok; + } + + template + Status TVector::clear() { + m_storage.invalidate(); + return Status::Ok; + } + + template + void TVector::validate_rw(FormatVector format) { + StorageManagerVector* manager = get_storage_manager(); + manager->validate_rw(format, m_storage); + } + + template + void TVector::validate_rwd(FormatVector format) { + StorageManagerVector* manager = get_storage_manager(); + manager->validate_rwd(format, m_storage); + } + + template + void TVector::validate_wd(FormatVector format) { + StorageManagerVector* manager = get_storage_manager(); + manager->validate_wd(format, m_storage); + } + + template + void TVector::validate_ctor(FormatVector format) { + StorageManagerVector* manager = get_storage_manager(); + manager->validate_ctor(format, m_storage); + } + + template + bool TVector::is_valid(FormatVector format) const { + return m_storage.is_valid(format); + } + + template + StorageManagerVector* TVector::get_storage_manager() { + static std::unique_ptr> storage_manager; + + if (!storage_manager) { + storage_manager = std::make_unique>(); + register_formats_vector(*storage_manager); + } + + return storage_manager.get(); + } + template<> + inline Status TVector::set_pair(uint row_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = value; + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, value, *get>()); + return Status::Ok; + } + template<> + inline Status TVector::get_pair(uint row_id, Pair& value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rw(FormatVector::CpuDok); + + const auto& Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + + if (entry != Ax.end()) { + value = entry->second; + } else { + value = m_storage.get_fill_value(); + } + + return Status::Ok; + } + + /** * @} */ -} // namespace spla +}// namespace spla -#endif // SPLA_TVECTOR_HPP +#endif// SPLA_TVECTOR_HPP diff --git a/src/opencl/cl_mxv.hpp b/src/opencl/cl_mxv.hpp index f7a04317c..6e4bc8a88 100644 --- a/src/opencl/cl_mxv.hpp +++ b/src/opencl/cl_mxv.hpp @@ -56,259 +56,260 @@ namespace spla { -template class Algo_mxv_masked_cl final : public RegistryAlgo { -public: - ~Algo_mxv_masked_cl() override = default; - - std::string get_name() override { return "mxv_masked"; } - - std::string get_description() override { - return "parallel matrix-vector masked product on opencl device"; - } - - Status execute(const DispatchContext &ctx) override { - auto t = ctx.task.template cast_safe(); - auto early_exit = t->get_desc_or_default()->get_early_exit(); - - if (early_exit) { - return execute_config_scalar(ctx); - } else { - return execute_vector(ctx); - } - } - -private: - Status execute_vector(const DispatchContext &ctx) { - TIME_PROFILE_SCOPE("opencl/mxv/vector"); - - auto t = ctx.task.template cast_safe(); - - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = - t->op_multiply.template cast_safe>(); - ref_ptr> op_add = - t->op_add.template cast_safe>(); - ref_ptr> op_select = - t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); - - r->validate_wd(FormatVector::AccDense); - mask->validate_rw(FormatVector::AccDense); - M->validate_rw(FormatMatrix::AccCsr); - v->validate_rw(FormatVector::AccDense); - - std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) - return Status::CompilationError; - - auto *p_cl_r = r->template get>(); - auto *p_cl_mask = mask->template get>(); - auto *p_cl_M = M->template get>(); - auto *p_cl_v = v->template get>(); - - auto *p_cl_acc = get_acc_cl(); - auto &queue = p_cl_acc->get_queue_default(); - - auto kernel_vector = program->make_kernel("mxv_vector"); - kernel_vector.setArg(0, p_cl_M->Ap); - kernel_vector.setArg(1, p_cl_M->Aj); - kernel_vector.setArg(2, p_cl_M->Ax); - kernel_vector.setArg(3, p_cl_v->Ax); - kernel_vector.setArg(4, p_cl_mask->Ax); - kernel_vector.setArg(5, p_cl_r->Ax); - kernel_vector.setArg(6, init->get_value()); - kernel_vector.setArg(7, r->get_n_rows()); - - uint n_groups_to_dispatch = - div_up_clamp(r->get_n_rows(), m_block_count, 1, 512); - - cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size); - cl::NDRange exec_local(m_block_count, m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), - exec_global, exec_local); - - return Status::Ok; - } - - Status execute_scalar(const DispatchContext &ctx) { - TIME_PROFILE_SCOPE("opencl/mxv/scalar"); - - auto t = ctx.task.template cast_safe(); - - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = - t->op_multiply.template cast_safe>(); - ref_ptr> op_add = - t->op_add.template cast_safe>(); - ref_ptr> op_select = - t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); - - r->validate_wd(FormatVector::AccDense); - mask->validate_rw(FormatVector::AccDense); - M->validate_rw(FormatMatrix::AccCsr); - v->validate_rw(FormatVector::AccDense); - - std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) - return Status::CompilationError; - - auto *p_cl_r = r->template get>(); - auto *p_cl_mask = mask->template get>(); - auto *p_cl_M = M->template get>(); - auto *p_cl_v = v->template get>(); - auto early_exit = t->get_desc_or_default()->get_early_exit(); - - auto *p_cl_acc = get_acc_cl(); - auto &queue = p_cl_acc->get_queue_default(); - - auto kernel_scalar = program->make_kernel("mxv_scalar"); - kernel_scalar.setArg(0, p_cl_M->Ap); - kernel_scalar.setArg(1, p_cl_M->Aj); - kernel_scalar.setArg(2, p_cl_M->Ax); - kernel_scalar.setArg(3, p_cl_v->Ax); - kernel_scalar.setArg(4, p_cl_mask->Ax); - kernel_scalar.setArg(5, p_cl_r->Ax); - kernel_scalar.setArg(6, init->get_value()); - kernel_scalar.setArg(7, r->get_n_rows()); - kernel_scalar.setArg(8, uint(early_exit)); - - uint n_groups_to_dispatch = - div_up_clamp(r->get_n_rows(), m_block_size, 1, 512); - - cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); - cl::NDRange exec_local(m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), - exec_global, exec_local); - - return Status::Ok; - } - - Status execute_config_scalar(const DispatchContext &ctx) { - TIME_PROFILE_SCOPE("opencl/mxv/config-scalar"); - - auto t = ctx.task.template cast_safe(); - - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = - t->op_multiply.template cast_safe>(); - ref_ptr> op_add = - t->op_add.template cast_safe>(); - ref_ptr> op_select = - t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); - - r->validate_wd(FormatVector::AccDense); - mask->validate_rw(FormatVector::AccDense); - M->validate_rw(FormatMatrix::AccCsr); - v->validate_rw(FormatVector::AccDense); - - std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) - return Status::CompilationError; - - auto *p_cl_r = r->template get>(); - auto *p_cl_mask = mask->template get>(); - auto *p_cl_M = M->template get>(); - auto *p_cl_v = v->template get>(); - auto early_exit = t->get_desc_or_default()->get_early_exit(); - - auto *p_cl_acc = get_acc_cl(); - auto &queue = p_cl_acc->get_queue_default(); - - uint config_size = 0; - cl::Buffer cl_config(p_cl_acc->get_context(), - CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, - sizeof(uint) * M->get_n_rows()); - cl::Buffer cl_config_size(p_cl_acc->get_context(), - CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | - CL_MEM_COPY_HOST_PTR, - sizeof(uint), &config_size); - - auto kernel_config = program->make_kernel("mxv_config"); - kernel_config.setArg(0, p_cl_mask->Ax); - kernel_config.setArg(1, p_cl_r->Ax); - kernel_config.setArg(2, cl_config); - kernel_config.setArg(3, cl_config_size); - kernel_config.setArg(4, init->get_value()); - kernel_config.setArg(5, M->get_n_rows()); - - uint n_groups_to_dispatch = - div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024); - - cl::NDRange config_global(m_block_size * n_groups_to_dispatch); - cl::NDRange config_local(m_block_size); - CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), - config_global, config_local); - - CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, - sizeof(config_size), &config_size); - - auto kernel_config_scalar = program->make_kernel("mxv_config_scalar"); - kernel_config_scalar.setArg(0, p_cl_M->Ap); - kernel_config_scalar.setArg(1, p_cl_M->Aj); - kernel_config_scalar.setArg(2, p_cl_M->Ax); - kernel_config_scalar.setArg(3, p_cl_v->Ax); - kernel_config_scalar.setArg(4, cl_config); - kernel_config_scalar.setArg(5, p_cl_r->Ax); - kernel_config_scalar.setArg(6, init->get_value()); - kernel_config_scalar.setArg(7, config_size); - kernel_config_scalar.setArg(8, uint(early_exit)); - - n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024); - - cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); - cl::NDRange exec_local(m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), - exec_global, exec_local); - - return Status::Ok; - } - - bool ensure_kernel(const ref_ptr> &op_multiply, - const ref_ptr> &op_add, - const ref_ptr> &op_select, - std::shared_ptr &program) { - m_block_size = get_acc_cl()->get_wave_size(); - m_block_count = 1; - - assert(m_block_count >= 1); - - CLProgramBuilder program_builder; - program_builder.set_name("mxv") - .add_define("WARP_SIZE", get_acc_cl()->get_wave_size()) - .add_define("BLOCK_SIZE", m_block_size) - .add_define("BLOCK_COUNT", m_block_count) - .add_type("TYPE", get_ttype().template as()); - - if constexpr (std::is_same_v) { - program_builder.add_define("USE_PAIR_SEMANTICS", 1); - program_builder.add_define("USE_PAIR_COMPARISON", 1); - } else { - program_builder.add_op("OP_BINARY1", op_multiply.template as()) - .add_op("OP_BINARY2", op_add.template as()) - .add_op("OP_SELECT", op_select.template as()); - } - program_builder.set_source(source_mxv).acquire(); - program = program_builder.get_program(); - - return true; - } - -private: - uint m_block_size = 0; - uint m_block_count = 0; -}; - -} // namespace spla - -#endif // SPLA_CL_MXV_HPP + template + class Algo_mxv_masked_cl final : public RegistryAlgo { + public: + ~Algo_mxv_masked_cl() override = default; + + std::string get_name() override { return "mxv_masked"; } + + std::string get_description() override { + return "parallel matrix-vector masked product on opencl device"; + } + + Status execute(const DispatchContext& ctx) override { + auto t = ctx.task.template cast_safe(); + auto early_exit = t->get_desc_or_default()->get_early_exit(); + + if (early_exit) { + return execute_config_scalar(ctx); + } else { + return execute_vector(ctx); + } + } + + private: + Status execute_vector(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("opencl/mxv/vector"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); + + r->validate_wd(FormatVector::AccDense); + mask->validate_rw(FormatVector::AccDense); + M->validate_rw(FormatMatrix::AccCsr); + v->validate_rw(FormatVector::AccDense); + + std::shared_ptr program; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; + + auto* p_cl_r = r->template get>(); + auto* p_cl_mask = mask->template get>(); + auto* p_cl_M = M->template get>(); + auto* p_cl_v = v->template get>(); + + auto* p_cl_acc = get_acc_cl(); + auto& queue = p_cl_acc->get_queue_default(); + + auto kernel_vector = program->make_kernel("mxv_vector"); + kernel_vector.setArg(0, p_cl_M->Ap); + kernel_vector.setArg(1, p_cl_M->Aj); + kernel_vector.setArg(2, p_cl_M->Ax); + kernel_vector.setArg(3, p_cl_v->Ax); + kernel_vector.setArg(4, p_cl_mask->Ax); + kernel_vector.setArg(5, p_cl_r->Ax); + kernel_vector.setArg(6, init->get_value()); + kernel_vector.setArg(7, r->get_n_rows()); + + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_count, 1, 512); + + cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size); + cl::NDRange exec_local(m_block_count, m_block_size); + CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), + exec_global, exec_local); + + return Status::Ok; + } + + Status execute_scalar(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("opencl/mxv/scalar"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); + + r->validate_wd(FormatVector::AccDense); + mask->validate_rw(FormatVector::AccDense); + M->validate_rw(FormatMatrix::AccCsr); + v->validate_rw(FormatVector::AccDense); + + std::shared_ptr program; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; + + auto* p_cl_r = r->template get>(); + auto* p_cl_mask = mask->template get>(); + auto* p_cl_M = M->template get>(); + auto* p_cl_v = v->template get>(); + auto early_exit = t->get_desc_or_default()->get_early_exit(); + + auto* p_cl_acc = get_acc_cl(); + auto& queue = p_cl_acc->get_queue_default(); + + auto kernel_scalar = program->make_kernel("mxv_scalar"); + kernel_scalar.setArg(0, p_cl_M->Ap); + kernel_scalar.setArg(1, p_cl_M->Aj); + kernel_scalar.setArg(2, p_cl_M->Ax); + kernel_scalar.setArg(3, p_cl_v->Ax); + kernel_scalar.setArg(4, p_cl_mask->Ax); + kernel_scalar.setArg(5, p_cl_r->Ax); + kernel_scalar.setArg(6, init->get_value()); + kernel_scalar.setArg(7, r->get_n_rows()); + kernel_scalar.setArg(8, uint(early_exit)); + + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_size, 1, 512); + + cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); + cl::NDRange exec_local(m_block_size); + CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), + exec_global, exec_local); + + return Status::Ok; + } + + Status execute_config_scalar(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("opencl/mxv/config-scalar"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); + + r->validate_wd(FormatVector::AccDense); + mask->validate_rw(FormatVector::AccDense); + M->validate_rw(FormatMatrix::AccCsr); + v->validate_rw(FormatVector::AccDense); + + std::shared_ptr program; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; + + auto* p_cl_r = r->template get>(); + auto* p_cl_mask = mask->template get>(); + auto* p_cl_M = M->template get>(); + auto* p_cl_v = v->template get>(); + auto early_exit = t->get_desc_or_default()->get_early_exit(); + + auto* p_cl_acc = get_acc_cl(); + auto& queue = p_cl_acc->get_queue_default(); + + uint config_size = 0; + cl::Buffer cl_config(p_cl_acc->get_context(), + CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, + sizeof(uint) * M->get_n_rows()); + cl::Buffer cl_config_size(p_cl_acc->get_context(), + CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | + CL_MEM_COPY_HOST_PTR, + sizeof(uint), &config_size); + + auto kernel_config = program->make_kernel("mxv_config"); + kernel_config.setArg(0, p_cl_mask->Ax); + kernel_config.setArg(1, p_cl_r->Ax); + kernel_config.setArg(2, cl_config); + kernel_config.setArg(3, cl_config_size); + kernel_config.setArg(4, init->get_value()); + kernel_config.setArg(5, M->get_n_rows()); + + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024); + + cl::NDRange config_global(m_block_size * n_groups_to_dispatch); + cl::NDRange config_local(m_block_size); + CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), + config_global, config_local); + + CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, + sizeof(config_size), &config_size); + + auto kernel_config_scalar = program->make_kernel("mxv_config_scalar"); + kernel_config_scalar.setArg(0, p_cl_M->Ap); + kernel_config_scalar.setArg(1, p_cl_M->Aj); + kernel_config_scalar.setArg(2, p_cl_M->Ax); + kernel_config_scalar.setArg(3, p_cl_v->Ax); + kernel_config_scalar.setArg(4, cl_config); + kernel_config_scalar.setArg(5, p_cl_r->Ax); + kernel_config_scalar.setArg(6, init->get_value()); + kernel_config_scalar.setArg(7, config_size); + kernel_config_scalar.setArg(8, uint(early_exit)); + + n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024); + + cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); + cl::NDRange exec_local(m_block_size); + CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), + exec_global, exec_local); + + return Status::Ok; + } + + bool ensure_kernel(const ref_ptr>& op_multiply, + const ref_ptr>& op_add, + const ref_ptr>& op_select, + std::shared_ptr& program) { + m_block_size = get_acc_cl()->get_wave_size(); + m_block_count = 1; + + assert(m_block_count >= 1); + + CLProgramBuilder program_builder; + program_builder.set_name("mxv") + .add_define("WARP_SIZE", get_acc_cl()->get_wave_size()) + .add_define("BLOCK_SIZE", m_block_size) + .add_define("BLOCK_COUNT", m_block_count) + .add_type("TYPE", get_ttype().template as()); + + if constexpr (std::is_same_v) { + program_builder.add_define("USE_PAIR_SEMANTICS", 1); + program_builder.add_define("USE_PAIR_COMPARISON", 1); + } else { + program_builder.add_op("OP_BINARY1", op_multiply.template as()) + .add_op("OP_BINARY2", op_add.template as()) + .add_op("OP_SELECT", op_select.template as()); + } + program_builder.set_source(source_mxv).acquire(); + program = program_builder.get_program(); + + return true; + } + + private: + uint m_block_size = 0; + uint m_block_count = 0; + }; + +}// namespace spla + +#endif// SPLA_CL_MXV_HPP diff --git a/src/opencl/cl_program_builder.cpp b/src/opencl/cl_program_builder.cpp index c1a4ab67d..5a2241e8e 100644 --- a/src/opencl/cl_program_builder.cpp +++ b/src/opencl/cl_program_builder.cpp @@ -83,22 +83,22 @@ namespace spla { } std::stringstream builder; - bool needs_pair_override = false; + bool needs_pair_override = false; for (const auto& define : m_defines) { builder << "#define " << define.first << " " << define.second << "\n"; if (define.first == "TYPE" && define.second.find("Pair") != std::string::npos) { - needs_pair_override = true; + needs_pair_override = true; } } builder << source_common_api; if (needs_pair_override) { - builder << "#define OP_BINARY1(a, b) make_pair((a).weight, " - "(b).vertex)\n\n"; - builder << "#define OP_BINARY2(a, b) min_pair(a, b)\n\n"; - builder << "#define OP_SELECT(a) pair_always(a)\n\n"; + builder << "#define OP_BINARY1(a, b) make_pair((a).weight, " + "(b).vertex)\n\n"; + builder << "#define OP_BINARY2(a, b) min_pair(a, b)\n\n"; + builder << "#define OP_SELECT(a) pair_always(a)\n\n"; } for (const auto& function : m_functions) {