stellar

Stellar - UCI Chess engine written in C++20
git clone git://git.dimitrijedobrota.com/stellar.git
Log | Files | Refs | README | LICENSE

engine.cpp (15451B)


      1 #include <algorithm>
      2 #include <cctype>
      3 #include <cstdio>
      4 #include <cstdlib>
      5 #include <cstring>
      6 
      7 #include "attack.hpp"
      8 #include "board.hpp"
      9 #include "evaluate.hpp"
     10 #include "move.hpp"
     11 #include "movelist.hpp"
     12 #include "piece.hpp"
     13 #include "repetition.hpp"
     14 #include "score.hpp"
     15 #include "timer.hpp"
     16 #include "uci.hpp"
     17 #include "utils.hpp"
     18 
// Tuning constants for the reductions performed in negamax().
enum {
    FULL_DEPTH = 4,      // number of moves searched before late move reduction may kick in
    REDUCTION_LIMIT = 3, // minimum remaining depth required for late move reduction
    REDUCTION_MOVE = 2   // extra depth subtracted for the null-move search
};

// Half-width of the aspiration window placed around the previous iteration's score
// in search_position().
enum {
    WINDOW = 50
};
     28 
     29 namespace engine {
     30 
// One transposition-table entry. Entries are aggregate-initialised in member order
// by TTable_internal::write().
struct Hashe {
    // How `score` relates to the window it was searched with.
    enum class Flag : uint8_t {
        Exact, // a move raised alpha without reaching beta
        Alpha, // no move raised alpha; read() uses the entry when score <= alpha
        Beta   // a move reached beta; read() uses the entry when score >= beta
    };
    U64 key;       // full position hash, compared on probe to reject index collisions
    Move best;     // best move recorded for the position
    uint8_t depth; // remaining search depth the entry was written at
    int16_t score; // stored score
    Flag flag;     // kind of bound, see Flag
};
     43 
     44 template <U64 size> class TTable_internal {
     45   public:
     46     static inline constexpr const int16_t unknown = 32500;
     47 
     48     static void clear() {
     49         memset(table.data(), 0x00, size * sizeof(Hashe));
     50 #ifdef USE_STATS
     51         accessed = 0, rewrite = 0, miss = 0;
     52 #endif
     53     };
     54 
     55 #ifdef USE_STATS
     56     static void print() {
     57         std::cout << "Transposition table: " << std::endl;
     58         std::cout << "\tSize:    " << size << " entries (" << sizeof(Hashe) << "B per entry)" << std::endl;
     59         std::cout << "\tSize:    " << std::fixed << std::setprecision(2)
     60                   << (double)size * sizeof(Hashe) / (1 << 20) << "MB" << std::endl;
     61         std::cout << "\tReads:   " << accessed << std::endl;
     62         std::cout << "\tMisses:  " << miss << "(" << (double)miss / accessed << "%)" << std::endl;
     63         std::cout << "\tRewrite: " << rewrite << std::endl;
     64         std::cout << "\tUsed     " << (double)used() / size << "%" << std::endl;
     65     }
     66 
     67     static U64 used() {
     68         U64 res = 0;
     69 
     70         for (int i = 0; i < size; i++) {
     71             if (table[i].key) res++;
     72         }
     73 
     74         return res;
     75     }
     76 #endif
     77 
     78     static int16_t read(const Board &board, int ply, Move *best, int16_t alpha, int16_t beta, uint8_t depth) {
     79         U64 hash = board.get_hash();
     80         const Hashe &phashe = table[hash % size];
     81 
     82 #ifdef USE_STATS
     83         accessed++;
     84 #endif
     85 
     86         if (phashe.key == hash) {
     87             if (phashe.depth >= depth) {
     88                 int16_t score = phashe.score;
     89 
     90                 if (score < -MATE_SCORE) score += ply;
     91                 if (score > MATE_SCORE) score -= ply;
     92 
     93                 if (phashe.flag == Hashe::Flag::Exact) return score;
     94                 if ((phashe.flag == Hashe::Flag::Alpha) && (score <= alpha)) return alpha;
     95                 if ((phashe.flag == Hashe::Flag::Beta) && (score >= beta)) return beta;
     96             }
     97             *best = phashe.best;
     98         }
     99 #ifdef USE_STATS
    100         else {
    101             miss++;
    102         }
    103 #endif
    104         return unknown;
    105     }
    106 
    107     static void write(const Board &board, int ply, Move best, int16_t score, uint8_t depth,
    108                       Hashe::Flag flag) {
    109         U64 hash = board.get_hash();
    110         Hashe &phashe = table[hash % size];
    111 
    112         if (score < -MATE_SCORE) score += ply;
    113         if (score > MATE_SCORE) score -= ply;
    114 
    115         if (phashe.key == hash) {
    116             if (phashe.depth > depth) return;
    117         }
    118 #ifdef USE_STATS
    119         else {
    120             rewrite++;
    121         }
    122 #endif
    123 
    124         phashe = {hash, best, depth, score, flag};
    125     }
    126 
    127   private:
    128     static std::array<Hashe, size> table;
    129 
    130 #ifdef USE_STATS
    131     static U64 accessed, rewrite, miss;
    132 #endif
    133 };
    134 
// Out-of-class definition of the static entry storage of TTable_internal<size>.
template <U64 size> std::array<Hashe, size> TTable_internal<size>::table;

#ifdef USE_STATS
// Access counters, compiled in only when USE_STATS is defined.
template <U64 size> U64 TTable_internal<size>::accessed = 0;
template <U64 size> U64 TTable_internal<size>::rewrite = 0;
template <U64 size> U64 TTable_internal<size>::miss = 0;
#endif

// The engine's transposition table type: 0x2000023 entries.
using TTable = TTable_internal<C64(0x2000023)>;

// Every member of TTable is static; this object only provides the `ttable.` call syntax.
TTable ttable;
    146 
    147 class PVTable {
    148   public:
    149     Move best(uint8_t ply = 0) { return table[0][ply]; }
    150 
    151     void start(uint8_t ply) { length[ply] = ply; }
    152     void store(Move move, uint8_t ply) {
    153         table[ply][ply] = move;
    154         for (uint8_t i = ply + 1; i < length[ply + 1]; i++)
    155             table[ply][i] = table[ply + 1][i];
    156         length[ply] = length[ply + 1];
    157     }
    158 
    159     friend std::ostream &operator<<(std::ostream &os, const PVTable &pvtable);
    160 
    161   private:
    162     Move table[MAX_PLY][MAX_PLY] = {{}};
    163     uint8_t length[MAX_PLY] = {0};
    164 };
    165 
    166 std::ostream &operator<<(std::ostream &os, const PVTable &pvtable) {
    167     for (uint8_t i = 0; i < pvtable.length[0]; i++)
    168         os << pvtable.table[0][i] << " ";
    169     return os;
    170 }
    171 
// ---- State shared by all search routines in this file ----

// Settings of the search in progress; assigned at the start of search_position().
static const uci::Settings *settings = nullptr;
// Position being searched; moves are made and unmade on it in place.
static Board board;
// Hashes of earlier positions, queried by negamax() for repetition detection.
static repetition::Table rtable;

// Principal variation collected by the search.
static PVTable pvtable;

// Two most recent quiet moves per ply that caused a beta cutoff.
static Move killer[2][MAX_PLY];
// Quiet-move ordering bonus indexed by [piece index][target square]; increased by
// `depth` whenever such a move raises alpha.
static U32 history[12][64];
// True while move ordering should still prioritise the previous iteration's PV.
static bool follow_pv;
// Node counter; also paces the calls to uci::communicate().
static U64 nodes;
// Distance of the current node from the root.
static uint8_t ply;
    183 
    184 U32 inline move_score(const Move move) {
    185     static constexpr const uint16_t capture[6][6] = {
    186         // clang-format off
    187         {105, 205, 305, 405, 505, 605},
    188         {104, 204, 304, 404, 504, 604},
    189         {103, 203, 303, 403, 503, 603},
    190         {102, 202, 302, 402, 502, 602},
    191         {101, 201, 301, 401, 501, 601},
    192         {100, 200, 300, 400, 500, 600},
    193         // clang-format on
    194     };
    195 
    196     const Type type = board.get_square_piece_type(move.source());
    197     if (move.is_capture()) {
    198         const Type captured = board.get_square_piece_type(move.target());
    199         return capture[type][captured] + 10000;
    200     }
    201     if (killer[0][ply] == move) return 9000;
    202     if (killer[1][ply] == move) return 8000;
    203     return history[piece::get_index(type, board.get_side())][move.target()];
    204 }
    205 
    206 void move_list_sort(MoveList &list, std::vector<int> &score, int crnt) {
    207     for (int i = crnt + 1; i < list.size(); i++) {
    208         if (score[crnt] < score[i]) {
    209             std::swap(list[crnt], list[i]);
    210             std::swap(score[crnt], score[i]);
    211         }
    212     }
    213 }
    214 
    215 std::vector<int> move_list_score(MoveList &list, const Move best) {
    216     std::vector<int> score(list.size(), 0);
    217 
    218     bool best_found = false;
    219     for (int i = 0; i < list.size(); i++) {
    220         score[i] = move_score(list[i]);
    221         if (list[i] == best) {
    222             score[i] = 30000;
    223             best_found = true;
    224         }
    225     }
    226 
    227     if (best_found) return score;
    228 
    229     if (ply && follow_pv) {
    230         follow_pv = false;
    231         for (int i = 0; i < list.size(); i++) {
    232             if (list[i] == pvtable.best(ply)) {
    233                 score[i] = 20000;
    234                 follow_pv = true;
    235                 break;
    236             }
    237         }
    238     }
    239 
    240     return score;
    241 }
    242 
    243 int stats_move_make(Board &copy, const Move move) {
    244     copy = board;
    245     if (!move.make(board)) {
    246         board = copy;
    247         return 0;
    248     }
    249     ply++;
    250     rtable.push_hash(copy.get_hash());
    251     if (!move.is_repeatable()) rtable.push_null();
    252     return 1;
    253 }
    254 
    255 void stats_move_make_pruning(Board &copy) {
    256     copy = board;
    257     board.switch_side();
    258     board.set_enpassant(Square::no_sq);
    259     ply++;
    260 }
    261 
    262 void stats_move_unmake_pruning(Board &copy) {
    263     board = copy;
    264     ply--;
    265 }
    266 
    267 void stats_move_unmake(Board &copy, const Move move) {
    268     board = copy;
    269     if (!move.is_repeatable()) rtable.pop();
    270     rtable.pop();
    271     ply--;
    272 }
    273 
    274 int16_t quiescence(int16_t alpha, int16_t beta) {
    275     pvtable.start(ply);
    276     if ((nodes & 2047) == 0) {
    277         uci::communicate(settings);
    278         if (settings->stopped) return 0;
    279     }
    280 
    281     nodes++;
    282 
    283     int score = evaluate::score_position(board);
    284     if (ply > MAX_PLY - 1) return score;
    285     if (score >= beta) return beta;
    286     if (score > alpha) alpha = score;
    287 
    288     Board copy;
    289     MoveList list(board, true);
    290     std::vector<int> listScore = move_list_score(list, Move());
    291     for (int i = 0; i < list.size(); i++) {
    292         move_list_sort(list, listScore, i);
    293         const Move move = list[i];
    294         if (!stats_move_make(copy, move)) continue;
    295         score = -quiescence(-beta, -alpha);
    296         stats_move_unmake(copy, move);
    297 
    298         if (settings->stopped) return 0;
    299         if (score > alpha) {
    300             alpha = score;
    301             pvtable.store(move, ply);
    302             if (score >= beta) return beta;
    303         }
    304     }
    305 
    306     return alpha;
    307 }
    308 
/**
 * Alpha-beta search of the global board in negamax form.
 *
 * Applies, in this order: repetition check, transposition-table probe, check
 * extension, quiescence at depth 0, mate-distance window clamping, static pruning
 * (evaluation pruning, null move, razoring, futility) and finally the move loop
 * with late move reduction and principal variation search.
 *
 * @param alpha lower bound of the search window
 * @param beta  upper bound of the search window
 * @param depth remaining depth
 * @param null  whether the null-move / razoring block may be used at this node
 * @return the score of the position from the side to move's point of view;
 *         0 when the search has been stopped
 */
int16_t negamax(int16_t alpha, int16_t beta, uint8_t depth, bool null) {
    // A window wider than one point marks a PV node
    int pv_node = (beta - alpha) > 1;
    // Stays Alpha unless some move raises alpha
    Hashe::Flag flag = Hashe::Flag::Alpha;
    int futility = 0;
    Move bestMove;
    Board copy;

    pvtable.start(ply);
    // Poll for input / time every 2048 nodes
    if ((nodes & 2047) == 0) {
        uci::communicate(settings);
        if (settings->stopped) return 0;
    }

    // Below the root a repeated position is scored as a draw
    // && fifty >= 100
    if (ply && rtable.is_repetition(board.get_hash())) return 0;

    // The probe fills bestMove when the position is known but gives no usable score
    int16_t score = ttable.read(board, ply, &bestMove, alpha, beta, depth);
    if (ply && score != TTable::unknown && !pv_node) return score;

    // Check extension
    bool isCheck = board.is_check();
    if (isCheck) depth++;

    if (depth == 0) {
        nodes++;
        // NOTE(review): this local shadows the outer `score`
        int16_t score = quiescence(alpha, beta);
        // ttable_write(board, ply, bestMove, score, depth, HasheFlag::Exact);
        return score;
    }

    // Clamp the window to the mate bounds and cut when it becomes empty
    if (alpha < -MATE_VALUE) alpha = -MATE_VALUE;
    if (beta > MATE_VALUE - 1) beta = MATE_VALUE - 1;
    if (alpha >= beta) return alpha;
    // NOTE(review): with the guard below disabled, nothing in this function keeps
    // `ply` below MAX_PLY before it indexes killer[][ply] and the PV table.
    // if (ply > MAX_PLY - 1) return evaluate::score_position(board);

    if (!pv_node && !isCheck) {
        static constexpr const U32 score_pawn = score::get(PAWN);
        int16_t staticEval = evaluate::score_position(board);

        // evaluation pruning
        // NOTE(review): abs() is never negative, so the second condition looks
        // always true for any MATE_VALUE above 100 — confirm the intended bound.
        if (depth < 3 && abs(beta - 1) > -MATE_VALUE + 100) {
            int16_t marginEval = score_pawn * depth;
            if (staticEval - marginEval >= beta) return staticEval - marginEval;
        }

        if (settings->stopped) return 0;

        if (null) {
            // null move pruning
            if (ply && depth > 2 && staticEval >= beta) {
                stats_move_make_pruning(copy);
                score = -negamax(-beta, -beta + 1, depth - 1 - REDUCTION_MOVE, false);
                stats_move_unmake_pruning(copy);
                if (score >= beta) return beta;
            }

            // razoring
            score = staticEval + score_pawn;
            // NOTE(review): the quiescence search runs at every depth although its
            // result is only used when depth < 4.
            int16_t scoreNew = quiescence(alpha, beta);

            if (score < beta && depth == 1) {
                return (scoreNew > score) ? scoreNew : score;
            }

            score += score_pawn;
            if (score < beta && depth < 4) {
                if (scoreNew < beta) return (scoreNew > score) ? scoreNew : score;
            }
        }

        // futility pruning condition
        // Margin per remaining depth (index 0 is unused: depth is at least 1 here)
        static constexpr const int16_t margin[] = {
            0,
            score::get(PAWN),
            score::get(KNIGHT),
            score::get(ROOK),
        };
        if (depth < 4 && abs(alpha) < MATE_SCORE && staticEval + margin[depth] <= alpha) futility = 1;
    }

    uint8_t legal_moves = 0; // moves accepted by stats_move_make()
    uint8_t searched = 0;    // moves actually searched (not skipped by futility)

    MoveList list(board);
    std::vector<int> listScore = move_list_score(list, bestMove);
    for (int i = 0; i < list.size(); i++) {
        move_list_sort(list, listScore, i);
        const Move move = list[i];
        if (!stats_move_make(copy, move)) continue;
        legal_moves++;

        // futility pruning
        // `board` already has the move made, so is_check() tells whether it gives check
        if (futility && searched && !move.is_capture() && !move.is_promote() && !board.is_check()) {
            stats_move_unmake(copy, move);
            continue;
        }

        if (!searched) {
            // First move: full window
            score = -negamax(-beta, -alpha, depth - 1, true);
        } else {
            // Late Move Reduction
            if (!pv_node && searched >= FULL_DEPTH && depth >= REDUCTION_LIMIT && !isCheck &&
                !move.is_capture() && !move.is_promote() &&
                (move.source() != killer[0][ply].source() || move.target() != killer[0][ply].target()) &&
                (move.source() != killer[1][ply].source() || move.target() != killer[1][ply].target())) {
                score = -negamax(-alpha - 1, -alpha, depth - 2, true);
            } else
                score = alpha + 1; // forces the null-window search below

            // Principal Variation Search
            if (score > alpha) {
                score = -negamax(-alpha - 1, -alpha, depth - 1, true);

                // if fail research
                if ((score > alpha) && (score < beta)) score = -negamax(-beta, -alpha, depth - 1, true);
            }
        }

        stats_move_unmake(copy, move);
        searched++;

        if (settings->stopped) return 0;
        if (score > alpha) {
            if (!move.is_capture()) {
                // The board is restored, so the moving piece is back on its source square
                const Type piece = board.get_square_piece_type(move.source());
                history[piece::get_index(piece, board.get_side())][move.target()] += depth;
            }

            alpha = score;
            flag = Hashe::Flag::Exact;
            bestMove = move;
            pvtable.store(move, ply);

            if (score >= beta) {
                ttable.write(board, ply, bestMove, beta, depth, Hashe::Flag::Beta);

                if (!move.is_capture()) {
                    // Remember the quiet cutoff move, keeping the previous one as second killer
                    killer[1][ply] = killer[0][ply];
                    killer[0][ply] = move;
                }

                return beta;
            }
        }
    }

    // No legal move: checkmate (scored by distance from the root) or stalemate
    if (legal_moves == 0) {
        if (isCheck) return -MATE_VALUE + ply;
        else
            return 0;
    }

    ttable.write(board, ply, bestMove, alpha, depth, flag);
    return alpha;
}
    463 
    464 Move search_position(const uci::Settings &settingsr) {
    465     int16_t alpha = -SCORE_INFINITY, beta = SCORE_INFINITY;
    466     settings = &settingsr;
    467 
    468     if (settings->newgame) ttable.clear();
    469 
    470     rtable.clear();
    471     board = settings->board;
    472     for (int i = 0; i < settings->madeMoves.size(); i++) {
    473         rtable.push_hash(board.get_hash());
    474         settings->madeMoves[i].make(board);
    475         if (!settings->madeMoves[i].is_repeatable()) rtable.clear();
    476     }
    477 
    478     ply = 0;
    479     nodes = 0;
    480     settings->stopped = false;
    481     memset(killer, 0x00, sizeof(killer));
    482     memset(history, 0x00, sizeof(history));
    483     rtable = repetition::Table();
    484 
    485     Move lastBest;
    486 
    487     uint64_t time_last = timer::get_ms();
    488     uint8_t max_depth = settings->depth ? settings->depth : MAX_PLY;
    489     for (uint8_t depth = 1; depth <= max_depth; depth++) {
    490         lastBest = pvtable.best();
    491         follow_pv = true;
    492         int16_t score = negamax(alpha, beta, depth, true);
    493 
    494         uci::communicate(settings);
    495         if (settings->stopped) break;
    496 
    497         if ((score <= alpha) || (score >= beta)) {
    498             alpha = -SCORE_INFINITY;
    499             beta = SCORE_INFINITY;
    500             depth--;
    501             continue;
    502         }
    503 
    504         alpha = score - WINDOW;
    505         beta = score + WINDOW;
    506 
    507         uint8_t mate_ply = 0xFF;
    508         if (score > -MATE_VALUE && score < -MATE_SCORE) {
    509             mate_ply = (score + MATE_VALUE) / 2 + 1;
    510             std::cout << "info score mate -" << (int)mate_ply;
    511         } else if (score > MATE_SCORE && score < MATE_VALUE) {
    512             mate_ply = (MATE_VALUE - score) / 2 + 1;
    513             std::cout << "info score mate " << (int)mate_ply;
    514         } else {
    515             std::cout << "info score cp " << score;
    516         }
    517 
    518         std::cout << " depth " << (unsigned)depth;
    519         std::cout << " nodes " << nodes;
    520         std::cout << " time " << timer::get_ms() - settings->starttime;
    521         std::cout << " pv " << pvtable << std::endl;
    522 
    523         if (depth >= mate_ply) break;
    524 
    525         uint64_t time_crnt = timer::get_ms();
    526         if (!settings->depth && 2 * time_crnt - time_last > settings->stoptime) break;
    527         time_last = time_crnt;
    528     }
    529 
    530     settings->board = board;
    531     return !settings->stopped ? pvtable.best() : lastBest;
    532 }
    533 } // namespace engine
    534 
    535 int main() {
    536     attack::init();
    537     zobrist::init();
    538     uci::loop();
    539 #ifdef USE_STATS
    540     engine::ttable.print();
    541 #endif
    542     return 0;
    543 }