doasku

Sudoku solver
git clone git://git.dimitrijedobrota.com/doasku.git
Log | Files | Refs | README | LICENSE

commit e5cc8a937714a060616c910895c3b132cfe8ae15
parent 260874dfa278ba73acb805eaf53464771e84d6e5
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date:   Wed,  3 Apr 2024 01:16:49 +0200

Better type safety

Diffstat:
Mmain.cpp | 99++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
1 file changed, 60 insertions(+), 39 deletions(-)

diff --git a/main.cpp b/main.cpp @@ -9,29 +9,48 @@ static constexpr const std::int64_t mask_field = (1 << 9) - 1; static constexpr const std::int64_t mask_value = 0x201008040201; -using row_col_t = std::tuple<uint8_t, uint8_t>; +class row_col_t; +class row_col_rt; -std::ostream &operator<<(std::ostream &os, row_col_t rc) { - return os << std::format("({}, {})", std::get<0>(rc), std::get<1>(rc)); -} +struct row_col_t { + row_col_t(uint8_t row, uint8_t col) : row(row), col(col) {} + row_col_t(uint8_t field) : row(field / 3), col(field % 3) {} + row_col_t(row_col_rt rc); + + explicit operator row_col_rt() const; -row_col_t row_col(const uint8_t field) { return {field / 3, field % 3}; } -row_col_t row_col_rev(const uint8_t field) { return {2 - field % 3, field / 3}; } + friend std::ostream &operator<<(std::ostream &os, row_col_t rc) { + return os << std::format("({}, {})", rc.row, rc.col); + } + + uint8_t row; + uint8_t col; +}; -row_col_t row_col(const row_col_t rcr) { return {std::get<1>(rcr), 2 - std::get<0>(rcr)}; } -row_col_t row_col_rev(const row_col_t rc) { return {2 - std::get<1>(rc), std::get<0>(rc)}; } +struct row_col_rt { + row_col_rt(uint8_t row, uint8_t col) : row(row), col(col) {} + row_col_rt(uint8_t field) : row(2 - field % 3), col(field / 3) {} + row_col_rt(row_col_t rc); + + friend std::ostream &operator<<(std::ostream &os, row_col_rt rc) { + return os << std::format("({}, {})", rc.row, rc.col); + } + + uint8_t row; + uint8_t col; +}; -uint8_t field(const row_col_t rc) { return 3 * std::get<0>(rc) + std::get<1>(rc); } +row_col_t::row_col_t(row_col_rt rc) : row(rc.col), col(2 - rc.row) {} +row_col_rt::row_col_rt(row_col_t rc) : row(2 - rc.col), col(rc.row) {} + +uint8_t field(row_col_t rc) { return 3 * rc.row + rc.col; } class Row { public: Row(uint64_t value = 0) : val(value) {} operator uint64_t() const { return val; } - void set(uint8_t field, uint8_t value) { - clear(field); - toggle(field, value); - } + void set(uint8_t field, uint8_t value) { clear(field), toggle(field, value); } uint64_t get(uint8_t field) const { return (val >> 9 * field) & mask_field; } uint64_t get(uint8_t field, uint8_t value) const { return (val >> 9 * field) & (1 << value); } @@ -57,8 +76,8 @@ class Subgrid { public: Subgrid() {} - uint64_t get(row_col_t rc) const { return rows[std::get<0>(rc)].get(std::get<1>(rc)); } - uint64_t get_rev(row_col_t rc) const { return rows[std::get<0>(rc)].get(std::get<1>(rc) + 3); } + uint64_t get(row_col_t rc) const { return rows[rc.row].get(rc.col); } + uint64_t get(row_col_rt rc) const { return rows[rc.row].get(rc.col + 3); } void set(uint8_t field, uint8_t value) { assert(field < 9 && value < 9); @@ -67,30 +86,31 @@ class Subgrid { rows[1].clear_all(value); rows[2].clear_all(value); - set(row_col(field), value); - set_rev(row_col_rev(field), value); + set(row_col_t(field), value); + set(row_col_rt(field), value); } void clear_row(uint8_t row, uint8_t value) { assert(row < 3 && value < 9); - for (int i = 0; i < 3; i++) { - clear({row, i}, value); - clear_rev(row_col_rev({row, i}), value); + for (uint8_t i = 0; i < 3; i++) { + const row_col_t rc = {row, i}; + clear(row_col_rt(rc), value); + clear(rc, value); } } void clear_col(uint8_t col, uint8_t value) { assert(col < 3 && value < 9); - for (int i = 0; i < 3; i++) { - clear({i, col}, value); - clear_rev(row_col_rev({i, col}), value); + for (uint8_t i = 0; i < 3; i++) { + const row_col_t rc = {i, col}; + clear(row_col_rt(rc), value); + clear(rc, value); } } friend std::ostream &operator<<(std::ostream &os, const Subgrid &b) { - for (int i = 0; i < 3; i++) { os << b.rows[i] << " "; } @@ -98,11 +118,11 @@ class Subgrid { } private: - void set(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].set(std::get<1>(rc), value); } - void set_rev(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].set(std::get<1>(rc) + 3, value); } + void set(row_col_t rc, uint8_t value) { rows[rc.row].set(rc.col, value); } + void set(row_col_rt rcr, uint8_t value) { rows[rcr.row].set(rcr.col + 3, value); } - void clear(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].clear(std::get<1>(rc), value); } - void clear_rev(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].clear(std::get<1>(rc) + 3, value); } + void clear(row_col_t rc, uint8_t value) { rows[rc.row].clear(rc.col, value); } + void clear(row_col_rt rcr, uint8_t value) { rows[rcr.row].clear(rcr.col + 3, value); } Row rows[3] = {~0, ~0, ~0}; }; @@ -112,8 +132,8 @@ class Grid { void set(uint8_t subgrid, uint8_t field, uint8_t value) { assert(subgrid < 9 && field < 9 && value < 9); - const auto [row, col] = row_col(subgrid); - const auto [frow, fcol] = row_col(field); + const auto [row, col] = row_col_t(subgrid); + const auto [frow, fcol] = row_col_t(field); subgrids[3 * row + 0].clear_row(frow, value); subgrids[3 * row + 1].clear_row(frow, value); @@ -129,10 +149,10 @@ class Grid { void print() const { std::cout << "Field: " << std::endl; - for (int i = 0; i < 9; i++) { - for (int j = 0; j < 9; j++) { - const row_col_t subgrid = {i / 3, j / 3}; - const row_col_t field = {i % 3, j % 3}; + for (uint8_t i = 0; i < 9; i++) { + for (uint8_t j = 0; j < 9; j++) { + const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)}; + const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)}; std::bitset<9> value = subgrids[::field(subgrid)].get(field); std::cout << value << " "; if (j % 3 == 2) std::cout << " "; @@ -144,11 +164,12 @@ class Grid { std::cout << std::endl; std::cout << "Reversed Field: " << std::endl; - for (int i = 0; i < 9; i++) { - for (int j = 0; j < 9; j++) { - const row_col_t subgrid = {i / 3, j / 3}; - const row_col_t rfield = row_col_rev({i % 3, j % 3}); - std::bitset<9> value = subgrids[::field(subgrid)].get_rev(rfield); + for (uint8_t i = 0; i < 9; i++) { + for (uint8_t j = 0; j < 9; j++) { + const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)}; + const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)}; + const row_col_rt rfield = row_col_rt(field); + std::bitset<9> value = subgrids[::field(subgrid)].get(rfield); std::cout << value << " "; if (j % 3 == 2) std::cout << " "; }