diff --git a/src/nnue/network.cpp b/src/nnue/network.cpp index 803e4627476..656ad97a1e3 100644 --- a/src/nnue/network.cpp +++ b/src/nnue/network.cpp @@ -186,12 +186,11 @@ bool Network::save(const std::optional& filename template -Value Network::evaluate( - const Position& pos, - AccumulatorCaches::Cache* cache, - bool adjusted, - int* complexity, - bool psqtOnly) const { +Value Network::evaluate(const Position& pos, + AccumulatorCaches::Cache* cache, + bool adjusted, + int* complexity, + bool psqtOnly) const { // We manually align the arrays on the stack because with gcc < 9.3 // overaligning stack variables with alignas() doesn't work correctly. @@ -199,14 +198,14 @@ Value Network::evaluate( constexpr int delta = 24; #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) - TransformedFeatureType transformedFeaturesUnaligned - [FeatureTransformer::BufferSize - + alignment / sizeof(TransformedFeatureType)]; + TransformedFeatureType + transformedFeaturesUnaligned[FeatureTransformer::BufferSize + + alignment / sizeof(TransformedFeatureType)]; auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); #else - alignas(alignment) TransformedFeatureType transformedFeatures - [FeatureTransformer::BufferSize]; + alignas(alignment) TransformedFeatureType + transformedFeatures[FeatureTransformer::BufferSize]; #endif ASSERT_ALIGNED(transformedFeatures, alignment); @@ -258,29 +257,29 @@ void Network::verify(std::string evalfilePath) const { template -void Network::hint_common_access( - const Position& pos, - AccumulatorCaches::Cache* cache, - bool psqtOnl) const { +void Network::hint_common_access(const Position& pos, + AccumulatorCaches::Cache* cache, + bool psqtOnl) const { featureTransformer->hint_common_access(pos, cache, psqtOnl); } template -NnueEvalTrace Network::trace_evaluate( - const Position& pos, AccumulatorCaches::Cache* cache) const { +NnueEvalTrace +Network::trace_evaluate(const Position& pos, + AccumulatorCaches::Cache* cache) const { // We manually align the arrays on the stack because with gcc < 9.3 // overaligning stack variables with alignas() doesn't work correctly. constexpr uint64_t alignment = CacheLineSize; #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) - TransformedFeatureType transformedFeaturesUnaligned - [FeatureTransformer::BufferSize - + alignment / sizeof(TransformedFeatureType)]; + TransformedFeatureType + transformedFeaturesUnaligned[FeatureTransformer::BufferSize + + alignment / sizeof(TransformedFeatureType)]; auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); #else - alignas(alignment) TransformedFeatureType transformedFeatures - [FeatureTransformer::BufferSize]; + alignas(alignment) TransformedFeatureType + transformedFeatures[FeatureTransformer::BufferSize]; #endif ASSERT_ALIGNED(transformedFeatures, alignment); diff --git a/src/nnue/network.h b/src/nnue/network.h index 049374766c2..df59732d955 100644 --- a/src/nnue/network.h +++ b/src/nnue/network.h @@ -43,6 +43,8 @@ enum class EmbeddedNNUEType { template class Network { + static constexpr IndexType FTDimensions = Arch::TransformedFeatureDimensions; + public: Network(EvalFile file, EmbeddedNNUEType type) : evalFile(file), @@ -51,21 +53,20 @@ class Network { void load(const std::string& rootDirectory, std::string evalfilePath); bool save(const std::optional& filename) const; - Value evaluate(const Position& pos, - AccumulatorCaches::Cache* cache, - bool adjusted = false, - int* complexity = nullptr, - bool psqtOnly = false) const; + Value evaluate(const Position& pos, + AccumulatorCaches::Cache* cache, + bool adjusted = false, + int* complexity = nullptr, + bool psqtOnly = false) const; - void hint_common_access(const Position& pos, - AccumulatorCaches::Cache* cache, - bool psqtOnl) const; + void hint_common_access(const Position& pos, + AccumulatorCaches::Cache* cache, + bool psqtOnl) const; - void verify(std::string evalfilePath) const; - NnueEvalTrace - trace_evaluate(const Position& pos, - AccumulatorCaches::Cache* cache) const; + void verify(std::string evalfilePath) const; + NnueEvalTrace trace_evaluate(const Position& pos, + AccumulatorCaches::Cache* cache) const; private: void load_user_net(const std::string&, const std::string&); diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index de09543b883..ace29b738a4 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -194,7 +194,7 @@ static constexpr int BestRegisterCount() { template StateInfo::*accPtr> class FeatureTransformer { - public: + // Number of output dimensions for one side static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;