diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c8cfa4f0..54f890fb5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -402,6 +402,7 @@ if (SPLA_BUILD_EXAMPLES) spla_example_application(tc) spla_example_application(pi) spla_example_application(convert) + spla_example_application(mst) endif () ###################################################################### diff --git a/examples/mst.cpp b/examples/mst.cpp new file mode 100644 index 000000000..60bcb45f0 --- /dev/null +++ b/examples/mst.cpp @@ -0,0 +1,132 @@ +/**********************************************************************************/ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ +/**********************************************************************************/ +/* MIT License */ +/* */ +/* Copyright (c) 2023 SparseLinearAlgebra */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ +/**********************************************************************************/ + +#include "common.hpp" +#include "options.hpp" + +#include +#include + +int main(int argc, const char* const* argv) { + auto options = make_options( + "mst", "Boruvka's Minimum Spanning Tree algorithm with spla library"); + + cxxopts::ParseResult args; + int ret; + + if (parse_options(argc, argv, options, args, ret)) { + std::cerr << "failed to parse options" << std::endl; + return ret; + } + + spla::Timer timer_total; + spla::Timer timer_gpu; + spla::Timer timer_ref; + spla::MtxLoader loader; + + timer_total.start(); + + if (!loader.load(args["mtxpath"].as())) { + std::cerr << "failed to load graph"; + return 1; + } + + std::string acc_info; + spla::Library* library = spla::Library::get(); + + library->set_platform(args["platform"].as()); + library->set_device(args["device"].as()); + library->set_queues_count(1); + library->get_accelerator_info(acc_info); + std::cout << "env: " << acc_info << std::endl; + + const spla::uint N = loader.get_n_rows(); + auto S = spla::Matrix::make(N, N, spla::PAIR); + + const auto& Ai = loader.get_Ai(); + const auto& Aj = loader.get_Aj(); + const auto& Aw = loader.get_Aw(); + + for (std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + + auto T_gpu = spla::Matrix::make(N, N, spla::FLOAT); + + auto desc = spla::Descriptor::make(); + + const int n_iters = args["niters"].as(); + + double total_weight_gpu = 0.0; + + if (args["run-gpu"].as()) { + library->set_force_no_acceleration(false); + + for (int i = 0; i < n_iters; ++i) { + T_gpu->clear(); + S = spla::Matrix::make(N, N, spla::PAIR); + for 
(std::size_t k = 0; k < loader.get_n_values(); ++k) { + S->set_pair(Ai[k], Aj[k], spla::T_PAIR(Aw[k], Aj[k])); + } + timer_gpu.lap_begin(); + spla::mst(T_gpu, S, desc, nullptr); + timer_gpu.lap_end(); + } + + total_weight_gpu = 0; + for (spla::uint i = 0; i < N; ++i) { + for (spla::uint j = i + 1; j < N; ++j) { + float w; + T_gpu->get_float(i, j, w); + if (w != 0.0) { + total_weight_gpu += w; + } + } + } + + std::cout << "GPU MST total weight: " << total_weight_gpu << std::endl; + } + + spla::Library::get()->finalize(); + + timer_total.stop(); + + std::cout << "\n=== Timing Results ===" << std::endl; + std::cout << "total(ms):" << timer_total.get_elapsed_ms() << std::endl; + std::cout << "gpu(ms): "; + timer_gpu.print(); + std::cout << std::endl; + + return 0; +} \ No newline at end of file diff --git a/include/spla.hpp b/include/spla.hpp index dc82624a6..8a2a4cc29 100644 --- a/include/spla.hpp +++ b/include/spla.hpp @@ -39,6 +39,7 @@ #include "spla/memview.hpp" #include "spla/object.hpp" #include "spla/op.hpp" +#include "spla/pair.hpp" #include "spla/ref.hpp" #include "spla/scalar.hpp" #include "spla/schedule.hpp" diff --git a/include/spla/algorithm.hpp b/include/spla/algorithm.hpp index a26a2dbba..c4f1cbbb4 100644 --- a/include/spla/algorithm.hpp +++ b/include/spla/algorithm.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the 
Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_ALGORITHM_HPP @@ -32,152 +39,160 @@ #include "descriptor.hpp" #include "matrix.hpp" #include "scalar.hpp" +#include "schedule.hpp" #include "vector.hpp" namespace spla { /** - * @addtogroup spla - * @{ - */ + * @addtogroup spla + * @{ + */ /** - * @brief Breadth-first search algorithm - * - * @param v int vector to store reached distances - * @param A int matrix filled with 1 where exist edge from i to j - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status bfs( - const ref_ptr& v, - const ref_ptr& A, - uint s, - const ref_ptr& descriptor = spla::Descriptor::make()); + * @brief Breadth-first search algorithm + * + * @param v int vector to store reached distances + * @param A int matrix filled with 1 where exist edge from i to j + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + bfs(const ref_ptr& v, const ref_ptr& A, uint s, + const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief Naive breadth-first search algorithm (reference cpu implementation) - * - * @param v int vector to store reached distances - * @param A int graph adjacency lists filled with 1 where exist edge from i to j - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status bfs_naive( - std::vector& v, - std::vector>& A, - uint s, - const ref_ptr& descriptor = spla::Descriptor::make()); + * @brief Naive breadth-first 
search algorithm (reference cpu implementation) + * + * @param v int vector to store reached distances + * @param A int graph adjacency lists filled with 1 where exist edge from i to j + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + bfs_naive(std::vector& v, std::vector>& A, uint s, + const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief Single-source shortest path algorithm - * - * @param v float vector to store reached distances - * @param A float matrix filled with >0.0f distances where exist edge from i to j otherwise 0.0f - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status sssp( - const ref_ptr& v, - const ref_ptr& A, - uint s, - const ref_ptr& descriptor = ref_ptr()); + * @brief Single-source shortest path algorithm + * + * @param v float vector to store reached distances + * @param A float matrix filled with >0.0f distances where exist edge from i to + * j otherwise 0.0f + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + sssp(const ref_ptr& v, const ref_ptr& A, uint s, + const ref_ptr& descriptor = ref_ptr()); /** - * @brief Naive single-source shortest path algorithm (reference cpu implementation) - * - * @param v float vector to store reached distances - * @param Ai uint matrix column indices - * @param Ax float matrix values with >0.0f distances where exist edge from i to j - * @param s start vertex id to search - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status sssp_naive(std::vector& v, - std::vector>& Ai, - std::vector>& Ax, - uint s, - const ref_ptr& descriptor = spla::Descriptor::make()); + * @brief Naive single-source shortest path algorithm (reference cpu + * implementation) + * + * 
@param v float vector to store reached distances + * @param Ai uint matrix column indices + * @param Ax float matrix values with >0.0f distances where exist edge from i to + * j + * @param s start vertex id to search + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + sssp_naive(std::vector& v, std::vector>& Ai, + std::vector>& Ax, uint s, + const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief PageRank algorithm - * - * @param p float vector to store result vertices weights - * @param A float graph matrix with weights A[i][j] = alpha / outdegree(i) - * @param alpha float alpha to control PageRank (default is 0.85) - * @param eps float tolerance to control precision of PageRank (default is 1e-6) - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status pr( - ref_ptr& p, - const ref_ptr& A, - float alpha = 0.85, - float eps = 1e-6, - const ref_ptr& descriptor = spla::Descriptor::make()); + * @brief PageRank algorithm + * + * @param p float vector to store result vertices weights + * @param A float graph matrix with weights A[i][j] = alpha / outdegree(i) + * @param alpha float alpha to control PageRank (default is 0.85) + * @param eps float tolerance to control precision of PageRank (default is 1e-6) + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + pr(ref_ptr& p, const ref_ptr& A, float alpha = 0.85, + float eps = 1e-6, + const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief Naive PageRank algorithm (reference cpu implementation) - * - * @param p float vector to store result vertices weights - * @param Ai float graph matrix column indices - * @param Ax float graph matrix weights A[i][j] = alpha / outdegree(i) - * @param alpha float alpha to control PageRank (default is 0.85) - * @param eps float tolerance to control precision of PageRank (default is 1e-6) - * @param 
descriptor optional descriptor for algorithm - * - * @return ok on success - */ + * @brief Naive PageRank algorithm (reference cpu implementation) + * + * @param p float vector to store result vertices weights + * @param Ai float graph matrix column indices + * @param Ax float graph matrix weights A[i][j] = alpha / outdegree(i) + * @param alpha float alpha to control PageRank (default is 0.85) + * @param eps float tolerance to control precision of PageRank (default is 1e-6) + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ SPLA_API Status pr_naive( - std::vector& p, - std::vector>& Ai, - std::vector>& Ax, - float alpha = 0.85, - float eps = 1e-6, - const ref_ptr& descriptor = spla::Descriptor::make()); + std::vector& p, std::vector>& Ai, + std::vector>& Ax, float alpha = 0.85, float eps = 1e-6, + const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief Triangles counting algorithm - * - * @param ntrins Number of triangles counted - * @param A Lower trilingual int matrix with 1 where has edge in a graph - * @param B Buffer int matrix to store result - * @param descriptor optional descriptor for algorithm - * - * @return ok on success - */ - SPLA_API Status tc( - int& ntrins, - const ref_ptr& A, - const ref_ptr& B, - const ref_ptr& descriptor = spla::Descriptor::make()); + * @brief Triangles counting algorithm + * + * @param ntrins Number of triangles counted + * @param A Lower trilingual int matrix with 1 where has edge in a graph + * @param B Buffer int matrix to store result + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + tc(int& ntrins, const ref_ptr& A, const ref_ptr& B, + const ref_ptr& descriptor = spla::Descriptor::make()); /** - * @brief Naive triangles counting algorithm (reference cpu implementation) - * - * @param ntrins Number of triangles counted - * @param A Lower trilingual int matrix structure - * @param descriptor optional descriptor 
for algorithm - * - * @return ok on success - */ - SPLA_API Status tc_naive( - int& ntrins, - std::vector>& Ai, - const ref_ptr& descriptor = spla::Descriptor::make()); + * @brief Naive triangles counting algorithm (reference cpu implementation) + * + * @param ntrins Number of triangles counted + * @param A Lower trilingual int matrix structure + * @param descriptor optional descriptor for algorithm + * + * @return ok on success + */ + SPLA_API Status + tc_naive(int& ntrins, std::vector>& Ai, + const ref_ptr& descriptor = spla::Descriptor::make()); + /** + * @brief Boruvka's Minimum Spanning Tree algorithm + * + * Finds the Minimum Spanning Tree of a weighted undirected graph using + * Boruvka's algorithm with algebraic operations. + * + * @param T float matrix to store MST edges (result). Only upper triangle is + * used. + * @param S PAIR matrix adjacency matrix with edges (weight, vertex). + * The vertex field stores the target vertex of the edge. + * @param descriptor optional descriptor for algorithm configuration + * @param task_hnd optional pointer to store task handle for async execution + * + * @return ok on success + */ + SPLA_API Status + mst(const ref_ptr& T, ref_ptr& S, + const ref_ptr& descriptor = spla::Descriptor::make(), + ref_ptr* task_hnd = nullptr); /** - * @} - */ + * @} + */ }// namespace spla -#endif//SPLA_ALGORITHM_HPP +#endif// SPLA_ALGORITHM_HPP diff --git a/include/spla/io.hpp b/include/spla/io.hpp index 90b0d5c54..f60763015 100644 --- a/include/spla/io.hpp +++ b/include/spla/io.hpp @@ -80,8 +80,9 @@ namespace spla { [[nodiscard]] SPLA_API const std::vector& get_Ai() const; [[nodiscard]] SPLA_API const std::vector& get_Aj() const; - [[nodiscard]] SPLA_API uint get_n_rows() const; - [[nodiscard]] SPLA_API uint get_n_cols() const; + [[nodiscard]] SPLA_API const std::vector& get_Aw() const; + [[nodiscard]] SPLA_API uint get_n_rows() const; + [[nodiscard]] SPLA_API uint get_n_cols() const; [[nodiscard]] SPLA_API std::size_t get_n_values() 
const; private: @@ -89,6 +90,7 @@ namespace spla { std::filesystem::path m_file_path; std::vector m_Ai; std::vector m_Aj; + std::vector m_Aw; bool m_base_is_zero = false; uint m_n_rows = 0; uint m_n_cols = 0; diff --git a/include/spla/matrix.hpp b/include/spla/matrix.hpp index 706a5dc86..645142b0e 100644 --- a/include/spla/matrix.hpp +++ b/include/spla/matrix.hpp @@ -57,9 +57,13 @@ namespace spla { SPLA_API virtual Status set_int(uint row_id, uint col_id, std::int32_t value) = 0; SPLA_API virtual Status set_uint(uint row_id, uint col_id, std::uint32_t value) = 0; SPLA_API virtual Status set_float(uint row_id, uint col_id, float value) = 0; + SPLA_API virtual Status set_pair(uint row_id, uint col_id, + Pair value) = 0; SPLA_API virtual Status get_int(uint row_id, uint col_id, std::int32_t& value) = 0; SPLA_API virtual Status get_uint(uint row_id, uint col_id, std::uint32_t& value) = 0; SPLA_API virtual Status get_float(uint row_id, uint col_id, float& value) = 0; + SPLA_API virtual Status get_pair(uint row_id, uint col_id, + Pair& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) = 0; SPLA_API virtual Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) = 0; SPLA_API virtual Status clear() = 0; diff --git a/include/spla/op.hpp b/include/spla/op.hpp index 96dd16848..727bdb120 100644 --- a/include/spla/op.hpp +++ b/include/spla/op.hpp @@ -29,6 +29,7 @@ #define SPLA_OP_HPP #include "object.hpp" +#include "spla/pair.hpp" #include "type.hpp" #include @@ -64,6 +65,9 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr + make_pair(std::string name, std::string code, + std::function function); }; /** @@ -78,6 +82,9 @@ namespace spla { 
SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr + make_pair(std::string name, std::string code, + std::function function); }; /** @@ -91,6 +98,9 @@ namespace spla { SPLA_API static ref_ptr make_int(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_uint(std::string name, std::string code, std::function function); SPLA_API static ref_ptr make_float(std::string name, std::string code, std::function function); + SPLA_API static ref_ptr + make_pair(std::string name, std::string code, + std::function function); }; //////////////////////////////// Unary //////////////////////////////// @@ -130,6 +140,7 @@ namespace spla { SPLA_API extern ref_ptr FLOOR_FLOAT; SPLA_API extern ref_ptr ROUND_FLOAT; SPLA_API extern ref_ptr TRUNC_FLOAT; + SPLA_API extern ref_ptr IDENTITY_PAIR; //////////////////////////////// Binary //////////////////////////////// @@ -182,6 +193,9 @@ namespace spla { SPLA_API extern ref_ptr BXOR_INT; SPLA_API extern ref_ptr BXOR_UINT; + SPLA_API extern ref_ptr MIN_PAIR; + SPLA_API extern ref_ptr MUL_PAIR; + //////////////////////////////// Select //////////////////////////////// SPLA_API extern ref_ptr EQZERO_INT; @@ -205,6 +219,7 @@ namespace spla { SPLA_API extern ref_ptr ALWAYS_INT; SPLA_API extern ref_ptr ALWAYS_UINT; SPLA_API extern ref_ptr ALWAYS_FLOAT; + SPLA_API extern ref_ptr ALWAYS_PAIR; SPLA_API extern ref_ptr NEVER_INT; SPLA_API extern ref_ptr NEVER_UINT; SPLA_API extern ref_ptr NEVER_FLOAT; diff --git a/include/spla/pair.hpp b/include/spla/pair.hpp new file mode 100644 index 000000000..cf1c93220 --- /dev/null +++ b/include/spla/pair.hpp @@ -0,0 +1,23 @@ +#ifndef SPLA_PAIR_HPP +#define SPLA_PAIR_HPP +#include + +namespace spla { + struct Pair { + 
float weight; + int vertex; + + Pair() : weight(std::numeric_limits::infinity()), vertex(-1) {} + Pair(float w, int v) : weight(w), vertex(v) {} + + bool operator<(const Pair& other) const { return weight < other.weight; } + bool operator==(const Pair& other) const { + return weight == other.weight && vertex == other.vertex; + } + + bool operator!=(const Pair& other) const { return !(*this == other); } + + Pair& operator=(const Pair& other) = default; + }; +}// namespace spla +#endif \ No newline at end of file diff --git a/include/spla/scalar.hpp b/include/spla/scalar.hpp index 0a0588914..b8e7f463e 100644 --- a/include/spla/scalar.hpp +++ b/include/spla/scalar.hpp @@ -44,17 +44,24 @@ namespace spla { */ class Scalar : public Object { public: - SPLA_API ~Scalar() override = default; - SPLA_API virtual ref_ptr get_type() = 0; - SPLA_API virtual Status set_int(std::int32_t value) = 0; - SPLA_API virtual Status set_uint(std::uint32_t value) = 0; - SPLA_API virtual Status set_float(float value) = 0; - SPLA_API virtual Status get_int(std::int32_t& value) = 0; - SPLA_API virtual Status get_uint(std::uint32_t& value) = 0; - SPLA_API virtual Status get_float(float& value) = 0; - SPLA_API virtual T_INT as_int() = 0; - SPLA_API virtual T_UINT as_uint() = 0; - SPLA_API virtual T_FLOAT as_float() = 0; + SPLA_API ~Scalar() override = default; + SPLA_API virtual ref_ptr get_type() = 0; + SPLA_API virtual Status set_int(std::int32_t value) = 0; + SPLA_API virtual Status set_uint(std::uint32_t value) = 0; + SPLA_API virtual Status set_float(float value) = 0; + SPLA_API virtual Status set_pair(Pair value) { + return Status::InvalidArgument; + } + SPLA_API virtual Status get_int(std::int32_t& value) = 0; + SPLA_API virtual Status get_uint(std::uint32_t& value) = 0; + SPLA_API virtual Status get_float(float& value) = 0; + SPLA_API virtual Status get_pair(Pair& value) { + return Status::InvalidArgument; + } + SPLA_API virtual T_INT as_int() = 0; + SPLA_API virtual T_UINT as_uint() = 
0; + SPLA_API virtual T_FLOAT as_float() = 0; + SPLA_API virtual T_PAIR as_pair() = 0; SPLA_API static ref_ptr make(const ref_ptr& type); SPLA_API static ref_ptr make_int(std::int32_t value); diff --git a/include/spla/type.hpp b/include/spla/type.hpp index f61adf867..afaf70c48 100644 --- a/include/spla/type.hpp +++ b/include/spla/type.hpp @@ -29,6 +29,7 @@ #define SPLA_TYPE_HPP #include "object.hpp" +#include "pair.hpp" #include @@ -58,11 +59,13 @@ namespace spla { using T_INT = std::int32_t; using T_UINT = std::uint32_t; using T_FLOAT = float; + using T_PAIR = Pair; SPLA_API extern ref_ptr BOOL; SPLA_API extern ref_ptr INT; SPLA_API extern ref_ptr UINT; SPLA_API extern ref_ptr FLOAT; + SPLA_API extern ref_ptr PAIR; /** * @} diff --git a/include/spla/vector.hpp b/include/spla/vector.hpp index 0ba6d512d..2af5546c3 100644 --- a/include/spla/vector.hpp +++ b/include/spla/vector.hpp @@ -59,6 +59,8 @@ namespace spla { SPLA_API virtual Status get_int(uint row_id, T_INT& value) = 0; SPLA_API virtual Status get_uint(uint row_id, T_UINT& value) = 0; SPLA_API virtual Status get_float(uint row_id, float& value) = 0; + SPLA_API virtual Status get_pair(uint row_id, Pair& value) = 0; + SPLA_API virtual Status set_pair(uint row_id, Pair value) = 0; SPLA_API virtual Status fill_noize(uint seed) = 0; SPLA_API virtual Status fill_with(const ref_ptr& value) = 0; SPLA_API virtual Status build(const ref_ptr& keys, const ref_ptr& values) = 0; diff --git a/src/algorithm.cpp b/src/algorithm.cpp index 189ea4b45..74f2c1901 100644 --- a/src/algorithm.cpp +++ b/src/algorithm.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* 
Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. 
*/ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -37,14 +44,16 @@ #include #include #include +#include +#include + +#define INF std::numeric_limits::infinity() namespace spla { #pragma region Bfs - Status bfs(const ref_ptr& v, - const ref_ptr& A, - uint s, + Status bfs(const ref_ptr& v, const ref_ptr& A, uint s, const ref_ptr& descriptor) { assert(v); assert(A); @@ -71,13 +80,17 @@ namespace spla { bool push_pull = descriptor->get_push_pull(); float front_factor = descriptor->get_front_factor(); - if (!(push || pull || push_pull)) push = true; + if (!(push || pull || push_pull)) + push = true; #ifndef SPLA_RELEASE std::string mode; - if (push_pull) mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; - if (pull) mode = "(pull)"; - if (push) mode = "(push)"; + if (push_pull) + mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; + if (pull) + mode = "(pull)"; + if (push) + mode = "(push)"; std::cout << "start bfs from " << s << " " << mode << std::endl; @@ -94,17 +107,19 @@ namespace spla { bool is_push_better = (front_density <= front_factor); if (push || (push_pull && is_push_better)) { - exec_vxm_masked(frontier_new, v, frontier_prev, A, BAND_INT, BOR_INT, EQZERO_INT, zero, desc); + exec_vxm_masked(frontier_new, v, frontier_prev, A, BAND_INT, BOR_INT, + EQZERO_INT, zero, desc); } else { - exec_mxv_masked(frontier_new, v, A, frontier_prev, BAND_INT, BOR_INT, EQZERO_INT, zero, 
desc); + exec_mxv_masked(frontier_new, v, A, frontier_prev, BAND_INT, BOR_INT, + EQZERO_INT, zero, desc); } exec_v_count_mf(frontier_size, frontier_new); #ifndef SPLA_RELEASE tight.stop(); - std::cout << " - iter " << current_level - << " front " << frontier_size->as_int() << " discovered " << discovered << " " + std::cout << " - iter " << current_level << " front " + << frontier_size->as_int() << " discovered " << discovered << " " << tight.get_elapsed_ms() << " ms" << std::endl; Library::get()->time_profile_dump(); Library::get()->time_profile_reset(); @@ -119,10 +134,8 @@ namespace spla { return Status::Ok; } - Status bfs_naive(std::vector& v, - std::vector>& A, - uint s, - const ref_ptr& descriptor) { + Status bfs_naive(std::vector& v, std::vector>& A, + uint s, const ref_ptr& descriptor) { const auto N = v.size(); @@ -155,9 +168,7 @@ namespace spla { #pragma region Sssp - Status sssp(const ref_ptr& v, - const ref_ptr& A, - uint s, + Status sssp(const ref_ptr& v, const ref_ptr& A, uint s, const ref_ptr& descriptor) { assert(v); assert(A); @@ -185,13 +196,17 @@ namespace spla { bool push_pull = descriptor->get_push_pull(); float front_factor = descriptor->get_front_factor(); - if (!(push || pull || push_pull)) push = true; + if (!(push || pull || push_pull)) + push = true; #ifndef SPLA_RELEASE std::string mode; - if (push_pull) mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; - if (pull) mode = "(pull)"; - if (push) mode = "(push)"; + if (push_pull) + mode = "(push_pull " + std::to_string(front_factor * 100.0f) + "%)"; + if (pull) + mode = "(pull)"; + if (push) + mode = "(push)"; std::cout << "start sssp from " << s << " " << mode << std::endl; @@ -205,9 +220,11 @@ namespace spla { bool is_push_better = (front_density <= front_factor); if (push || (push_pull && is_push_better)) { - exec_vxm_masked(frontier, dummy_mask, feedback, A, PLUS_FLOAT, MIN_FLOAT, ALWAYS_FLOAT, inf_init); + exec_vxm_masked(frontier, dummy_mask, feedback, A, PLUS_FLOAT, 
MIN_FLOAT, + ALWAYS_FLOAT, inf_init); } else { - exec_mxv_masked(frontier, dummy_mask, A, feedback, PLUS_FLOAT, MIN_FLOAT, ALWAYS_FLOAT, inf_init); + exec_mxv_masked(frontier, dummy_mask, A, feedback, PLUS_FLOAT, MIN_FLOAT, + ALWAYS_FLOAT, inf_init); } exec_v_eadd_fdb(v, frontier, feedback, MIN_FLOAT); @@ -215,9 +232,9 @@ namespace spla { #ifndef SPLA_RELEASE tight.stop(); - std::cout << " - iter " << current_level - << " feed " << feedback_size->as_int() - << " " << tight.get_elapsed_ms() << " ms" << std::endl; + std::cout << " - iter " << current_level << " feed " + << feedback_size->as_int() << " " << tight.get_elapsed_ms() + << " ms" << std::endl; Library::get()->time_profile_dump(); Library::get()->time_profile_reset(); #endif @@ -228,11 +245,9 @@ namespace spla { return Status::Ok; } - Status sssp_naive(std::vector& v, - std::vector>& Ai, - std::vector>& Ax, - uint s, - const ref_ptr& descriptor) { + Status sssp_naive(std::vector& v, std::vector>& Ai, + std::vector>& Ax, uint s, + const ref_ptr& descriptor) { const auto N = v.size(); const auto inf = std::numeric_limits::max(); @@ -275,10 +290,7 @@ namespace spla { #pragma region Pr - Status pr(ref_ptr& p, - const ref_ptr& A, - float alpha, - float eps, + Status pr(ref_ptr& p, const ref_ptr& A, float alpha, float eps, const ref_ptr& descriptor) { assert(p); assert(A); @@ -309,7 +321,8 @@ namespace spla { tight.start(); #endif // p = A*p + (1-alpha)/N - exec_mxv_masked(p_tmp, dummy_mask, A, p_prev, MULT_FLOAT, PLUS_FLOAT, ALWAYS_FLOAT, zero); + exec_mxv_masked(p_tmp, dummy_mask, A, p_prev, MULT_FLOAT, PLUS_FLOAT, + ALWAYS_FLOAT, zero); exec_v_eadd(p, p_tmp, addition, PLUS_FLOAT); // error = sqrt((p[01]-prev[0])^2 + ... 
+ p[N-1]-prev[N-1])^2) @@ -322,9 +335,8 @@ namespace spla { #ifndef SPLA_RELEASE tight.stop(); - std::cout << " - iter " << iter++ - << " error " << error - << " " << tight.get_elapsed_ms() << " ms" << std::endl; + std::cout << " - iter " << iter++ << " error " << error << " " + << tight.get_elapsed_ms() << " ms" << std::endl; Library::get()->time_profile_dump(); Library::get()->time_profile_reset(); #endif @@ -334,12 +346,9 @@ namespace spla { return Status::Ok; } - Status pr_naive(std::vector& p, - std::vector>& Ai, - std::vector>& Ax, - float alpha, - float eps, - const ref_ptr& descriptor) { + Status pr_naive(std::vector& p, std::vector>& Ai, + std::vector>& Ax, float alpha, float eps, + const ref_ptr& descriptor) { const auto N = p.size(); @@ -377,11 +386,8 @@ namespace spla { #pragma region Tc - Status tc( - int& ntrins, - const ref_ptr& A, - const ref_ptr& B, - const ref_ptr& descriptor) { + Status tc(int& ntrins, const ref_ptr& A, const ref_ptr& B, + const ref_ptr& descriptor) { assert(A); assert(B); @@ -403,8 +409,8 @@ namespace spla { #ifndef SPLA_RELEASE tight.stop(); - std::cout << " - ntrins " << ntrins - << " " << tight.get_elapsed_ms() << " ms" << std::endl; + std::cout << " - ntrins " << ntrins << " " << tight.get_elapsed_ms() << " ms" + << std::endl; Library::get()->time_profile_dump(); Library::get()->time_profile_reset(); @@ -413,10 +419,8 @@ namespace spla { return Status::Ok; } - Status tc_naive( - int& ntrins, - std::vector>& Ai, - const ref_ptr& descriptor) { + Status tc_naive(int& ntrins, std::vector>& Ai, + const ref_ptr& descriptor) { ntrins = 0; @@ -448,5 +452,264 @@ namespace spla { } #pragma endregion Pr +#pragma region Mst + Status mst(const ref_ptr& T, ref_ptr& S, + const ref_ptr& descriptor, + ref_ptr* task_hnd) { + + assert(S); + assert(T); + + const auto n = S->get_n_rows(); + int comp = n; + + auto parent = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + parent->set_pair(i, T_PAIR(0.0f, i)); + } + auto edge = 
Vector::make(n, PAIR); + auto cedge = Vector::make(n, PAIR); + auto t_vec = Vector::make(n, PAIR); + auto mask = Vector::make(n, PAIR); + for (uint i = 0; i < n; i++) { + mask->set_pair(i, T_PAIR(1.0f, 0)); + } + auto init_inf = Scalar::make(PAIR); + T_PAIR init_val; + init_inf->set_pair(init_val); + int iteration = 0; + auto new_S = S; +#ifdef SPLA_RELEASE + std::cout << "start Boruvka MST, vertices = " << n << "\n"; + Timer tight; +#endif + + while (comp > 1) { +#ifdef SPLA_RELEASE + tight.start(); +#endif + iteration++; + int edges_added_this_iteration = 0; + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, + spla::ALWAYS_PAIR, init_inf); +#ifdef SPLA_DEBUG + + std::cout << "edge = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; + } + std::cout << "]\n"; +#endif + + for (int32_t i = 0; i < n; i++) { + cedge->set_pair(i, init_val); + } + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + spla::T_PAIR p1; + spla::T_PAIR p2; + parent->get_pair(i, p); + auto p_i = p.vertex; + cedge->get_pair(p_i, p1); + edge->get_pair(i, p2); + auto min_for_comp = p1.weight <= p2.weight ? 
p1 : p2; + cedge->set_pair(p_i, min_for_comp); + } + +#ifdef SPLA_DEBUG + std::cout << "cedge = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + cedge->get_pair(i, p); + std::cout << "(" << p.weight << ", " << p.vertex << "), "; + } + std::cout << "]\n"; +#endif + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + spla::T_PAIR cedge_v; + parent->get_pair(i, parent_v); + cedge->get_pair(parent_v.vertex, cedge_v); + t_vec->set_pair(i, cedge_v); + } + + auto index = spla::Vector::make(n, spla::INT); + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR edge_v, t_v; + edge->get_pair(i, edge_v); + t_vec->get_pair(i, t_v); + if (edge_v == t_v) + index->set_int(i, i); + else + index->set_int(i, n); + } + auto temp = spla::Vector::make(n, spla::INT); + for (int32_t i = 0; i < n; i++) + temp->set_int(i, n); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v, ind_v; + temp->get_int(p_i, temp_v); + index->get_int(i, ind_v); + spla::T_INT min_v = temp_v < ind_v ? 
temp_v : ind_v; + temp->set_int(p_i, min_v); + } +#ifdef SPLA_DEBUG + std::cout << "t = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + temp->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; +#endif + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR parent_v; + parent->get_pair(i, parent_v); + auto p_i = parent_v.vertex; + spla::T_INT temp_v; + temp->get_int(p_i, temp_v); + index->set_int(i, temp_v); + } + +#ifdef SPLA_DEBUG + std::cout << "index = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_INT p; + index->get_int(i, p); + std::cout << p << ", "; + } + std::cout << "]\n"; +#endif + auto new_parent = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + new_parent->set_pair(i, p); + } + for (int32_t i = 0; i < n; i++) { + spla::T_INT ind_v; + index->get_int(i, ind_v); + if (i == ind_v) { + auto row = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row, S, i, spla::IDENTITY_PAIR); + int min_vertex = -1; + float min_weight = INF; + + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR pair_row; + row->get_pair(j, pair_row); + auto pair_row_weight = pair_row.weight; + auto pair_row_vertex = pair_row.vertex; + if (pair_row_weight < INF) { + spla::T_PAIR p1, p2; + parent->get_pair(i, p1); + parent->get_pair(pair_row_vertex, p2); + if (p1.vertex != p2.vertex) { + if (pair_row_weight < min_weight) { + min_weight = pair_row_weight; + min_vertex = j; + } + } + } + } + if (min_vertex == -1) + continue; + T->set_float(i, min_vertex, min_weight); + T->set_float(min_vertex, i, min_weight); + edges_added_this_iteration++; + if (i < min_vertex) { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(i, p); + new_parent->get_pair(min_vertex, old_p); + new_parent->set_pair(min_vertex, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) + new_parent->set_pair(k, 
spla::T_PAIR(0.0f, p.vertex)); + } + } else { + spla::T_PAIR p; + spla::T_PAIR old_p; + new_parent->get_pair(min_vertex, p); + new_parent->get_pair(i, old_p); + new_parent->set_pair(i, spla::T_PAIR(0.0f, p.vertex)); + for (int k = 0; k < n; k++) { + spla::T_PAIR p1; + new_parent->get_pair(k, p1); + if (p1.vertex == old_p.vertex) + new_parent->set_pair(k, spla::T_PAIR(0.0f, p.vertex)); + } + } + } + } + parent = new_parent; + std::vector seen(n, false); + for (uint i = 0; i < n; i++) { + T_PAIR p; + parent->get_pair(i, p); + seen[p.vertex] = true; + } + + comp = 0; + for (uint i = 0; i < n; i++) { + if (seen[i]) + comp++; + } + +#ifdef SPLA_DEBUG + std::cout << "parent = ["; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + parent->get_pair(i, p); + std::cout << p.vertex << ", "; + } + std::cout << "]\n"; +#endif +#ifdef SPLA_RELEASE + tight.stop(); + std::cout << " - iteration " << iteration << " components " << comp << " " + << tight.get_elapsed_ms() << " ms" << std::endl; + Library::get()->time_profile_dump(); + Library::get()->time_profile_reset(); +#endif + if (comp == 1) { + std::cout << "MST complete after " << iteration << " iterations" + << std::endl; + return Status::Ok; + } + auto filtered_S = spla::Matrix::make(n, n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + for (int32_t j = 0; j < n; j++) { + spla::T_PAIR val; + S->get_pair(i, j, val); + if (val.weight != std::numeric_limits::infinity()) { + spla::T_PAIR parent_i, parent_j; + parent->get_pair(i, parent_i); + parent->get_pair(j, parent_j); + if ((parent_i.vertex != parent_j.vertex) && + (val.weight != std::numeric_limits::infinity())) { + filtered_S->set_pair(i, j, val); + } + } + } + } + S = filtered_S; + if (edges_added_this_iteration == 0) { + return Status::Ok; + } + } + return Status::Ok; + } + +#pragma endregion Mst }// namespace spla diff --git a/src/binding/c_op.cpp b/src/binding/c_op.cpp index e3202e4f0..3feb1ac1c 100644 --- a/src/binding/c_op.cpp +++ b/src/binding/c_op.cpp @@ 
-1,130 +1,347 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/JetBrains-Research/spla */ +/* This file is part of spla project */ +/* https://github.com/JetBrains-Research/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include "c_config.hpp" -spla_OpUnary spla_OpUnary_IDENTITY_INT() { return as_ptr(spla::IDENTITY_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_IDENTITY_UINT() { return as_ptr(spla::IDENTITY_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_IDENTITY_FLOAT() { return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_AINV_INT() { return as_ptr(spla::AINV_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_AINV_UINT() { return as_ptr(spla::AINV_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_AINV_FLOAT() { return as_ptr(spla::AINV_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_MINV_INT() { return as_ptr(spla::MINV_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_MINV_UINT() { return as_ptr(spla::MINV_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_MINV_FLOAT() { return as_ptr(spla::MINV_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LNOT_INT() { return as_ptr(spla::LNOT_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LNOT_UINT() { return as_ptr(spla::LNOT_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LNOT_FLOAT() { return 
as_ptr(spla::LNOT_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_UONE_INT() { return as_ptr(spla::UONE_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_UONE_UINT() { return as_ptr(spla::UONE_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_UONE_FLOAT() { return as_ptr(spla::UONE_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ABS_INT() { return as_ptr(spla::ABS_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ABS_UINT() { return as_ptr(spla::ABS_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ABS_FLOAT() { return as_ptr(spla::ABS_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_BNOT_INT() { return as_ptr(spla::BNOT_INT.ref_and_get()); } -spla_OpUnary spla_OpUnary_BNOT_UINT() { return as_ptr(spla::BNOT_UINT.ref_and_get()); } -spla_OpUnary spla_OpUnary_SQRT_FLOAT() { return as_ptr(spla::SQRT_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_LOG_FLOAT() { return as_ptr(spla::LOG_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_EXP_FLOAT() { return as_ptr(spla::EXP_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_SIN_FLOAT() { return as_ptr(spla::SIN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_COS_FLOAT() { return as_ptr(spla::COS_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_TAN_FLOAT() { return as_ptr(spla::TAN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ASIN_FLOAT() { return as_ptr(spla::ASIN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ACOS_FLOAT() { return as_ptr(spla::ACOS_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ATAN_FLOAT() { return as_ptr(spla::ATAN_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_CEIL_FLOAT() { return as_ptr(spla::CEIL_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_FLOOR_FLOAT() { return as_ptr(spla::FLOOR_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_ROUND_FLOAT() { return as_ptr(spla::ROUND_FLOAT.ref_and_get()); } -spla_OpUnary spla_OpUnary_TRUNC_FLOAT() { return as_ptr(spla::TRUNC_FLOAT.ref_and_get()); } +spla_OpUnary spla_OpUnary_IDENTITY_INT() { + return 
as_ptr(spla::IDENTITY_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_IDENTITY_UINT() { + return as_ptr(spla::IDENTITY_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_IDENTITY_FLOAT() { + return as_ptr(spla::IDENTITY_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_IDENTITY_PAIR() { + return as_ptr(spla::IDENTITY_PAIR.ref_and_get()); +} +spla_OpUnary spla_OpUnary_AINV_INT() { + return as_ptr(spla::AINV_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_AINV_UINT() { + return as_ptr(spla::AINV_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_AINV_FLOAT() { + return as_ptr(spla::AINV_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_MINV_INT() { + return as_ptr(spla::MINV_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_MINV_UINT() { + return as_ptr(spla::MINV_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_MINV_FLOAT() { + return as_ptr(spla::MINV_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LNOT_INT() { + return as_ptr(spla::LNOT_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LNOT_UINT() { + return as_ptr(spla::LNOT_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_LNOT_FLOAT() { + return as_ptr(spla::LNOT_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_UONE_INT() { + return as_ptr(spla::UONE_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_UONE_UINT() { + return as_ptr(spla::UONE_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_UONE_FLOAT() { + return as_ptr(spla::UONE_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ABS_INT() { + return as_ptr(spla::ABS_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ABS_UINT() { + return as_ptr(spla::ABS_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ABS_FLOAT() { + return as_ptr(spla::ABS_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_BNOT_INT() { + return as_ptr(spla::BNOT_INT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_BNOT_UINT() { + return as_ptr(spla::BNOT_UINT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_SQRT_FLOAT() { + return as_ptr(spla::SQRT_FLOAT.ref_and_get()); +} 
+spla_OpUnary spla_OpUnary_LOG_FLOAT() { + return as_ptr(spla::LOG_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_EXP_FLOAT() { + return as_ptr(spla::EXP_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_SIN_FLOAT() { + return as_ptr(spla::SIN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_COS_FLOAT() { + return as_ptr(spla::COS_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_TAN_FLOAT() { + return as_ptr(spla::TAN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ASIN_FLOAT() { + return as_ptr(spla::ASIN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ACOS_FLOAT() { + return as_ptr(spla::ACOS_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ATAN_FLOAT() { + return as_ptr(spla::ATAN_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_CEIL_FLOAT() { + return as_ptr(spla::CEIL_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_FLOOR_FLOAT() { + return as_ptr(spla::FLOOR_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_ROUND_FLOAT() { + return as_ptr(spla::ROUND_FLOAT.ref_and_get()); +} +spla_OpUnary spla_OpUnary_TRUNC_FLOAT() { + return as_ptr(spla::TRUNC_FLOAT.ref_and_get()); +} -spla_OpBinary spla_OpBinary_PLUS_INT() { return as_ptr(spla::PLUS_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_PLUS_UINT() { return as_ptr(spla::PLUS_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_PLUS_FLOAT() { return as_ptr(spla::PLUS_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_INT() { return as_ptr(spla::MINUS_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_UINT() { return as_ptr(spla::MINUS_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_FLOAT() { return as_ptr(spla::MINUS_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MULT_INT() { return as_ptr(spla::MULT_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MULT_UINT() { return as_ptr(spla::MULT_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MULT_FLOAT() { return as_ptr(spla::MULT_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_DIV_INT() { return 
as_ptr(spla::DIV_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_DIV_UINT() { return as_ptr(spla::DIV_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_DIV_FLOAT() { return as_ptr(spla::DIV_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_POW2_INT() { return as_ptr(spla::MINUS_POW2_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_POW2_UINT() { return as_ptr(spla::MINUS_POW2_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MINUS_POW2_FLOAT() { return as_ptr(spla::MINUS_POW2_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_FIRST_INT() { return as_ptr(spla::FIRST_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_FIRST_UINT() { return as_ptr(spla::FIRST_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_FIRST_FLOAT() { return as_ptr(spla::FIRST_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_SECOND_INT() { return as_ptr(spla::SECOND_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_SECOND_UINT() { return as_ptr(spla::SECOND_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_SECOND_FLOAT() { return as_ptr(spla::SECOND_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BONE_INT() { return as_ptr(spla::BONE_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BONE_UINT() { return as_ptr(spla::BONE_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BONE_FLOAT() { return as_ptr(spla::BONE_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_INT() { return as_ptr(spla::MIN_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_UINT() { return as_ptr(spla::MIN_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MIN_FLOAT() { return as_ptr(spla::MIN_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MAX_INT() { return as_ptr(spla::MAX_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MAX_UINT() { return as_ptr(spla::MAX_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_MAX_FLOAT() { return as_ptr(spla::MAX_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LOR_INT() { return as_ptr(spla::LOR_INT.ref_and_get()); } 
-spla_OpBinary spla_OpBinary_LOR_UINT() { return as_ptr(spla::LOR_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LOR_FLOAT() { return as_ptr(spla::LOR_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LAND_INT() { return as_ptr(spla::LAND_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LAND_UINT() { return as_ptr(spla::LAND_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_LAND_FLOAT() { return as_ptr(spla::LAND_FLOAT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BOR_INT() { return as_ptr(spla::BOR_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BOR_UINT() { return as_ptr(spla::BOR_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BAND_INT() { return as_ptr(spla::BAND_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BAND_UINT() { return as_ptr(spla::BAND_UINT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BXOR_INT() { return as_ptr(spla::BXOR_INT.ref_and_get()); } -spla_OpBinary spla_OpBinary_BXOR_UINT() { return as_ptr(spla::BXOR_UINT.ref_and_get()); } +spla_OpBinary spla_OpBinary_PLUS_INT() { + return as_ptr(spla::PLUS_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_PLUS_UINT() { + return as_ptr(spla::PLUS_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_PLUS_FLOAT() { + return as_ptr(spla::PLUS_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_INT() { + return as_ptr(spla::MINUS_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_UINT() { + return as_ptr(spla::MINUS_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_FLOAT() { + return as_ptr(spla::MINUS_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MULT_INT() { + return as_ptr(spla::MULT_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MULT_UINT() { + return as_ptr(spla::MULT_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MULT_FLOAT() { + return as_ptr(spla::MULT_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_DIV_INT() { + return as_ptr(spla::DIV_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_DIV_UINT() { + return 
as_ptr(spla::DIV_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_DIV_FLOAT() { + return as_ptr(spla::DIV_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_POW2_INT() { + return as_ptr(spla::MINUS_POW2_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_POW2_UINT() { + return as_ptr(spla::MINUS_POW2_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MINUS_POW2_FLOAT() { + return as_ptr(spla::MINUS_POW2_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_FIRST_INT() { + return as_ptr(spla::FIRST_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_FIRST_UINT() { + return as_ptr(spla::FIRST_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_FIRST_FLOAT() { + return as_ptr(spla::FIRST_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_SECOND_INT() { + return as_ptr(spla::SECOND_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_SECOND_UINT() { + return as_ptr(spla::SECOND_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_SECOND_FLOAT() { + return as_ptr(spla::SECOND_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BONE_INT() { + return as_ptr(spla::BONE_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BONE_UINT() { + return as_ptr(spla::BONE_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BONE_FLOAT() { + return as_ptr(spla::BONE_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_INT() { + return as_ptr(spla::MIN_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_UINT() { + return as_ptr(spla::MIN_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_FLOAT() { + return as_ptr(spla::MIN_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MAX_INT() { + return as_ptr(spla::MAX_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MAX_UINT() { + return as_ptr(spla::MAX_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MAX_FLOAT() { + return as_ptr(spla::MAX_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LOR_INT() { + return as_ptr(spla::LOR_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LOR_UINT() 
{ + return as_ptr(spla::LOR_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LOR_FLOAT() { + return as_ptr(spla::LOR_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LAND_INT() { + return as_ptr(spla::LAND_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LAND_UINT() { + return as_ptr(spla::LAND_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_LAND_FLOAT() { + return as_ptr(spla::LAND_FLOAT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BOR_INT() { + return as_ptr(spla::BOR_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BOR_UINT() { + return as_ptr(spla::BOR_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BAND_INT() { + return as_ptr(spla::BAND_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BAND_UINT() { + return as_ptr(spla::BAND_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BXOR_INT() { + return as_ptr(spla::BXOR_INT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_BXOR_UINT() { + return as_ptr(spla::BXOR_UINT.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MIN_PAIR() { + return as_ptr(spla::MIN_PAIR.ref_and_get()); +} +spla_OpBinary spla_OpBinary_MUL_PAIR() { + return as_ptr(spla::MUL_PAIR.ref_and_get()); +} -spla_OpSelect spla_OpSelect_EQZERO_INT() { return as_ptr(spla::EQZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_EQZERO_UINT() { return as_ptr(spla::EQZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_EQZERO_FLOAT() { return as_ptr(spla::EQZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NQZERO_INT() { return as_ptr(spla::NQZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NQZERO_UINT() { return as_ptr(spla::NQZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NQZERO_FLOAT() { return as_ptr(spla::NQZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GTZERO_INT() { return as_ptr(spla::GTZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GTZERO_UINT() { return as_ptr(spla::GTZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GTZERO_FLOAT() { return 
as_ptr(spla::GTZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GEZERO_INT() { return as_ptr(spla::GEZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GEZERO_UINT() { return as_ptr(spla::GEZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_GEZERO_FLOAT() { return as_ptr(spla::GEZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LTZERO_INT() { return as_ptr(spla::LTZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LTZERO_UINT() { return as_ptr(spla::LTZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LTZERO_FLOAT() { return as_ptr(spla::LTZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LEZERO_INT() { return as_ptr(spla::LEZERO_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LEZERO_UINT() { return as_ptr(spla::LEZERO_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_LEZERO_FLOAT() { return as_ptr(spla::LEZERO_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_INT() { return as_ptr(spla::ALWAYS_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_UINT() { return as_ptr(spla::ALWAYS_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_ALWAYS_FLOAT() { return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NEVER_INT() { return as_ptr(spla::NEVER_INT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NEVER_UINT() { return as_ptr(spla::NEVER_UINT.ref_and_get()); } -spla_OpSelect spla_OpSelect_NEVER_FLOAT() { return as_ptr(spla::NEVER_FLOAT.ref_and_get()); } \ No newline at end of file +spla_OpSelect spla_OpSelect_EQZERO_INT() { + return as_ptr(spla::EQZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_EQZERO_UINT() { + return as_ptr(spla::EQZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_EQZERO_FLOAT() { + return as_ptr(spla::EQZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NQZERO_INT() { + return as_ptr(spla::NQZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NQZERO_UINT() { + return as_ptr(spla::NQZERO_UINT.ref_and_get()); +} +spla_OpSelect 
spla_OpSelect_NQZERO_FLOAT() { + return as_ptr(spla::NQZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GTZERO_INT() { + return as_ptr(spla::GTZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GTZERO_UINT() { + return as_ptr(spla::GTZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GTZERO_FLOAT() { + return as_ptr(spla::GTZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GEZERO_INT() { + return as_ptr(spla::GEZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GEZERO_UINT() { + return as_ptr(spla::GEZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_GEZERO_FLOAT() { + return as_ptr(spla::GEZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LTZERO_INT() { + return as_ptr(spla::LTZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LTZERO_UINT() { + return as_ptr(spla::LTZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LTZERO_FLOAT() { + return as_ptr(spla::LTZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LEZERO_INT() { + return as_ptr(spla::LEZERO_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LEZERO_UINT() { + return as_ptr(spla::LEZERO_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_LEZERO_FLOAT() { + return as_ptr(spla::LEZERO_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_INT() { + return as_ptr(spla::ALWAYS_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_UINT() { + return as_ptr(spla::ALWAYS_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_FLOAT() { + return as_ptr(spla::ALWAYS_FLOAT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_ALWAYS_PAIR() { + return as_ptr(spla::ALWAYS_PAIR.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NEVER_INT() { + return as_ptr(spla::NEVER_INT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NEVER_UINT() { + return as_ptr(spla::NEVER_UINT.ref_and_get()); +} +spla_OpSelect spla_OpSelect_NEVER_FLOAT() { + return as_ptr(spla::NEVER_FLOAT.ref_and_get()); +} \ No newline at end of file diff --git 
a/src/binding/c_type.cpp b/src/binding/c_type.cpp index 2ca432791..1c294878a 100644 --- a/src/binding/c_type.cpp +++ b/src/binding/c_type.cpp @@ -36,6 +36,5 @@ spla_Type spla_Type_INT() { spla_Type spla_Type_UINT() { return as_ptr(spla::UINT.get()); } -spla_Type spla_Type_FLOAT() { - return as_ptr(spla::FLOAT.get()); -} \ No newline at end of file +spla_Type spla_Type_FLOAT() { return as_ptr(spla::FLOAT.get()); } +spla_Type spla_Type_PAIR() { return as_ptr(spla::PAIR.get()); } \ No newline at end of file diff --git a/src/core/tmatrix.hpp b/src/core/tmatrix.hpp index cd36e72b5..20c8344b8 100644 --- a/src/core/tmatrix.hpp +++ b/src/core/tmatrix.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of 
the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. 
*/ /**********************************************************************************/ #ifndef SPLA_TMATRIX_HPP @@ -43,16 +50,16 @@ namespace spla { /** - * @addtogroup internal - * @{ - */ + * @addtogroup internal + * @{ + */ /** - * @class TMatrix - * @brief Matrix interface implementation with type information bound - * - * @tparam T Type of stored elements - */ + * @class TMatrix + * @brief Matrix interface implementation with type information bound + * + * @tparam T Type of stored elements + */ template class TMatrix final : public Matrix { public: @@ -70,15 +77,25 @@ namespace spla { Status set_int(uint row_id, uint col_id, std::int32_t value) override; Status set_uint(uint row_id, uint col_id, std::uint32_t value) override; Status set_float(uint row_id, uint col_id, float value) override; - Status get_int(uint row_id, uint col_id, int32_t& value) override; - Status get_uint(uint row_id, uint col_id, uint32_t& value) override; - Status get_float(uint row_id, uint col_id, float& value) override; - Status build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) override; - Status read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) override; - Status clear() override; + Status set_pair(uint row_id, uint col_id, Pair value) override { + return Status::InvalidArgument; + } + Status get_int(uint row_id, uint col_id, int32_t& value) override; + Status get_uint(uint row_id, uint col_id, uint32_t& value) override; + Status get_float(uint row_id, uint col_id, float& value) override; + Status get_pair(uint row_id, uint col_id, Pair& value) override { + return Status::InvalidArgument; + } + Status build(const ref_ptr& keys1, const ref_ptr& keys2, + const ref_ptr& values) override; + Status read(ref_ptr& keys1, ref_ptr& keys2, + ref_ptr& values) override; + Status clear() override; template - Decorator* get() { return m_storage.template get(); } + Decorator* get() { + return m_storage.template get(); + } void validate_rw(FormatMatrix format); void 
validate_rwd(FormatMatrix format); @@ -132,9 +149,14 @@ namespace spla { if (value) { m_storage.invalidate(); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_int()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_uint()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_float()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_int()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_uint()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_float()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_pair()); return Status::Ok; } @@ -161,18 +183,32 @@ namespace spla { cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TMatrix::set_int(uint row_id, uint col_id, + std::int32_t value) { + return Status::InvalidArgument; + } template Status TMatrix::set_uint(uint row_id, uint col_id, std::uint32_t value) { validate_rwd(FormatMatrix::CpuLil); cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TMatrix::set_uint(uint row_id, uint col_id, + std::uint32_t value) { + return Status::InvalidArgument; + } template Status TMatrix::set_float(uint row_id, uint col_id, float value) { validate_rwd(FormatMatrix::CpuLil); cpu_lil_add_element(row_id, col_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TMatrix::set_float(uint row_id, uint col_id, float value) { + return Status::InvalidArgument; + } template Status TMatrix::get_int(uint row_id, uint col_id, int32_t& value) { @@ -188,6 +224,11 @@ namespace spla { return Status::Ok; } + template<> + inline Status TMatrix::get_int(uint row_id, uint col_id, + std::int32_t& value) { + return Status::InvalidArgument; + } template Status TMatrix::get_uint(uint row_id, uint col_id, uint32_t& value) { 
validate_rw(FormatMatrix::CpuDok); @@ -202,6 +243,11 @@ namespace spla { return Status::Ok; } + template<> + inline Status TMatrix::get_uint(uint row_id, uint col_id, + std::uint32_t& value) { + return Status::InvalidArgument; + } template Status TMatrix::get_float(uint row_id, uint col_id, float& value) { validate_rw(FormatMatrix::CpuDok); @@ -216,9 +262,15 @@ namespace spla { return Status::Ok; } + template<> + inline Status TMatrix::get_float(uint row_id, uint col_id, float& value) { + return Status::InvalidArgument; + } template - Status TMatrix::build(const ref_ptr& keys1, const ref_ptr& keys2, const ref_ptr& values) { + Status TMatrix::build(const ref_ptr& keys1, + const ref_ptr& keys2, + const ref_ptr& values) { assert(keys1); assert(keys2); assert(values); @@ -252,7 +304,8 @@ namespace spla { return Status::Ok; } template - Status TMatrix::read(ref_ptr& keys1, ref_ptr& keys2, ref_ptr& values) { + Status TMatrix::read(ref_ptr& keys1, ref_ptr& keys2, + ref_ptr& values) { const auto key_size = sizeof(uint); const auto value_size = sizeof(T); @@ -314,12 +367,38 @@ namespace spla { return storage_manager.get(); } + template<> + inline Status TMatrix::set_pair(uint row_id, uint col_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rwd(FormatMatrix::CpuLil); + cpu_lil_add_element(row_id, col_id, value, *get>()); + return Status::Ok; + } + template<> + inline Status TMatrix::get_pair(uint row_id, uint col_id, Pair& value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + validate_rw(FormatMatrix::CpuDok); + + auto& Ax = get>()->Ax; + auto entry = Ax.find(typename CpuDok::Key(row_id, col_id)); + value = m_storage.get_fill_value(); + + if (entry != Ax.end()) { + value = static_cast(entry->second); + } + + return Status::Ok; + } /** - * @} - */ + * @} + */ }// namespace spla - -#endif//SPLA_TMATRIX_HPP +#endif// SPLA_TMATRIX_HPP diff --git a/src/core/top.hpp b/src/core/top.hpp index 
654e928eb..a313564a0 100644 --- a/src/core/top.hpp +++ b/src/core/top.hpp @@ -43,7 +43,7 @@ namespace spla { { \ auto func = make_ref>(); \ \ - func->function = [](A0 a) -> R __VA_ARGS__; \ + func->function = [](A0 a)->R __VA_ARGS__; \ func->name = #fname; \ \ std::stringstream source_builder; \ @@ -69,7 +69,7 @@ namespace spla { { \ auto func = make_ref>(); \ \ - func->function = [](A0 a, A1 b) -> R __VA_ARGS__; \ + func->function = [](A0 a, A1 b)->R __VA_ARGS__; \ func->name = #fname; \ \ std::stringstream source_builder; \ @@ -98,7 +98,7 @@ namespace spla { { \ auto func = make_ref>(); \ \ - func->function = [](A0 a) -> bool __VA_ARGS__; \ + func->function = [](A0 a)->bool __VA_ARGS__; \ func->name = #fname; \ \ std::stringstream source_builder; \ diff --git a/src/core/tscalar.hpp b/src/core/tscalar.hpp index 091a469fa..4a07060f1 100644 --- a/src/core/tscalar.hpp +++ b/src/core/tscalar.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this 
software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. 
*/ /**********************************************************************************/ #ifndef SPLA_TSCALAR_HPP @@ -35,14 +42,14 @@ namespace spla { /** - * @addtogroup internal - * @{ - */ + * @addtogroup internal + * @{ + */ /** - * - * @tparam T - */ + * + * @tparam T + */ template class TScalar final : public Scalar { public: @@ -60,6 +67,7 @@ namespace spla { T_INT as_int() override { return static_cast(m_value); } T_UINT as_uint() override { return static_cast(m_value); } T_FLOAT as_float() override { return static_cast(m_value); } + T_PAIR as_pair() override { return static_cast(m_value); } void set_label(std::string label) override; const std::string& get_label() const override; @@ -73,8 +81,7 @@ namespace spla { }; template - TScalar::TScalar(T value) : m_value(value) { - } + TScalar::TScalar(T value) : m_value(value) {} template ref_ptr TScalar::get_type() { @@ -124,18 +131,78 @@ namespace spla { } template - T& TScalar::get_value() { - return m_value; - } + T& TScalar::get_value() { return m_value; } template - T TScalar::get_value() const { - return m_value; - } + T TScalar::get_value() const { return m_value; } + template<> + inline T_PAIR TScalar::as_pair() { return Pair(); } + template<> + inline T_PAIR TScalar::as_pair() { return Pair(); } + template<> + inline T_PAIR TScalar::as_pair() { return Pair(); } + + template<> + class TScalar final : public Scalar { + public: + TScalar() = default; + explicit TScalar(Pair value) : m_value(value) {} + ~TScalar() override = default; - /** - * @} - */ + Status set_pair(Pair value) { + m_value = value; + return Status::Ok; + } + + Status get_pair(Pair& value) const { + value = m_value; + return Status::Ok; + } + + ref_ptr get_type() override { return PAIR; } + + Status set_int(std::int32_t) override { return Status::InvalidArgument; } + + Status set_uint(std::uint32_t) override { return Status::InvalidArgument; } + + Status set_float(float) override { return Status::InvalidArgument; } + + Status 
get_int(std::int32_t&) override { return Status::InvalidArgument; } + + Status get_uint(std::uint32_t&) override { return Status::InvalidArgument; } + + Status get_float(float&) override { return Status::InvalidArgument; } + + T_INT as_int() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to int"); + return 0; + } + + T_UINT as_uint() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to uint"); + return 0; + } + + T_FLOAT as_float() override { + LOG_MSG(Status::InvalidArgument, "cannot convert Pair to float"); + return 0.0f; + } + T_PAIR as_pair() override { + /* Pair -> Pair is the identity conversion: hand back the stored value instead of logging an error and returning a default Pair. */ + return m_value; + } + + void set_label(std::string label) override { m_label = std::move(label); } + + const std::string& get_label() const override { return m_label; } + + Pair& get_value() { return m_value; } + Pair get_value() const { return m_value; } + + private: + std::string m_label; + Pair m_value = Pair(); + }; }// namespace spla -#endif//SPLA_TSCALAR_HPP +#endif// SPLA_TSCALAR_HPP diff --git a/src/core/ttype.hpp b/src/core/ttype.hpp index 344de29ac..136401a7f 100644 --- a/src/core/ttype.hpp +++ b/src/core/ttype.hpp @@ -129,6 +129,10 @@ namespace spla { ref_ptr> get_ttype() { return FLOAT.cast_safe>(); } + template<> + ref_ptr> get_ttype() { + return PAIR.cast_safe>(); + } /** * @} diff --git a/src/core/tvector.hpp b/src/core/tvector.hpp index af3e5c71f..ed48a3a44 100644 --- a/src/core/tvector.hpp +++ b/src/core/tvector.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 
SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. 
*/ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_TVECTOR_HPP @@ -40,22 +47,23 @@ #include #include +#include "spla/pair.hpp" #include #include namespace spla { /** - * @addtogroup internal - * @{ - */ + * @addtogroup internal + * @{ + */ /** - * @class TVector - * @brief Vector interface implementation with type information bound - * - * @tparam T Type of stored elements - */ + * @class TVector + * @brief Vector interface implementation with type information bound + * + * @tparam T Type of stored elements + */ template class TVector final : public Vector { public: @@ -72,17 +80,26 @@ namespace spla { Status set_int(uint row_id, std::int32_t value) override; Status set_uint(uint row_id, std::uint32_t value) override; Status set_float(uint row_id, float value) override; - Status get_int(uint row_id, int32_t& value) override; - Status get_uint(uint row_id, uint32_t& value) override; - Status get_float(uint row_id, float& value) override; - Status fill_noize(uint seed) override; - Status fill_with(const ref_ptr& value) override; - Status build(const ref_ptr& keys, const ref_ptr& values) override; - Status read(ref_ptr& keys, ref_ptr& values) override; - Status clear() override; + Status set_pair(uint row_id, Pair value) override { + return Status::InvalidArgument; + } + Status get_int(uint row_id, int32_t& value) override; + Status get_uint(uint row_id, uint32_t& value) override; + Status get_float(uint row_id, 
float& value) override; + Status get_pair(uint row_id, Pair& value) override { + return Status::InvalidArgument; + } + Status fill_noize(uint seed) override; + Status fill_with(const ref_ptr& value) override; + Status build(const ref_ptr& keys, + const ref_ptr& values) override; + Status read(ref_ptr& keys, ref_ptr& values) override; + Status clear() override; template - Decorator* get() { return m_storage.template get(); } + Decorator* get() { + return m_storage.template get(); + } void validate_rw(FormatVector format); void validate_rwd(FormatVector format); @@ -132,9 +149,12 @@ namespace spla { if (value) { m_storage.invalidate(); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_int()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_uint()); - if constexpr (std::is_same::value) m_storage.set_fill_value(value->as_float()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_int()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_uint()); + if constexpr (std::is_same::value) + m_storage.set_fill_value(value->as_float()); return Status::Ok; } @@ -167,6 +187,11 @@ namespace spla { cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TVector::set_int(uint row_id, std::int32_t value) { + return Status::InvalidArgument; + } + template Status TVector::set_uint(uint row_id, std::uint32_t value) { if (is_valid(FormatVector::CpuDense)) { @@ -179,6 +204,10 @@ namespace spla { cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline Status TVector::set_uint(uint row_id, std::uint32_t value) { + return Status::InvalidArgument; + } template Status TVector::set_float(uint row_id, float value) { if (is_valid(FormatVector::CpuDense)) { @@ -191,6 +220,10 @@ namespace spla { cpu_dok_vec_add_element(row_id, static_cast(value), *get>()); return Status::Ok; } + template<> + inline 
Status TVector::set_float(uint row_id, float value) { + return Status::InvalidArgument; + } template Status TVector::get_int(uint row_id, int32_t& value) { @@ -206,6 +239,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TVector::get_int(uint row_id, int32_t& value) { + return Status::InvalidArgument; + } template Status TVector::get_uint(uint row_id, uint32_t& value) { validate_rw(FormatVector::CpuDok); @@ -220,6 +257,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TVector::get_uint(uint row_id, uint32_t& value) { + return Status::InvalidArgument; + } template Status TVector::get_float(uint row_id, float& value) { validate_rw(FormatVector::CpuDok); @@ -234,6 +275,10 @@ namespace spla { return Status::Ok; } + template<> + inline Status TVector::get_float(uint row_id, float& value) { + return Status::InvalidArgument; + } template Status TVector::fill_noize(uint seed) { @@ -243,11 +288,13 @@ namespace spla { if constexpr (std::is_integral_v) { std::uniform_int_distribution dist; - for (auto& x : Ax) x = dist(engine); + for (auto& x : Ax) + x = dist(engine); } if constexpr (std::is_floating_point_v) { std::uniform_real_distribution dist; - for (auto& x : Ax) x = dist(engine); + for (auto& x : Ax) + x = dist(engine); } return Status::Ok; @@ -258,9 +305,12 @@ namespace spla { T t = T(); - if constexpr (std::is_same::value) t = value->as_int(); - if constexpr (std::is_same::value) t = value->as_uint(); - if constexpr (std::is_same::value) t = value->as_float(); + if constexpr (std::is_same::value) + t = value->as_int(); + if constexpr (std::is_same::value) + t = value->as_uint(); + if constexpr (std::is_same::value) + t = value->as_float(); validate_wd(FormatVector::CpuDense); auto& Ax = get>()->Ax; @@ -270,7 +320,8 @@ namespace spla { } template - Status TVector::build(const ref_ptr& keys, const ref_ptr& values) { + Status TVector::build(const ref_ptr& keys, + const ref_ptr& values) { assert(keys); assert(values); @@ 
-359,11 +410,46 @@ namespace spla { return storage_manager.get(); } + template<> + inline Status TVector::set_pair(uint row_id, Pair value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + if (is_valid(FormatVector::CpuDense)) { + validate_rwd(FormatVector::CpuDense); + get>()->Ax[row_id] = value; + return Status::Ok; + } + + validate_rwd(FormatVector::CpuDok); + cpu_dok_vec_add_element(row_id, value, *get>()); + return Status::Ok; + } + template<> + inline Status TVector::get_pair(uint row_id, Pair& value) { + if (get_type() != PAIR) { + return Status::InvalidArgument; + } + + validate_rw(FormatVector::CpuDok); + + const auto& Ax = get>()->Ax; + const auto entry = Ax.find(row_id); + + if (entry != Ax.end()) { + value = entry->second; + } else { + value = m_storage.get_fill_value(); + } + + return Status::Ok; + } /** - * @} - */ + * @} + */ }// namespace spla -#endif//SPLA_TVECTOR_HPP +#endif// SPLA_TVECTOR_HPP diff --git a/src/cpu/cpu_algo_registry.cpp b/src/cpu/cpu_algo_registry.cpp index 4944eb756..2e23d1ce0 100644 --- a/src/cpu/cpu_algo_registry.cpp +++ b/src/cpu/cpu_algo_registry.cpp @@ -127,6 +127,8 @@ namespace spla { g_registry->add(MAKE_KEY_CPU_0("m_extract_row", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_extract_row", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_extract_row", FLOAT), std::make_shared>()); + g_registry->add(MAKE_KEY_CPU_0("m_extract_row", PAIR), + std::make_shared>()); // algorthm m_extract_column g_registry->add(MAKE_KEY_CPU_0("m_extract_column", INT), std::make_shared>()); diff --git a/src/io.cpp b/src/io.cpp index 9a962087d..65c753441 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ 
/**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. 
*/ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -44,10 +51,10 @@ namespace spla { - MtxLoader::MtxLoader(std::string name) : m_name(std::move(name)) { - } + MtxLoader::MtxLoader(std::string name) : m_name(std::move(name)) {} - bool MtxLoader::load(std::filesystem::path file_path, bool offset_indices, bool make_undirected, bool remove_loops) { + bool MtxLoader::load(std::filesystem::path file_path, bool offset_indices, + bool make_undirected, bool remove_loops) { m_file_path = std::move(file_path); m_base_is_zero = offset_indices; @@ -65,7 +72,8 @@ namespace spla { std::string line; while (std::getline(file, line)) { - if (line[0] != '%') break; + if (line[0] != '%') + break; n_lines++; } @@ -73,13 +81,23 @@ namespace spla { std::stringstream header(line); header >> m_n_rows >> m_n_cols >> nnz; + bool file_has_values = false; + if (line.find("pattern") == + std::string::npos) {// a pattern substring would mean an unweighted graph; none found => edge values are present. NOTE(review): line holds the size line here (the % banner lines were consumed by the loop above), so this search looks unable to ever match - confirm + file_has_values = true; + } + std::cout << "Loading matrix-market coordinate format data... 
" << std::endl; std::cout << " Reading from " << m_file_path << std::endl; - std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" << std::endl; + std::cout << " Matrix size " << m_n_rows << " rows, " << m_n_cols << " cols" + << std::endl; std::cout << " Data: " << nnz << " directed edges" << std::endl; - if (remove_loops) std::cout << " Opt: remove self-loops" << std::endl; - if (offset_indices) std::cout << " Opt: offset indices by -1" << std::endl; - if (make_undirected) std::cout << " Opt: double edges" << std::endl; + if (remove_loops) + std::cout << " Opt: remove self-loops" << std::endl; + if (offset_indices) + std::cout << " Opt: offset indices by -1" << std::endl; + if (make_undirected) + std::cout << " Opt: double edges" << std::endl; std::cout << " Reading data: "; // optimized reading by sliding window @@ -89,15 +107,18 @@ namespace spla { char buffer[BUFFER_CAPACITY + 1]; // read data - std::size_t to_count = 0; - std::size_t to_read = nnz; - std::size_t to_preallocate = to_read * (make_undirected ? 2 : 1); - std::vector Ai; - std::vector Aj; + std::size_t to_count = 0; + std::size_t to_read = nnz; + std::size_t to_preallocate = to_read * (make_undirected ? 
2 : 1); + std::vector Ai; + std::vector Aj; + std::vector Av; // preallocate to avoid copy Ai.reserve(to_preallocate); Aj.reserve(to_preallocate); + if (file_has_values) + Av.reserve(to_preallocate); float job_done = 0.0f; float job_total = 35.0f; @@ -134,7 +155,8 @@ namespace spla { if (buffer_offset > 0) { if (buffer_offset < BUFFER_CAPACITY) { - std::memcpy(buffer, buffer + buffer_offset, BUFFER_CAPACITY - buffer_offset); + std::memcpy(buffer, buffer + buffer_offset, + BUFFER_CAPACITY - buffer_offset); } buffer_offset = BUFFER_CAPACITY - buffer_offset; } @@ -149,15 +171,26 @@ namespace spla { } } - char* end = nullptr; - auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); - auto j = uint(std::strtoll(end, &end, 10)); + char* end = nullptr; + auto i = uint(std::strtoll(buffer + buffer_offset, &end, 10)); + auto j = uint(std::strtoll(end, &end, 10)); + float val = 1.0f;// default value + + if (file_has_values) { + char* next = end; + while (*next == ' ' || *next == '\t') + next++; + if (*next != '\n' && *next != '\0') { + val = static_cast(std::strtod(next, &end)); + } + } buffer_offset = line_end + 1; assert(i > 0 && j > 0); if (remove_loops) { - if (i == j) continue; + if (i == j) + continue; } if (offset_indices) { i -= 1; @@ -166,53 +199,98 @@ namespace spla { if (make_undirected) { Ai.push_back(j); Aj.push_back(i); + if (file_has_values) + Av.push_back(val); } Ai.push_back(i); Aj.push_back(j); + if (file_has_values) + Av.push_back(val); } t.lap_end();// parsing - std::vector sorted; - { - sorted.reserve(Ai.size()); - n_sort = Ai.size(); + if (file_has_values) { + struct Edge { + uint i, j; + float w; + bool operator<(const Edge& other) const { + if (i != other.i) + return i < other.i; + return j < other.j; + } + }; + std::vector edges; + edges.reserve(Ai.size()); for (std::size_t k = 0; k < Ai.size(); k++) { - std::uint64_t entry = 0; - entry |= std::uint64_t(Ai[k]) << 32u; - entry |= std::uint64_t(Aj[k]) << 0u; - sorted.push_back(entry); + 
edges.push_back({Ai[k], Aj[k], Av[k]}); + } - Ai.clear(); - Aj.clear(); - std::sort(sorted.begin(), sorted.end()); - } - t.lap_end();// sorting - - std::vector reduced_Ai; - std::vector reduced_Aj; - { - reduced_Ai.reserve(sorted.size()); - reduced_Aj.reserve(sorted.size()); - - std::uint64_t entry_prev = 0xffffffffffffffff; - for (std::uint64_t entry : sorted) { - if (entry_prev != entry) { - uint i = uint((entry >> 32u) & 0xffffffff); - uint j = uint((entry >> 0u) & 0xffffffff); - reduced_Ai.push_back(i); - reduced_Aj.push_back(j); + std::sort(edges.begin(), edges.end()); + t.lap_end();// sorting (weighted path; the summary below reads laps 0..3, so every path must record all four) + std::vector reduced_Ai; + std::vector reduced_Aj; + std::vector reduced_Av; + reduced_Ai.reserve(edges.size()); + reduced_Aj.reserve(edges.size()); + reduced_Av.reserve(edges.size()); + n_sort = edges.size();// number of sorted entries reported by the summary + for (std::size_t k = 0; k < edges.size(); k++) { + if (k == 0 || edges[k].i != edges[k - 1].i || + edges[k].j != edges[k - 1].j) { + reduced_Ai.push_back(edges[k].i); + reduced_Aj.push_back(edges[k].j); + reduced_Av.push_back(edges[k].w); } - entry_prev = entry; } m_n_values = reduced_Ai.size(); m_Ai = std::move(reduced_Ai); m_Aj = std::move(reduced_Aj); + m_Aw = std::move(reduced_Av); + } else { + std::vector sorted; + { + sorted.reserve(Ai.size()); + n_sort = Ai.size(); + + for (std::size_t k = 0; k < Ai.size(); k++) { + std::uint64_t entry = 0; + entry |= std::uint64_t(Ai[k]) << 32u; + entry |= std::uint64_t(Aj[k]) << 0u; + sorted.push_back(entry); + } + Ai.clear(); + Aj.clear(); + + std::sort(sorted.begin(), sorted.end()); + } + t.lap_end();// sorting + + std::vector reduced_Ai; + std::vector reduced_Aj; + { + reduced_Ai.reserve(sorted.size()); + reduced_Aj.reserve(sorted.size()); + + std::uint64_t entry_prev = 0xffffffffffffffff; + for (std::uint64_t entry : sorted) { + if (entry_prev != entry) { + uint i = uint((entry >> 32u) & 0xffffffff); + uint j = uint((entry >> 0u) & 0xffffffff); + reduced_Ai.push_back(i); + reduced_Aj.push_back(j); + } + entry_prev = entry; + } + + m_n_values = 
reduced_Ai.size(); + m_Ai = std::move(reduced_Ai); + m_Aj = std::move(reduced_Aj); + } } t.lap_end();// reducing calc_stats(); t.lap_end();// stats @@ -220,12 +298,18 @@ namespace spla { t.stop(); std::cout << " 100%" << std::endl; - std::cout << " Parsed in " << t.get_laps_ms()[0] * 1e-3 << " sec " << n_lines << " lines" - << " speed " << float(n_lines) / (t.get_laps_ms()[0] * 1e-3) << " lines/sec" << std::endl; - std::cout << " Sorted in " << t.get_laps_ms()[1] * 1e-3 << " sec " << n_sort << " lines" << std::endl; - std::cout << " Reduced in " << t.get_laps_ms()[2] * 1e-3 << " sec " << m_n_values << " lines" << std::endl; - std::cout << " Calc stats in " << t.get_laps_ms()[3] * 1e-3 << " sec" << std::endl; - std::cout << " Loaded in " << t.get_elapsed_ms() * 1e-3 << " sec, " << m_n_values << " edges total" << std::endl; + std::cout << " Parsed in " << t.get_laps_ms()[0] * 1e-3 << " sec " << n_lines + << " lines" + << " speed " << float(n_lines) / (t.get_laps_ms()[0] * 1e-3) + << " lines/sec" << std::endl; + std::cout << " Sorted in " << t.get_laps_ms()[1] * 1e-3 << " sec " << n_sort + << " lines" << std::endl; + std::cout << " Reduced in " << t.get_laps_ms()[2] * 1e-3 << " sec " + << m_n_values << " lines" << std::endl; + std::cout << " Calc stats in " << t.get_laps_ms()[3] * 1e-3 << " sec" + << std::endl; + std::cout << " Loaded in " << t.get_elapsed_ms() * 1e-3 << " sec, " + << m_n_values << " edges total" << std::endl; output_stats(); @@ -241,8 +325,10 @@ namespace spla { } file << "%%MatrixMarket matrix coordinate pattern general\n"; - file << "%-------------------------------------------------------------------------------\n"; - file << "%-------------------------------------------------------------------------------\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; file << "% meta-info:\n"; file << "% 
name: " << m_name << "\n"; @@ -254,10 +340,12 @@ namespace spla { file << "% deg-distribution: \n"; for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { - file << "% " << m_deg_ranges[i] << " " << m_deg_ranges[i + 1] << " " << m_deg_distribution[i] << "\n"; + file << "% " << m_deg_ranges[i] << " " << m_deg_ranges[i + 1] << " " + << m_deg_distribution[i] << "\n"; } - file << "%-------------------------------------------------------------------------------\n"; + file << "%-------------------------------------------------------------------" + "------------\n"; file << m_n_rows << " " << m_n_cols << " " << m_n_values << "\n"; if (!stats_only) { @@ -292,9 +380,11 @@ namespace spla { auto n = static_cast(m_n_rows); m_deg_avg = m_deg_avg / n; - m_deg_sd = std::sqrt(n * (m_deg_sd / n - m_deg_avg * m_deg_avg) / (n > 1.0 ? n - 1.0 : 1.0)); + m_deg_sd = std::sqrt(n * (m_deg_sd / n - m_deg_avg * m_deg_avg) / + (n > 1.0 ? n - 1.0 : 1.0)); - const uint GROUPS_COUNT_MAX = std::max(uint(10), uint(std::log2(double(m_n_rows) * 0.77))); + const uint GROUPS_COUNT_MAX = + std::max(uint(10), uint(std::log2(double(m_n_rows) * 0.77))); std::vector count_per_deg(static_cast(m_deg_max) + 2, 0); std::vector count_per_deg_offsets(static_cast(m_deg_max) + 2, 0); @@ -303,23 +393,27 @@ namespace spla { count_per_deg[std::min(deg_pre_vertex[i], uint(m_deg_max))] += 1; } - std::exclusive_scan(count_per_deg.begin(), count_per_deg.end(), count_per_deg_offsets.begin(), 0); + std::exclusive_scan(count_per_deg.begin(), count_per_deg.end(), + count_per_deg_offsets.begin(), 0); count_per_deg_offsets.back() += 1; std::vector distributions; std::vector ranges; - auto range = m_deg_max - m_deg_min; - auto groups_count = std::max(std::min(GROUPS_COUNT_MAX, static_cast(range)), 1u); - auto g = static_cast(groups_count); + auto range = m_deg_max - m_deg_min; + auto groups_count = + std::max(std::min(GROUPS_COUNT_MAX, static_cast(range)), 1u); + auto g = static_cast(groups_count); auto total = 
static_cast(count_per_deg_offsets.back()); auto from = count_per_deg_offsets.begin(); ranges.push_back(static_cast(m_deg_min)); for (uint i = 0; i < groups_count; ++i) { - auto next = (from + 1 == count_per_deg_offsets.end()) ? from : from + 1; - auto to = std::lower_bound(next, count_per_deg_offsets.end(), static_cast(total / g * static_cast(i + 1))); + auto next = (from + 1 == count_per_deg_offsets.end()) ? from : from + 1; + auto to = std::lower_bound( + next, count_per_deg_offsets.end(), + static_cast(total / g * static_cast(i + 1))); auto to_offset = std::distance(count_per_deg_offsets.begin(), to); assert(to != count_per_deg_offsets.end()); @@ -346,7 +440,8 @@ namespace spla { const auto default_precision{std::cout.precision()}; const auto n_digits = static_cast(std::log10(n) + 1.0); - const double DISPLAY_DENSITY = std::max(double(100), double(m_deg_distribution.size())); + const double DISPLAY_DENSITY = + std::max(double(100), double(m_deg_distribution.size())); for (std::size_t i = 0; i < m_deg_distribution.size(); i++) { auto deg = m_deg_distribution[i] >= 0.01 ? 
m_deg_distribution[i] : 0.0; @@ -354,28 +449,23 @@ namespace spla { auto k_end = m_deg_ranges[i + 1]; auto k_count = std::round(static_cast(deg * DISPLAY_DENSITY)); - std::cout << " [" << std::setw(n_digits) << k_start << " - " << std::setw(n_digits) << k_end << "): "; - std::cout << std::setw(6) << std::setprecision(2) << deg * 100.0 << std::setprecision(default_precision) << "% "; - for (uint s = 0; s < k_count; ++s) { std::cout << "*"; } + std::cout << " [" << std::setw(n_digits) << k_start << " - " + << std::setw(n_digits) << k_end << "): "; + std::cout << std::setw(6) << std::setprecision(2) << deg * 100.0 + << std::setprecision(default_precision) << "% "; + for (uint s = 0; s < k_count; ++s) { + std::cout << "*"; + } std::cout << std::endl; } } - const std::vector& MtxLoader::get_Ai() const { - return m_Ai; - } - const std::vector& MtxLoader::get_Aj() const { - return m_Aj; - } + const std::vector& MtxLoader::get_Ai() const { return m_Ai; } + const std::vector& MtxLoader::get_Aj() const { return m_Aj; } + const std::vector& MtxLoader::get_Aw() const { return m_Aw; } - uint MtxLoader::get_n_rows() const { - return m_n_rows; - } - uint MtxLoader::get_n_cols() const { - return m_n_cols; - } - std::size_t MtxLoader::get_n_values() const { - return m_n_values; - } + uint MtxLoader::get_n_rows() const { return m_n_rows; } + uint MtxLoader::get_n_cols() const { return m_n_cols; } + std::size_t MtxLoader::get_n_values() const { return m_n_values; } }// namespace spla \ No newline at end of file diff --git a/src/matrix.cpp b/src/matrix.cpp index 4ecab16aa..aa319ca3b 100644 --- a/src/matrix.cpp +++ b/src/matrix.cpp @@ -51,6 +51,9 @@ namespace spla { if (type == FLOAT) { return ref_ptr(new TMatrix(n_rows, n_cols)); } + if (type == PAIR) { + return ref_ptr(new TMatrix(n_rows, n_cols)); + } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); return ref_ptr(); diff --git a/src/op.cpp b/src/op.cpp index a1bca7514..8509dc750 100644 --- a/src/op.cpp 
+++ b/src/op.cpp @@ -1,36 +1,45 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. 
*/ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. 
*/ /**********************************************************************************/ #include #include "spla/op.hpp" +#include "spla/pair.hpp" #include #include +#include namespace spla { @@ -69,6 +78,7 @@ namespace spla { ref_ptr FLOOR_FLOAT; ref_ptr ROUND_FLOAT; ref_ptr TRUNC_FLOAT; + ref_ptr IDENTITY_PAIR; ////////////////////////////////////////////////////////////////////////////// @@ -121,6 +131,9 @@ namespace spla { ref_ptr BXOR_INT; ref_ptr BXOR_UINT; + ref_ptr MIN_PAIR; + ref_ptr MUL_PAIR; + ////////////////////////////////////////////////////////////////////////////// ref_ptr EQZERO_INT; @@ -144,6 +157,7 @@ namespace spla { ref_ptr ALWAYS_INT; ref_ptr ALWAYS_UINT; ref_ptr ALWAYS_FLOAT; + ref_ptr ALWAYS_PAIR; ref_ptr NEVER_INT; ref_ptr NEVER_UINT; ref_ptr NEVER_FLOAT; @@ -190,6 +204,8 @@ namespace spla { DECL_OP_UNA_S(FLOOR_FLOAT, FLOOR, T_FLOAT, { return floor(a); }); DECL_OP_UNA_S(ROUND_FLOAT, ROUND, T_FLOAT, { return round(a); }); DECL_OP_UNA_S(TRUNC_FLOAT, TRUNC, T_FLOAT, { return trunc(a); }); + IDENTITY_PAIR = spla::OpUnary::make_pair( + "IDENTITY_PAIR", "(a) identity_pair(a)", [](Pair a) { return a; }); DECL_OP_BIN_S(PLUS_INT, PLUS, T_INT, { return a + b; }); DECL_OP_BIN_S(PLUS_UINT, PLUS, T_UINT, { return a + b; }); @@ -204,9 +220,12 @@ namespace spla { DECL_OP_BIN_S(DIV_UINT, DIV, T_UINT, { return a / b; }); DECL_OP_BIN_S(DIV_FLOAT, DIV, T_FLOAT, { return a / b; }); - DECL_OP_BIN_S(MINUS_POW2_INT, MINUS_POW2, T_INT, { return (a - b) * (a - b); }); - DECL_OP_BIN_S(MINUS_POW2_UINT, MINUS_POW2, T_UINT, { return (a - b) * (a - b); }); - DECL_OP_BIN_S(MINUS_POW2_FLOAT, MINUS_POW2, T_FLOAT, { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_INT, MINUS_POW2, T_INT, + { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_UINT, MINUS_POW2, T_UINT, + { return (a - b) * (a - b); }); + DECL_OP_BIN_S(MINUS_POW2_FLOAT, MINUS_POW2, T_FLOAT, + { return (a - b) * (a - b); }); DECL_OP_BIN_S(FIRST_INT, FIRST, T_INT, { return a; }); 
DECL_OP_BIN_S(FIRST_UINT, FIRST, T_UINT, { return a; }); @@ -240,6 +259,16 @@ namespace spla { DECL_OP_BIN_S(BXOR_INT, BXOR, T_INT, { return a ^ b; }); DECL_OP_BIN_S(BXOR_UINT, BXOR, T_UINT, { return a ^ b; }); + MUL_PAIR = OpBinary::make_pair( + "MUL_PAIR", "(a, b) make_pair(a.weight, b.vertex)", + [](Pair a, Pair b) { return Pair(a.weight, b.vertex); }); + MIN_PAIR = OpBinary::make_pair("MIN_PAIR", "(a, b) min_pair(a, b)", + [](Pair a, Pair b) { + if (a.weight == b.weight) + return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? a : b; + }); + DECL_OP_SELECT(EQZERO_INT, EQZERO, T_INT, { return a == 0; }); DECL_OP_SELECT(EQZERO_UINT, EQZERO, T_UINT, { return a == 0; }); DECL_OP_SELECT(EQZERO_FLOAT, EQZERO, T_FLOAT, { return a == 0; }); @@ -261,62 +290,101 @@ namespace spla { DECL_OP_SELECT(ALWAYS_INT, ALWAYS, T_INT, { return 1; }); DECL_OP_SELECT(ALWAYS_UINT, ALWAYS, T_UINT, { return 1; }); DECL_OP_SELECT(ALWAYS_FLOAT, ALWAYS, T_FLOAT, { return 1; }); + ALWAYS_PAIR = OpSelect::make_pair("ALWAYS_PAIR", "(a) pair_always(a)", + [](Pair a) { return 1; }); DECL_OP_SELECT(NEVER_INT, NEVER, T_INT, { return 0; }); DECL_OP_SELECT(NEVER_UINT, NEVER, T_UINT, { return 0; }); DECL_OP_SELECT(NEVER_FLOAT, NEVER, T_FLOAT, { return 0; }); } - ref_ptr OpUnary::make_int(std::string name, std::string code, std::function function) { + ref_ptr OpUnary::make_int(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); return op.as(); } - ref_ptr OpUnary::make_uint(std::string name, std::string code, std::function function) { + ref_ptr OpUnary::make_uint(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = 
std::move(name); op->function = std::move(function); op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); return op.as(); } - ref_ptr OpUnary::make_float(std::string name, std::string code, std::function function) { + ref_ptr OpUnary::make_float(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_res()->get_code(); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr OpUnary::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_res()->get_code(); return op.as(); } - ref_ptr OpBinary::make_int(std::string name, std::string code, std::function function) { + ref_ptr + OpBinary::make_int(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); return op.as(); } - ref_ptr OpBinary::make_uint(std::string name, std::string code, std::function function) { + ref_ptr + OpBinary::make_uint(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = 
std::move(function); op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); return op.as(); } - ref_ptr OpBinary::make_float(std::string name, std::string code, std::function function) { + ref_ptr + OpBinary::make_float(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); op->source = std::move(code); - op->key = op->name + "_" + op->get_type_arg_0()->get_code() + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); + return op.as(); + } + ref_ptr + OpBinary::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code() + + op->get_type_arg_1()->get_code() + op->get_type_res()->get_code(); return op.as(); } - ref_ptr OpSelect::make_int(std::string name, std::string code, std::function function) { + ref_ptr OpSelect::make_int(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); @@ -324,7 +392,8 @@ namespace spla { op->key = op->name + "_" + op->get_type_arg_0()->get_code(); return op.as(); } - ref_ptr OpSelect::make_uint(std::string name, std::string code, std::function function) { + ref_ptr OpSelect::make_uint(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); @@ -332,7 +401,8 @@ namespace spla { 
op->key = op->name + "_" + op->get_type_arg_0()->get_code(); return op.as(); } - ref_ptr OpSelect::make_float(std::string name, std::string code, std::function function) { + ref_ptr OpSelect::make_float(std::string name, std::string code, + std::function function) { auto op = make_ref>(); op->name = std::move(name); op->function = std::move(function); @@ -340,5 +410,14 @@ namespace spla { op->key = op->name + "_" + op->get_type_arg_0()->get_code(); return op.as(); } + ref_ptr OpSelect::make_pair(std::string name, std::string code, + std::function function) { + auto op = make_ref>(); + op->name = std::move(name); + op->function = std::move(function); + op->source = std::move(code); + op->key = op->name + "_" + op->get_type_arg_0()->get_code(); + return op.as(); + } }// namespace spla \ No newline at end of file diff --git a/src/opencl/cl_algo_registry.cpp b/src/opencl/cl_algo_registry.cpp index 65726579b..715acde66 100644 --- a/src/opencl/cl_algo_registry.cpp +++ b/src/opencl/cl_algo_registry.cpp @@ -83,6 +83,8 @@ namespace spla { g_registry->add(MAKE_KEY_CL_0("mxv_masked", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxv_masked", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CL_0("mxv_masked", FLOAT), std::make_shared>()); + g_registry->add(MAKE_KEY_CL_0("mxv_masked", PAIR), + std::make_shared>()); // algorthm vxm_masked g_registry->add(MAKE_KEY_CL_0("vxm_masked", INT), std::make_shared>()); diff --git a/src/opencl/cl_mxv.hpp b/src/opencl/cl_mxv.hpp index 9a2e437bd..6e4bc8a88 100644 --- a/src/opencl/cl_mxv.hpp +++ b/src/opencl/cl_mxv.hpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* 
*/ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. 
*/ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #ifndef SPLA_CL_MXV_HPP @@ -54,9 +61,7 @@ namespace spla { public: ~Algo_mxv_masked_cl() override = default; - std::string get_name() override { - return "mxv_masked"; - } + std::string get_name() override { return "mxv_masked"; } std::string get_description() override { return "parallel matrix-vector masked product on opencl device"; @@ -79,14 +84,17 @@ namespace spla { auto t = ctx.task.template cast_safe(); - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = t->op_multiply.template cast_safe>(); - ref_ptr> op_add = t->op_add.template cast_safe>(); - ref_ptr> op_select = t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); r->validate_wd(FormatVector::AccDense); mask->validate_rw(FormatVector::AccDense); @@ -94,7 +102,8 @@ namespace spla { v->validate_rw(FormatVector::AccDense); 
std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; auto* p_cl_r = r->template get>(); auto* p_cl_mask = mask->template get>(); @@ -114,11 +123,13 @@ namespace spla { kernel_vector.setArg(6, init->get_value()); kernel_vector.setArg(7, r->get_n_rows()); - uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_count, 1, 512); + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_count, 1, 512); cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size); cl::NDRange exec_local(m_block_count, m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), exec_global, exec_local); + CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), + exec_global, exec_local); return Status::Ok; } @@ -128,14 +139,17 @@ namespace spla { auto t = ctx.task.template cast_safe(); - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = t->op_multiply.template cast_safe>(); - ref_ptr> op_add = t->op_add.template cast_safe>(); - ref_ptr> op_select = t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); r->validate_wd(FormatVector::AccDense); mask->validate_rw(FormatVector::AccDense); @@ -143,7 +157,8 @@ namespace spla { v->validate_rw(FormatVector::AccDense); std::shared_ptr 
program; - if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; auto* p_cl_r = r->template get>(); auto* p_cl_mask = mask->template get>(); @@ -165,11 +180,13 @@ namespace spla { kernel_scalar.setArg(7, r->get_n_rows()); kernel_scalar.setArg(8, uint(early_exit)); - uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 512); + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_size, 1, 512); cl::NDRange exec_global(m_block_size * n_groups_to_dispatch); cl::NDRange exec_local(m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), exec_global, exec_local); + CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), + exec_global, exec_local); return Status::Ok; } @@ -179,14 +196,17 @@ namespace spla { auto t = ctx.task.template cast_safe(); - ref_ptr> r = t->r.template cast_safe>(); - ref_ptr> mask = t->mask.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> v = t->v.template cast_safe>(); - ref_ptr> op_multiply = t->op_multiply.template cast_safe>(); - ref_ptr> op_add = t->op_add.template cast_safe>(); - ref_ptr> op_select = t->op_select.template cast_safe>(); - ref_ptr> init = t->init.template cast_safe>(); + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> mask = t->mask.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op_multiply = + t->op_multiply.template cast_safe>(); + ref_ptr> op_add = + t->op_add.template cast_safe>(); + ref_ptr> op_select = + t->op_select.template cast_safe>(); + ref_ptr> init = t->init.template cast_safe>(); r->validate_wd(FormatVector::AccDense); mask->validate_rw(FormatVector::AccDense); @@ -194,7 +214,8 @@ namespace spla { v->validate_rw(FormatVector::AccDense); std::shared_ptr program; - if (!ensure_kernel(op_multiply, op_add, 
op_select, program)) return Status::CompilationError; + if (!ensure_kernel(op_multiply, op_add, op_select, program)) + return Status::CompilationError; auto* p_cl_r = r->template get>(); auto* p_cl_mask = mask->template get>(); @@ -206,8 +227,13 @@ namespace spla { auto& queue = p_cl_acc->get_queue_default(); uint config_size = 0; - cl::Buffer cl_config(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, sizeof(uint) * M->get_n_rows()); - cl::Buffer cl_config_size(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeof(uint), &config_size); + cl::Buffer cl_config(p_cl_acc->get_context(), + CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, + sizeof(uint) * M->get_n_rows()); + cl::Buffer cl_config_size(p_cl_acc->get_context(), + CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | + CL_MEM_COPY_HOST_PTR, + sizeof(uint), &config_size); auto kernel_config = program->make_kernel("mxv_config"); kernel_config.setArg(0, p_cl_mask->Ax); @@ -217,13 +243,16 @@ namespace spla { kernel_config.setArg(4, init->get_value()); kernel_config.setArg(5, M->get_n_rows()); - uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024); + uint n_groups_to_dispatch = + div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024); cl::NDRange config_global(m_block_size * n_groups_to_dispatch); cl::NDRange config_local(m_block_size); - CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), config_global, config_local); + CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), + config_global, config_local); - CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, sizeof(config_size), &config_size); + CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, + sizeof(config_size), &config_size); auto kernel_config_scalar = program->make_kernel("mxv_config_scalar"); kernel_config_scalar.setArg(0, p_cl_M->Ap); @@ -240,7 +269,8 @@ namespace spla { cl::NDRange exec_global(m_block_size * 
n_groups_to_dispatch); cl::NDRange exec_local(m_block_size); - CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), exec_global, exec_local); + CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), + exec_global, exec_local); return Status::Ok; } @@ -255,18 +285,21 @@ namespace spla { assert(m_block_count >= 1); CLProgramBuilder program_builder; - program_builder - .set_name("mxv") + program_builder.set_name("mxv") .add_define("WARP_SIZE", get_acc_cl()->get_wave_size()) .add_define("BLOCK_SIZE", m_block_size) .add_define("BLOCK_COUNT", m_block_count) - .add_type("TYPE", get_ttype().template as()) - .add_op("OP_BINARY1", op_multiply.template as()) - .add_op("OP_BINARY2", op_add.template as()) - .add_op("OP_SELECT", op_select.template as()) - .set_source(source_mxv) - .acquire(); + .add_type("TYPE", get_ttype().template as()); + if constexpr (std::is_same_v) { + program_builder.add_define("USE_PAIR_SEMANTICS", 1); + program_builder.add_define("USE_PAIR_COMPARISON", 1); + } else { + program_builder.add_op("OP_BINARY1", op_multiply.template as()) + .add_op("OP_BINARY2", op_add.template as()) + .add_op("OP_SELECT", op_select.template as()); + } + program_builder.set_source(source_mxv).acquire(); program = program_builder.get_program(); return true; @@ -279,4 +312,4 @@ namespace spla { }// namespace spla -#endif//SPLA_CL_MXV_HPP +#endif// SPLA_CL_MXV_HPP diff --git a/src/opencl/cl_program_builder.cpp b/src/opencl/cl_program_builder.cpp index 0c17e4e13..5a2241e8e 100644 --- a/src/opencl/cl_program_builder.cpp +++ b/src/opencl/cl_program_builder.cpp @@ -30,6 +30,7 @@ #include #include +#include #include namespace spla { @@ -63,6 +64,7 @@ namespace spla { return *this; } void CLProgramBuilder::acquire() { + CLAccelerator* acc = get_acc_cl(); CLProgramCache* cache = acc->get_cache(); @@ -81,22 +83,31 @@ namespace spla { } std::stringstream builder; - + bool needs_pair_override = false; for (const auto& define : m_defines) { builder << 
"#define " << define.first << " " << define.second << "\n"; + if (define.first == "TYPE" && + define.second.find("Pair") != std::string::npos) { + needs_pair_override = true; + } } + builder << source_common_api; + if (needs_pair_override) { + builder << "#define OP_BINARY1(a, b) make_pair((a).weight, " + "(b).vertex)\n\n"; + builder << "#define OP_BINARY2(a, b) min_pair(a, b)\n\n"; + builder << "#define OP_SELECT(a) pair_always(a)\n\n"; + } + for (const auto& function : m_functions) { builder << function.second->get_type_res()->get_cpp() << " " << function.first << function.second->get_source_cl() << "\n"; } builder << m_source; - m_program_code = builder.str(); - m_program = std::make_shared(); - m_program->m_program = cl::Program(acc->get_context(), m_program_code); - + m_program_code = builder.str(); Timer t; t.start(); auto status = m_program->m_program.build("-cl-std=CL1.2"); diff --git a/src/opencl/generated/auto_common_api.hpp b/src/opencl/generated/auto_common_api.hpp index 0e1e2f7a6..30b06fe21 100644 --- a/src/opencl/generated/auto_common_api.hpp +++ b/src/opencl/generated/auto_common_api.hpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////// -// Copyright (c) 2021 - 2023 SparseLinearAlgebra +// Copyright (c) 2021 - 2026 SparseLinearAlgebra // Autogenerated file, do not modify //////////////////////////////////////////////////////////////////// @@ -7,6 +7,30 @@ static const char source_common_api[] = R"( +struct Pair{ + float weight; + int vertex; +}; + +struct Pair make_pair(float w, int v) { + struct Pair p; + p.weight = w; + p.vertex = v; + return p; + +} + +struct Pair min_pair(struct Pair a, struct Pair b) { + if (a.weight == b.weight) return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? 
a : b; +} + +int pair_always(struct Pair a) { + return 1; +} +struct Pair identity_pair(struct Pair a) { + return a; +} uint random_gen_java(ulong seed) { seed = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48L) - 1); diff --git a/src/opencl/kernels/common_api.cl b/src/opencl/kernels/common_api.cl index 73e3af610..aaaa91765 100644 --- a/src/opencl/kernels/common_api.cl +++ b/src/opencl/kernels/common_api.cl @@ -26,6 +26,30 @@ /**********************************************************************************/ #include "common_def.cl" +struct Pair{ + float weight; + int vertex; +}; + +struct Pair make_pair(float w, int v) { + struct Pair p; + p.weight = w; + p.vertex = v; + return p; + +} + +struct Pair min_pair(struct Pair a, struct Pair b) { + if (a.weight == b.weight) return a.vertex < b.vertex ? a : b; + return a.weight < b.weight ? a : b; +} + +int pair_always(struct Pair a) { + return 1; +} +struct Pair identity_pair(struct Pair a) { + return a; +} uint random_gen_java(ulong seed) { seed = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48L) - 1); diff --git a/src/opencl/kernels/common_def.cl b/src/opencl/kernels/common_def.cl index 880708f42..792ce452a 100644 --- a/src/opencl/kernels/common_def.cl +++ b/src/opencl/kernels/common_def.cl @@ -1,42 +1,54 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation 
the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #pragma once -#define TYPE int -#define BLOCK_SIZE 32 +#define TYPE int +#define BLOCK_SIZE 32 #define LM_NUM_MEM_BANKS 32 -#define BLOCK_COUNT 1 -#define WARP_SIZE 32 -#define OP_SELECT(a) a -#define OP_UNARY(a) a -#define OP_BINARY(a, b) a + b +#define BLOCK_COUNT 1 +#define WARP_SIZE 32 +#define OP_UNARY(a) a +#ifndef USE_PAIR_SEMANTICS #define OP_BINARY1(a, b) a + b #define OP_BINARY2(a, b) a + b +#define OP_SELECT(a) a +#else +#define OP_BINARY1(a, b) make_pair((a).weight, (b).vertex) +#define OP_BINARY2(a, b) min_pair(a, b) +#define OP_SELECT(a) pair_always(a) +#endif #define __kernel #define __global @@ -48,22 +60,19 @@ #define half float struct float2 { - float x; + float x; }; struct float3 { - float x, y, z; + float x, y, z; }; struct float4 { - float x, y, z, w; + float x, y, z, w; }; -#define uint unsigned int +#define uint unsigned int #define ulong unsigned long int -enum cl_mem_fence_flags { - CLK_LOCAL_MEM_FENCE, - CLK_GLOBAL_MEM_FENCE -}; +enum cl_mem_fence_flags { CLK_LOCAL_MEM_FENCE, CLK_GLOBAL_MEM_FENCE }; void barrier(cl_mem_fence_flags flags); @@ -74,16 +83,16 @@ size_t get_local_id(uint dimindx); size_t get_num_groups(uint dimindx); size_t get_group_id(uint dimindx); size_t get_global_offset(uint dimindx); -uint get_work_dim(); +uint get_work_dim(); -#define atomic_add(p, val) p[0] += val -#define atomic_sub(p, val) p[0] -= val -#define atomic_inc(p) (p)[0] -#define atomic_dec(p) (p)[0] +#define atomic_add(p, val) p[0] += val +#define atomic_sub(p, val) p[0] -= val +#define atomic_inc(p) (p)[0] +#define atomic_dec(p) (p)[0] #define atomic_cmpxchg(p, cmp, val) ((p)[0] 
== cmp ? val : (p)[0]) -#define min(x, y) (x < y ? x : y) -#define max(x, y) (x > y ? x : y) -#define sin(x) x -#define cos(x) x +#define min(x, y) (x < y ? x : y) +#define max(x, y) (x > y ? x : y) +#define sin(x) x +#define cos(x) x #define fract(x, ptr) x \ No newline at end of file diff --git a/src/opencl/kernels/mxv.cl b/src/opencl/kernels/mxv.cl index e7bc77337..c61aa8478 100644 --- a/src/opencl/kernels/mxv.cl +++ b/src/opencl/kernels/mxv.cl @@ -112,7 +112,13 @@ __kernel void mxv_scalar(__global const uint* g_Ap, const uint col_id = g_Aj[i]; sum = OP_BINARY2(sum, OP_BINARY1(g_Ax[i], g_vx[col_id])); - if (early_exit && (sum != init)) break; + if (early_exit) { + #ifdef USE_PAIR_COMPARISON + if (sum.weight != init.weight || sum.vertex != init.vertex) break; + #else + if (sum != init) break; + #endif + } } } @@ -162,7 +168,13 @@ __kernel void mxv_config_scalar(__global const uint* g_Ap, const uint col_id = g_Aj[i]; sum = OP_BINARY2(sum, OP_BINARY1(g_Ax[i], g_vx[col_id])); - if (early_exit && (sum != init)) break; + if (early_exit) { + #ifdef USE_PAIR_COMPARISON + if (sum.weight != init.weight || sum.vertex != init.vertex) break; + #else + if (sum != init) break; + #endif + } } g_rx[row_id] = sum; diff --git a/src/scalar.cpp b/src/scalar.cpp index e48ce958e..551e34765 100644 --- a/src/scalar.cpp +++ b/src/scalar.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the 
"Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. */ /**********************************************************************************/ #include @@ -47,6 +54,9 @@ namespace spla { if (type == FLOAT) { return ref_ptr(new TScalar()); } + if (type == PAIR) { + return ref_ptr(new TScalar()); + } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); return ref_ptr(); diff --git a/src/type.cpp b/src/type.cpp index 7b75a7154..3e32829f2 100644 --- a/src/type.cpp +++ b/src/type.cpp @@ -1,37 +1,50 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. 
*/ /**********************************************************************************/ #include namespace spla { - ref_ptr BOOL = TType::make_type("BOOL", "B", "bool", "4 byte logical type", 1); - ref_ptr INT = TType::make_type("INT", "I", "int", "signed 4 byte integral type", 2); - ref_ptr UINT = TType::make_type("UINT", "U", "uint", "unsigned 4 byte integral type", 3); - ref_ptr FLOAT = TType::make_type("FLOAT", "F", "float", "4 byte floating point type", 4); + ref_ptr BOOL = + TType::make_type("BOOL", "B", "bool", "4 byte logical type", 1); + ref_ptr INT = TType::make_type("INT", "I", "int", + "signed 4 byte integral type", 2); + ref_ptr UINT = TType::make_type( + "UINT", "U", "uint", "unsigned 4 byte integral type", 3); + ref_ptr FLOAT = TType::make_type( + "FLOAT", "F", "float", "4 byte floating point type", 4); + ref_ptr PAIR = TType::make_type( + "PAIR", "P", "struct Pair", "weight-vertex pair float-int", 5); }// namespace spla \ No newline at end of file diff --git a/src/vector.cpp b/src/vector.cpp index e6427d90f..a25fc6d0c 100644 --- a/src/vector.cpp +++ b/src/vector.cpp @@ -1,28 +1,35 @@ /**********************************************************************************/ -/* This file is part of spla project */ -/* https://github.com/SparseLinearAlgebra/spla */ +/* This file is part of spla project */ +/* https://github.com/SparseLinearAlgebra/spla */ /**********************************************************************************/ -/* MIT License */ +/* MIT License */ /* */ -/* Copyright (c) 2023 SparseLinearAlgebra */ +/* Copyright (c) 2023 SparseLinearAlgebra */ /* */ -/* Permission is hereby granted, free of charge, to any person obtaining a copy */ -/* of this software and associated documentation files (the "Software"), to deal */ -/* in the Software without restriction, including without limitation the rights */ -/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ -/* copies of the Software, and to permit persons to 
whom the Software is */ -/* furnished to do so, subject to the following conditions: */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy + */ +/* of this software and associated documentation files (the "Software"), to deal + */ +/* in the Software without restriction, including without limitation the rights + */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ /* */ -/* The above copyright notice and this permission notice shall be included in all */ -/* copies or substantial portions of the Software. */ +/* The above copyright notice and this permission notice shall be included in + * all */ +/* copies or substantial portions of the Software. */ /* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ -/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ -/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ -/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ -/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ -/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ -/* SOFTWARE. */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + */ +/* SOFTWARE. 
*/ /**********************************************************************************/ #include @@ -51,6 +58,9 @@ namespace spla { if (type == FLOAT) { return ref_ptr(new TVector(n_rows)); } + if (type == spla::PAIR) { + return ref_ptr(new TVector(n_rows)); + } LOG_MSG(Status::NotImplemented, "not supported type " << type->get_name()); return ref_ptr{}; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e20657a89..244ea481f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -15,6 +15,7 @@ spla_test_target(test_mxmT) spla_test_target(test_mxv) spla_test_target(test_vxm) spla_test_target(test_op) +spla_test_target(test_pair) if (SPLA_BUILD_OPENCL) spla_test_target(test_opencl) spla_test_target(test_opencl_merge) diff --git a/tests/test_op.cpp b/tests/test_op.cpp index a1d915891..db74ae415 100644 --- a/tests/test_op.cpp +++ b/tests/test_op.cpp @@ -48,6 +48,7 @@ TEST(op_binary, info_built_in) { display_op_info(spla::MIN_FLOAT); display_op_info(spla::BONE_FLOAT); + display_op_info(spla::MIN_PAIR); } TEST(op_binary, custom) { diff --git a/tests/test_pair.cpp b/tests/test_pair.cpp new file mode 100644 index 000000000..3f460061e --- /dev/null +++ b/tests/test_pair.cpp @@ -0,0 +1,137 @@ +#include "test_common.hpp" +#include + +#include "spla.hpp" + +TEST(pair, struct_creation) { + spla::T_PAIR p(2.5f, 2); + EXPECT_EQ(p.weight, 2.5f); + EXPECT_EQ(p.vertex, 2); +} +TEST(pair, basic_operations) { + spla::T_PAIR p1(2.5f, 2); + spla::T_PAIR p2(1.5f, 5); + + EXPECT_EQ(p1.weight, 2.5f); + EXPECT_EQ(p1.vertex, 2); + EXPECT_TRUE(p2.weight < p1.weight); +} + +TEST(pair, type_registration) { + auto type = spla::PAIR; + ASSERT_TRUE(type); + EXPECT_EQ(type->get_name(), "PAIR"); + EXPECT_EQ(type->get_code(), "P"); + EXPECT_EQ(type->get_cpp(), "struct Pair"); + EXPECT_EQ(type->get_description(), "weight-vertex pair float-int"); + EXPECT_EQ(type->get_size(), sizeof(spla::Pair)); + EXPECT_EQ(type->get_id(), 5); +} + +TEST(pair, op_registration) { + 
spla::Library::get(); + EXPECT_EQ(spla::MIN_PAIR->get_name(), "MIN_PAIR"); +} +TEST(pair, set_get_pair_matrix) { + auto S = spla::Matrix::make(2, 2, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + spla::T_PAIR pair; + S->get_pair(0, 1, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 7.0f); +} +TEST(pair, set_get_pair_vector) { + auto V = spla::Vector::make(2, spla::PAIR); + V->set_pair(0, spla::T_PAIR(1.0f, 1)); + V->set_pair(1, spla::T_PAIR(2.0f, 2)); + spla::T_PAIR pair; + V->get_pair(0, pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); + V->get_pair(1, pair); + EXPECT_EQ(pair.vertex, 2); + EXPECT_EQ(pair.weight, 2.0f); +} +TEST(pair, set_get_pair_scalar) { + auto V = spla::Scalar::make(spla::PAIR); + spla::T_PAIR pair = spla::T_PAIR(1.0f, 1); + V->set_pair(pair); + V->get_pair(pair); + EXPECT_EQ(pair.vertex, 1); + EXPECT_EQ(pair.weight, 1.0f); +} + +TEST(pair, mxv_pair) { + int32_t n = 7; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(0, 4, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 0, spla::T_PAIR(7.0f, 0)); + S->set_pair(1, 2, spla::T_PAIR(11.0f, 2)); + S->set_pair(1, 3, spla::T_PAIR(10.0f, 3)); + S->set_pair(1, 4, spla::T_PAIR(9.0f, 4)); + S->set_pair(2, 1, spla::T_PAIR(11.0f, 1)); + S->set_pair(2, 3, spla::T_PAIR(5.0f, 3)); + S->set_pair(3, 1, spla::T_PAIR(10.0f, 1)); + S->set_pair(3, 2, spla::T_PAIR(5.0f, 2)); + S->set_pair(3, 4, spla::T_PAIR(15.0f, 4)); + S->set_pair(3, 5, spla::T_PAIR(12.0f, 5)); + S->set_pair(3, 6, spla::T_PAIR(8.0f, 6)); + S->set_pair(4, 0, spla::T_PAIR(4.0f, 0)); + S->set_pair(4, 1, spla::T_PAIR(9.0f, 1)); + S->set_pair(4, 3, spla::T_PAIR(15.0f, 3)); + S->set_pair(4, 5, spla::T_PAIR(6.0f, 5)); + S->set_pair(5, 3, spla::T_PAIR(12.0f, 3)); + S->set_pair(5, 4, spla::T_PAIR(6.0f, 4)); + S->set_pair(5, 6, spla::T_PAIR(13.0f, 6)); + S->set_pair(6, 3, spla::T_PAIR(8.0f, 3)); + S->set_pair(6, 5, spla::T_PAIR(13.0f, 5)); + + auto parent = 
spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + parent->set_pair(i, spla::T_PAIR(0.0f, i)); + } + auto edge = spla::Vector::make(n, spla::PAIR); + + auto mask = spla::Vector::make(n, spla::PAIR); + for (int32_t i = 0; i < n; i++) { + mask->set_pair(i, spla::T_PAIR(0.0f, 0)); + } + auto init_inf = spla::Scalar::make(spla::PAIR); + spla::T_PAIR init_val(1e9f, -1); + init_inf->set_pair(init_val); + spla::exec_mxv_masked(edge, mask, S, parent, spla::MUL_PAIR, spla::MIN_PAIR, + spla::ALWAYS_PAIR, init_inf); + + spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(7.0f, 0), + spla::T_PAIR(5.0f, 3), spla::T_PAIR(5.0f, 2), + spla::T_PAIR(4.0f, 0), spla::T_PAIR(6.0f, 4), + spla::T_PAIR(8.0f, 3)}; + + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + edge->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } +} +TEST(pair, extract_row) { + int32_t n = 3; + auto S = spla::Matrix::make(n, n, spla::PAIR); + S->set_pair(0, 1, spla::T_PAIR(7.0f, 1)); + S->set_pair(1, 0, spla::T_PAIR(4.0f, 4)); + S->set_pair(1, 2, spla::T_PAIR(7.0f, 0)); + + auto row1 = spla::Vector::make(n, spla::PAIR); + spla::exec_m_extract_row(row1, S, 1, spla::IDENTITY_PAIR); + spla::T_PAIR expected[] = {spla::T_PAIR(4.0f, 4), spla::T_PAIR(), + spla::T_PAIR(7.0f, 0)}; + for (int32_t i = 0; i < n; i++) { + spla::T_PAIR p; + row1->get_pair(i, p); + EXPECT_FLOAT_EQ(p.weight, expected[i].weight); + EXPECT_EQ(p.vertex, expected[i].vertex); + } +} + +SPLA_GTEST_MAIN