commit e5cc8a937714a060616c910895c3b132cfe8ae15
parent 260874dfa278ba73acb805eaf53464771e84d6e5
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date: Wed, 3 Apr 2024 01:16:49 +0200
Better type safety
Diffstat:
M | main.cpp | | | 99 | ++++++++++++++++++++++++++++++++++++++++++++++++------------------------------- |
1 file changed, 60 insertions(+), 39 deletions(-)
diff --git a/main.cpp b/main.cpp
@@ -9,29 +9,48 @@
static constexpr const std::int64_t mask_field = (1 << 9) - 1;
static constexpr const std::int64_t mask_value = 0x201008040201;
-using row_col_t = std::tuple<uint8_t, uint8_t>;
+class row_col_t;
+class row_col_rt;
-std::ostream &operator<<(std::ostream &os, row_col_t rc) {
- return os << std::format("({}, {})", std::get<0>(rc), std::get<1>(rc));
-}
+struct row_col_t {
+ row_col_t(uint8_t row, uint8_t col) : row(row), col(col) {}
+ row_col_t(uint8_t field) : row(field / 3), col(field % 3) {}
+ row_col_t(row_col_rt rc);
+
+ explicit operator row_col_rt() const;
-row_col_t row_col(const uint8_t field) { return {field / 3, field % 3}; }
-row_col_t row_col_rev(const uint8_t field) { return {2 - field % 3, field / 3}; }
+ friend std::ostream &operator<<(std::ostream &os, row_col_t rc) {
+ return os << std::format("({}, {})", rc.row, rc.col);
+ }
+
+ uint8_t row;
+ uint8_t col;
+};
-row_col_t row_col(const row_col_t rcr) { return {std::get<1>(rcr), 2 - std::get<0>(rcr)}; }
-row_col_t row_col_rev(const row_col_t rc) { return {2 - std::get<1>(rc), std::get<0>(rc)}; }
+struct row_col_rt {
+ row_col_rt(uint8_t row, uint8_t col) : row(row), col(col) {}
+ row_col_rt(uint8_t field) : row(2 - field % 3), col(field / 3) {}
+ row_col_rt(row_col_t rc);
+
+ friend std::ostream &operator<<(std::ostream &os, row_col_rt rc) {
+ return os << std::format("({}, {})", rc.row, rc.col);
+ }
+
+ uint8_t row;
+ uint8_t col;
+};
-uint8_t field(const row_col_t rc) { return 3 * std::get<0>(rc) + std::get<1>(rc); }
+row_col_t::row_col_t(row_col_rt rc) : row(rc.col), col(2 - rc.row) {}
+row_col_rt::row_col_rt(row_col_t rc) : row(2 - rc.col), col(rc.row) {}
+
+uint8_t field(row_col_t rc) { return 3 * rc.row + rc.col; }
class Row {
public:
Row(uint64_t value = 0) : val(value) {}
operator uint64_t() const { return val; }
- void set(uint8_t field, uint8_t value) {
- clear(field);
- toggle(field, value);
- }
+ void set(uint8_t field, uint8_t value) { clear(field), toggle(field, value); }
uint64_t get(uint8_t field) const { return (val >> 9 * field) & mask_field; }
uint64_t get(uint8_t field, uint8_t value) const { return (val >> 9 * field) & (1 << value); }
@@ -57,8 +76,8 @@ class Subgrid {
public:
Subgrid() {}
- uint64_t get(row_col_t rc) const { return rows[std::get<0>(rc)].get(std::get<1>(rc)); }
- uint64_t get_rev(row_col_t rc) const { return rows[std::get<0>(rc)].get(std::get<1>(rc) + 3); }
+ uint64_t get(row_col_t rc) const { return rows[rc.row].get(rc.col); }
+ uint64_t get(row_col_rt rc) const { return rows[rc.row].get(rc.col + 3); }
void set(uint8_t field, uint8_t value) {
assert(field < 9 && value < 9);
@@ -67,30 +86,31 @@ class Subgrid {
rows[1].clear_all(value);
rows[2].clear_all(value);
- set(row_col(field), value);
- set_rev(row_col_rev(field), value);
+ set(row_col_t(field), value);
+ set(row_col_rt(field), value);
}
void clear_row(uint8_t row, uint8_t value) {
assert(row < 3 && value < 9);
- for (int i = 0; i < 3; i++) {
- clear({row, i}, value);
- clear_rev(row_col_rev({row, i}), value);
+ for (uint8_t i = 0; i < 3; i++) {
+ const row_col_t rc = {row, i};
+ clear(row_col_rt(rc), value);
+ clear(rc, value);
}
}
void clear_col(uint8_t col, uint8_t value) {
assert(col < 3 && value < 9);
- for (int i = 0; i < 3; i++) {
- clear({i, col}, value);
- clear_rev(row_col_rev({i, col}), value);
+ for (uint8_t i = 0; i < 3; i++) {
+ const row_col_t rc = {i, col};
+ clear(row_col_rt(rc), value);
+ clear(rc, value);
}
}
friend std::ostream &operator<<(std::ostream &os, const Subgrid &b) {
-
for (int i = 0; i < 3; i++) {
os << b.rows[i] << " ";
}
@@ -98,11 +118,11 @@ class Subgrid {
}
private:
- void set(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].set(std::get<1>(rc), value); }
- void set_rev(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].set(std::get<1>(rc) + 3, value); }
+ void set(row_col_t rc, uint8_t value) { rows[rc.row].set(rc.col, value); }
+ void set(row_col_rt rcr, uint8_t value) { rows[rcr.row].set(rcr.col + 3, value); }
- void clear(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].clear(std::get<1>(rc), value); }
- void clear_rev(row_col_t rc, uint8_t value) { rows[std::get<0>(rc)].clear(std::get<1>(rc) + 3, value); }
+ void clear(row_col_t rc, uint8_t value) { rows[rc.row].clear(rc.col, value); }
+ void clear(row_col_rt rcr, uint8_t value) { rows[rcr.row].clear(rcr.col + 3, value); }
Row rows[3] = {~0, ~0, ~0};
};
@@ -112,8 +132,8 @@ class Grid {
void set(uint8_t subgrid, uint8_t field, uint8_t value) {
assert(subgrid < 9 && field < 9 && value < 9);
- const auto [row, col] = row_col(subgrid);
- const auto [frow, fcol] = row_col(field);
+ const auto [row, col] = row_col_t(subgrid);
+ const auto [frow, fcol] = row_col_t(field);
subgrids[3 * row + 0].clear_row(frow, value);
subgrids[3 * row + 1].clear_row(frow, value);
@@ -129,10 +149,10 @@ class Grid {
void print() const {
std::cout << "Field: " << std::endl;
- for (int i = 0; i < 9; i++) {
- for (int j = 0; j < 9; j++) {
- const row_col_t subgrid = {i / 3, j / 3};
- const row_col_t field = {i % 3, j % 3};
+ for (uint8_t i = 0; i < 9; i++) {
+ for (uint8_t j = 0; j < 9; j++) {
+ const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)};
+ const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)};
std::bitset<9> value = subgrids[::field(subgrid)].get(field);
std::cout << value << " ";
if (j % 3 == 2) std::cout << " ";
@@ -144,11 +164,12 @@ class Grid {
std::cout << std::endl;
std::cout << "Reversed Field: " << std::endl;
- for (int i = 0; i < 9; i++) {
- for (int j = 0; j < 9; j++) {
- const row_col_t subgrid = {i / 3, j / 3};
- const row_col_t rfield = row_col_rev({i % 3, j % 3});
- std::bitset<9> value = subgrids[::field(subgrid)].get_rev(rfield);
+ for (uint8_t i = 0; i < 9; i++) {
+ for (uint8_t j = 0; j < 9; j++) {
+ const row_col_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)};
+ const row_col_t field = {uint8_t(i % 3), uint8_t(j % 3)};
+ const row_col_rt rfield = row_col_rt(field);
+ std::bitset<9> value = subgrids[::field(subgrid)].get(rfield);
std::cout << value << " ";
if (j % 3 == 2) std::cout << " ";
}