Skip to content

Commit

Permalink
refactor (#3890)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #3890

Differential Revision: D63406173
  • Loading branch information
mengdilin authored and facebook-github-bot committed Sep 25, 2024
1 parent b7c7bc3 commit dd86d34
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 40 deletions.
26 changes: 14 additions & 12 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,8 @@ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
return simd8float32(
{vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))});
#else
throw std::runtime_error("not implemented");

#endif
}
Expand Down Expand Up @@ -671,7 +673,7 @@ struct QuantizerBF16<8> : QuantizerBF16<1> {

FAISS_ALWAYS_INLINE simd8float32
reconstruct_8_components(const uint8_t* code, int i) const {
#ifdef __AVX2__
// #ifdef __AVX2__
// reference impl: decode_bf16(((uint16_t*)code)[i]);
// decode_bf16(v) -> (uint32_t(v) << 16)
// read 128-bits (16 uint8_t) -> (uint16_t*)code)[i]
Expand All @@ -683,18 +685,18 @@ struct QuantizerBF16<8> : QuantizerBF16<1> {
simd8uint32 shifted_16 = code_256i << 16;
return as_float32(shifted_16);

#endif
// #endif

#ifdef __aarch64__
// #ifdef __aarch64__

uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
return simd8float32(
{vreinterpretq_f32_u32(
vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
vreinterpretq_f32_u32(
vshlq_n_u32(vmovl_u16(codei.val[1]), 16))});
// uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 *
// i)); return simd8float32(
// {vreinterpretq_f32_u32(
// vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
// vreinterpretq_f32_u32(
// vshlq_n_u32(vmovl_u16(codei.val[1]), 16))});

#endif
// #endif
}
};

Expand Down Expand Up @@ -1119,7 +1121,7 @@ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
}
};

#else
#elif defined(__AVX2__) || defined(__aarch64__)

template <>
struct SimilarityL2<8> {
Expand Down Expand Up @@ -1229,7 +1231,7 @@ struct SimilarityL2<8> {
}
};

#else
#elif defined(__AVX2__) || defined(__aarch64__)

template <>
struct SimilarityIP<8> {
Expand Down
90 changes: 62 additions & 28 deletions faiss/utils/simdlib_emulated.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,19 @@ struct simd8uint32 : simd256bit {
bool operator!=(simd8uint32 other) const {
return !(*this == other);
}
template <typename F>
static simd8uint32 unary_func(const simd8uint32& a, F&& f) {
simd8uint32 c;
for (int j = 0; j < 8; j++) {
c.u32[j] = f(a.u32[j]);
}
return c;
}

// shift must be known at compile time
simd8uint32 operator<<(const int shift) const {
return unary_func(*this, [shift](uint16_t a) { return a << shift; });
}

std::string elements_to_string(const char* fmt) const {
char res[1000], *ptr = res;
Expand Down Expand Up @@ -705,6 +718,13 @@ struct simd8uint32 : simd256bit {
}
};

inline simd8uint32 load8_16bits_as_uint32(const uint8_t* code, int i) {
simd8uint32 res;
for (int j = 0; j < 16; j = j + 2) {
res.u32[j / 2] = *(code + i + j);
}
return res;
}
// Vectorized version of the following code:
// for (size_t i = 0; i < n; i++) {
// bool flag = (candidateValues[i] < currentValues[i]);
Expand Down Expand Up @@ -833,8 +853,12 @@ struct simd8float32 : simd256bit {
ptr[-1] = 0;
return std::string(res);
}
};

float accumulate() const {
return f32[0] + f32[1] + f32[2] + f32[3] + f32[4] + f32[5] + f32[6] +
f32[7];
};
};
// hadd does not cross lanes
inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
simd8float32 c;
Expand Down Expand Up @@ -880,25 +904,17 @@ inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {

return c;
}

// compute a * b + c
inline simd8float32 fmadd(
const simd8float32& a,
const simd8float32& b,
const simd8float32& c) {
simd8float32 res;
for (int i = 0; i < 8; i++) {
res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
}
return res;
}

