diff --git a/lib/statespace_avx.h b/lib/statespace_avx.h index f21dd3f82..069317045 100644 --- a/lib/statespace_avx.h +++ b/lib/statespace_avx.h @@ -239,22 +239,31 @@ class StateSpaceAVX : public StateSpace, For, float> { state.get()[k + 8] = im; } - // Sets state[i] = val where (i & mask) == bits + // Sets state[i] = complex(re, im) where (i & mask) == bits. + // if `exclude` is true then the criteria becomes (i & mask) != bits. void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, - const std::complex& val) const { - BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + const std::complex& val, + bool exclude = false) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude); } - // Sets state[i] = complex(re, im) where (i & mask) == bits + // Sets state[i] = complex(re, im) where (i & mask) == bits. + // if `exclude` is true then the criteria becomes (i & mask) != bits. void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, - fp_type im) const { + fp_type im, bool exclude = false) const { __m256 re_reg = _mm256_set1_ps(re); __m256 im_reg = _mm256_set1_ps(im); + __m256i exclude_reg = _mm256_setzero_si256(); + if (exclude) { + exclude_reg = _mm256_cmpeq_epi32(exclude_reg, exclude_reg); + } + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, - uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) { - __m256 ml = - _mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv)); + uint64_t bitsv, __m256 re_n, __m256 im_n, __m256i exclude_n, + fp_type* p) { + __m256 ml = _mm256_castsi256_ps(_mm256_xor_si256( + detail::GetZeroMaskAVX(8 * i, maskv, bitsv), exclude_n)); __m256 re = _mm256_load_ps(p + 16 * i); __m256 im = _mm256_load_ps(p + 16 * i + 8); @@ -267,7 +276,7 @@ class StateSpaceAVX : public StateSpace, For, float> { }; Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg, - im_reg, state.get()); + im_reg, exclude_reg, state.get()); } // Does the equivalent of dest += src elementwise. diff --git a/lib/statespace_basic.h b/lib/statespace_basic.h index 2cdab2c8f..e8d6ddf2d 100644 --- a/lib/statespace_basic.h +++ b/lib/statespace_basic.h @@ -96,26 +96,30 @@ class StateSpaceBasic : public StateSpace, For, FP> { state.get()[p + 1] = im; } - // Sets state[i] = val where (i & mask) == bits + // Sets state[i] = complex(re, im) where (i & mask) == bits. + // if `exclude` is true then the criteria becomes (i & mask) != bits. void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, - const std::complex& val) const { - BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + const std::complex& val, + bool exclude = false) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude); } - // Sets state[i] = complex(re, im) where (i & mask) == bits + // Sets state[i] = complex(re, im) where (i & mask) == bits. + // if `exclude` is true then the criteria becomes (i & mask) != bits. void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, - fp_type im) const { + fp_type im, bool exclude = false) const { auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, - uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) { + uint64_t bitsv, fp_type re_n, fp_type im_n, bool excludev, + fp_type* p) { auto s = p + 2 * i; bool in_mask = (i & maskv) == bitsv; - + in_mask ^= excludev; s[0] = in_mask ? re_n : s[0]; s[1] = in_mask ? im_n : s[1]; }; Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im, - state.get()); + exclude, state.get()); } // Does the equivalent of dest += src elementwise. diff --git a/lib/statespace_sse.h b/lib/statespace_sse.h index 5c1ecd415..a85dfe6a0 100644 --- a/lib/statespace_sse.h +++ b/lib/statespace_sse.h @@ -200,21 +200,30 @@ class StateSpaceSSE : public StateSpace, For, float> { state.get()[p + 4] = im; } - // Sets state[i] = val where (i & mask) == bits + // Sets state[i] = complex(re, im) where (i & mask) == bits. + // if `exclude` is true then the criteria becomes (i & mask) != bits. void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, - const std::complex& val) const { + const std::complex& val, + bool exclude = false) const { BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); } - // Sets state[i] = complex(re, im) where (i & mask) == bits + // Sets state[i] = complex(re, im) where (i & mask) == bits. + // if `exclude` is true then the criteria becomes (i & mask) != bits. void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, - fp_type im) const { + fp_type im, bool exclude = false) const { __m128 re_reg = _mm_set1_ps(re); __m128 im_reg = _mm_set1_ps(im); + __m128i exclude_reg = _mm_setzero_si128(); + if (exclude) { + exclude_reg = _mm_cmpeq_epi32(exclude_reg, exclude_reg); + } auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, - uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) { - __m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv)); + uint64_t bitsv, __m128 re_n, __m128 im_n, __m128i exclude_n, + fp_type* p) { + __m128 ml = _mm_castsi128_ps(_mm_xor_si128( + detail::GetZeroMaskSSE(4 * i, maskv, bitsv), exclude_n)); __m128 re = _mm_load_ps(p + 8 * i); __m128 im = _mm_load_ps(p + 8 * i + 4); @@ -227,7 +236,7 @@ class StateSpaceSSE : public StateSpace, For, float> { }; Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg, - im_reg, state.get()); + im_reg, exclude_reg, state.get()); } // Does the equivalent of dest += src elementwise. diff --git a/tests/statespace_avx_test.cc b/tests/statespace_avx_test.cc index ca72a13b1..1889ea560 100644 --- a/tests/statespace_avx_test.cc +++ b/tests/statespace_avx_test.cc @@ -62,10 +62,18 @@ TEST(StateSpaceAVXTest, InvalidStateSize) { TestInvalidStateSize>(); } -TEST(StateSpaceBasicTest, BulkSetAmpl) { +TEST(StateSpaceAVXTest, BulkSetAmpl) { TestBulkSetAmplitude>(); } +TEST(StateSpaceAVXTest, BulkSetAmplExclude) { + TestBulkSetAmplitudeExclusion>(); +} + +TEST(StateSpaceAVXTest, BulkSetAmplDefault) { + TestBulkSetAmplitudeDefault>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_basic_test.cc b/tests/statespace_basic_test.cc index 6a789f744..49950db43 100644 --- a/tests/statespace_basic_test.cc +++ b/tests/statespace_basic_test.cc @@ -66,6 +66,14 @@ TEST(StateSpaceBasicTest, BulkSetAmpl) { TestBulkSetAmplitude>(); } +TEST(StateSpaceBasicTest, BulkSetAmplExclude) { + TestBulkSetAmplitudeExclusion>(); +} + +TEST(StateSpaceBasicTest, BulkSetAmplDefault) { + TestBulkSetAmplitudeDefault>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_sse_test.cc b/tests/statespace_sse_test.cc index 45f61e598..a61ade44d 100644 --- a/tests/statespace_sse_test.cc +++ b/tests/statespace_sse_test.cc @@ -62,10 +62,18 @@ TEST(StateSpaceSSETest, InvalidStateSize) { TestInvalidStateSize>(); } -TEST(StateSpaceBasicTest, BulkSetAmpl) { +TEST(StateSpaceSSETest, BulkSetAmpl) { TestBulkSetAmplitude>(); } +TEST(StateSpaceSSETest, BulkSetAmplExclude) { + TestBulkSetAmplitudeExclusion>(); +} + +TEST(StateSpaceSSETest, BulkSetAmplDefault) { + TestBulkSetAmplitudeDefault>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_testfixture.h b/tests/statespace_testfixture.h index 28ca0cfd3..dcf78aee2 100644 --- a/tests/statespace_testfixture.h +++ b/tests/statespace_testfixture.h @@ -820,7 +820,7 @@ void TestBulkSetAmplitude() { for(int i = 0; i < 8; i++) { state_space.SetAmpl(state, i, 1, 1); } - state_space.BulkSetAmpl(state, 1, 0, 0, 0); + state_space.BulkSetAmpl(state, 1, 0, 0, 0, false); EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); @@ -833,7 +833,7 @@ void TestBulkSetAmplitude() { for(int i = 0; i < 8; i++) { state_space.SetAmpl(state, i, 1, 1); } - state_space.BulkSetAmpl(state, 2, 0, 0, 0); + state_space.BulkSetAmpl(state, 2, 0, 0, 0, false); EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); @@ -846,7 +846,7 @@ void TestBulkSetAmplitude() { for(int i = 0; i < 8; i++) { state_space.SetAmpl(state, i, 1, 1); } - state_space.BulkSetAmpl(state, 4, 0, 0, 0); + state_space.BulkSetAmpl(state, 4, 0, 0, 0, false); EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); @@ -856,6 +856,89 @@ void TestBulkSetAmplitude() { EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0, false); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); +} + +template +void TestBulkSetAmplitudeExclusion() { + using State = typename StateSpace::State; + unsigned num_qubits = 3; + + StateSpace state_space(1); + + State state = state_space.Create(num_qubits); + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 1, 0, 0, 0, true); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(0, 0)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 2, 0, 0, 0, true); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(0, 0)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4, 0, 0, 0, true); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(0, 0)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0, true); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(0, 0)); +} + +template +void TestBulkSetAmplitudeDefault() { + using State = typename StateSpace::State; + unsigned num_qubits = 3; + + StateSpace state_space(1); + + State state = state_space.Create(num_qubits); for(int i = 0; i < 8; i++) { state_space.SetAmpl(state, i, 1, 1); }