Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,10 @@ cc_library(
cc_library(
name = "unitary_calculator_basic",
hdrs = ["unitary_calculator_basic.h"],
deps = [":unitaryspace_basic"],
deps = [
":bits",
":unitaryspace_basic"
],
)

### Unitary mux header ###
Expand Down
105 changes: 101 additions & 4 deletions lib/unitary_calculator_basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cstdint>
#include <vector>

#include "bits.h"
#include "unitaryspace_basic.h"

namespace qsim {
Expand Down Expand Up @@ -69,8 +70,9 @@ class UnitaryCalculatorBasic final {
const fp_type* matrix, Unitary& state) const {
if (qs.size() == 1) {
ApplyControlledGate1(qs[0], cqs, cmask, matrix, state);
} else if (qs.size() == 2) {
ApplyControlledGate2(qs[0], qs[1], cqs, cmask, matrix, state);
}
// Implement 2 qubit version.
}

private:
Expand Down Expand Up @@ -164,6 +166,101 @@ class UnitaryCalculatorBasic final {
emaskh, rstate);
}

void ApplyControlledGate2(unsigned q0, unsigned q1,
const std::vector<unsigned>& cqs, uint64_t cmask,
const fp_type* matrix, State& state) const {
uint64_t xs[2];
uint64_t ms[3];

xs[0] = uint64_t{1} << (q0 + 1);
ms[0] = (uint64_t{1} << q0) - 1;

xs[1] = uint64_t{1} << (q1 + 1);
ms[1] = ((uint64_t{1} << q1) - 1) ^ (xs[0] - 1);

ms[2] = ((uint64_t{1} << num_qubits_) - 1) ^ (xs[1] - 1);

uint64_t xss[4];
for (unsigned i = 0; i < 4; ++i) {
uint64_t a = 0;
for (uint64_t k = 0; k < 2; ++k) {
if (((i >> k) & 1) == 1) {
a += xs[k];
}
}
xss[i] = a;
}

uint64_t emaskh = 0;

for (auto q : cqs) {
emaskh |= uint64_t{1} << q;
}

uint64_t cmaskh = bits::ExpandBits(cmask, num_qubits_, emaskh);

emaskh |= uint64_t{1} << q0;
emaskh |= uint64_t{1} << q1;

emaskh = ~emaskh;

auto f = [](unsigned n, unsigned m, uint64_t ii, const fp_type* v,
const uint64_t* ms, const uint64_t* xss, unsigned n_qb,
unsigned sqrt_size, uint64_t cmaskh, uint64_t emaskh,
fp_type* rstate) {
fp_type rn, in;
fp_type rs[16], is[16];

auto row_size = uint64_t{1} << n_qb;

uint64_t i = ii % sqrt_size;
uint64_t j = ii / sqrt_size;

uint64_t col_loc = (1 * i & ms[0]) | (2 * i & ms[1]) | (4 * i & ms[2]);
uint64_t row_loc = bits::ExpandBits(j, n_qb, emaskh) | cmaskh;

auto p0 = rstate + row_size * 2 * row_loc + 2 * col_loc;

for (unsigned l = 0; l < 4; ++l) {
for (unsigned k = 0; k < 4; ++k) {
rs[4 * l + k] = *(p0 + xss[l] * row_size + xss[k]);
is[4 * l + k] = *(p0 + xss[l] * row_size + xss[k] + 1);
}
}

for (unsigned l = 0; l < 4; l++) {
uint64_t j = 0;
for (unsigned k = 0; k < 4; ++k) {
rn = rs[l] * v[j] - is[l] * v[j + 1];
in = rs[l] * v[j + 1] + is[l] * v[j];
j += 2;

for (unsigned p = 1; p < 4; ++p) {
rn += rs[4 * p + l] * v[j] - is[4 * p + l] * v[j + 1];
in += rs[4 * p + l] * v[j + 1] + is[4 * p + l] * v[j];

j += 2;
}
*(p0 + xss[k] * row_size + xss[l]) = rn;
*(p0 + xss[k] * row_size + xss[l] + 1) = in;
}
}
};

fp_type* rstate = state.get();

unsigned k = 2 + cqs.size();
unsigned n = num_qubits_ > k ? num_qubits_ - k : 0;
uint64_t size = uint64_t{1} << n;

unsigned kk = 2;
unsigned nn = num_qubits_ > kk ? num_qubits_ - kk : 0;
uint64_t size2 = uint64_t{1} << nn;

for_.Run(size * size2, f, matrix, ms, xss, num_qubits_, size2, cmaskh,
emaskh, rstate);
}

void ApplyGate1(unsigned q0, const fp_type* matrix, Unitary& state) const {
uint64_t xs[1];
uint64_t ms[2];
Expand Down Expand Up @@ -259,9 +356,9 @@ class UnitaryCalculatorBasic final {
xss[i] = a;
}

auto f = [q0, q1](unsigned n, unsigned m, uint64_t ii, const fp_type* v,
const uint64_t* ms, const uint64_t* xss, unsigned n_qb,
unsigned sqrt_size, fp_type* rstate) {
auto f = [](unsigned n, unsigned m, uint64_t ii, const fp_type* v,
const uint64_t* ms, const uint64_t* xss, unsigned n_qb,
unsigned sqrt_size, fp_type* rstate) {
fp_type rn, in;
fp_type rs[16], is[16];

Expand Down
14 changes: 8 additions & 6 deletions tests/unitary_calculator_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "unitary_calculator_testfixture.h"

#include "gtest/gtest.h"
#include "../lib/unitary_calculator_basic.h"

#include "../lib/formux.h"
#include "../lib/unitaryspace_basic.h"
#include "../lib/unitary_calculator_basic.h"
#include "gtest/gtest.h"
#include "unitary_calculator_testfixture.h"

namespace qsim {

namespace unitary {
namespace {

Expand All @@ -37,11 +35,15 @@ TEST(UnitaryCalculatorTest, ApplyGate2) {
TestApplyGate2<UnitaryCalculatorBasic<For, float>>();
}

TEST(UnitaryCalculatorTest, ApplyControlledGate2) {
TestApplyControlledGate2<UnitaryCalculatorBasic<For, float>>();
}

TEST(UnitaryCalculatorTest, ApplyFusedGate) {
TestApplyFusedGate<UnitaryCalculatorBasic<For, float>>();
}

} // namspace
} // namespace
} // namespace unitary
} // namespace qsim

Expand Down
69 changes: 69 additions & 0 deletions tests/unitary_calculator_testfixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,75 @@ void TestApplyGate2() {
EUnitaryEQ(us, u, n_qubits, expected_mat_02);
}

template <typename UC>
void TestApplyControlledGate2() {
const int n_qubits = 3;
UC uc(n_qubits, 1);
using UnitarySpace = typename UC::UnitarySpace;
using Unitary = typename UC::Unitary;

UnitarySpace us(n_qubits, 1);
Unitary u = us.CreateUnitary();

// clang-format off
float ref_gate[] = {1,2,3,4,5,6,7,8,
9,10,11,12,13,14,15,16,
17,18,19,20,21,22,23,24,
25,26,27,28,29,30,31,32};
// clang-format on

// Test applying on qubit 0, 1
FillMatrix(us, u, n_qubits);
// clang-format off
float expected_mat_01[] = {
0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,
16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0,
32.0,33.0,34.0,35.0,36.0,37.0,38.0,39.0,40.0,41.0,42.0,43.0,44.0,45.0,46.0,47.0,
48.0,49.0,50.0,51.0,52.0,53.0,54.0,55.0,56.0,57.0,58.0,59.0,60.0,61.0,62.0,63.0,
-372.0,3504.0,-380.0,3576.0,-388.0,3648.0,-396.0,3720.0,-404.0,3792.0,-412.0,3864.0,-420.0,3936.0,-428.0,4008.0,
-404.0,9168.0,-412.0,9368.0,-420.0,9568.0,-428.0,9768.0,-436.0,9968.0,-444.0,10168.0,-452.0,10368.0,-460.0,10568.0,
-436.0,14832.0,-444.0,15160.0,-452.0,15488.0,-460.0,15816.0,-468.0,16144.0,-476.0,16472.0,-484.0,16800.0,-492.0,17128.0,
-468.0,20496.0,-476.0,20952.0,-484.0,21408.0,-492.0,21864.0,-500.0,22320.0,-508.0,22776.0,-516.0,23232.0,-524.0,23688.0,
};
// clang-format on
uc.ApplyControlledGate({0, 1}, {2}, 1, ref_gate, u);
EUnitaryEQ(us, u, n_qubits, expected_mat_01);

// Test applying on qubit 1, 2
FillMatrix(us, u, n_qubits);
// clang-format off
float expected_mat_12[] = {
0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,
-276.0,2960.0,-284.0,3032.0,-292.0,3104.0,-300.0,3176.0,-308.0,3248.0,-316.0,3320.0,-324.0,3392.0,-332.0,3464.0,
32.0,33.0,34.0,35.0,36.0,37.0,38.0,39.0,40.0,41.0,42.0,43.0,44.0,45.0,46.0,47.0,
-308.0,7088.0,-316.0,7288.0,-324.0,7488.0,-332.0,7688.0,-340.0,7888.0,-348.0,8088.0,-356.0,8288.0,-364.0,8488.0,
64.0,65.0,66.0,67.0,68.0,69.0,70.0,71.0,72.0,73.0,74.0,75.0,76.0,77.0,78.0,79.0,
-340.0,11216.0,-348.0,11544.0,-356.0,11872.0,-364.0,12200.0,-372.0,12528.0,-380.0,12856.0,-388.0,13184.0,-396.0,13512.0,
96.0,97.0,98.0,99.0,100.0,101.0,102.0,103.0,104.0,105.0,106.0,107.0,108.0,109.0,110.0,111.0,
-372.0,15344.0,-380.0,15800.0,-388.0,16256.0,-396.0,16712.0,-404.0,17168.0,-412.0,17624.0,-420.0,18080.0,-428.0,18536.0,
};
// clang-format on
uc.ApplyControlledGate({1, 2}, {0}, 1, ref_gate, u);
EUnitaryEQ(us, u, n_qubits, expected_mat_12);

// Test applying on qubit 0, 2
FillMatrix(us, u, n_qubits);
// clang-format off
float expected_mat_02[] = {
0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,
16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0,
-308.0,3184.0,-316.0,3256.0,-324.0,3328.0,-332.0,3400.0,-340.0,3472.0,-348.0,3544.0,-356.0,3616.0,-364.0,3688.0,
-340.0,7824.0,-348.0,8024.0,-356.0,8224.0,-364.0,8424.0,-372.0,8624.0,-380.0,8824.0,-388.0,9024.0,-396.0,9224.0,
64.0,65.0,66.0,67.0,68.0,69.0,70.0,71.0,72.0,73.0,74.0,75.0,76.0,77.0,78.0,79.0,
80.0,81.0,82.0,83.0,84.0,85.0,86.0,87.0,88.0,89.0,90.0,91.0,92.0,93.0,94.0,95.0,
-372.0,12464.0,-380.0,12792.0,-388.0,13120.0,-396.0,13448.0,-404.0,13776.0,-412.0,14104.0,-420.0,14432.0,-428.0,14760.0,
-404.0,17104.0,-412.0,17560.0,-420.0,18016.0,-428.0,18472.0,-436.0,18928.0,-444.0,19384.0,-452.0,19840.0,-460.0,20296.0,
};
// clang-format on
uc.ApplyControlledGate({0, 2}, {1}, 1, ref_gate, u);
EUnitaryEQ(us, u, n_qubits, expected_mat_02);
}

template <typename UnitaryCalculator>
void TestApplyFusedGate() {
using UnitarySpace = typename UnitaryCalculator::UnitarySpace;
Expand Down