Skip to content

Commit d935902

Browse files
committed
refactor the backend end tests for running all examples scripts
1 parent cd29093 commit d935902

14 files changed

+3371
-110
lines changed

src/tyxonq/backends/abstract_backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def sqrtmh(self: Any, a: Tensor, psd: bool = False) -> Tensor:
6363
if psd:
6464
e = self.relu(e)
6565
e = self.sqrt(e)
66+
# Ensure consistent dtype for complex matrix operations
67+
e = self.cast(e, v.dtype)
6668
return v @ self.diagflat(e) @ self.adjoint(v)
6769

6870
def eigvalsh(self: Any, a: Tensor) -> Tensor:

src/tyxonq/backends/complex_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,24 @@ def quantum_gradient_safe(complex_tensor: torch.Tensor) -> torch.Tensor:
8282
else:
8383
return complex_tensor
8484

85-
def safe_cast(tensor: torch.Tensor, dtype: str) -> torch.Tensor:
85+
def safe_cast(tensor, dtype):
8686
"""
87-
Safely cast tensor to specified dtype without complex warnings.
87+
Safely cast tensor to the specified dtype, handling complex numbers.
8888
8989
:param tensor: Input tensor
9090
:param dtype: Target dtype
9191
:return: Casted tensor
9292
"""
93-
if dtype in ['float32', 'float64'] and tensor.is_complex():
94-
# For real dtypes, convert complex to real safely
95-
return ComplexHandler.safe_complex_to_real(tensor, "real")
96-
else:
93+
import torch
94+
95+
# Convert NumPy arrays to PyTorch tensors first
96+
if hasattr(tensor, 'numpy'): # Already a PyTorch tensor
97+
return tensor.type(getattr(torch, dtype))
98+
else: # NumPy array or other type
99+
if hasattr(tensor, 'detach'): # Already a PyTorch tensor
100+
tensor = tensor.detach().clone()
101+
else:
102+
tensor = torch.tensor(tensor)
97103
return tensor.type(getattr(torch, dtype))
98104

99105
def quantum_expectation_value(complex_tensor: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)