Skip to content

Commit

Permalink
Simplify ClippedReLU
Browse files Browse the repository at this point in the history
Removes some max calls, and removes the need for SSSE3 specialisation.

Some speedup stats, courtesy of @AndyGrant
Dev  749240 nps
Base 748495 nps
Gain 0.100%
289936 games

LLR: 2.94 (-2.94,2.94) <-1.75,0.25>
Total: 203040 W: 52213 L: 52179 D: 98648
Ptnml(0-2): 480, 20722, 59139, 20642, 537
https://tests.stockfishchess.org/tests/view/664805fe6dcff0d1d6b05f2c

closes #5261

No functional change
  • Loading branch information
cj5716 committed May 18, 2024
1 parent 4edd1a3 commit 0a8372c
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 47 deletions.
51 changes: 17 additions & 34 deletions src/nnue/layers/clipped_relu.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,74 +65,57 @@ class ClippedReLU {
if constexpr (InputDimensions % SimdWidth == 0)
{
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
const __m256i Zero = _mm256_setzero_si256();
const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
const auto in = reinterpret_cast<const __m256i*>(input);
const auto out = reinterpret_cast<__m256i*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
{
const __m256i words0 =
_mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 0]),
_mm256_load_si256(&in[i * 4 + 1])),
_mm256_srli_epi16(_mm256_packus_epi32(_mm256_load_si256(&in[i * 4 + 0]),
_mm256_load_si256(&in[i * 4 + 1])),
WeightScaleBits);
const __m256i words1 =
_mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 2]),
_mm256_load_si256(&in[i * 4 + 3])),
_mm256_srli_epi16(_mm256_packus_epi32(_mm256_load_si256(&in[i * 4 + 2]),
_mm256_load_si256(&in[i * 4 + 3])),
WeightScaleBits);
_mm256_store_si256(
&out[i], _mm256_permutevar8x32_epi32(
_mm256_max_epi8(_mm256_packs_epi16(words0, words1), Zero), Offsets));
_mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(
_mm256_packs_epi16(words0, words1), Offsets));
}
}
else
{
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const __m128i Zero = _mm_setzero_si128();
const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
{
const __m128i words0 = _mm_srai_epi16(
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
const __m128i words0 = _mm_srli_epi16(
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
WeightScaleBits);
const __m128i words1 = _mm_srai_epi16(
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
const __m128i words1 = _mm_srli_epi16(
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
WeightScaleBits);
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero));
_mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
}
}
constexpr IndexType Start = InputDimensions % SimdWidth == 0
? InputDimensions / SimdWidth * SimdWidth
? InputDimensions
: InputDimensions / (SimdWidth / 2) * (SimdWidth / 2);

#elif defined(USE_SSE2)
constexpr IndexType NumChunks = InputDimensions / SimdWidth;

#ifdef USE_SSE41
const __m128i Zero = _mm_setzero_si128();
#else
const __m128i k0x80s = _mm_set1_epi8(-128);
#endif

const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
{
const __m128i words0 = _mm_srai_epi16(
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
const __m128i words0 = _mm_srli_epi16(
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
WeightScaleBits);
const __m128i words1 = _mm_srai_epi16(
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
const __m128i words1 = _mm_srli_epi16(
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
WeightScaleBits);
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i],

#ifdef USE_SSE41
_mm_max_epi8(packedbytes, Zero)
#else
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
#endif
_mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)

);
}
Expand Down
9 changes: 3 additions & 6 deletions src/nnue/nnue_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,11 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat
ss << "| " << bucket << " ";
ss << " | ";
format_cp_aligned_dot(t.psqt[bucket], ss, pos);
ss << " "
<< " | ";
ss << " " << " | ";
format_cp_aligned_dot(t.positional[bucket], ss, pos);
ss << " "
<< " | ";
ss << " " << " | ";
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss, pos);
ss << " "
<< " |";
ss << " " << " |";
if (bucket == t.correctBucket)
ss << " <-- this bucket is used";
ss << '\n';
Expand Down
6 changes: 2 additions & 4 deletions src/tune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ void make_option(OptionsMap* options, const string& n, int v, const SetRange& r)

// Print formatted parameters, ready to be copy-pasted in Fishtest
std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << ","
<< (r(v).second - r(v).first) / 20.0 << ","
<< "0.0020" << std::endl;
<< (r(v).second - r(v).first) / 20.0 << "," << "0.0020" << std::endl;
}
}

Expand Down Expand Up @@ -118,7 +117,6 @@ void Tune::Entry<Tune::PostUpdate>::read_option() {

namespace Stockfish {

void Tune::read_results() { /* ...insert your values here... */
}
void Tune::read_results() { /* ...insert your values here... */ }

} // namespace Stockfish
6 changes: 3 additions & 3 deletions src/uci.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ void UCIEngine::bench(std::istream& args) {

dbg_print();

std::cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes
<< "\nNodes/second : " << 1000 * nodes / elapsed << std::endl;
std::cerr << "\n===========================" << "\nTotal time (ms) : " << elapsed
<< "\nNodes searched : " << nodes << "\nNodes/second : " << 1000 * nodes / elapsed
<< std::endl;

// reset callback, to not capture a dangling reference to nodesSearched
engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
Expand Down

0 comments on commit 0a8372c

Please sign in to comment.