Skip to content

Commit

Permalink
Implement accumulator refresh table
Browse files Browse the repository at this point in the history
  • Loading branch information
gab8192 committed Apr 20, 2024
1 parent c55ae37 commit 0712cab
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/nnue/features/half_ka_v2_hm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active)
// Explicit template instantiations
template void HalfKAv2_hm::append_active_indices<WHITE>(const Position& pos, IndexList& active);
template void HalfKAv2_hm::append_active_indices<BLACK>(const Position& pos, IndexList& active);
template IndexType HalfKAv2_hm::make_index<WHITE>(Square s, Piece pc, Square ksq);
template IndexType HalfKAv2_hm::make_index<BLACK>(Square s, Piece pc, Square ksq);

// Get a list of indices for recently changed features
template<Color Perspective>
Expand Down
8 changes: 4 additions & 4 deletions src/nnue/features/half_ka_v2_hm.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ class HalfKAv2_hm {
{PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE,
PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE}};

// Index of a feature for a given king position and another piece on some square
template<Color Perspective>
static IndexType make_index(Square s, Piece pc, Square ksq);

public:
// Feature name
static constexpr const char* Name = "HalfKAv2_hm(Friend)";
Expand Down Expand Up @@ -126,6 +122,10 @@ class HalfKAv2_hm {
static constexpr IndexType MaxActiveDimensions = 32;
using IndexList = ValueList<IndexType, MaxActiveDimensions>;

// Index of a feature for a given king position and another piece on some square
template<Color Perspective>
static IndexType make_index(Square s, Piece pc, Square ksq);

// Get a list of indices for active features
template<Color Perspective>
static void append_active_indices(const Position& pos, IndexList& active);
Expand Down
5 changes: 5 additions & 0 deletions src/nnue/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ void Network<Arch, Transformer>::hint_common_access(const Position& pos, bool ps
featureTransformer->hint_common_access(pos, psqtOnl);
}

template<typename Arch, typename Transformer>
void Network<Arch, Transformer>::init_refresh_entry(AccumulatorRefreshEntry& entry) const {
featureTransformer->init_refresh_entry(entry);
}


template<typename Arch, typename Transformer>
NnueEvalTrace Network<Arch, Transformer>::trace_evaluate(const Position& pos) const {
Expand Down
4 changes: 3 additions & 1 deletion src/nnue/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

namespace Stockfish::Eval::NNUE {

struct AccumulatorRefreshEntry;

enum class EmbeddedNNUEType {
BIG,
Expand All @@ -51,7 +52,6 @@ class Network {
void load(const std::string& rootDirectory, std::string evalfilePath);
bool save(const std::optional<std::string>& filename) const;


Value evaluate(const Position& pos,
bool adjusted = false,
int* complexity = nullptr,
Expand All @@ -60,6 +60,8 @@ class Network {

void hint_common_access(const Position& pos, bool psqtOnl) const;

void init_refresh_entry(AccumulatorRefreshEntry& entry) const;

void verify(std::string evalfilePath) const;
NnueEvalTrace trace_evaluate(const Position& pos) const;

Expand Down
6 changes: 6 additions & 0 deletions src/nnue/nnue_accumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ struct alignas(CacheLineSize) Accumulator {
bool computedPSQT[2];
};

struct AccumulatorRefreshEntry {
Bitboard byColorBB[2][2];
Bitboard byTypeBB[2][8];
Accumulator<TransformedFeatureDimensionsBig> acc;
};

} // namespace Stockfish::Eval::NNUE

#endif // NNUE_ACCUMULATOR_H_INCLUDED
163 changes: 163 additions & 0 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,20 @@ class FeatureTransformer {
hint_common_access_for_perspective<BLACK>(pos, psqtOnly);
}

void init_refresh_entry(AccumulatorRefreshEntry& entry) {
assert(HalfDimensions == TransformedFeatureDimensionsBig);

// To initialize a refresh entry, we set all its bitboards empty,
// so we put the biases in the accumulation, without any weights on top

std::memset(entry.byColorBB, 0, 2 * 2 * sizeof(Bitboard));
std::memset(entry.byTypeBB, 0, 2 * 8 * sizeof(Bitboard));

std::memcpy(entry.acc.accumulation[WHITE], biases, HalfDimensions * sizeof(BiasType));
std::memcpy(entry.acc.accumulation[BLACK], biases, HalfDimensions * sizeof(BiasType));
std::memset(entry.acc.psqtAccumulation, 0, sizeof(entry.acc.psqtAccumulation));
}

private:
template<Color Perspective>
[[nodiscard]] std::pair<StateInfo*, StateInfo*>
Expand Down Expand Up @@ -651,8 +665,157 @@ class FeatureTransformer {
#endif
}

template<Color Perspective>
void update_accumulator_refresh_cache(Position& pos) const {

assert(HalfDimensions == TransformedFeatureDimensionsBig);

Square ksq = pos.square<KING>(Perspective);
AccumulatorRefreshEntry& entry = pos.refreshTable[ksq];

auto& accumulator = pos.state()->*accPtr;
accumulator.computed[Perspective] = true;
accumulator.computedPSQT[Perspective] = true;

FeatureSet::IndexList removed, added;
for (Color c = WHITE; c <= BLACK; c = Color(int(c)+1)) {
for (PieceType pt = PAWN; pt <= KING; ++pt) {
const Piece piece = make_piece(c, pt);
const Bitboard oldBB = entry.byColorBB[Perspective][c] & entry.byTypeBB[Perspective][pt];
const Bitboard newBB = pos.pieces(c, pt);
Bitboard toRemove = oldBB & ~newBB;
Bitboard toAdd = newBB & ~oldBB;

while (toRemove) {
Square sq = pop_lsb(toRemove);
removed.push_back(FeatureSet::make_index<Perspective>(sq, piece, ksq));
}
while (toAdd) {
Square sq = pop_lsb(toAdd);
added.push_back(FeatureSet::make_index<Perspective>(sq, piece, ksq));
}
}
}

#ifdef VECTOR
int16_t* entryAccumulation = entry.acc.accumulation[Perspective];
int32_t* entryPsqtAccumulation = entry.acc.psqtAccumulation[Perspective];

vec_t acc[NumRegs];
psqt_vec_t psqt[NumPsqtRegs];

for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
{
auto entryTile = reinterpret_cast<vec_t*>(&entryAccumulation[j * TileHeight]);
for (IndexType k = 0; k < NumRegs; ++k)
acc[k] = entryTile[k];

for (int i = 0; i < int(added.size()); ++i)
{
IndexType index = added[i];
const IndexType offset = HalfDimensions * index + j * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);

for (unsigned k = 0; k < NumRegs; ++k)
acc[k] = vec_add_16(acc[k], column[k]);
}
for (int i = 0; i < int(removed.size()); ++i)
{
IndexType index = removed[i];
const IndexType offset = HalfDimensions * index + j * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);

for (unsigned k = 0; k < NumRegs; ++k)
acc[k] = vec_sub_16(acc[k], column[k]);
}

for (IndexType k = 0; k < NumRegs; k++)
vec_store(&entryTile[k], acc[k]);
}

for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
{
auto entryTilePsqt = reinterpret_cast<psqt_vec_t*>(&entryPsqtAccumulation[j * PsqtTileHeight]);
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
psqt[k] = entryTilePsqt[k];

for (int i = 0; i < int(added.size()); ++i)
{
IndexType index = added[i];
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
}
for (int i = 0; i < int(removed.size()); ++i)
{
IndexType index = removed[i];
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);
}

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
vec_store_psqt(&entryTilePsqt[k], psqt[k]);
}

#else

for (const auto index : added)
{
const IndexType offset = HalfDimensions * index;
for (IndexType j = 0; j < HalfDimensions; ++j)
entry.acc.accumulation[Perspective][j] += weights[offset + j];

for (std::size_t k = 0; k < PSQTBuckets; ++k)
entry.acc.psqtAccumulation[Perspective][k] +=
psqtWeights[index * PSQTBuckets + k];
}
for (const auto index : removed)
{
const IndexType offset = HalfDimensions * index;
for (IndexType j = 0; j < HalfDimensions; ++j)
entry.acc.accumulation[Perspective][j] -= weights[offset + j];

for (std::size_t k = 0; k < PSQTBuckets; ++k)
entry.acc.psqtAccumulation[Perspective][k] -=
psqtWeights[index * PSQTBuckets + k];
}

#endif

// The accumulator of the refresh entry has been updated.
// Now copy its content to the actual accumulator we were refreshing

std::memcpy(accumulator.psqtAccumulation[Perspective],
entry.acc.psqtAccumulation[Perspective],
sizeof(int32_t) * PSQTBuckets);

std::memcpy(accumulator.accumulation[Perspective],
entry.acc.accumulation[Perspective],
sizeof(int16_t) * HalfDimensions);

for (int i = WHITE; i <= BLACK; i++)
entry.byColorBB[Perspective][i] = pos.pieces(Color(i));

for (int i = PAWN; i <= KING; i++)
entry.byTypeBB[Perspective][i] = pos.pieces(PieceType(i));
}

template<Color Perspective>
void update_accumulator_refresh(const Position& pos, bool psqtOnly) const {

// When we are refreshing the accumulator of the big net,
// redirect to the version of refresh that uses the refresh table
if (HalfDimensions == Eval::NNUE::TransformedFeatureDimensionsBig) {
// TODO: find a better solution than const_casting the position
update_accumulator_refresh_cache<Perspective>(const_cast<Position&>(pos));
return;
}

#ifdef VECTOR
// Gcc-10.2 unnecessarily spills AVX2 registers if this array
// is defined in the VECTOR code below, once in each branch
Expand Down
3 changes: 3 additions & 0 deletions src/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ class Position {
void put_piece(Piece pc, Square s);
void remove_piece(Square s);

// Used by NNUE
Eval::NNUE::AccumulatorRefreshEntry refreshTable[SQUARE_NB];

private:
// Initialization helpers (used while setting up a position)
void set_castling_right(Color c, Square rfrom);
Expand Down
7 changes: 7 additions & 0 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include "misc.h"
#include "movegen.h"
#include "movepick.h"
#include "nnue/network.h"
#include "nnue/nnue_accumulator.h"
#include "nnue/nnue_common.h"
#include "nnue/nnue_misc.h"
#include "position.h"
Expand Down Expand Up @@ -143,6 +145,11 @@ Search::Worker::Worker(SharedState& sharedState,
}

void Search::Worker::start_searching() {

// Initialize accumulator refresh entries
for (int i = 0; i < SQUARE_NB; i++)
networks.big.init_refresh_entry(rootPos.refreshTable[i]);

// Non-main threads go directly to iterative_deepening()
if (!is_mainthread())
{
Expand Down

0 comments on commit 0712cab

Please sign in to comment.