doasku

Sudoku solver
git clone git://git.dimitrijedobrota.com/doasku.git
Log | Files | Refs | README | LICENSE

commit 026fc34262e261516f73ea642afee4c838dc2056
parent f95e9990a789decde3e871a3b48461ce51a40493
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date:   Thu, 11 Apr 2024 19:53:03 +0200

Unify ref::get operations and simplify solve

Diffstat:
Mmain.cpp | 412+++++++++++++++++++++++++++++++++++--------------------------------------------
1 file changed, 181 insertions(+), 231 deletions(-)

diff --git a/main.cpp b/main.cpp @@ -1,3 +1,4 @@ +#include <algorithm> #include <array> #include <bit> #include <bitset> @@ -9,6 +10,10 @@ #include <iostream> #include <vector> +template <class Input, class UnaryFunc> UnaryFunc for_each(Input input, UnaryFunc f) { + return std::for_each(begin(input), end(input), f); +} + struct cord_t { cord_t(uint8_t row, uint8_t col) : value(row * 3 + col) { assert(row < 3 && col < 3); } cord_t(uint8_t value) : value(value) { assert(value < 9); } @@ -60,8 +65,13 @@ static std::tuple<uint8_t, uint8_t> other_subgrid_col(uint8_t subgrid) { } class Ref { + public: + using change_t = std::tuple<uint8_t, uint16_t>; + using changes_t = std::vector<change_t>; + + private: auto get_pointing(uint16_t mask, uint16_t &seen) const { - std::vector<std::tuple<uint8_t, uint8_t>> res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { const uint8_t popcnt = std::popcount(ref[i]); @@ -90,10 +100,7 @@ class Ref { } } - using hidden_t = std::tuple<uint8_t, uint16_t>; - using hiddens_t = std::vector<hidden_t>; - - void get_hidden_changes(uint8_t fields[9], uint16_t mask, uint8_t n, hiddens_t &vec) const { + void get_hidden_changes(uint8_t fields[9], uint16_t mask, uint8_t n, changes_t &vec) const { for (uint8_t i = 0; i < n; i++) { const uint16_t change = value[fields[i]] & ~(value[fields[i]] & mask); if (!change) continue; @@ -121,18 +128,11 @@ class Ref { res[field] = value + 1; } - auto get_pointing_row() const { - static const uint16_t row_mask = 0x7; - return get_pointing(row_mask, seen_point_row); - } - - auto get_pointing_col() const { - static const uint16_t col_mask = 0x49; - return get_pointing(col_mask, seen_point_col); - } + auto get_pointing_row() const { return get_pointing(0x7, seen_point_row); } + auto get_pointing_col() const { return get_pointing(0x49, seen_point_col); } auto get_hidden_singles() const { - std::vector<std::tuple<uint8_t, uint8_t>> res; + changes_t res; for (uint8_t candidate = 0; candidate < 9; candidate++) { if (std::popcount(ref[candidate]) != 1) continue; @@ -143,7 +143,7 @@ class Ref { } auto get_hidden_pairs() const { - hiddens_t res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { if (std::popcount(ref[i]) != 2) continue; @@ -167,7 +167,7 @@ class Ref { } auto get_hidden_triplets() const { - hiddens_t res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { if (std::popcount(ref[i]) < 2) continue; @@ -201,7 +201,7 @@ class Ref { } auto get_hidden_quads() const { - std::vector<std::tuple<uint8_t, uint16_t>> res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { if (std::popcount(ref[i]) < 2) continue; @@ -241,7 +241,7 @@ class Ref { } auto get_naked_singles() const { - std::vector<std::tuple<uint8_t, uint8_t>> res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { const auto values = get(i); @@ -255,7 +255,7 @@ class Ref { } auto get_naked_pairs() const { - std::vector<std::tuple<uint8_t, uint8_t, uint8_t>> res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { if (std::popcount(value[i]) != 2) continue; @@ -267,11 +267,14 @@ class Ref { seen_naked_pair |= 1ul << i; seen_naked_pair |= 1ul << j; - uint16_t tval = value[i]; - while (tval) { - const uint8_t idx = std::countr_zero(tval); - res.emplace_back(idx, i, j); - tval ^= 1ull << idx; + uint16_t val = value[i]; + while (val) { + const uint8_t idx = std::countr_zero(val); + for (uint8_t pos = 0; pos < 9; pos++) { + if (pos == i || pos == j) continue; + res.emplace_back(pos, idx); + } + val ^= 1ull << idx; } } } @@ -280,7 +283,8 @@ class Ref { } auto get_naked_triplets() const { - std::vector<std::tuple<uint8_t, uint8_t, uint8_t, uint8_t>> res; + changes_t res; + for (uint8_t i = 0; i < 9; i++) { if (this->res[i]) continue; if (seen_naked_triplet & (1ul << i)) continue; @@ -302,7 +306,10 @@ class Ref { while (val) { const uint8_t idx = std::countr_zero(val); - res.emplace_back(idx, i, j, k); + for (uint8_t pos = 0; pos < 9; pos++) { + if (pos == i || pos == j || pos == k) continue; + res.emplace_back(pos, idx); + } val ^= 1ull << idx; } } @@ -312,7 +319,7 @@ class Ref { } auto get_naked_quads() const { - std::vector<std::tuple<uint8_t, uint8_t, uint8_t, uint8_t, uint8_t>> res; + changes_t res; for (uint8_t i = 0; i < 9; i++) { if (this->res[i]) continue; @@ -340,7 +347,10 @@ class Ref { while (val) { const uint8_t idx = std::countr_zero(val); - res.emplace_back(idx, i, j, k, l); + for (uint8_t pos = 0; pos < 9; pos++) { + if (pos == i || pos == j || pos == k || pos == l) continue; + res.emplace_back(pos, idx); + } val ^= 1ull << idx; } } @@ -362,16 +372,9 @@ class Ref { uint16_t res[9] = {0}; - mutable uint16_t seen_hidden_pair = 0; - mutable uint16_t seen_hidden_triplet = 0; - mutable uint16_t seen_hidden_quads = 0; - - mutable uint16_t seen_naked_pair = 0; - mutable uint16_t seen_naked_triplet = 0; - mutable uint16_t seen_naked_quad = 0; - - mutable uint16_t seen_point_row = 0; - mutable uint16_t seen_point_col = 0; + mutable uint16_t seen_hidden_pair = 0, seen_hidden_triplet = 0, seen_hidden_quads = 0; + mutable uint16_t seen_naked_pair = 0, seen_naked_triplet = 0, seen_naked_quad = 0; + mutable uint16_t seen_point_row = 0, seen_point_col = 0; }; class Grid { @@ -381,241 +384,111 @@ class Grid { for (uint8_t i = 0; i < 9; i++) { for (uint8_t j = 0; j < 9; j++, idx++) { if (s[idx] == '0') continue; - set({i, j}, s[idx] - '1'); + _set({i, j}, s[idx] - '1'); } } } - void set(acord_t ab, uint8_t value) { - rows[ab.row()].set(ab.col(), value); - cols[ab.col()].set(ab.row(), value); - subgrids[(uint8_t)ab.subgrid()].set((uint8_t)ab.field(), value); - - _clear_row(ab.subgrid(), 0, value); - _clear_row(ab.subgrid(), 1, value); - _clear_row(ab.subgrid(), 2, value); - - const auto [r1, r2] = other_subgrid_row((uint8_t)ab.subgrid()); - _clear_row(r1, ab.field().row(), value); - _clear_row(r2, ab.field().row(), value); - - const auto [c1, c2] = other_subgrid_col((uint8_t)ab.subgrid()); - _clear_col(c1, ab.field().col(), value); - _clear_col(c2, ab.field().col(), value); - } - bool solve() { - int iter = 1; - while (iter--) { + // clang-format off + static const auto sub_op = + [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t subgrid) { + for_each((subgrids[subgrid].*(f))(), [this, subgrid, op](const Ref::change_t ch) { + (this->*(op))(operation_t({cord_t(subgrid), cord_t(std::get<0>(ch))}, std::get<1>(ch))); + }); + }; + + static const auto row_op = + [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t row) { + for_each((rows[row].*(f))(), [this, row, op](const Ref::change_t ch) { + (this->*(op))(operation_t({row, std::get<0>(ch)}, std::get<1>(ch))); + }); + }; + + static const auto col_op = + [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t col) { + for_each((cols[col].*(f))(), [this, col, op](const Ref::change_t ch) { + (this->*(op))(operation_t({std::get<0>(ch), col}, std::get<1>(ch))); + }); + }; + // clang-format on + + changed = true; + while (changed) { + changed = false; for (uint8_t subgrid = 0; subgrid < 9; subgrid++) { - for (const auto [field, val] : subgrids[subgrid].get_naked_singles()) { - set({cord_t(subgrid), cord_t(field)}, val); - iter = true; - } + sub_op(&Grid::op_set, &Ref::get_naked_singles, subgrid); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [field, val] : subgrids[idx].get_hidden_singles()) { - set({cord_t(idx), cord_t(field)}, val); - iter = true; - } - - for (const auto [col, val] : rows[idx].get_hidden_singles()) { - set({idx, col}, val); - iter = true; - } - - for (const auto [row, val] : cols[idx].get_hidden_singles()) { - set({row, idx}, val); - iter = true; - } + sub_op(&Grid::op_set, &Ref::get_hidden_singles, idx); + row_op(&Grid::op_set, &Ref::get_hidden_singles, idx); + col_op(&Grid::op_set, &Ref::get_hidden_singles, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t subgrid = 0; subgrid < 9; subgrid++) { - for (const auto [row, val] : subgrids[subgrid].get_pointing_row()) { - const auto [r1, r2] = other_subgrid_row(subgrid); - _clear_row(r1, row, val); - _clear_row(r2, row, val); - - iter = true; - } - - for (const auto [col, val] : subgrids[subgrid].get_pointing_col()) { - const auto [c1, c2] = other_subgrid_col(subgrid); - _clear_col(c1, col, val); - _clear_col(c2, col, val); - - iter = true; - } + sub_op(&Grid::op_clear_row, &Ref::get_pointing_row, subgrid); + sub_op(&Grid::op_clear_col, &Ref::get_pointing_col, subgrid); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [field, mask] : subgrids[idx].get_hidden_pairs()) { - _mask({cord_t(idx), cord_t(field)}, mask); - iter = true; - } - - for (const auto [col, mask] : rows[idx].get_hidden_pairs()) { - _mask({idx, col}, mask); - iter = true; - } - - for (const auto [row, mask] : cols[idx].get_hidden_pairs()) { - _mask({row, idx}, mask); - iter = true; - } + sub_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx); + row_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx); + col_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [field, mask] : subgrids[idx].get_hidden_triplets()) { - _mask({cord_t(idx), cord_t(field)}, mask); - iter = true; - } - - for (const auto [col, mask] : rows[idx].get_hidden_triplets()) { - _mask({idx, col}, mask); - iter = true; - } - - for (const auto [row, mask] : cols[idx].get_hidden_triplets()) { - _mask({row, idx}, mask); - iter = true; - } + sub_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx); + row_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx); + col_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [field, mask] : subgrids[idx].get_hidden_quads()) { - _mask({cord_t(idx), cord_t(field)}, mask); - iter = true; - } - - for (const auto [col, mask] : rows[idx].get_hidden_quads()) { - _mask({idx, col}, mask); - iter = true; - } - - for (const auto [row, mask] : cols[idx].get_hidden_quads()) { - _mask({row, idx}, mask); - iter = true; - } + sub_op(&Grid::op_mask, &Ref::get_hidden_quads, idx); + row_op(&Grid::op_mask, &Ref::get_hidden_quads, idx); + col_op(&Grid::op_mask, &Ref::get_hidden_quads, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [val, f1, f2] : subgrids[idx].get_naked_pairs()) { - for (uint8_t field = 0; field < 9; field++) { - if (field == f1 || field == f2) continue; - _clear({cord_t(idx), cord_t(field)}, val); - } - - iter = true; - } - - for (const auto [val, c1, c2] : rows[idx].get_naked_pairs()) { - for (uint8_t col = 0; col < 9; col++) { - if (col == c1 || col == c2) continue; - _clear({idx, col}, val); - } - iter = true; - } - - for (const auto [val, r1, r2] : cols[idx].get_naked_pairs()) { - for (uint8_t row = 0; row < 9; row++) { - if (row == r1 || row == r2) continue; - _clear({row, idx}, val); - } - iter = true; - } + sub_op(&Grid::op_clear, &Ref::get_naked_pairs, idx); + row_op(&Grid::op_clear, &Ref::get_naked_pairs, idx); + col_op(&Grid::op_clear, &Ref::get_naked_pairs, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [val, f1, f2, f3] : subgrids[idx].get_naked_triplets()) { - for (uint8_t field = 0; field < 9; field++) { - if (field == f1 || field == f2 || field == f3) continue; - _clear({cord_t(idx), cord_t(field)}, val); - } - - iter = true; - } - - for (const auto [val, c1, c2, c3] : rows[idx].get_naked_triplets()) { - for (uint8_t col = 0; col < 9; col++) { - if (col == c1 || col == c2 || col == c3) continue; - _clear({idx, col}, val); - } - iter = true; - } - - for (const auto [val, r1, r2, r3] : cols[idx].get_naked_triplets()) { - for (uint8_t row = 0; row < 9; row++) { - if (row == r1 || row == r2 || row == r3) continue; - _clear({row, idx}, val); - } - iter = true; - } + sub_op(&Grid::op_clear, &Ref::get_naked_triplets, idx); + row_op(&Grid::op_clear, &Ref::get_naked_triplets, idx); + col_op(&Grid::op_clear, &Ref::get_naked_triplets, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [val, f1, f2, f3, f4] : subgrids[idx].get_naked_quads()) { - for (uint8_t field = 0; field < 9; field++) { - if (field == f1 || field == f2 || field == f3 || field == f4) continue; - _clear({cord_t(idx), cord_t(field)}, val); - } - - iter = true; - } - - for (const auto [val, c1, c2, c3, c4] : rows[idx].get_naked_quads()) { - for (uint8_t col = 0; col < 9; col++) { - if (col == c1 || col == c2 || col == c3 || col == c4) continue; - _clear({idx, col}, val); - } - iter = true; - } - - for (const auto [val, r1, r2, r3, r4] : cols[idx].get_naked_quads()) { - for (uint8_t row = 0; row < 9; row++) { - if (row == r1 || row == r2 || row == r3 || row == r4) continue; - _clear({row, idx}, val); - } - iter = true; - } + sub_op(&Grid::op_clear, &Ref::get_naked_quads, idx); + row_op(&Grid::op_clear, &Ref::get_naked_quads, idx); + col_op(&Grid::op_clear, &Ref::get_naked_quads, idx); } - if (iter) continue; + if (changed) continue; for (uint8_t idx = 0; idx < 9; idx++) { - for (const auto [index, val] : rows[idx].get_pointing_row()) { - const auto [r1, r2] = other_subgrid_row(idx); - _clear_row((r1 / 3) * 3 + index, r1 % 3, val); - _clear_row((r2 / 3) * 3 + index, r2 % 3, val); - - iter = true; - } - - for (const auto [index, val] : cols[idx].get_pointing_row()) { - const auto [c1, c2] = other_subgrid_row(idx); - _clear_col(index * 3 + c1 / 3, c1 % 3, val); - _clear_col(index * 3 + c2 / 3, c2 % 3, val); - - iter = true; - } + row_op(&Grid::op_clear_row_rel, &Ref::get_pointing_row, idx); + col_op(&Grid::op_clear_col_rel, &Ref::get_pointing_row, idx); } } @@ -684,6 +557,81 @@ class Grid { } private: + using operation_t = std::tuple<acord_t, uint16_t>; + + void op_set(operation_t op) { + _set(std::get<0>(op), std::get<1>(op)); + changed = true; + } + + void op_mask(operation_t op) { + _mask(std::get<0>(op), std::get<1>(op)); + changed = true; + } + + void op_clear(operation_t op) { + _clear(std::get<0>(op), std::get<1>(op)); + changed = true; + } + + void op_clear_row(operation_t op) { + const auto [ab, val] = op; + + const auto [r1, r2] = other_subgrid_row((uint8_t)ab.subgrid()); + _clear_row(r1, (uint8_t)ab.field(), val); + _clear_row(r2, (uint8_t)ab.field(), val); + + changed = true; + } + + void op_clear_col(operation_t op) { + const auto [ab, val] = op; + + const auto [c1, c2] = other_subgrid_col((uint8_t)ab.subgrid()); + _clear_col(c1, (uint8_t)ab.field(), val); + _clear_col(c2, (uint8_t)ab.field(), val); + + changed = true; + } + + void op_clear_row_rel(operation_t op) { + const auto [ab, val] = op; + + const auto [r1, r2] = other_subgrid_row((uint8_t)ab.row()); + _clear_row((r1 / 3) * 3 + (uint8_t)ab.col(), r1 % 3, val); + _clear_row((r2 / 3) * 3 + (uint8_t)ab.col(), r2 % 3, val); + + changed = true; + } + + void op_clear_col_rel(operation_t op) { + const auto [ab, val] = op; + + const auto [c1, c2] = other_subgrid_row((uint8_t)ab.col()); + _clear_col((uint8_t)ab.row() * 3 + c1 / 3, c1 % 3, val); + _clear_col((uint8_t)ab.row() * 3 + c2 / 3, c2 % 3, val); + + changed = true; + } + + void _set(acord_t ab, uint8_t value) { + rows[ab.row()].set(ab.col(), value); + cols[ab.col()].set(ab.row(), value); + subgrids[(uint8_t)ab.subgrid()].set((uint8_t)ab.field(), value); + + _clear_row(ab.subgrid(), 0, value); + _clear_row(ab.subgrid(), 1, value); + _clear_row(ab.subgrid(), 2, value); + + const auto [r1, r2] = other_subgrid_row((uint8_t)ab.subgrid()); + _clear_row(r1, ab.field().row(), value); + _clear_row(r2, ab.field().row(), value); + + const auto [c1, c2] = other_subgrid_col((uint8_t)ab.subgrid()); + _clear_col(c1, ab.field().col(), value); + _clear_col(c2, ab.field().col(), value); + } + void _mask(acord_t ab, uint16_t mask) { while (mask) { const uint8_t idx = std::countr_zero(mask); @@ -728,6 +676,8 @@ class Grid { Ref subgrids[9]; Ref rows[9]; Ref cols[9]; + + bool changed = false; }; int main(const int argc, const char *argv[]) {