Skip to content

Commit f1b8835

Browse files
Meteorixyangulei
authored andcommitted
add einsum in pytorch frontend (apache#9651)
* add einsum in pytorch frontend * add einsum in pytorch frontend
1 parent 31267b1 commit f1b8835

2 files changed

Lines changed: 17 additions & 0 deletions

File tree

python/tvm/relay/frontend/pytorch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2819,6 +2819,10 @@ def slide_axes(inp, shape, ax):
28192819

28202820
return out
28212821

2822+
def einsum(self, inputs, input_types):
2823+
equation, data = inputs
2824+
return _op.einsum(data, equation)
2825+
28222826
# Operator mappings
28232827
def create_convert_map(self):
28242828
self.convert_map = {
@@ -3063,6 +3067,7 @@ def create_convert_map(self):
30633067
"aten::searchsorted": self.searchsorted,
30643068
"aten::bucketize": self.bucketize,
30653069
"aten::roll": self.roll,
3070+
"aten::einsum": self.einsum,
30663071
}
30673072

30683073
def update_convert_map(self, custom_map):

tests/python/frontend/pytorch/test_forward.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4020,5 +4020,17 @@ def test_fn(shifts, dims):
40204020
verify_model(test_fn(shifts=(2, 1), dims=(0, 1)), [x])
40214021

40224022

4023+
@tvm.testing.uses_gpu
4024+
def test_einsum():
4025+
def test_fn(equation):
4026+
return lambda *x: torch.einsum(equation, *x)
4027+
4028+
x = torch.ones([2, 3])
4029+
y = torch.ones([3, 4])
4030+
z = torch.ones([4, 5])
4031+
verify_model(test_fn("ij,jk"), [x, y])
4032+
verify_model(test_fn("ij,jk,km->im"), [x, y, z])
4033+
4034+
40234035
if __name__ == "__main__":
40244036
pytest.main([__file__])

0 commit comments

Comments
 (0)