inline simd8float32 load8(const uint8_t* code, int i) {
simd8float32 res;
for (int j = 0; j < 8; j++) {
res.f32[i] = *(code + i + j);
}
return res;
inline simd8float32 as_float32(simd8uint32 x) {
simd8float32 c;
c.f32[0] = x.u32[0];
c.f32[1] = x.u32[1];
c.f32[2] = x.u32[2];
c.f32[3] = x.u32[3];
c.f32[4] = x.u32[4];
c.f32[5] = x.u32[5];
c.f32[6] = x.u32[6];
c.f32[7] = x.u32[7];
return c;
}

namespace {
Expand Down Expand Up @@ -981,8 +997,8 @@ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
// lowestIndex = i;
// }
// }
// Vectorized version can be implemented via two operations: cmp and blend
// with something like this:
// Vectorized version can be implemented via two operations: cmp and
// blend with something like this:
// lowestValues = [HUGE_VAL; 8];
// lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
// for (size_t i = 0; i < n; i += 8) {
Expand All @@ -1000,8 +1016,9 @@ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
// The problem is that blend primitive needs very different instruction
// order for AVX and ARM.
// So, let's introduce a combination of these two in order to avoid
// confusion for ppl who write in low-level SIMD instructions. Additionally,
// these two ops (cmp and blend) are very often used together.
// confusion for ppl who write in low-level SIMD instructions.
// Additionally, these two ops (cmp and blend) are very often used
// together.
inline void cmplt_and_blend_inplace(
const simd8float32 candidateValues,
const simd8uint32 candidateIndices,
Expand All @@ -1024,9 +1041,9 @@ inline void cmplt_and_blend_inplace(
// maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
// maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
// }
// Max indices evaluation is inaccurate in case of equal values (the index of
// the last equal value is saved instead of the first one), but this behavior
// saves instructions.
// Max indices evaluation is inaccurate in case of equal values (the
// index of the last equal value is saved instead of the first one), but
// this behavior saves instructions.
inline void cmplt_min_max_fast(
const simd8float32 candidateValues,
const simd8uint32 candidateIndices,
Expand All @@ -1049,5 +1066,22 @@ inline void cmplt_min_max_fast(
}

} // namespace

// compute a * b + c
inline simd8float32 fmadd(
const simd8float32& a,
const simd8float32& b,
const simd8float32& c) {
simd8float32 res;
for (int i = 0; i < 8; i++) {
res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
}
return res;
}
inline simd8float32 load8(const uint8_t* code, int i) {
simd8float32 res;
for (int j = 0; j < 8; j++) {
res.f32[j] = *(code + i + j);
}
return res;
}
} // namespace faiss
69 changes: 69 additions & 0 deletions faiss/utils/simdlib_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ static inline uint32_t cmp_xe32(
return d0_mask | static_cast<uint32_t>(d1_mask) << 16;
}

template <std::uint8_t Shift>
static inline uint32x4_t vshlq(uint32x4_t vec) {
return vshlq_n_u32(vec, Shift);
}

template <std::uint8_t Shift>
static inline uint16x8_t vshlq(uint16x8_t vec) {
return vshlq_n_u16(vec, Shift);
Expand Down Expand Up @@ -972,6 +977,63 @@ struct simd8uint32 {
return ~(*this == other);
}

// shift must be known at compile time
simd8uint32 operator<<(const int shift) const {
switch (shift) {
case 0:
return *this;
case 1:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<1>>()};
case 2:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<2>>()};
case 3:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<3>>()};
case 4:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<4>>()};
case 5:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<5>>()};
case 6:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<6>>()};
case 7:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<7>>()};
case 8:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<8>>()};
case 9:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<9>>()};
case 10:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<10>>()};
case 11:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<11>>()};
case 12:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<12>>()};
case 13:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<13>>()};
case 14:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<14>>()};
case 15:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<15>>()};
case 16:
return simd8uint32{detail::simdlib::unary_func(data)
.call<detail::simdlib::vshlq<16>>()};
default:
FAISS_THROW_FMT("Invalid shift %d", shift);
}
}
// Checks whether the other holds exactly the same bytes.
template <typename T>
bool is_same_as(T other) const {
Expand Down Expand Up @@ -1240,6 +1302,13 @@ inline simd8float32 load8(const uint8_t* code, int i) {
{vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))});
}

inline simd8uint32 load8_16bits_as_uint32(const uint8_t* code, int i) {
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
return simd8uint32({vmovl_u16(codei.val[0]), vmovl_u16(codei.val[1])});
}
inline simd8float32 as_float32(simd8uint32 x) {
return simd8float32(detail::simdlib::reinterpret_f32(x.data));
}
// The following primitive is a vectorized version of the following code
// snippet:
// float lowestValue = HUGE_VAL;
Expand Down

0 comments on commit dd86d34

Please sign in to comment.