commit b1e0eabe0aa23b9f0a858437535b4f2199cceb66
parent 64a67cd47dc75e777739555b0380a7134fe1cef8
Author: Dimitrije Dobrota <mail@dimitrijedobrota.com>
Date: Mon, 7 Aug 2023 21:02:32 +0200
Implement transposition tables
Diffstat:
6 files changed, 166 insertions(+), 29 deletions(-)
diff --git a/src/engine/CMakeLists.txt b/src/engine/CMakeLists.txt
@@ -1,4 +1,7 @@
-add_executable(engine engine.c)
+add_executable(engine
+ engine.c
+ transposition.c
+)
target_link_libraries(engine
PRIVATE attacks
diff --git a/src/engine/engine.c b/src/engine/engine.c
@@ -8,6 +8,7 @@
#include "moves.h"
#include "perft.h"
#include "score.h"
+#include "transposition.h"
#include "utils.h"
#include "zobrist.h"
@@ -19,20 +20,17 @@
#define REDUCTION_LIMIT 3
#define REDUCTION_MOVE 2
-#define INFINITY 50000
-#define MATE_VALUE 49000
-#define MATE_SCORE 48000
-
#define WINDOW 50
typedef struct Stats Stats;
struct Stats {
- long nodes;
- int ply;
- int pv_length[MAX_PLY];
Move pv_table[MAX_PLY][MAX_PLY];
Move killer_moves[2][MAX_PLY];
U32 history_moves[16][64];
+ TTable *ttable;
+ long nodes;
+ int ply;
+ int pv_length[MAX_PLY];
int follow_pv, score_pv;
};
@@ -118,10 +116,10 @@ int evaluate(const Board *board) {
int quiescence(Stats *stats, const Board *board, int alpha, int beta) {
stats->nodes++;
- int eval = evaluate(board);
- if (stats->ply > MAX_PLY - 1) return eval;
- if (eval >= beta) return beta;
- if (eval > alpha) alpha = eval;
+ int score = evaluate(board);
+ if (score >= beta) return beta;
+ if (score > alpha) alpha = score;
+ if (stats->ply > MAX_PLY - 1) return score;
Board copy;
MoveList moves;
@@ -135,7 +133,7 @@ int quiescence(Stats *stats, const Board *board, int alpha, int beta) {
if (move_make(move, ©, 1) == 0) continue;
stats->ply++;
- int score = -quiescence(stats, ©, -beta, -alpha);
+ score = -quiescence(stats, ©, -beta, -alpha);
stats->ply--;
if (score > alpha) {
@@ -148,15 +146,28 @@ int quiescence(Stats *stats, const Board *board, int alpha, int beta) {
}
int negamax(Stats *stats, const Board *board, int alpha, int beta, int depth) {
+ HasheFlag flag = flagAlpha;
+ U64 bhash = board_hash(board);
+
stats->pv_length[stats->ply] = stats->ply;
+
+ int pv_node = (beta - alpha) > 1;
+ int score =
+ ttable_read(stats->ttable, bhash, alpha, beta, depth, stats->ply);
+ if (stats->ply && score != TTABLE_UNKNOWN && !pv_node) return score;
+
stats->nodes++;
- if (depth == 0) return quiescence(stats, board, alpha, beta);
- if (stats->ply > MAX_PLY - 1) return evaluate(board);
+ if (depth == 0) {
+ int score = quiescence(stats, board, alpha, beta);
+ ttable_write(stats->ttable, bhash, score, depth, stats->ply, flagExact);
+ return score;
+ }
- // if (alpha < -MATE_VALUE) alpha = -MATE_VALUE;
- // if (beta > MATE_VALUE - 1) beta = MATE_VALUE - 1;
- // if (alpha >= beta) return alpha;
+ if (alpha < -MATE_VALUE) alpha = -MATE_VALUE;
+ if (beta > MATE_VALUE - 1) beta = MATE_VALUE - 1;
+ if (alpha >= beta) return alpha;
+ if (stats->ply > MAX_PLY - 1) return evaluate(board);
int isCheck = board_isCheck(board);
if (isCheck) depth++;
@@ -168,8 +179,10 @@ int negamax(Stats *stats, const Board *board, int alpha, int beta, int depth) {
board_side_switch(©);
board_enpassant_set(©, no_sq);
- int score = -negamax(stats, ©, -beta, -beta + 1,
- depth - 1 - REDUCTION_MOVE);
+ stats->ply++;
+ score = -negamax(stats, ©, -beta, -beta + 1,
+ depth - 1 - REDUCTION_MOVE);
+ stats->ply--;
if (score >= beta) return beta;
}
@@ -218,24 +231,31 @@ int negamax(Stats *stats, const Board *board, int alpha, int beta, int depth) {
searched++;
if (score > alpha) {
- if (!move_capture(move))
+ if (!move_capture(move)) {
stats->history_moves[piece_index(move_piece(move))]
[move_target(move)] += depth;
+ }
- alpha = score;
stats->pv_table[stats->ply][stats->ply] = move;
for (int i = stats->ply + 1; i < stats->pv_length[stats->ply + 1];
- i++)
+ i++) {
stats->pv_table[stats->ply][i] =
stats->pv_table[stats->ply + 1][i];
+ }
stats->pv_length[stats->ply] = stats->pv_length[stats->ply + 1];
+ alpha = score;
+ flag = flagExact;
+
if (score >= beta) {
if (!move_capture(move)) {
stats->killer_moves[1][stats->ply] =
stats->killer_moves[0][stats->ply];
stats->killer_moves[0][stats->ply] = move;
}
+
+ ttable_write(stats->ttable, board_hash(©), beta, depth,
+ stats->ply, flagBeta);
return beta;
}
}
@@ -247,6 +267,7 @@ int negamax(Stats *stats, const Board *board, int alpha, int beta, int depth) {
return 0;
}
+ ttable_write(stats->ttable, bhash, alpha, depth, stats->ply, flag);
return alpha;
}
@@ -256,24 +277,26 @@ void move_print_UCI(Move move) {
if (move_promote(move)) printf("%c", piece_asci(move_piece_promote(move)));
}
+TTable *ttable = NULL;
void search_position(const Board *board, int depth) {
- Stats stats = {0};
+ Stats stats = {.ttable = ttable, 0};
int alpha = -INFINITY, beta = INFINITY;
for (int crnt = 1; crnt <= depth;) {
stats.follow_pv = 1;
int score = negamax(&stats, board, alpha, beta, crnt);
- if (score <= alpha || score >= beta) {
+ if ((score <= alpha) || (score >= beta)) {
alpha = -INFINITY;
beta = INFINITY;
+ continue;
}
- alpha = score - 50;
- beta = score + 50;
+ alpha = score - WINDOW;
+ beta = score + WINDOW;
if (score > -MATE_VALUE && score < -MATE_SCORE) {
printf("info score mate %d depth %d nodes %ld pv ",
- (MATE_VALUE - score) / 2 + 1, crnt, stats.nodes);
+ -(score + MATE_VALUE) / 2 - 1, crnt, stats.nodes);
} else if (score > MATE_SCORE && score < MATE_VALUE) {
printf("info score mate %d depth %d nodes %ld pv ",
(MATE_VALUE - score) / 2 + 1, crnt, stats.nodes);
@@ -482,10 +505,12 @@ void uci_loop(void) {
void init(void) {
attacks_init();
zobrist_init();
+ ttable = ttable_new(C64(0x400000));
}
int main(void) {
init();
uci_loop();
+ ttable_free(&ttable);
return 0;
}
diff --git a/src/engine/transposition.c b/src/engine/transposition.c
@@ -0,0 +1,78 @@
+#include <cul/assert.h>
+#include <cul/mem.h>
+#include <string.h>
+
+#include "board.h"
+#include "moves.h"
+#include "transposition.h"
+
+#define TTABLE_SIZE 0x400000
+
+#define T TTable
+
+typedef struct Hashe Hashe;
+struct Hashe {
+ U64 key;
+ Move best;
+ int depth;
+ int score;
+ HasheFlag flag;
+};
+
+struct T {
+ U64 size;
+ Hashe table[];
+};
+
+T *ttable_new(U64 size) {
+ T *self = CALLOC(1, sizeof(T) + size * sizeof(Hashe));
+ self->size = size;
+ return self;
+}
+
+void ttable_free(T **self) {
+ assert(self && *self);
+ FREE(*self);
+}
+
+void ttable_clear(T *self) {
+ assert(self);
+ memset(self->table, 0x0, sizeof(T) + self->size * sizeof(Hashe));
+}
+
+int ttable_read(T *self, U64 hash, int alpha, int beta, int depth, int ply) {
+ assert(self);
+
+ Hashe *phashe = &self->table[hash % self->size];
+ if (phashe->key == hash) {
+ if (phashe->depth >= depth) {
+ int score = phashe->score;
+
+ if (score < -MATE_SCORE) score += ply;
+ if (score > MATE_SCORE) score -= ply;
+
+ if (phashe->flag == flagExact) return score;
+ if ((phashe->flag == flagAlpha) && (score <= alpha)) return alpha;
+ if ((phashe->flag == flagBeta) && (score >= beta)) return beta;
+ }
+ }
+ return TTABLE_UNKNOWN;
+}
+
+void ttable_write(T *self, U64 hash, int score, int depth, int ply,
+ HasheFlag flag) {
+ assert(self);
+
+ Hashe *phashe = &self->table[hash % self->size];
+
+ if (score < -MATE_SCORE) score += ply;
+ if (score > MATE_SCORE) score -= ply;
+
+ *phashe = (Hashe){
+ .key = hash,
+ .best = 0,
+ .depth = depth,
+ .score = score,
+ .flag = flag,
+ };
+}
diff --git a/src/include/transposition.h b/src/include/transposition.h
@@ -0,0 +1,28 @@
+#ifndef STELLAR_TRANSPOSITION_H
+#define STELLAR_TRANSPOSITION_H
+
+#include "utils.h"
+
+#define TTABLE_UNKNOWN 100000
+
+#define T TTable
+
+typedef enum HasheFlag HasheFlag;
+enum HasheFlag {
+ flagExact,
+ flagAlpha,
+ flagBeta
+};
+
+typedef struct T T;
+
+T *ttable_new(U64 size);
+void ttable_free(T **self);
+void ttable_clear(T *self);
+
+int ttable_read(T *self, U64 hash, int alpha, int beta, int depth, int ply);
+void ttable_write(T *self, U64 hash, int score, int depth, int ply,
+ HasheFlag flag);
+
+#undef T
+#endif
diff --git a/src/include/utils.h b/src/include/utils.h
@@ -3,6 +3,10 @@
#include <inttypes.h>
+#define INFINITY 50000
+#define MATE_VALUE 49000
+#define MATE_SCORE 48000
+
// useful macros
#define MAX(a, b) ((a > b) ? a : b)
#define MIN(a, b) ((a < b) ? a : b)
diff --git a/src/moves/moves.c b/src/moves/moves.c
@@ -61,4 +61,3 @@ void move_list_print(const MoveList *self) {
move_print(self->moves[i]);
printf("Total: %d\n", self->count);
}
-