commit 026fc34262e261516f73ea642afee4c838dc2056
parent f95e9990a789decde3e871a3b48461ce51a40493
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date: Thu, 11 Apr 2024 19:53:03 +0200
Unify ref::get operations and simplify solve
Diffstat:
M | main.cpp | | | 412 | +++++++++++++++++++++++++++++++++++-------------------------------------------- |
1 file changed, 181 insertions(+), 231 deletions(-)
diff --git a/main.cpp b/main.cpp
@@ -1,3 +1,4 @@
+#include <algorithm>
#include <array>
#include <bit>
#include <bitset>
@@ -9,6 +10,10 @@
#include <iostream>
#include <vector>
+template <class Input, class UnaryFunc> UnaryFunc for_each(Input input, UnaryFunc f) {
+ return std::for_each(begin(input), end(input), f);
+}
+
struct cord_t {
cord_t(uint8_t row, uint8_t col) : value(row * 3 + col) { assert(row < 3 && col < 3); }
cord_t(uint8_t value) : value(value) { assert(value < 9); }
@@ -60,8 +65,13 @@ static std::tuple<uint8_t, uint8_t> other_subgrid_col(uint8_t subgrid) {
}
class Ref {
+ public:
+ using change_t = std::tuple<uint8_t, uint16_t>;
+ using changes_t = std::vector<change_t>;
+
+ private:
auto get_pointing(uint16_t mask, uint16_t &seen) const {
- std::vector<std::tuple<uint8_t, uint8_t>> res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
const uint8_t popcnt = std::popcount(ref[i]);
@@ -90,10 +100,7 @@ class Ref {
}
}
- using hidden_t = std::tuple<uint8_t, uint16_t>;
- using hiddens_t = std::vector<hidden_t>;
-
- void get_hidden_changes(uint8_t fields[9], uint16_t mask, uint8_t n, hiddens_t &vec) const {
+ void get_hidden_changes(uint8_t fields[9], uint16_t mask, uint8_t n, changes_t &vec) const {
for (uint8_t i = 0; i < n; i++) {
const uint16_t change = value[fields[i]] & ~(value[fields[i]] & mask);
if (!change) continue;
@@ -121,18 +128,11 @@ class Ref {
res[field] = value + 1;
}
- auto get_pointing_row() const {
- static const uint16_t row_mask = 0x7;
- return get_pointing(row_mask, seen_point_row);
- }
-
- auto get_pointing_col() const {
- static const uint16_t col_mask = 0x49;
- return get_pointing(col_mask, seen_point_col);
- }
+ auto get_pointing_row() const { return get_pointing(0x7, seen_point_row); }
+ auto get_pointing_col() const { return get_pointing(0x49, seen_point_col); }
auto get_hidden_singles() const {
- std::vector<std::tuple<uint8_t, uint8_t>> res;
+ changes_t res;
for (uint8_t candidate = 0; candidate < 9; candidate++) {
if (std::popcount(ref[candidate]) != 1) continue;
@@ -143,7 +143,7 @@ class Ref {
}
auto get_hidden_pairs() const {
- hiddens_t res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
if (std::popcount(ref[i]) != 2) continue;
@@ -167,7 +167,7 @@ class Ref {
}
auto get_hidden_triplets() const {
- hiddens_t res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
if (std::popcount(ref[i]) < 2) continue;
@@ -201,7 +201,7 @@ class Ref {
}
auto get_hidden_quads() const {
- std::vector<std::tuple<uint8_t, uint16_t>> res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
if (std::popcount(ref[i]) < 2) continue;
@@ -241,7 +241,7 @@ class Ref {
}
auto get_naked_singles() const {
- std::vector<std::tuple<uint8_t, uint8_t>> res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
const auto values = get(i);
@@ -255,7 +255,7 @@ class Ref {
}
auto get_naked_pairs() const {
- std::vector<std::tuple<uint8_t, uint8_t, uint8_t>> res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
if (std::popcount(value[i]) != 2) continue;
@@ -267,11 +267,14 @@ class Ref {
seen_naked_pair |= 1ul << i;
seen_naked_pair |= 1ul << j;
- uint16_t tval = value[i];
- while (tval) {
- const uint8_t idx = std::countr_zero(tval);
- res.emplace_back(idx, i, j);
- tval ^= 1ull << idx;
+ uint16_t val = value[i];
+ while (val) {
+ const uint8_t idx = std::countr_zero(val);
+ for (uint8_t pos = 0; pos < 9; pos++) {
+ if (pos == i || pos == j) continue;
+ res.emplace_back(pos, idx);
+ }
+ val ^= 1ull << idx;
}
}
}
@@ -280,7 +283,8 @@ class Ref {
}
auto get_naked_triplets() const {
- std::vector<std::tuple<uint8_t, uint8_t, uint8_t, uint8_t>> res;
+ changes_t res;
+
for (uint8_t i = 0; i < 9; i++) {
if (this->res[i]) continue;
if (seen_naked_triplet & (1ul << i)) continue;
@@ -302,7 +306,10 @@ class Ref {
while (val) {
const uint8_t idx = std::countr_zero(val);
- res.emplace_back(idx, i, j, k);
+ for (uint8_t pos = 0; pos < 9; pos++) {
+ if (pos == i || pos == j || pos == k) continue;
+ res.emplace_back(pos, idx);
+ }
val ^= 1ull << idx;
}
}
@@ -312,7 +319,7 @@ class Ref {
}
auto get_naked_quads() const {
- std::vector<std::tuple<uint8_t, uint8_t, uint8_t, uint8_t, uint8_t>> res;
+ changes_t res;
for (uint8_t i = 0; i < 9; i++) {
if (this->res[i]) continue;
@@ -340,7 +347,10 @@ class Ref {
while (val) {
const uint8_t idx = std::countr_zero(val);
- res.emplace_back(idx, i, j, k, l);
+ for (uint8_t pos = 0; pos < 9; pos++) {
+ if (pos == i || pos == j || pos == k || pos == l) continue;
+ res.emplace_back(pos, idx);
+ }
val ^= 1ull << idx;
}
}
@@ -362,16 +372,9 @@ class Ref {
uint16_t res[9] = {0};
- mutable uint16_t seen_hidden_pair = 0;
- mutable uint16_t seen_hidden_triplet = 0;
- mutable uint16_t seen_hidden_quads = 0;
-
- mutable uint16_t seen_naked_pair = 0;
- mutable uint16_t seen_naked_triplet = 0;
- mutable uint16_t seen_naked_quad = 0;
-
- mutable uint16_t seen_point_row = 0;
- mutable uint16_t seen_point_col = 0;
+ mutable uint16_t seen_hidden_pair = 0, seen_hidden_triplet = 0, seen_hidden_quads = 0;
+ mutable uint16_t seen_naked_pair = 0, seen_naked_triplet = 0, seen_naked_quad = 0;
+ mutable uint16_t seen_point_row = 0, seen_point_col = 0;
};
class Grid {
@@ -381,241 +384,111 @@ class Grid {
for (uint8_t i = 0; i < 9; i++) {
for (uint8_t j = 0; j < 9; j++, idx++) {
if (s[idx] == '0') continue;
- set({i, j}, s[idx] - '1');
+ _set({i, j}, s[idx] - '1');
}
}
}
- void set(acord_t ab, uint8_t value) {
- rows[ab.row()].set(ab.col(), value);
- cols[ab.col()].set(ab.row(), value);
- subgrids[(uint8_t)ab.subgrid()].set((uint8_t)ab.field(), value);
-
- _clear_row(ab.subgrid(), 0, value);
- _clear_row(ab.subgrid(), 1, value);
- _clear_row(ab.subgrid(), 2, value);
-
- const auto [r1, r2] = other_subgrid_row((uint8_t)ab.subgrid());
- _clear_row(r1, ab.field().row(), value);
- _clear_row(r2, ab.field().row(), value);
-
- const auto [c1, c2] = other_subgrid_col((uint8_t)ab.subgrid());
- _clear_col(c1, ab.field().col(), value);
- _clear_col(c2, ab.field().col(), value);
- }
-
bool solve() {
- int iter = 1;
- while (iter--) {
+ // clang-format off
+ static const auto sub_op =
+ [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t subgrid) {
+ for_each((subgrids[subgrid].*(f))(), [this, subgrid, op](const Ref::change_t ch) {
+ (this->*(op))(operation_t({cord_t(subgrid), cord_t(std::get<0>(ch))}, std::get<1>(ch)));
+ });
+ };
+
+ static const auto row_op =
+ [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t row) {
+ for_each((rows[row].*(f))(), [this, row, op](const Ref::change_t ch) {
+ (this->*(op))(operation_t({row, std::get<0>(ch)}, std::get<1>(ch)));
+ });
+ };
+
+ static const auto col_op =
+ [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t col) {
+ for_each((cols[col].*(f))(), [this, col, op](const Ref::change_t ch) {
+ (this->*(op))(operation_t({std::get<0>(ch), col}, std::get<1>(ch)));
+ });
+ };
+ // clang-format on
+
+ changed = true;
+ while (changed) {
+ changed = false;
for (uint8_t subgrid = 0; subgrid < 9; subgrid++) {
- for (const auto [field, val] : subgrids[subgrid].get_naked_singles()) {
- set({cord_t(subgrid), cord_t(field)}, val);
- iter = true;
- }
+ sub_op(&Grid::op_set, &Ref::get_naked_singles, subgrid);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [field, val] : subgrids[idx].get_hidden_singles()) {
- set({cord_t(idx), cord_t(field)}, val);
- iter = true;
- }
-
- for (const auto [col, val] : rows[idx].get_hidden_singles()) {
- set({idx, col}, val);
- iter = true;
- }
-
- for (const auto [row, val] : cols[idx].get_hidden_singles()) {
- set({row, idx}, val);
- iter = true;
- }
+ sub_op(&Grid::op_set, &Ref::get_hidden_singles, idx);
+ row_op(&Grid::op_set, &Ref::get_hidden_singles, idx);
+ col_op(&Grid::op_set, &Ref::get_hidden_singles, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t subgrid = 0; subgrid < 9; subgrid++) {
- for (const auto [row, val] : subgrids[subgrid].get_pointing_row()) {
- const auto [r1, r2] = other_subgrid_row(subgrid);
- _clear_row(r1, row, val);
- _clear_row(r2, row, val);
-
- iter = true;
- }
-
- for (const auto [col, val] : subgrids[subgrid].get_pointing_col()) {
- const auto [c1, c2] = other_subgrid_col(subgrid);
- _clear_col(c1, col, val);
- _clear_col(c2, col, val);
-
- iter = true;
- }
+ sub_op(&Grid::op_clear_row, &Ref::get_pointing_row, subgrid);
+ sub_op(&Grid::op_clear_col, &Ref::get_pointing_col, subgrid);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [field, mask] : subgrids[idx].get_hidden_pairs()) {
- _mask({cord_t(idx), cord_t(field)}, mask);
- iter = true;
- }
-
- for (const auto [col, mask] : rows[idx].get_hidden_pairs()) {
- _mask({idx, col}, mask);
- iter = true;
- }
-
- for (const auto [row, mask] : cols[idx].get_hidden_pairs()) {
- _mask({row, idx}, mask);
- iter = true;
- }
+ sub_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx);
+ row_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx);
+ col_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [field, mask] : subgrids[idx].get_hidden_triplets()) {
- _mask({cord_t(idx), cord_t(field)}, mask);
- iter = true;
- }
-
- for (const auto [col, mask] : rows[idx].get_hidden_triplets()) {
- _mask({idx, col}, mask);
- iter = true;
- }
-
- for (const auto [row, mask] : cols[idx].get_hidden_triplets()) {
- _mask({row, idx}, mask);
- iter = true;
- }
+ sub_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx);
+ row_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx);
+ col_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [field, mask] : subgrids[idx].get_hidden_quads()) {
- _mask({cord_t(idx), cord_t(field)}, mask);
- iter = true;
- }
-
- for (const auto [col, mask] : rows[idx].get_hidden_quads()) {
- _mask({idx, col}, mask);
- iter = true;
- }
-
- for (const auto [row, mask] : cols[idx].get_hidden_quads()) {
- _mask({row, idx}, mask);
- iter = true;
- }
+ sub_op(&Grid::op_mask, &Ref::get_hidden_quads, idx);
+ row_op(&Grid::op_mask, &Ref::get_hidden_quads, idx);
+ col_op(&Grid::op_mask, &Ref::get_hidden_quads, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [val, f1, f2] : subgrids[idx].get_naked_pairs()) {
- for (uint8_t field = 0; field < 9; field++) {
- if (field == f1 || field == f2) continue;
- _clear({cord_t(idx), cord_t(field)}, val);
- }
-
- iter = true;
- }
-
- for (const auto [val, c1, c2] : rows[idx].get_naked_pairs()) {
- for (uint8_t col = 0; col < 9; col++) {
- if (col == c1 || col == c2) continue;
- _clear({idx, col}, val);
- }
- iter = true;
- }
-
- for (const auto [val, r1, r2] : cols[idx].get_naked_pairs()) {
- for (uint8_t row = 0; row < 9; row++) {
- if (row == r1 || row == r2) continue;
- _clear({row, idx}, val);
- }
- iter = true;
- }
+ sub_op(&Grid::op_clear, &Ref::get_naked_pairs, idx);
+ row_op(&Grid::op_clear, &Ref::get_naked_pairs, idx);
+ col_op(&Grid::op_clear, &Ref::get_naked_pairs, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [val, f1, f2, f3] : subgrids[idx].get_naked_triplets()) {
- for (uint8_t field = 0; field < 9; field++) {
- if (field == f1 || field == f2 || field == f3) continue;
- _clear({cord_t(idx), cord_t(field)}, val);
- }
-
- iter = true;
- }
-
- for (const auto [val, c1, c2, c3] : rows[idx].get_naked_triplets()) {
- for (uint8_t col = 0; col < 9; col++) {
- if (col == c1 || col == c2 || col == c3) continue;
- _clear({idx, col}, val);
- }
- iter = true;
- }
-
- for (const auto [val, r1, r2, r3] : cols[idx].get_naked_triplets()) {
- for (uint8_t row = 0; row < 9; row++) {
- if (row == r1 || row == r2 || row == r3) continue;
- _clear({row, idx}, val);
- }
- iter = true;
- }
+ sub_op(&Grid::op_clear, &Ref::get_naked_triplets, idx);
+ row_op(&Grid::op_clear, &Ref::get_naked_triplets, idx);
+ col_op(&Grid::op_clear, &Ref::get_naked_triplets, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [val, f1, f2, f3, f4] : subgrids[idx].get_naked_quads()) {
- for (uint8_t field = 0; field < 9; field++) {
- if (field == f1 || field == f2 || field == f3 || field == f4) continue;
- _clear({cord_t(idx), cord_t(field)}, val);
- }
-
- iter = true;
- }
-
- for (const auto [val, c1, c2, c3, c4] : rows[idx].get_naked_quads()) {
- for (uint8_t col = 0; col < 9; col++) {
- if (col == c1 || col == c2 || col == c3 || col == c4) continue;
- _clear({idx, col}, val);
- }
- iter = true;
- }
-
- for (const auto [val, r1, r2, r3, r4] : cols[idx].get_naked_quads()) {
- for (uint8_t row = 0; row < 9; row++) {
- if (row == r1 || row == r2 || row == r3 || row == r4) continue;
- _clear({row, idx}, val);
- }
- iter = true;
- }
+ sub_op(&Grid::op_clear, &Ref::get_naked_quads, idx);
+ row_op(&Grid::op_clear, &Ref::get_naked_quads, idx);
+ col_op(&Grid::op_clear, &Ref::get_naked_quads, idx);
}
- if (iter) continue;
+ if (changed) continue;
for (uint8_t idx = 0; idx < 9; idx++) {
- for (const auto [index, val] : rows[idx].get_pointing_row()) {
- const auto [r1, r2] = other_subgrid_row(idx);
- _clear_row((r1 / 3) * 3 + index, r1 % 3, val);
- _clear_row((r2 / 3) * 3 + index, r2 % 3, val);
-
- iter = true;
- }
-
- for (const auto [index, val] : cols[idx].get_pointing_row()) {
- const auto [c1, c2] = other_subgrid_row(idx);
- _clear_col(index * 3 + c1 / 3, c1 % 3, val);
- _clear_col(index * 3 + c2 / 3, c2 % 3, val);
-
- iter = true;
- }
+ row_op(&Grid::op_clear_row_rel, &Ref::get_pointing_row, idx);
+ col_op(&Grid::op_clear_col_rel, &Ref::get_pointing_row, idx);
}
}
@@ -684,6 +557,81 @@ class Grid {
}
private:
+ using operation_t = std::tuple<acord_t, uint16_t>;
+
+ void op_set(operation_t op) {
+ _set(std::get<0>(op), std::get<1>(op));
+ changed = true;
+ }
+
+ void op_mask(operation_t op) {
+ _mask(std::get<0>(op), std::get<1>(op));
+ changed = true;
+ }
+
+ void op_clear(operation_t op) {
+ _clear(std::get<0>(op), std::get<1>(op));
+ changed = true;
+ }
+
+ void op_clear_row(operation_t op) {
+ const auto [ab, val] = op;
+
+ const auto [r1, r2] = other_subgrid_row((uint8_t)ab.subgrid());
+ _clear_row(r1, (uint8_t)ab.field(), val);
+ _clear_row(r2, (uint8_t)ab.field(), val);
+
+ changed = true;
+ }
+
+ void op_clear_col(operation_t op) {
+ const auto [ab, val] = op;
+
+ const auto [c1, c2] = other_subgrid_col((uint8_t)ab.subgrid());
+ _clear_col(c1, (uint8_t)ab.field(), val);
+ _clear_col(c2, (uint8_t)ab.field(), val);
+
+ changed = true;
+ }
+
+ void op_clear_row_rel(operation_t op) {
+ const auto [ab, val] = op;
+
+ const auto [r1, r2] = other_subgrid_row((uint8_t)ab.row());
+ _clear_row((r1 / 3) * 3 + (uint8_t)ab.col(), r1 % 3, val);
+ _clear_row((r2 / 3) * 3 + (uint8_t)ab.col(), r2 % 3, val);
+
+ changed = true;
+ }
+
+ void op_clear_col_rel(operation_t op) {
+ const auto [ab, val] = op;
+
+ const auto [c1, c2] = other_subgrid_row((uint8_t)ab.col());
+ _clear_col((uint8_t)ab.row() * 3 + c1 / 3, c1 % 3, val);
+ _clear_col((uint8_t)ab.row() * 3 + c2 / 3, c2 % 3, val);
+
+ changed = true;
+ }
+
+ void _set(acord_t ab, uint8_t value) {
+ rows[ab.row()].set(ab.col(), value);
+ cols[ab.col()].set(ab.row(), value);
+ subgrids[(uint8_t)ab.subgrid()].set((uint8_t)ab.field(), value);
+
+ _clear_row(ab.subgrid(), 0, value);
+ _clear_row(ab.subgrid(), 1, value);
+ _clear_row(ab.subgrid(), 2, value);
+
+ const auto [r1, r2] = other_subgrid_row((uint8_t)ab.subgrid());
+ _clear_row(r1, ab.field().row(), value);
+ _clear_row(r2, ab.field().row(), value);
+
+ const auto [c1, c2] = other_subgrid_col((uint8_t)ab.subgrid());
+ _clear_col(c1, ab.field().col(), value);
+ _clear_col(c2, ab.field().col(), value);
+ }
+
void _mask(acord_t ab, uint16_t mask) {
while (mask) {
const uint8_t idx = std::countr_zero(mask);
@@ -728,6 +676,8 @@ class Grid {
Ref subgrids[9];
Ref rows[9];
Ref cols[9];
+
+ bool changed = false;
};
int main(const int argc, const char *argv[]) {