doasku

Sudoku solver
git clone git://git.dimitrijedobrota.com/doasku.git
Log | Files | Refs

main.cpp (17013B)


#include <algorithm>
#include <array>
#include <bit>
#include <bitset>
#include <cassert>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <format>
#include <iomanip>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>
     12 
     13 template <class Input, class UnaryFunc> UnaryFunc for_each(Input input, UnaryFunc f) {
     14     return std::for_each(begin(input), end(input), f);
     15 }
     16 
     17 struct cord_t {
     18     cord_t(uint8_t row, uint8_t col) : value(row * 3 + col) { assert(row < 3 && col < 3); }
     19     cord_t(uint8_t value) : value(value) { assert(value < 9); }
     20 
     21     operator uint8_t() const { return value; }
     22 
     23     uint8_t row() const { return value / 3; }
     24     uint8_t col() const { return value % 3; }
     25 
     26     friend std::ostream &operator<<(std::ostream &os, cord_t cord) {
     27         return os << std::format("({}, {})", cord.row(), cord.col());
     28     }
     29 
     30     uint8_t value;
     31 };
     32 
     33 class acord_t {
     34   public:
     35     acord_t(cord_t subgrid, cord_t field) : subgrid_i(subgrid), field_i(field) {}
     36     acord_t(uint8_t row, uint8_t col) : subgrid_i(row / 3, col / 3), field_i(row % 3, col % 3) {}
     37 
     38     cord_t subgrid() const { return subgrid_i; }
     39     cord_t field() const { return field_i; }
     40 
     41     uint8_t row() const { return subgrid_i.row() * 3 + field_i.row(); }
     42     uint8_t col() const { return subgrid_i.col() * 3 + field_i.col(); }
     43 
     44     std::tuple<cord_t, cord_t> relative() const { return {subgrid_i, field_i}; }
     45 
     46     friend std::ostream &operator<<(std::ostream &os, acord_t acord) {
     47         return os << std::format("(({}, {}))", acord.row(), acord.col());
     48     }
     49 
     50   private:
     51     cord_t subgrid_i;
     52     cord_t field_i;
     53 };
     54 
     55 static std::tuple<uint8_t, uint8_t> other_subgrid_row(uint8_t subgrid) {
     56     static std::tuple<uint8_t, uint8_t> mapping[9] = {{1, 2}, {0, 2}, {0, 1}, {4, 5}, {3, 5},
     57                                                       {3, 4}, {7, 8}, {6, 8}, {6, 7}};
     58     return mapping[subgrid];
     59 }
     60 
     61 static std::tuple<uint8_t, uint8_t> other_subgrid_col(uint8_t subgrid) {
     62     static std::tuple<uint8_t, uint8_t> mapping[9] = {{3, 6}, {4, 7}, {5, 8}, {0, 6}, {1, 7},
     63                                                       {2, 8}, {0, 3}, {1, 4}, {2, 5}};
     64     return mapping[subgrid];
     65 }
     66 
     67 class Ref {
     68   public:
     69     using change_t = std::tuple<uint8_t, uint16_t>;
     70     using changes_t = std::vector<change_t>;
     71 
     72   public:
     73     uint16_t get(uint8_t field) const { return value[field]; }
     74     uint16_t get_ref(uint8_t value) const { return ref[value]; }
     75     uint16_t get_value(uint8_t field) const { return res[field]; }
     76 
     77     void clear(uint8_t field, uint8_t value) {
     78         this->value[field] &= ~(1 << value);
     79         ref[value] &= ~(1 << field);
     80     }
     81 
     82     void set(uint8_t field, uint8_t value) {
     83         for (uint8_t i = 0; i < 9; i++) {
     84             ref[i] &= ~(1 << field);
     85         }
     86 
     87         this->value[field] = ref[value] = 0;
     88         res[field] = value + 1;
     89     }
     90 
     91     changes_t get_hidden_singles() const {
     92         changes_t res;
     93 
     94         for (uint8_t candidate = 0; candidate < 9; candidate++) {
     95             if (std::popcount(ref[candidate]) != 1) continue;
     96             res.emplace_back(std::countr_zero(ref[candidate]), candidate);
     97         }
     98 
     99         return res;
    100     }
    101 
    102     changes_t get_naked_singles() const {
    103         changes_t res;
    104 
    105         for (uint8_t i = 0; i < 9; i++) {
    106             const auto values = get(i);
    107             if (std::popcount(values) != 1) continue;
    108             const auto value = std::countr_zero(values);
    109             if (!ref[value]) continue;
    110             res.emplace_back(i, value);
    111         }
    112 
    113         return res;
    114     }
    115 
    116     auto get_hidden_pairs() const { return get_hidden(2); }
    117     auto get_hidden_triplets() const { return get_hidden(3); }
    118     auto get_hidden_quads() const { return get_hidden(4); }
    119 
    120     auto get_naked_pairs() const { return get_naked(2); }
    121     auto get_naked_triplets() const { return get_naked(3); }
    122     auto get_naked_quads() const { return get_naked(4); }
    123 
    124     auto get_pointing_row() const { return get_pointing(0x7, 0); }
    125     auto get_pointing_col() const { return get_pointing(0x49, 1); }
    126 
    127   private:
    128     changes_t get_hidden(int number) const {
    129         assert(number > 1 && number <= 4);
    130 
    131         changes_t res;
    132         get_hidden(res, number, number, 0, 0, 0);
    133         return res;
    134     }
    135 
    136     bool get_hidden(changes_t &res, uint8_t og, uint8_t number, uint8_t first, uint16_t val,
    137                     uint16_t mask) const {
    138         if (number != 0) {
    139             for (uint8_t i = first; i < 9; i++) {
    140                 if (std::popcount(ref[i]) < 2) continue;
    141                 if (seen_hidden[og] & (1ul << i)) continue;
    142 
    143                 bool used = get_hidden(res, og, number - 1, i + 1, val | ref[i], mask | (1 << i));
    144                 if (!used) continue;
    145 
    146                 seen_hidden[og] |= 1ul << i;
    147                 if (number != og) return true;
    148             }
    149 
    150             return false;
    151         }
    152 
    153         if (std::popcount(val) != og) return false;
    154 
    155         static uint8_t fields[9];
    156         uint8_t size = 0;
    157 
    158         while (val) {
    159             const uint8_t idx = std::countr_zero(val);
    160             fields[size++] = idx;
    161             val ^= 1ull << idx;
    162         }
    163 
    164         for (uint8_t i = 0; i < og; i++) {
    165             const uint16_t change = value[fields[i]] & ~(value[fields[i]] & mask);
    166             if (!change) continue;
    167             res.emplace_back(fields[i], change);
    168         }
    169 
    170         return true;
    171     }
    172 
    173     changes_t get_naked(int number) const {
    174         assert(number > 1 && number <= 4);
    175 
    176         changes_t res;
    177         get_naked(res, number, number, 0, 0);
    178         return res;
    179     }
    180 
    181     bool get_naked(changes_t &res, uint8_t og, uint8_t number, uint8_t first, uint16_t val) const {
    182         static uint8_t seen[4] = {0};
    183 
    184         if (number != 0) {
    185             for (uint8_t i = first; i < 9; i++) {
    186                 if (this->res[i]) continue;
    187                 if (number == og && seen_naked[og] & (1ul << i)) continue;
    188 
    189                 seen[og - number] = i;
    190                 bool used = get_naked(res, og, number - 1, i + 1, val | value[i]);
    191                 if (!used) continue;
    192 
    193                 if (number == og) seen_naked[og] |= 1ul << i;
    194                 if (number != og) return true;
    195             }
    196 
    197             return false;
    198         }
    199 
    200         if (std::popcount(val) != og) return false;
    201 
    202         while (val) {
    203             const uint8_t idx = std::countr_zero(val);
    204             for (uint8_t pos = 0, i; pos < 9; pos++) {
    205                 if ((value[pos] & idx) == 0) continue;
    206                 for (i = 0; i < og; i++) {
    207                     if (pos == seen[i]) break;
    208                 }
    209                 if (i == og) res.emplace_back(pos, idx);
    210             }
    211             val ^= 1ull << idx;
    212         }
    213 
    214         return true;
    215     }
    216 
    217     changes_t get_pointing(uint16_t mask, bool seen) const {
    218         changes_t res;
    219 
    220         for (uint8_t i = 0; i < 9; i++) {
    221             const uint8_t popcnt = std::popcount(ref[i]);
    222             if (popcnt < 2 || popcnt > 3) continue;
    223 
    224             if ((seen_point[seen] & (1 << i)) == 0) {
    225                 uint16_t cmask = mask;
    226                 for (uint8_t k = 0; k < 3; k++, cmask <<= 3) {
    227                     if (std::popcount(uint16_t(cmask & ref[i])) != popcnt) continue;
    228                     seen_point[seen] |= 1 << i;
    229                     res.emplace_back(k, i);
    230                     break;
    231                 }
    232             }
    233         }
    234 
    235         return res;
    236     }
    237 
    238     static constexpr const std::int64_t mask_field = (1 << 9) - 1;
    239     static constexpr const std::int64_t mask_value = 0x201008040201;
    240 
    241     uint16_t value[9] = {mask_field, mask_field, mask_field, mask_field, mask_field,
    242                          mask_field, mask_field, mask_field, mask_field};
    243 
    244     uint16_t ref[9] = {mask_field, mask_field, mask_field, mask_field, mask_field,
    245                        mask_field, mask_field, mask_field, mask_field};
    246 
    247     uint16_t res[9] = {0};
    248 
    249     mutable uint16_t seen_hidden[4] = {0}, seen_naked[4] = {0}, seen_point[2] = {0};
    250 };
    251 
    252 class Grid {
    253   public:
    254     Grid(const std::string &s) {
    255         int idx = 0;
    256         for (uint8_t i = 0; i < 9; i++) {
    257             for (uint8_t j = 0; j < 9; j++, idx++) {
    258                 if (s[idx] == '0') continue;
    259                 _set({i, j}, s[idx] - '1');
    260             }
    261         }
    262     }
    263 
    264     bool solve() {
    265         // clang-format off
    266         static const auto sub_op =
    267             [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t subgrid) {
    268                 for_each((subgrids[subgrid].*(f))(), [this, subgrid, op](const Ref::change_t ch) {
    269                     (this->*(op))(operation_t({cord_t(subgrid), cord_t(std::get<0>(ch))}, std::get<1>(ch)));
    270                 });
    271             };
    272 
    273         static const auto row_op =
    274             [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t row) {
    275                 for_each((rows[row].*(f))(), [this, row, op](const Ref::change_t ch) {
    276                     (this->*(op))(operation_t({row, std::get<0>(ch)}, std::get<1>(ch)));
    277                 });
    278             };
    279 
    280         static const auto col_op =
    281             [this](void (Grid::*op)(operation_t), Ref::changes_t (Ref::*f)() const, uint8_t col) {
    282                 for_each((cols[col].*(f))(), [this, col, op](const Ref::change_t ch) {
    283                     (this->*(op))(operation_t({std::get<0>(ch), col}, std::get<1>(ch)));
    284                 });
    285             };
    286         // clang-format on
    287 
    288         changed = true;
    289         while (changed) {
    290             changed = false;
    291 
    292             for (uint8_t subgrid = 0; subgrid < 9; subgrid++) {
    293                 sub_op(&Grid::op_set, &Ref::get_naked_singles, subgrid);
    294             }
    295 
    296             if (changed) continue;
    297 
    298             for (uint8_t idx = 0; idx < 9; idx++) {
    299                 sub_op(&Grid::op_set, &Ref::get_hidden_singles, idx);
    300                 row_op(&Grid::op_set, &Ref::get_hidden_singles, idx);
    301                 col_op(&Grid::op_set, &Ref::get_hidden_singles, idx);
    302             }
    303 
    304             if (changed) continue;
    305 
    306             for (uint8_t subgrid = 0; subgrid < 9; subgrid++) {
    307                 sub_op(&Grid::op_clear_row, &Ref::get_pointing_row, subgrid);
    308                 sub_op(&Grid::op_clear_col, &Ref::get_pointing_col, subgrid);
    309             }
    310 
    311             if (changed) continue;
    312 
    313             for (uint8_t idx = 0; idx < 9; idx++) {
    314                 sub_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx);
    315                 row_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx);
    316                 col_op(&Grid::op_mask, &Ref::get_hidden_pairs, idx);
    317             }
    318 
    319             if (changed) continue;
    320 
    321             for (uint8_t idx = 0; idx < 9; idx++) {
    322                 sub_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx);
    323                 row_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx);
    324                 col_op(&Grid::op_mask, &Ref::get_hidden_triplets, idx);
    325             }
    326 
    327             if (changed) continue;
    328 
    329             for (uint8_t idx = 0; idx < 9; idx++) {
    330                 sub_op(&Grid::op_mask, &Ref::get_hidden_quads, idx);
    331                 row_op(&Grid::op_mask, &Ref::get_hidden_quads, idx);
    332                 col_op(&Grid::op_mask, &Ref::get_hidden_quads, idx);
    333             }
    334 
    335             if (changed) continue;
    336 
    337             for (uint8_t idx = 0; idx < 9; idx++) {
    338                 sub_op(&Grid::op_clear, &Ref::get_naked_pairs, idx);
    339                 row_op(&Grid::op_clear, &Ref::get_naked_pairs, idx);
    340                 col_op(&Grid::op_clear, &Ref::get_naked_pairs, idx);
    341             }
    342 
    343             if (changed) continue;
    344 
    345             for (uint8_t idx = 0; idx < 9; idx++) {
    346                 sub_op(&Grid::op_clear, &Ref::get_naked_triplets, idx);
    347                 row_op(&Grid::op_clear, &Ref::get_naked_triplets, idx);
    348                 col_op(&Grid::op_clear, &Ref::get_naked_triplets, idx);
    349             }
    350 
    351             if (changed) continue;
    352 
    353             for (uint8_t idx = 0; idx < 9; idx++) {
    354                 sub_op(&Grid::op_clear, &Ref::get_naked_quads, idx);
    355                 row_op(&Grid::op_clear, &Ref::get_naked_quads, idx);
    356                 col_op(&Grid::op_clear, &Ref::get_naked_quads, idx);
    357             }
    358 
    359             if (changed) continue;
    360 
    361             for (uint8_t idx = 0; idx < 9; idx++) {
    362                 row_op(&Grid::op_clear_row_rel, &Ref::get_pointing_row, idx);
    363                 col_op(&Grid::op_clear_col_rel, &Ref::get_pointing_row, idx);
    364             }
    365         }
    366 
    367         return _is_finished();
    368     }
    369 
    370     void print() const {
    371         // clang-format off
    372         static const auto print_i = [this](const Ref refs[9], uint16_t (Ref::*f)(uint8_t) const, bool bits) {
    373             for (uint8_t i = 0; i < 9; i++) {
    374                 for (uint8_t j = 0; j < 9; j++) {
    375                     const cord_t subgrid = {uint8_t(i / 3), uint8_t(j / 3)};
    376                     const cord_t field = {uint8_t(i % 3), uint8_t(j % 3)};
    377 
    378                     uint16_t value = (refs[subgrid].*(f))((uint8_t)field);
    379                     if (bits) std::cout << std::bitset<9>(value) << " ";
    380                     else std::cout << value << " ";
    381 
    382                     if (j % 3 == 2) std::cout << " ";
    383                 }
    384 
    385                 std::cout << std::endl;
    386                 if (i % 3 == 2) std::cout << std::endl;
    387             }
    388         };
    389         // clang-format on
    390 
    391         std::cout << "Field: " << std::endl;
    392         print_i(subgrids, &Ref::get, true);
    393 
    394         std::cout << "Refs: " << std::endl;
    395         print_i(subgrids, &Ref::get_ref, true);
    396 
    397         std::cout << "Board: " << std::endl;
    398         print_i(subgrids, &Ref::get_value, false);
    399     }
    400 
    401   private:
    402     using operation_t = std::tuple<acord_t, uint16_t>;
    403 
    404     void op_set(operation_t op) {
    405         _set(std::get<0>(op), std::get<1>(op));
    406         changed = true;
    407     }
    408 
    409     void op_mask(operation_t op) {
    410         _mask(std::get<0>(op), std::get<1>(op));
    411         changed = true;
    412     }
    413 
    414     void op_clear(operation_t op) {
    415         _clear(std::get<0>(op), std::get<1>(op));
    416         changed = true;
    417     }
    418 
    419     void op_clear_row(operation_t op) {
    420         const auto [ab, val] = op;
    421 
    422         const auto [r1, r2] = other_subgrid_row(ab.subgrid());
    423         _clear_row(r1, ab.field(), val);
    424         _clear_row(r2, ab.field(), val);
    425 
    426         changed = true;
    427     }
    428 
    429     void op_clear_col(operation_t op) {
    430         const auto [ab, val] = op;
    431 
    432         const auto [c1, c2] = other_subgrid_col(ab.subgrid());
    433         _clear_col(c1, ab.field(), val);
    434         _clear_col(c2, ab.field(), val);
    435 
    436         changed = true;
    437     }
    438 
    439     void op_clear_row_rel(operation_t op) {
    440         const auto [ab, val] = op;
    441 
    442         const auto [r1, r2] = other_subgrid_row(ab.row());
    443         _clear_row((r1 / 3) * 3 + ab.col(), r1 % 3, val);
    444         _clear_row((r2 / 3) * 3 + ab.col(), r2 % 3, val);
    445 
    446         changed = true;
    447     }
    448 
    449     void op_clear_col_rel(operation_t op) {
    450         const auto [ab, val] = op;
    451 
    452         const auto [c1, c2] = other_subgrid_row(ab.col());
    453         _clear_col(ab.row() * 3 + c1 / 3, c1 % 3, val);
    454         _clear_col(ab.row() * 3 + c2 / 3, c2 % 3, val);
    455 
    456         changed = true;
    457     }
    458 
    459     void _set(acord_t ab, uint8_t value) {
    460         rows[ab.row()].set(ab.col(), value);
    461         cols[ab.col()].set(ab.row(), value);
    462         subgrids[ab.subgrid()].set(ab.field(), value);
    463 
    464         _clear_row(ab.subgrid(), 0, value);
    465         _clear_row(ab.subgrid(), 1, value);
    466         _clear_row(ab.subgrid(), 2, value);
    467 
    468         const auto [r1, r2] = other_subgrid_row(ab.subgrid());
    469         _clear_row(r1, ab.field().row(), value);
    470         _clear_row(r2, ab.field().row(), value);
    471 
    472         const auto [c1, c2] = other_subgrid_col(ab.subgrid());
    473         _clear_col(c1, ab.field().col(), value);
    474         _clear_col(c2, ab.field().col(), value);
    475     }
    476 
    477     void _mask(acord_t ab, uint16_t mask) {
    478         while (mask) {
    479             const uint8_t idx = std::countr_zero(mask);
    480             _clear(ab, idx);
    481             mask ^= 1ull << idx;
    482         }
    483     }
    484 
    485     void _clear(acord_t ab, uint8_t value) {
    486         subgrids[ab.subgrid()].clear(ab.field(), value);
    487 
    488         rows[ab.row()].clear(ab.col(), value);
    489         cols[ab.col()].clear(ab.row(), value);
    490     }
    491 
    492     void _clear_row(cord_t sg, uint8_t row, uint8_t value) {
    493         for (uint8_t i = 0; i < 3; i++) {
    494             _clear({sg, {row, i}}, value);
    495         }
    496     }
    497 
    498     void _clear_col(cord_t sg, uint8_t col, uint8_t value) {
    499         for (uint8_t i = 0; i < 3; i++) {
    500             _clear({sg, {i, col}}, value);
    501         }
    502     }
    503 
    504     bool _is_finished(uint8_t subgrid) {
    505         for (uint8_t i = 0; i < 9; i++) {
    506             if (!subgrids[subgrid].get_value(i)) return false;
    507         }
    508         return true;
    509     }
    510 
    511     bool _is_finished() {
    512         for (uint8_t i = 0; i < 9; i++) {
    513             if (!_is_finished(i)) return false;
    514         }
    515         return true;
    516     }
    517 
    518     Ref subgrids[9];
    519     Ref rows[9];
    520     Ref cols[9];
    521 
    522     bool changed = false;
    523 };
    524 
    525 int main(const int argc, const char *argv[]) {
    526 
    527     for (int i = 1; i < argc; i++) {
    528         Grid g(argv[i]);
    529 
    530         g.print();
    531         std::cout << (g.solve() ? "solved" : "unable to solve") << std::endl;
    532         g.print();
    533     }
    534 
    535     return 0;
    536 }