perft.cpp (4402B)
1 #include <iomanip> 2 #include <semaphore> 3 #include <thread> 4 5 #include "attack.hpp" 6 #include "board.hpp" 7 #include "move.hpp" 8 #include "movelist.hpp" 9 #include "utils.hpp" 10 11 // FEN debug positions 12 #define tricky_position "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1 " 13 14 enum { 15 THREAD_MAX = 64 16 }; 17 18 class Perft { 19 public: 20 using semaphore_t = std::counting_semaphore<THREAD_MAX>; 21 Perft(semaphore_t &sem) : sem(sem) {} 22 void operator()(const Board &board_start, Move move, int depth) { 23 Board board = board_start; 24 if (!move.make(board)) return; 25 sem.acquire(); 26 // debug(board_start, move, board); 27 28 if (depth > 1) { 29 test(board, depth - 1); 30 } else { 31 score(board, move); 32 } 33 34 mutex.acquire(); 35 result += local; 36 mutex.release(); 37 sem.release(); 38 } 39 40 struct result_t { 41 U64 node = 0; 42 #ifdef USE_FULL_COUNT 43 U64 check = 0; 44 U64 castle = 0; 45 U64 promote = 0; 46 U64 capture = 0; 47 U64 enpassant = 0; 48 #endif 49 result_t &operator+=(const result_t res) { 50 node += res.node; 51 #ifdef USE_FULL_COUNT 52 check += res.check; 53 castle += res.castle; 54 promote += res.promote; 55 capture += res.capture; 56 enpassant += res.enpassant; 57 #endif 58 return *this; 59 } 60 }; 61 62 static result_t result; 63 64 private: 65 void test(const Board &board, int depth) { 66 const MoveList list(board); 67 for (int i = 0; i < list.size(); i++) { 68 Board copy = board; 69 if (!list[i].make(copy)) continue; 70 // debug(board, list[i], copy); 71 if (depth != 1) test(copy, depth - 1); 72 else 73 score(copy, list[i]); 74 } 75 } 76 77 void debug(const Board &before, Move move, const Board &after) { 78 std::cout << std::setw(16) << std::hex << before.get_hash() << " "; 79 std::cout << move << " "; 80 std::cout << std::setw(16) << std::hex << after.get_hash() << "\n"; 81 } 82 83 void score(const Board &board, Move move) { 84 local.node++; 85 #ifdef USE_FULL_COUNT 86 if (board.is_check()) local.check++; 87 if (move.is_capture()) local.capture++; 88 if (move.is_enpassant()) local.enpassant++; 89 if (move.is_castle()) local.castle++; 90 if (move.is_promote()) local.promote++; 91 #endif 92 } 93 result_t local; 94 semaphore_t &sem; 95 static std::binary_semaphore mutex; 96 }; 97 98 std::binary_semaphore Perft::mutex{1}; 99 Perft::result_t Perft::result; 100 101 void perft_test(const char *fen, int depth, int thread_num) { 102 const Board board = Board(fen); 103 const MoveList list = MoveList(board); 104 std::vector<std::thread> threads(list.size()); 105 106 Perft::semaphore_t sem(thread_num); 107 108 int index = 0; 109 for (int i = 0; i < list.size(); i++) 110 threads[index++] = std::thread(Perft(sem), board, list[i], depth); 111 112 for (auto &thread : threads) 113 thread.join(); 114 115 std::cout << std::dec; 116 std::cout << " Nodes: " << Perft::result.node << "\n"; 117 #ifdef USE_FULL_COUNT 118 std::cout << " Captures: " << Perft::result.capture << "\n"; 119 std::cout << "Enpassants: " << Perft::result.enpassant << "\n"; 120 std::cout << " Castles: " << Perft::result.castle << "\n"; 121 std::cout << "Promotions: " << Perft::result.promote << "\n"; 122 std::cout << " Checks: " << Perft::result.check << "\n"; 123 #endif 124 } 125 126 void usage(const char *program) { 127 std::cout << "Usage: " << program; 128 std::cout << " [-h]"; 129 std::cout << " [-t thread number]"; 130 std::cout << " [-d depth]"; 131 std::cout << " [-f fen]" << std::endl; 132 } 133 134 int main(int argc, char *argv[]) { 135 int c = 0, depth = 1, thread_num = 1; 136 std::string s(start_position); 137 const char *fen = s.data(); 138 while ((c = getopt(argc, argv, "ht:f:d:")) != -1) { 139 switch (c) { 140 case 't': 141 thread_num = atoi(optarg); 142 if (thread_num <= 0 && thread_num > THREAD_MAX) abort(); 143 break; 144 case 'f': fen = optarg; break; 145 case 'd': 146 depth = atoi(optarg); 147 if (depth <= 0) abort(); 148 break; 149 case 'h': usage(argv[0]); return 1; 150 default: usage(argv[0]); abort(); 151 } 152 } 153 154 attack::init(); 155 zobrist::init(); 156 perft_test(fen, depth, thread_num); 157 return 0; 158 }