Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test_dot
  • Loading branch information
AndrewZhaoLuo committed Jun 24, 2022
commit cea751e01f7aa0b2d3ac75c92223b772b6b46abc
43 changes: 29 additions & 14 deletions tests/python/integration/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,46 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test scheduling and running a dot product."""
import numpy as np

import tvm
import tvm.testing
from tvm import te
import numpy as np


@tvm.testing.requires_llvm
def test_dot():
nn = 12
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
k = te.reduce_axis((0, n), "k")
C = te.compute((), lambda: te.sum(A[k] * B[k], axis=k), name="C")
s = te.create_schedule(C.op)
"""Test dot product."""
arr_length = 12
arr_lenght_tvm = tvm.runtime.convert(arr_length)
Comment thread
AndrewZhaoLuo marked this conversation as resolved.
Outdated
placeholder_a = te.placeholder((arr_lenght_tvm,), name="A")
placeholder_b = te.placeholder((arr_lenght_tvm,), name="B")
reduce_axis_k = te.reduce_axis((0, arr_lenght_tvm), "k")
result_c = te.compute(
(),
lambda: te.sum(
placeholder_a[reduce_axis_k] * placeholder_b[reduce_axis_k], axis=reduce_axis_k
),
name="C",
)
schedule = te.create_schedule(result_c.op)

def verify(target):
f = tvm.driver.build(s, [A, B, C], target)
f = tvm.driver.build(schedule, [placeholder_a, placeholder_b, result_c], target)
# verify
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(nn,)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((), dtype=C.dtype), dev)
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), np.dot(a.numpy(), b.numpy()), rtol=1e-4)
buff_a = tvm.nd.array(
np.random.uniform(size=(arr_length,)).astype(placeholder_a.dtype), dev
)
buff_b = tvm.nd.array(
np.random.uniform(size=(arr_length,)).astype(placeholder_b.dtype), dev
)
buff_c = tvm.nd.array(np.zeros((), dtype=result_c.dtype), dev)
f(buff_a, buff_b, buff_c)
tvm.testing.assert_allclose(
buff_c.numpy(), np.dot(buff_a.numpy(), buff_b.numpy()), rtol=1e-4
)

verify("llvm")

Expand Down