doasku

Sudoku solver
git clone git://git.dimitrijedobrota.com/doasku.git
Log | Files | Refs | README | LICENSE

commit 1f999f0dafb7e5b013add6f89e89a1fb28a60d02
parent 63d9656c58864cb1199578a2fb9a443998a5a8e9
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date:   Sat,  6 Apr 2024 22:25:40 +0200

Simplify the logic

* Get rid of row and subgrid
* Put everything inside ref class

Diffstat:
Mmain.cpp | 422++++++++++++++++++++++---------------------------------------------------------
1 file changed, 114 insertions(+), 308 deletions(-)

diff --git a/main.cpp b/main.cpp @@ -1,7 +1,6 @@ #include <array> #include <bit> #include <bitset> -#include <cassert> #include <cinttypes> #include <cstring> #include <format> @@ -12,36 +11,19 @@ static constexpr const std::int64_t mask_field = (1 << 9) - 1; static constexpr const std::int64_t mask_value = 0x201008040201; -using change_t = std::tuple<uint16_t, uint16_t>; -using changes_t = std::vector<change_t>; - class row_col_t; -class row_col_rt; struct row_col_t { row_col_t(uint8_t row, uint8_t col) : row(row), col(col) {} row_col_t(uint8_t field) : row(field / 3), col(field % 3) {} - row_col_t(row_col_rt rc); - explicit operator row_col_rt() const; operator uint8_t() const { return 3 * row + col; } row_col_t absolute(row_col_t rc) { return {uint8_t(row * 3 + rc.row), uint8_t(col * 3 + rc.col)}; } - friend std::ostream &operator<<(std::ostream &os, row_col_t rc) { - return os << std::format("({}, {})", rc.row, rc.col); - } + std::tuple<uint8_t, uint8_t> relative() { return {(row / 3) * 3 + col / 3, (row % 3) * 3 + col % 3}; } - uint8_t row; - uint8_t col; -}; - -struct row_col_rt { - row_col_rt(uint8_t row, uint8_t col) : row(row), col(col) {} - explicit row_col_rt(uint8_t field) : row(2 - field % 3), col(field / 3) {} - row_col_rt(row_col_t rc); - - friend std::ostream &operator<<(std::ostream &os, row_col_rt rc) { + friend std::ostream &operator<<(std::ostream &os, row_col_t rc) { return os << std::format("({}, {})", rc.row, rc.col); } @@ -49,24 +31,29 @@ struct row_col_rt { uint8_t col; }; -row_col_t::row_col_t(row_col_rt rc) : row(rc.col), col(2 - rc.row) {} -row_col_rt::row_col_rt(row_col_t rc) : row(2 - rc.col), col(rc.row) {} - class Ref { public: - uint16_t get(uint8_t value) const { return ref[value]; } + uint16_t get(uint8_t field) const { return value[field]; } + uint16_t get_ref(uint8_t value) const { return ref[value]; } - void zero(uint8_t value) { ref[value] = 0; } + uint8_t get_value(uint8_t field) const { return res[field]; } - void clear(uint8_t field, uint8_t value) { ref[value] &= ~(1 << field); } - void clear(uint8_t field) { + void clear(uint8_t field, uint8_t value) { + this->value[field] &= ~(1 << value); + ref[value] &= ~(1 << field); + } + + void set(uint8_t field, uint8_t value) { for (uint8_t i = 0; i < 9; i++) { ref[i] &= ~(1 << field); } + + this->value[field] = ref[value] = 0; + res[field] = value + 1; } - changes_t get_hidden_singles() const { - changes_t res; + auto get_hidden_singles() const { + std::vector<std::tuple<uint8_t, uint8_t>> res; for (uint8_t candidate = 0; candidate < 9; candidate++) { if (std::popcount(ref[candidate]) != 1) continue; @@ -76,225 +63,54 @@ class Ref { return res; } - private: - uint16_t ref[9] = {mask_field, mask_field, mask_field, mask_field, mask_field, - mask_field, mask_field, mask_field, mask_field}; -}; - -class Row { - public: - Row(uint64_t value = 0) : val(value) {} - operator uint64_t() const { return val; } - - void set_field(uint8_t field, uint8_t value) { clear(field), set(field, value); } - - uint16_t get(uint8_t field) const { return (val >> 9 * field) & mask_field; } - uint16_t get(uint8_t field, uint8_t value) const { return (val >> 9 * field) & (1 << value); } - - void set(uint8_t field) { val |= mask_field << 9 * field; } - void toggle(uint8_t field) { val ^= mask_field << 9 * field; } - void clear(uint8_t field) { val &= ~(mask_field << 9 * field); } - - void set(uint8_t field, uint8_t value) { val |= 1ull << (9 * field + value); } - void toggle(uint8_t field, uint8_t value) { val ^= 1ull << (9 * field + value); } - void clear(uint8_t field, uint8_t value) { val &= ~(1ull << (9 * field + value)); } - - void toggle_all(uint8_t value) { val ^= mask_value << value; } - void clear_all(uint8_t value) { val &= ~(mask_value << value); } - - friend std::ostream &operator<<(std::ostream &os, Row r) { - for (int i = 0; i < 6; i++) { - os << std::bitset<9>(r.get(i)) << " "; - } - return os; - } - - private: - uint64_t val; -}; - -class Subgrid { - public: - Subgrid() {} - - uint16_t get(row_col_t rc) const { return rows[rc.row].get(rc.col); } - uint16_t get(row_col_rt rc) const { return rows[rc.row].get(rc.col + 3); } - - uint16_t get_ref(uint8_t value) const { return ref.get(value); } - - uint8_t value(row_col_t rc) const { - if (!is_finished(rc)) return 0; - return 1 + std::countr_zero(get(rc)); - } - - bool is_finished(row_col_t rc) const { return finished & (1 << rc); } - bool is_finished() const { return finished == (1 << 9) - 1; } - - void set(row_col_t rc, uint8_t value) { - assert(value < 9); - - rows[0].clear_all(value); - rows[1].clear_all(value); - rows[2].clear_all(value); - - _set(row_col_rt(rc), value); - _set(rc, value); - - ref.zero(value); - ref.clear(rc); - - finished |= 1 << rc; - } - - void clear(row_col_t rc, uint8_t value) { - _clear(row_col_rt(rc), value); - _clear(rc, value); - - ref.clear(rc, value); - } - - changes_t get_hidden_singles() const { - if (is_finished()) return {}; - return ref.get_hidden_singles(); - } - - changes_t get_naked_singles() const { - if (is_finished()) return {}; - - changes_t res; + auto get_naked_singles() const { + std::vector<std::tuple<uint8_t, uint8_t>> res; for (uint8_t i = 0; i < 9; i++) { const auto values = get(i); if (std::popcount(values) != 1) continue; const auto value = std::countr_zero(values); - if (!ref.get(value)) continue; + if (!ref[value]) continue; res.emplace_back(i, value); } return res; } - std::array<changes_t, 2> check_ray() const { - if (is_finished()) return {}; - - Row o012 = rows[0] | rows[1] | rows[2]; - uint64_t mask = { - // clang-format off - uint64_t(o012.get(1) | o012.get(2)) << 0 | - uint64_t(o012.get(0) | o012.get(2)) << 9 | - uint64_t(o012.get(0) | o012.get(1)) << 18 | - uint64_t(o012.get(4) | o012.get(5)) << 27 | - uint64_t(o012.get(3) | o012.get(5)) << 36 | - uint64_t(o012.get(3) | o012.get(4)) << 45 - // clang-format on - }; - - Row potent = (rows[0] & rows[1]) | (rows[0] & rows[2]) | (rows[1] & rows[2]); - Row shooting = potent & ~mask & ~shot; - - shot |= shooting; - - std::array<changes_t, 2> res; - - for (uint8_t i = 0; i < 3; i++) { - uint16_t val = shooting.get(i); - while (val) { - uint8_t idx = std::countr_zero(val); - res[0].emplace_back(i, idx); - val ^= 1ull << idx; - } - } - - for (uint8_t i = 0; i < 3; i++) { - uint16_t val = shooting.get(i + 3); - while (val) { - uint8_t idx = std::countr_zero(val); - res[1].emplace_back(i, idx); - val ^= 1ull << idx; - } - } - - return res; - } - - changes_t check_naked() const { - if (is_finished()) return {}; - - static uint8_t count[512]; - static uint16_t value[9]; - - for (uint8_t field = 0; field < 9; field++) { - value[field] = get(field); - count[value[field]]++; - } - - changes_t res; + auto get_naked_pairs() const { + std::vector<std::tuple<uint8_t, uint8_t, uint8_t>> res; - for (int field = 0; field < 9; field++) { - const uint8_t popcnt = std::popcount(value[field]); - const uint8_t found = count[value[field]]; - - // no need to check any more times - count[value[field]] = 0; - - // if current is not part of a tuple continue - if (popcnt <= 1 || popcnt != found) continue; - - for (uint8_t nfield = 0; nfield < 9; nfield++) { - - // skip part of the tuple - if (value[field] == value[nfield]) continue; - - // are we going to clear any bits? - if ((value[field] & value[nfield]) == 0) continue; - - res.emplace_back(nfield, value[field]); + for (uint8_t i = 0; i < 9; i++) { + if (std::popcount(value[i]) != 2) continue; + if (seen_naked_pair & (1ul << i)) continue; + for (uint8_t j = i + 1; j < 9; j++) { + if (value[i] != value[j]) continue; + if (seen_naked_pair & (1ul << j)) continue; + + seen_naked_pair |= 1ul << i; + seen_naked_pair |= 1ul << j; + + uint16_t tval = value[i]; + while (tval) { + const uint8_t idx = std::countr_zero(tval); + res.emplace_back(idx, i, j); + tval ^= 1ull << idx; + } } } return res; } - friend std::ostream &operator<<(std::ostream &os, Subgrid s) { - os << std::endl << "Field:" << std::endl; - for (uint8_t i = 0; i < 3; i++) { - for (uint8_t j = 0; j < 3; j++) { - std::cout << std::bitset<9>(s.get(row_col_t(i, j))) << " "; - } - std::cout << std::endl; - } - - os << std::endl << "Rev:" << std::endl; - for (uint8_t i = 0; i < 3; i++) { - for (uint8_t j = 0; j < 3; j++) { - std::cout << std::bitset<9>(s.get(row_col_rt(i, j))) << " "; - } - std::cout << std::endl; - } - - os << std::endl << "Refs:" << std::endl; - for (uint8_t i = 0; i < 3; i++) { - for (uint8_t j = 0; j < 3; j++) { - std::cout << std::bitset<9>(s.get_ref(row_col_t{i, j})) << " "; - } - std::cout << std::endl; - } - - return os; - } - private: - void _set(row_col_t rc, uint8_t value) { rows[rc.row].set_field(rc.col, value); } - void _set(row_col_rt rcr, uint8_t value) { rows[rcr.row].set_field(rcr.col + 3, value); } - - void _clear(row_col_t rc, uint8_t value) { rows[rc.row].clear(rc.col, value); } - void _clear(row_col_rt rcr, uint8_t value) { rows[rcr.row].clear(rcr.col + 3, value); } + uint16_t value[9] = {mask_field, mask_field, mask_field, mask_field, mask_field, + mask_field, mask_field, mask_field, mask_field}; + uint16_t ref[9] = {mask_field, mask_field, mask_field, mask_field, mask_field, + mask_field, mask_field, mask_field, mask_field}; - Row rows[3] = {~0, ~0, ~0}; - Ref ref; + uint16_t res[9] = {0}; - uint16_t finished = 0; - mutable uint64_t shot = 0; + mutable uint16_t seen_naked_pair = 0; }; class Grid { @@ -310,24 +126,20 @@ class Grid { } void set(row_col_t subgrid, row_col_t field, uint8_t value) { - assert(value < 9); - const auto ab = subgrid.absolute(field); std::cout << "setting " << (int)value << ": " << ab << std::endl; - _zero(ab, value); + rows[ab.row].set(ab.col, value); + cols[ab.col].set(ab.row, value); + subgrids[subgrid].set(field, value); + // clear subgrid, row and col for (uint8_t i = 0; i < 3; i++) { - // clear current block _clear_row(subgrid, i, value); - - // clear intersecting row and col _clear_row(3 * subgrid.row + i, field.row, value); _clear_col(subgrid.col + 3 * i, field.col, value); } - - subgrids[subgrid].set(field, value); } bool solve() { @@ -345,85 +157,81 @@ class Grid { set(subgrid, field, val); changes++; } - } - for (uint8_t row = 0; row < 9; row++) { - const auto hs = rows[row].get_hidden_singles(); - for (const auto [col, val] : hs) { - const uint8_t subgrid = (row / 3) * 3 + col / 3; - const uint8_t field = (row % 3) * 3 + col % 3; + const auto ns = subgrids[subgrid].get_naked_singles(); + for (const auto [field, val] : ns) { - std::cout << "hidden singles row: " << (int)subgrid << " " << (int)field << " " - << int(val) << std::endl; + std::cout << "naked singles: " << (int)subgrid << " " << (int)field << " " << int(val) + << std::endl; set(subgrid, field, val); changes++; } - } - for (uint8_t col = 0; col < 9; col++) { - const auto hs = rows[col].get_hidden_singles(); - for (const auto [row, val] : hs) { - const uint8_t subgrid = (col / 3) * 3 + row / 3; - const uint8_t field = (col % 3) * 3 + row % 3; + const auto np = subgrids[subgrid].get_naked_pairs(); + for (const auto [val, f1, f2] : np) { + std::cout << "naked pairs: " << (int)subgrid << " " << (int)f1 << " " << (int)f2 << " " + << (int)val << std::endl; - std::cout << "hidden singles col: " << (int)subgrid << " " << (int)field << " " - << int(val) << std::endl; + for (uint8_t field = 0; field < 9; field++) { + if (field == f1 || field == f2) continue; + _clear(subgrid, field, val); + } - set(subgrid, field, val); changes++; } } - for (uint8_t subgrid = 0; subgrid < 9; subgrid++) { - const auto naked = subgrids[subgrid].get_naked_singles(); - for (const auto [field, val] : naked) { - std::cout << "naked singles: " << (int)subgrid << " " << (int)field << " " << int(val) - << std::endl; + for (uint8_t row = 0; row < 9; row++) { + const auto hs = rows[row].get_hidden_singles(); + for (const auto [col, val] : hs) { + const auto [subgrid, field] = row_col_t(row, col).relative(); + + std::cout << "hidden singles row: " << (int)subgrid << " " << (int)field << " " + << int(val) << std::endl; set(subgrid, field, val); changes++; } - } - for (uint8_t i = 0; i < 9; i++) { - const auto naked = subgrids[i].check_naked(); - for (const auto [field, mask] : naked) { + const auto np = rows[row].get_naked_pairs(); + for (const auto [val, c1, c2] : np) { + std::cout << "naked pairs row: " << (int)row << " " << (int)c1 << " " << (int)c2 << " " + << (int)val << std::endl; - const auto rc = row_col_t(i).absolute(field); - uint16_t tmask = mask; - while (tmask) { - uint16_t idx = std::countr_zero(tmask); - _clear(i, field, idx); - tmask ^= 1ull << idx; + for (uint8_t col = 0; col < 9; col++) { + if (col == c1 || col == c2) continue; + const auto [subgrid, field] = row_col_t(row, col).relative(); + _clear(subgrid, field, val); } - std::cout << "naked: " << (int)i << " " << (int)field << " " << std::bitset<9>(mask) - << std::endl; changes++; } } - for (uint8_t i = 0; i < 9; i++) { - const auto rays = subgrids[i].check_ray(); + for (uint8_t col = 0; col < 9; col++) { + const auto hs = cols[col].get_hidden_singles(); + for (const auto [row, val] : hs) { + const auto [subgrid, field] = row_col_t(row, col).relative(); - for (const auto [col, val] : rays[0]) { - static uint8_t mapping[9][2] = {{3, 6}, {4, 7}, {5, 8}, {0, 6}, {1, 7}, - {2, 8}, {0, 3}, {1, 4}, {2, 5}}; - _clear_col(mapping[i][0], col, val); - _clear_col(mapping[i][1], col, val); + std::cout << "hidden singles col: " << (int)subgrid << " " << (int)field << " " + << int(val) << std::endl; - std::cout << "ray col: " << (int)i << " " << (int)col << " " << (int)val << std::endl; + set(subgrid, field, val); changes++; } - for (const auto [row, val] : rays[1]) { - static uint8_t mapping[9][2] = {{1, 2}, {0, 2}, {0, 1}, {4, 5}, {3, 5}, - {3, 4}, {7, 8}, {6, 8}, {6, 7}}; - _clear_row(mapping[i][0], row, val); - _clear_row(mapping[i][1], row, val); + const auto np = cols[col].get_naked_pairs(); + for (const auto [val, r1, r2] : np) { + std::cout << "naked pairs col: " << (int)col << " " << (int)r1 << " " << (int)r2 << " " + << (int)val << std::endl; + + for (uint8_t row = 0; row < 9; row++) { + if (row == r1 || row == r2) continue; + const auto [subgrid, field] = row_col_t(row, col).relative(); + _clear(subgrid, field, val); + } - std::cout << "ray row: " << (int)i << " " << (int)row << " " << (int)val << std::endl; changes++; } } @@ -431,11 +239,7 @@ class Grid { if (changes) iter++; } - for (uint8_t i = 0; i < 9; i++) { - if (!subgrids[i].is_finished()) return false; - } - - return true; + return _is_finished(); } void print() const { @@ -469,7 +273,7 @@ class Grid { std::cout << "Row Refs: " << std::endl; for (uint8_t i = 0; i < 9; i++) { for (uint8_t j = 0; j < 9; j++) { - std::cout << std::bitset<9>(rows[i].get(j)) << " "; + std::cout << std::bitset<9>(rows[i].get_ref(j)) << " "; } std::cout << std::endl; if (i % 3 == 2) std::cout << std::endl; @@ -478,7 +282,7 @@ class Grid { std::cout << "Col Refs: " << std::endl; for (uint8_t i = 0; i < 9; i++) { for (uint8_t j = 0; j < 9; j++) { - std::cout << std::bitset<9>(cols[i].get(j)) << " "; + std::cout << std::bitset<9>(cols[i].get_ref(j)) << " "; } std::cout << std::endl; if (i % 3 == 2) std::cout << std::endl; @@ -489,7 +293,7 @@ class Grid { for (uint8_t j = 0; j < 9; j++) { const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)}; const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)}; - std::cout << (int)subgrids[subgrid].value(field) << " "; + std::cout << (int)subgrids[subgrid].get_value(field) << " "; if (j % 3 == 2) std::cout << " "; } std::cout << std::endl; @@ -498,24 +302,12 @@ class Grid { } private: - void _clear(row_col_t ab, uint8_t value) { - rows[ab.row].clear(ab.col, value); - cols[ab.col].clear(ab.row, value); - } - - void _clear(row_col_t ab) { - rows[ab.row].clear(ab.col); - cols[ab.col].clear(ab.row); - } - - void _zero(row_col_t ab, uint8_t value) { - rows[ab.row].zero(value); - cols[ab.col].zero(value); - } - void _clear(row_col_t sg, row_col_t field, uint8_t value) { - _clear(sg.absolute(field), value); subgrids[sg].clear(field, value); + + row_col_t ab = sg.absolute(field); + rows[ab.row].clear(ab.col, value); + cols[ab.col].clear(ab.row, value); } void _clear_row(row_col_t sg, uint8_t row, uint8_t value) { @@ -530,7 +322,21 @@ class Grid { } } - Subgrid subgrids[9]; + bool _is_finished(uint8_t subgrid) { + for (uint8_t i = 0; i < 9; i++) { + if (subgrids[subgrid].get_ref(i)) return false; + } + return true; + } + + bool _is_finished() { + for (uint8_t i = 0; i < 9; i++) { + if (!_is_finished(i)) return false; + } + return true; + } + + Ref subgrids[9]; Ref rows[9]; Ref cols[9]; };