diff --git a/src/nnue/features/half_ka_v2_hm.cpp b/src/nnue/features/half_ka_v2_hm.cpp index 5789db4844a..dd49fce6b18 100644 --- a/src/nnue/features/half_ka_v2_hm.cpp +++ b/src/nnue/features/half_ka_v2_hm.cpp @@ -49,6 +49,8 @@ void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active) // Explicit template instantiations template void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active); template void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active); +template IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq); +template IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq); // Get a list of indices for recently changed features template diff --git a/src/nnue/features/half_ka_v2_hm.h b/src/nnue/features/half_ka_v2_hm.h index 8363184f430..96349704745 100644 --- a/src/nnue/features/half_ka_v2_hm.h +++ b/src/nnue/features/half_ka_v2_hm.h @@ -63,10 +63,6 @@ class HalfKAv2_hm { {PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE, PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE}}; - // Index of a feature for a given king position and another piece on some square - template - static IndexType make_index(Square s, Piece pc, Square ksq); - public: // Feature name static constexpr const char* Name = "HalfKAv2_hm(Friend)"; @@ -126,6 +122,10 @@ class HalfKAv2_hm { static constexpr IndexType MaxActiveDimensions = 32; using IndexList = ValueList; + // Index of a feature for a given king position and another piece on some square + template + static IndexType make_index(Square s, Piece pc, Square ksq); + // Get a list of indices for active features template static void append_active_indices(const Position& pos, IndexList& active); diff --git a/src/nnue/network.cpp b/src/nnue/network.cpp index bea3e7cb398..5df7270162c 100644 --- a/src/nnue/network.cpp +++ b/src/nnue/network.cpp @@ -259,6 +259,11 @@ void Network::hint_common_access(const Position& pos, bool ps featureTransformer->hint_common_access(pos, psqtOnl); } +template +void Network::init_refresh_entry(AccumulatorRefreshEntry& entry) const { + featureTransformer->init_refresh_entry(entry); +} + template NnueEvalTrace Network::trace_evaluate(const Position& pos) const { diff --git a/src/nnue/network.h b/src/nnue/network.h index 21e1c622205..0c095d85c31 100644 --- a/src/nnue/network.h +++ b/src/nnue/network.h @@ -34,6 +34,7 @@ namespace Stockfish::Eval::NNUE { +struct AccumulatorRefreshEntry; enum class EmbeddedNNUEType { BIG, @@ -51,7 +52,6 @@ class Network { void load(const std::string& rootDirectory, std::string evalfilePath); bool save(const std::optional& filename) const; - Value evaluate(const Position& pos, bool adjusted = false, int* complexity = nullptr, @@ -60,6 +60,8 @@ class Network { void hint_common_access(const Position& pos, bool psqtOnl) const; + void init_refresh_entry(AccumulatorRefreshEntry& entry) const; + void verify(std::string evalfilePath) const; NnueEvalTrace trace_evaluate(const Position& pos) const; diff --git a/src/nnue/nnue_accumulator.h b/src/nnue/nnue_accumulator.h index c0746b4ee86..791b61607e7 100644 --- a/src/nnue/nnue_accumulator.h +++ b/src/nnue/nnue_accumulator.h @@ -37,6 +37,12 @@ struct alignas(CacheLineSize) Accumulator { bool computedPSQT[2]; }; +struct AccumulatorRefreshEntry { + Bitboard byColorBB[2][2]; + Bitboard byTypeBB[2][8]; + Accumulator acc; +}; + } // namespace Stockfish::Eval::NNUE #endif // NNUE_ACCUMULATOR_H_INCLUDED diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 3101c8d2689..a88b94b1ef9 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -378,6 +378,20 @@ class FeatureTransformer { hint_common_access_for_perspective(pos, psqtOnly); } + void init_refresh_entry(AccumulatorRefreshEntry& entry) { + assert(HalfDimensions == TransformedFeatureDimensionsBig); + + // To initialize a refresh entry, we set all its bitboards empty, + // so we put the biases in the accumulation, without any weights on top + + std::memset(entry.byColorBB, 0, 2 * 2 * sizeof(Bitboard)); + std::memset(entry.byTypeBB, 0, 2 * 8 * sizeof(Bitboard)); + + std::memcpy(entry.acc.accumulation[WHITE], biases, HalfDimensions * sizeof(BiasType)); + std::memcpy(entry.acc.accumulation[BLACK], biases, HalfDimensions * sizeof(BiasType)); + std::memset(entry.acc.psqtAccumulation, 0, sizeof(entry.acc.psqtAccumulation)); + } + private: template [[nodiscard]] std::pair @@ -651,8 +665,157 @@ class FeatureTransformer { #endif } + template + void update_accumulator_refresh_cache(Position& pos) const { + + assert(HalfDimensions == TransformedFeatureDimensionsBig); + + Square ksq = pos.square(Perspective); + AccumulatorRefreshEntry& entry = pos.refreshTable[ksq]; + + auto& accumulator = pos.state()->*accPtr; + accumulator.computed[Perspective] = true; + accumulator.computedPSQT[Perspective] = true; + + FeatureSet::IndexList removed, added; + for (Color c = WHITE; c <= BLACK; c = Color(int(c)+1)) { + for (PieceType pt = PAWN; pt <= KING; ++pt) { + const Piece piece = make_piece(c, pt); + const Bitboard oldBB = entry.byColorBB[Perspective][c] & entry.byTypeBB[Perspective][pt]; + const Bitboard newBB = pos.pieces(c, pt); + Bitboard toRemove = oldBB & ~newBB; + Bitboard toAdd = newBB & ~oldBB; + + while (toRemove) { + Square sq = pop_lsb(toRemove); + removed.push_back(FeatureSet::make_index(sq, piece, ksq)); + } + while (toAdd) { + Square sq = pop_lsb(toAdd); + added.push_back(FeatureSet::make_index(sq, piece, ksq)); + } + } + } + +#ifdef VECTOR + int16_t* entryAccumulation = entry.acc.accumulation[Perspective]; + int32_t* entryPsqtAccumulation = entry.acc.psqtAccumulation[Perspective]; + + vec_t acc[NumRegs]; + psqt_vec_t psqt[NumPsqtRegs]; + + for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) + { + auto entryTile = reinterpret_cast(&entryAccumulation[j * TileHeight]); + for (IndexType k = 0; k < NumRegs; ++k) + acc[k] = entryTile[k]; + + for (int i = 0; i < int(added.size()); ++i) + { + IndexType index = added[i]; + const IndexType offset = HalfDimensions * index + j * TileHeight; + auto column = reinterpret_cast(&weights[offset]); + + for (unsigned k = 0; k < NumRegs; ++k) + acc[k] = vec_add_16(acc[k], column[k]); + } + for (int i = 0; i < int(removed.size()); ++i) + { + IndexType index = removed[i]; + const IndexType offset = HalfDimensions * index + j * TileHeight; + auto column = reinterpret_cast(&weights[offset]); + + for (unsigned k = 0; k < NumRegs; ++k) + acc[k] = vec_sub_16(acc[k], column[k]); + } + + for (IndexType k = 0; k < NumRegs; k++) + vec_store(&entryTile[k], acc[k]); + } + + for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j) + { + auto entryTilePsqt = reinterpret_cast(&entryPsqtAccumulation[j * PsqtTileHeight]); + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = entryTilePsqt[k]; + + for (int i = 0; i < int(added.size()); ++i) + { + IndexType index = added[i]; + const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; + auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]); + } + for (int i = 0; i < int(removed.size()); ++i) + { + IndexType index = removed[i]; + const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; + auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]); + } + + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + vec_store_psqt(&entryTilePsqt[k], psqt[k]); + } + +#else + + for (const auto index : added) + { + const IndexType offset = HalfDimensions * index; + for (IndexType j = 0; j < HalfDimensions; ++j) + entry.acc.accumulation[Perspective][j] += weights[offset + j]; + + for (std::size_t k = 0; k < PSQTBuckets; ++k) + entry.acc.psqtAccumulation[Perspective][k] += + psqtWeights[index * PSQTBuckets + k]; + } + for (const auto index : removed) + { + const IndexType offset = HalfDimensions * index; + for (IndexType j = 0; j < HalfDimensions; ++j) + entry.acc.accumulation[Perspective][j] -= weights[offset + j]; + + for (std::size_t k = 0; k < PSQTBuckets; ++k) + entry.acc.psqtAccumulation[Perspective][k] -= + psqtWeights[index * PSQTBuckets + k]; + } + +#endif + + // The accumulator of the refresh entry has been updated. + // Now copy its content to the actual accumulator we were refreshing + + std::memcpy(accumulator.psqtAccumulation[Perspective], + entry.acc.psqtAccumulation[Perspective], + sizeof(int32_t) * PSQTBuckets); + + std::memcpy(accumulator.accumulation[Perspective], + entry.acc.accumulation[Perspective], + sizeof(int16_t) * HalfDimensions); + + for (int i = WHITE; i <= BLACK; i++) + entry.byColorBB[Perspective][i] = pos.pieces(Color(i)); + + for (int i = PAWN; i <= KING; i++) + entry.byTypeBB[Perspective][i] = pos.pieces(PieceType(i)); + } + template void update_accumulator_refresh(const Position& pos, bool psqtOnly) const { + + // When we are refreshing the accumulator of the big net, + // redirect to the version of refresh that uses the refresh table + if (HalfDimensions == Eval::NNUE::TransformedFeatureDimensionsBig) { + // TODO: find a better solution than const_casting the position + update_accumulator_refresh_cache(const_cast(pos)); + return; + } + #ifdef VECTOR // Gcc-10.2 unnecessarily spills AVX2 registers if this array // is defined in the VECTOR code below, once in each branch diff --git a/src/position.h b/src/position.h index 154ed652942..b3925094266 100644 --- a/src/position.h +++ b/src/position.h @@ -172,6 +172,9 @@ class Position { void put_piece(Piece pc, Square s); void remove_piece(Square s); + // Used by NNUE + Eval::NNUE::AccumulatorRefreshEntry refreshTable[SQUARE_NB]; + private: // Initialization helpers (used while setting up a position) void set_castling_right(Color c, Square rfrom); diff --git a/src/search.cpp b/src/search.cpp index 24805aa70ea..baa35b1e127 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -33,6 +33,8 @@ #include "misc.h" #include "movegen.h" #include "movepick.h" +#include "nnue/network.h" +#include "nnue/nnue_accumulator.h" #include "nnue/nnue_common.h" #include "nnue/nnue_misc.h" #include "position.h" @@ -143,6 +145,11 @@ Search::Worker::Worker(SharedState& sharedState, } void Search::Worker::start_searching() { + + // Initialize accumulator refresh entries + for (int i = 0; i < SQUARE_NB; i++) + networks.big.init_refresh_entry(rootPos.refreshTable[i]); + // Non-main threads go directly to iterative_deepening() if (!is_mainthread()) {