doasku

Sudoku solver
git clone git://git.dimitrijedobrota.com/doasku.git
Log | Files | Refs | README | LICENSE

commit 4bd367d95546d5af024d7a6cce8b631ae75cb53a
parent dacb9c0718e06fafa9f537593f512b833aa29db5
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date:   Thu,  4 Apr 2024 20:25:39 +0200

Add row and column reference counting

Diffstat:
Mmain.cpp | 150++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
1 file changed, 115 insertions(+), 35 deletions(-)

diff --git a/main.cpp b/main.cpp @@ -25,6 +25,8 @@ struct row_col_t { explicit operator row_col_rt() const; + row_col_t absolute(row_col_t rc) { return {uint8_t(row * 3 + rc.row), uint8_t(col * 3 + rc.col)}; } + friend std::ostream &operator<<(std::ostream &os, row_col_t rc) { return os << std::format("({}, {})", rc.row, rc.col); } @@ -53,14 +55,14 @@ uint8_t field(row_col_t rc) { return 3 * rc.row + rc.col; } class Ref { public: - uint16_t get(row_col_t rc) const { return ref[::field(rc)]; } + uint16_t get(uint8_t value) const { return ref[value]; } void clear(uint8_t value) { ref[value] = 0; } - void remove(row_col_t rc, uint8_t value) { ref[value] &= ~(1 << ::field(rc)); } - void remove(row_col_t rc) { + void remove(uint8_t field, uint8_t value) { ref[value] &= ~(1 << field); } + void remove(uint8_t field) { for (uint8_t i = 0; i < 9; i++) { - ref[i] &= ~(1 << ::field(rc)); + ref[i] &= ~(1 << field); } } @@ -121,7 +123,7 @@ class Subgrid { uint16_t get(row_col_t rc) const { return rows[rc.row].get(rc.col); } uint16_t get(row_col_rt rc) const { return rows[rc.row].get(rc.col + 3); } - uint16_t get_ref(row_col_t rc) const { return ref.get(rc); } + uint16_t get_ref(uint8_t value) const { return ref.get(value); } uint8_t value(row_col_t rc) const { if (!is_finished(rc)) return 0; @@ -141,10 +143,12 @@ class Subgrid { _set(row_col_rt(rc), value); _set(rc, value); - ref.remove(rc); + const uint8_t field = ::field(rc); + ref.clear(value); + ref.remove(field); - finished |= 1 << ::field(rc); + finished |= 1 << field; } void clear_row(uint8_t row, uint8_t value) { @@ -152,9 +156,10 @@ class Subgrid { for (uint8_t i = 0; i < 3; i++) { const row_col_t rc = {row, i}; - ref.remove(rc, value); _clear(row_col_rt(rc), value); _clear(rc, value); + + ref.remove(::field(rc), value); } } @@ -163,9 +168,10 @@ class Subgrid { for (uint8_t i = 0; i < 3; i++) { const row_col_t rc = {i, col}; - ref.remove(rc, value); _clear(row_col_rt(rc), value); _clear(rc, value); + + ref.remove(::field(rc), value); } } @@ -174,9 +180,11 @@ class Subgrid { _mask(rc, mask); } - void remove_ref(row_col_t rc, uint8_t value) { ref.remove(rc, value); } + void ref_clear(row_col_t rc, uint8_t value) { ref.clear(value); } + void ref_remove(row_col_t rc, uint8_t value) { ref.remove(::field(rc), value); } + void ref_remove(row_col_t rc) { ref.remove(::field(rc)); } - changes_t check_lone() const { + changes_t get_lone() const { if (is_finished()) return {}; return ref.get_lone(); } @@ -264,25 +272,25 @@ class Subgrid { friend std::ostream &operator<<(std::ostream &os, Subgrid s) { os << std::endl << "Field:" << std::endl; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { + for (uint8_t i = 0; i < 3; i++) { + for (uint8_t j = 0; j < 3; j++) { std::cout << std::bitset<9>(s.get(row_col_t(i, j))) << " "; } std::cout << std::endl; } os << std::endl << "Rev:" << std::endl; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { + for (uint8_t i = 0; i < 3; i++) { + for (uint8_t j = 0; j < 3; j++) { std::cout << std::bitset<9>(s.get(row_col_rt(i, j))) << " "; } std::cout << std::endl; } os << std::endl << "Refs:" << std::endl; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { - std::cout << std::bitset<9>(s.get_ref(row_col_t(i, j))) << " "; + for (uint8_t i = 0; i < 3; i++) { + for (uint8_t j = 0; j < 3; j++) { + std::cout << std::bitset<9>(s.get_ref(::field({i, j}))) << " "; } std::cout << std::endl; } @@ -319,21 +327,39 @@ class Grid { } } - void set(uint8_t subgrid, uint8_t field, uint8_t value) { - assert(subgrid < 9 && field < 9 && value < 9); + void set(row_col_t subgrid, row_col_t field, uint8_t value) { + assert(value < 9); + + const auto ab = subgrid.absolute(field); + + subgrids[3 * subgrid.row + 0].clear_row(field.row, value); + subgrids[3 * subgrid.row + 1].clear_row(field.row, value); + subgrids[3 * subgrid.row + 2].clear_row(field.row, value); - const auto [row, col] = row_col_t(subgrid); - const auto [frow, fcol] = row_col_t(field); + subgrids[subgrid.col + 0].clear_col(field.col, value); + subgrids[subgrid.col + 3].clear_col(field.col, value); + subgrids[subgrid.col + 6].clear_col(field.col, value); - subgrids[3 * row + 0].clear_row(frow, value); - subgrids[3 * row + 1].clear_row(frow, value); - subgrids[3 * row + 2].clear_row(frow, value); + rows[ab.row].remove(ab.col); + cols[ab.col].remove(ab.row); - subgrids[col + 0].clear_col(fcol, value); - subgrids[col + 3].clear_col(fcol, value); - subgrids[col + 6].clear_col(fcol, value); + rows[ab.row].clear(value); + cols[ab.col].clear(value); - subgrids[subgrid].set(field, value); + for (uint8_t i = 0; i < 9; i++) { + rows[i].remove(ab.col, value); + cols[i].remove(ab.row, value); + + uint8_t trow = subgrid.row * 3 + i / 3; + uint8_t tcol = subgrid.col * 3 + i % 3; + + rows[trow].remove(tcol, value); + cols[tcol].remove(trow, value); + } + + std::cout << "setting " << (int)value << ": " << ab << std::endl; + + subgrids[::field(subgrid)].set(field, value); } bool solve() { @@ -341,12 +367,41 @@ class Grid { while (iter--) { uint8_t changes = 0; - for (uint8_t i = 0; i < 9; i++) { - const auto lones = subgrids[i].check_lone(); + for (uint8_t subgrid = 0; subgrid < 9; subgrid++) { + const auto lones = subgrids[subgrid].get_lone(); for (const auto [field, val] : lones) { - set(i, field, val); + set(subgrid, field, val); + + std::cout << "lone: " << subgrid << " " << (int)field << " " << int(val) << std::endl; + changes++; + } + } + + for (uint8_t row = 0; row < 9; row++) { + const auto lones = rows[row].get_lone(); + for (const auto [col, val] : lones) { + const uint8_t subgrid = (row / 3) * 3 + col / 3; + const uint8_t field = (row % 3) * 3 + col % 3; - std::cout << "lone: " << i << " " << (int)field << " " << int(val) << std::endl; + std::cout << (int)row << " " << (int)col << std::endl; + std::cout << "lone row: " << (int)subgrid << " " << (int)field << " " << int(val) + << std::endl; + + set(subgrid, field, val); + changes++; + } + } + + for (uint8_t col = 0; col < 9; col++) { + const auto lones = rows[col].get_lone(); + for (const auto [row, val] : lones) { + const uint8_t subgrid = (col / 3) * 3 + row / 3; + const uint8_t field = (col % 3) * 3 + row % 3; + + std::cout << "lone col: " << (int)subgrid << " " << (int)field << " " << int(val) + << std::endl; + + set(subgrid, field, val); changes++; } } @@ -356,10 +411,15 @@ class Grid { for (const auto [field, mask] : naked) { subgrids[i].mask(field, mask); + const auto rc = row_col_t(i).absolute(field); uint16_t tmask = mask; while (tmask) { uint16_t idx = std::countr_zero(tmask); - subgrids[i].remove_ref(field, idx); + + subgrids[i].ref_remove(field, idx); + rows[rc.row].remove(rc.col); + cols[rc.col].remove(rc.row); + tmask ^= 1ull << idx; } @@ -421,7 +481,7 @@ class Grid { for (uint8_t j = 0; j < 9; j++) { const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)}; const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)}; - std::bitset<9> value = subgrids[::field(subgrid)].get_ref(field); + std::bitset<9> value = subgrids[::field(subgrid)].get_ref(::field(field)); std::cout << value << " "; if (j % 3 == 2) std::cout << " "; } @@ -429,6 +489,24 @@ class Grid { if (i % 3 == 2) std::cout << std::endl; } + std::cout << "Row Refs: " << std::endl; + for (uint8_t i = 0; i < 9; i++) { + for (uint8_t j = 0; j < 9; j++) { + std::cout << std::bitset<9>(rows[i].get(j)) << " "; + } + std::cout << std::endl; + if (i % 3 == 2) std::cout << std::endl; + } + + std::cout << "Col Refs: " << std::endl; + for (uint8_t i = 0; i < 9; i++) { + for (uint8_t j = 0; j < 9; j++) { + std::cout << std::bitset<9>(cols[i].get(j)) << " "; + } + std::cout << std::endl; + if (i % 3 == 2) std::cout << std::endl; + } + std::cout << "Board: " << std::endl; for (uint8_t i = 0; i < 9; i++) { for (uint8_t j = 0; j < 9; j++) { @@ -444,6 +522,8 @@ class Grid { private: Subgrid subgrids[9]; + Ref rows[9]; + Ref cols[9]; }; int main(const int argc, const char *argv[]) {