stellar

Stellar - UCI Chess engine written in C++20
git clone git://git.dimitrijedobrota.com/stellar.git
Log | Files | Refs | README | LICENSE

perft.cpp (4402B)


      1 #include <iomanip>
      2 #include <semaphore>
      3 #include <thread>
      4 
      5 #include "attack.hpp"
      6 #include "board.hpp"
      7 #include "move.hpp"
      8 #include "movelist.hpp"
      9 #include "utils.hpp"
     10 
     11 // FEN debug positions
     12 #define tricky_position "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1 "
     13 
     14 enum {
     15     THREAD_MAX = 64
     16 };
     17 
     18 class Perft {
     19   public:
     20     using semaphore_t = std::counting_semaphore<THREAD_MAX>;
     21     Perft(semaphore_t &sem) : sem(sem) {}
     22     void operator()(const Board &board_start, Move move, int depth) {
     23         Board board = board_start;
     24         if (!move.make(board)) return;
     25         sem.acquire();
     26         // debug(board_start, move, board);
     27 
     28         if (depth > 1) {
     29             test(board, depth - 1);
     30         } else {
     31             score(board, move);
     32         }
     33 
     34         mutex.acquire();
     35         result += local;
     36         mutex.release();
     37         sem.release();
     38     }
     39 
     40     struct result_t {
     41         U64 node = 0;
     42 #ifdef USE_FULL_COUNT
     43         U64 check = 0;
     44         U64 castle = 0;
     45         U64 promote = 0;
     46         U64 capture = 0;
     47         U64 enpassant = 0;
     48 #endif
     49         result_t &operator+=(const result_t res) {
     50             node += res.node;
     51 #ifdef USE_FULL_COUNT
     52             check += res.check;
     53             castle += res.castle;
     54             promote += res.promote;
     55             capture += res.capture;
     56             enpassant += res.enpassant;
     57 #endif
     58             return *this;
     59         }
     60     };
     61 
     62     static result_t result;
     63 
     64   private:
     65     void test(const Board &board, int depth) {
     66         const MoveList list(board);
     67         for (int i = 0; i < list.size(); i++) {
     68             Board copy = board;
     69             if (!list[i].make(copy)) continue;
     70             // debug(board, list[i], copy);
     71             if (depth != 1) test(copy, depth - 1);
     72             else
     73                 score(copy, list[i]);
     74         }
     75     }
     76 
     77     void debug(const Board &before, Move move, const Board &after) {
     78         std::cout << std::setw(16) << std::hex << before.get_hash() << " ";
     79         std::cout << move << " ";
     80         std::cout << std::setw(16) << std::hex << after.get_hash() << "\n";
     81     }
     82 
     83     void score(const Board &board, Move move) {
     84         local.node++;
     85 #ifdef USE_FULL_COUNT
     86         if (board.is_check()) local.check++;
     87         if (move.is_capture()) local.capture++;
     88         if (move.is_enpassant()) local.enpassant++;
     89         if (move.is_castle()) local.castle++;
     90         if (move.is_promote()) local.promote++;
     91 #endif
     92     }
     93     result_t local;
     94     semaphore_t &sem;
     95     static std::binary_semaphore mutex;
     96 };
     97 
     98 std::binary_semaphore Perft::mutex{1};
     99 Perft::result_t Perft::result;
    100 
    101 void perft_test(const char *fen, int depth, int thread_num) {
    102     const Board board = Board(fen);
    103     const MoveList list = MoveList(board);
    104     std::vector<std::thread> threads(list.size());
    105 
    106     Perft::semaphore_t sem(thread_num);
    107 
    108     int index = 0;
    109     for (int i = 0; i < list.size(); i++)
    110         threads[index++] = std::thread(Perft(sem), board, list[i], depth);
    111 
    112     for (auto &thread : threads)
    113         thread.join();
    114 
    115     std::cout << std::dec;
    116     std::cout << "     Nodes: " << Perft::result.node << "\n";
    117 #ifdef USE_FULL_COUNT
    118     std::cout << "  Captures: " << Perft::result.capture << "\n";
    119     std::cout << "Enpassants: " << Perft::result.enpassant << "\n";
    120     std::cout << "   Castles: " << Perft::result.castle << "\n";
    121     std::cout << "Promotions: " << Perft::result.promote << "\n";
    122     std::cout << "    Checks: " << Perft::result.check << "\n";
    123 #endif
    124 }
    125 
    126 void usage(const char *program) {
    127     std::cout << "Usage: " << program;
    128     std::cout << " [-h]";
    129     std::cout << " [-t thread number]";
    130     std::cout << " [-d depth]";
    131     std::cout << " [-f fen]" << std::endl;
    132 }
    133 
    134 int main(int argc, char *argv[]) {
    135     int c = 0, depth = 1, thread_num = 1;
    136     std::string s(start_position);
    137     const char *fen = s.data();
    138     while ((c = getopt(argc, argv, "ht:f:d:")) != -1) {
    139         switch (c) {
    140         case 't':
    141             thread_num = atoi(optarg);
    142             if (thread_num <= 0 && thread_num > THREAD_MAX) abort();
    143             break;
    144         case 'f': fen = optarg; break;
    145         case 'd':
    146             depth = atoi(optarg);
    147             if (depth <= 0) abort();
    148             break;
    149         case 'h': usage(argv[0]); return 1;
    150         default: usage(argv[0]); abort();
    151         }
    152     }
    153 
    154     attack::init();
    155     zobrist::init();
    156     perft_test(fen, depth, thread_num);
    157     return 0;
    158 }