doasku

Human-like solver for sudoku
git clone git://git.dimitrijedobrota.com/doasku.git
Log | Files | Refs | README | LICENSE | HACKING | CONTRIBUTING | CODE_OF_CONDUCT | BUILDING |

commite5cc8a937714a060616c910895c3b132cfe8ae15
parent260874dfa278ba73acb805eaf53464771e84d6e5
authorDimitrije Dobrota <mail@dimitrijedobrota.com>
dateTue, 2 Apr 2024 23:16:49 +0200

Better type safety

Diffstat:
Mmain.cpp|+++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------

1 files changed, 60 insertions(+), 39 deletions(-)


diff --git a/main.cpp b/main.cpp

@@ -9,29 +9,48 @@

static constexpr const std::int64_t mask_field = (1 << 9) - 1;
static constexpr const std::int64_t mask_value = 0x201008040201;
using row_col_t = std::tuple<uint8_t, uint8_t>;
class row_col_t;
class row_col_rt;
std::ostream &operator<<(std::ostream &os, row_col_t rc) {
return os << std::format("({}, {})", std::get<0>(rc), std::get<1>(rc));
}
struct row_col_t {
row_col_t(uint8_t row, uint8_t col) : row(row), col(col) {}
row_col_t(uint8_t field) : row(field / 3), col(field % 3) {}
row_col_t(row_col_rt rc);
explicit operator row_col_rt() const;
row_col_t row_col(const uint8_t field) { return {field / 3, field % 3}; }
row_col_t row_col_rev(const uint8_t field) { return {2 - field % 3, field / 3}; }
friend std::ostream &operator<<(std::ostream &os, row_col_t rc) {
return os << std::format("({}, {})", rc.row, rc.col);
}
uint8_t row;
uint8_t col;
};
row_col_t row_col(const row_col_t rcr) { return {std::get<1>(rcr), 2 - std::get<0>(rcr)}; }
row_col_t row_col_rev(const row_col_t rc) { return {2 - std::get<1>(rc), std::get<0>(rc)}; }
struct row_col_rt {
row_col_rt(uint8_t row, uint8_t col) : row(row), col(col) {}
row_col_rt(uint8_t field) : row(2 - field % 3), col(field / 3) {}
row_col_rt(row_col_t rc);
friend std::ostream &operator<<(std::ostream &os, row_col_rt rc) {
return os << std::format("({}, {})", rc.row, rc.col);
}
uint8_t row;
uint8_t col;
};
uint8_t field(const row_col_t rc) { return 3 * std::get<0>(rc) + std::get<1>(rc); }
row_col_t::row_col_t(row_col_rt rc) : row(rc.col), col(2 - rc.row) {}
row_col_rt::row_col_rt(row_col_t rc) : row(2 - rc.col), col(rc.row) {}
uint8_t field(row_col_t rc) { return 3 * rc.row + rc.col; }
class Row {
public:
Row(uint64_t value = 0) : val(value) {}
operator uint64_t() const { return val; }
void set(uint8_t field, uint8_t value) {
clear(field);
toggle(field, value);
}
void set(uint8_t field, uint8_t value) { clear(field), toggle(field, value); }
uint64_t get(uint8_t field) const { return (val >> 9 * field) & mask_field; }
uint64_t get(uint8_t field, uint8_t value) const { return (val >> 9 * field) & (1 << value); }

@@ -57,8 +76,8 @@ class Subgrid {

public:
Subgrid() {}
uint64_t get(row_col_t rc) const { return rows[std::get<0>(rc)].get(std::get<1>(rc)); }
uint64_t get_rev(row_col_t rc) const { return rows[std::get<0>(rc)].get(std::get<1>(rc) + 3); }
uint64_t get(row_col_t rc) const { return rows[rc.row].get(rc.col); }
uint64_t get(row_col_rt rc) const { return rows[rc.row].get(rc.col + 3); }
void set(uint8_t field, uint8_t value) {
assert(field < 9 && value < 9);

@@ -67,30 +86,31 @@ class Subgrid {

rows[1].clear_all(value);
rows[2].clear_all(value);
set(row_col(field), value);
set_rev(row_col_rev(field), value);
set(row_col_t(field), value);
set(row_col_rt(field), value);
}
void clear_row(uint8_t row, uint8_t value) {
assert(row < 3 && value < 9);
for (int i = 0; i < 3; i++) {
clear({row, i}, value);
clear_rev(row_col_rev({row, i}), value);
for (uint8_t i = 0; i < 3; i++) {
const row_col_t rc = {row, i};
clear(row_col_rt(rc), value);
clear(rc, value);
}
}
void clear_col(uint8_t col, uint8_t value) {
assert(col < 3 && value < 9);
for (int i = 0; i < 3; i++) {
clear({i, col}, value);
clear_rev(row_col_rev({i, col}), value);
for (uint8_t i = 0; i < 3; i++) {
const row_col_t rc = {i, col};
clear(row_col_rt(rc), value);
clear(rc, value);
}
}
friend std::ostream &operator<<(std::ostream &os, const Subgrid &b) {
for (int i = 0; i < 3; i++) {
os << b.rows[i] << " ";
}

@@ -98,11 +118,11 @@ class Subgrid {

}
private:
void set(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].set(std::get<1>(rc), value); }
void set_rev(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].set(std::get<1>(rc) + 3, value); }
void set(row_col_t rc, uint8_t value) { rows[rc.row].set(rc.col, value); }
void set(row_col_rt rcr, uint8_t value) { rows[rcr.row].set(rcr.col + 3, value); }
void clear(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].clear(std::get<1>(rc), value); }
void clear_rev(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].clear(std::get<1>(rc) + 3, value); }
void clear(row_col_t rc, uint8_t value) { rows[rc.row].clear(rc.col, value); }
void clear(row_col_rt rcr, uint8_t value) { rows[rcr.row].clear(rcr.col + 3, value); }
Row rows[3] = {~0, ~0, ~0};
};

@@ -112,8 +132,8 @@ class Grid {

void set(uint8_t subgrid, uint8_t field, uint8_t value) {
assert(subgrid < 9 && field < 9 && value < 9);
const auto [row, col] = row_col(subgrid);
const auto [frow, fcol] = row_col(field);
const auto [row, col] = row_col_t(subgrid);
const auto [frow, fcol] = row_col_t(field);
subgrids[3 * row + 0].clear_row(frow, value);
subgrids[3 * row + 1].clear_row(frow, value);

@@ -129,10 +149,10 @@ class Grid {

void print() const {
std::cout << "Field: " << std::endl;
for (int i = 0; i < 9; i++) {
for (int j = 0; j < 9; j++) {
const row_col_t subgrid = {i / 3, j / 3};
const row_col_t field = {i % 3, j % 3};
for (uint8_t i = 0; i < 9; i++) {
for (uint8_t j = 0; j < 9; j++) {
const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)};
const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)};
std::bitset<9> value = subgrids[::field(subgrid)].get(field);
std::cout << value << " ";
if (j % 3 == 2) std::cout << " ";

@@ -144,11 +164,12 @@ class Grid {

std::cout << std::endl;
std::cout << "Reversed Field: " << std::endl;
for (int i = 0; i < 9; i++) {
for (int j = 0; j < 9; j++) {
const row_col_t subgrid = {i / 3, j / 3};
const row_col_t rfield = row_col_rev({i % 3, j % 3});
std::bitset<9> value = subgrids[::field(subgrid)].get_rev(rfield);
for (uint8_t i = 0; i < 9; i++) {
for (uint8_t j = 0; j < 9; j++) {
const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)};
const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)};
const row_col_rt rfield = row_col_rt(field);
std::bitset<9> value = subgrids[::field(subgrid)].get(rfield);
std::cout << value << " ";
if (j % 3 == 2) std::cout << " ";
}