|
10 | 10 | """ |
11 | 11 |
|
12 | 12 | # Standard Python modules |
| 13 | +import contextlib |
13 | 14 | import importlib |
14 | 15 | import os |
15 | 16 | import sys |
16 | 17 | import types |
17 | | -from typing import Optional, Tuple, Union |
| 18 | +from typing import Literal, Sequence, Tuple, Union |
18 | 19 | import warnings |
19 | 20 |
|
20 | 21 | # External modules |
@@ -361,9 +362,9 @@ def convertToCSC(mat: Union[dict, spmatrix, ndarray]) -> dict: |
361 | 362 |
|
362 | 363 | def convertToDense(mat: Union[dict, spmatrix, ndarray]) -> ndarray: |
363 | 364 | """ |
364 | | - Take a pyopsparse sparse matrix definition and convert back to a dense |
| 365 | + Take a pyoptsparse sparse matrix definition and convert back to a dense |
365 | 366 | format. This is typically the final step for optimizers with dense constraint |
366 | | - jacibians. |
| 367 | + jacobians. |
367 | 368 |
|
368 | 369 | Parameters |
369 | 370 | ---------- |
@@ -576,40 +577,53 @@ def _broadcast_to_array(name: str, value: ArrayType, n_values: int, allow_none: |
576 | 577 | return value |
577 | 578 |
|
578 | 579 |
|
579 | | -def try_import_compiled_module_from_path( |
580 | | - module_name: str, path: Optional[str] = None, raise_warning: bool = False |
581 | | -) -> Union[types.ModuleType, str]: |
| 580 | +@contextlib.contextmanager |
| 581 | +def _prepend_path(path: Union[str, Sequence[str]]): |
| 582 | + """Context manager which temporarily prepends to `sys.path`.""" |
| 583 | + if isinstance(path, str): |
| 584 | + path = [path] |
| 585 | + orig_path = sys.path |
| 586 | + if path: |
| 587 | + path = [os.path.abspath(os.path.expandvars(os.path.expanduser(p))) for p in path] |
| 588 | + sys.path = path + sys.path |
| 589 | + yield |
| 590 | + sys.path = orig_path |
| 591 | + return |
| 592 | + |
| 593 | + |
| 594 | +def import_module( |
| 595 | + module_name: str, |
| 596 | + path: Union[str, Sequence[str]] = (), |
| 597 | + on_error: Literal["raise", "return"] = "return", |
| 598 | +) -> Union[types.ModuleType, Exception]: |
582 | 599 | """ |
583 | 600 | Attempt to import a module from a given path. |
584 | 601 |
|
585 | 602 | Parameters |
586 | 603 | ---------- |
587 | 604 | module_name : str |
588 | | - The name of the module |
589 | | - path : Optional[str] |
590 | | - The path to import from. If None, the default ``sys.path`` is used. |
591 | | - raise_warning : bool |
592 | | - If true, raise an import warning. By default false. |
| 605 | + The name of the module. |
| 606 | + path : Union[str, Sequence[str]] |
| 607 | + The search path, which will be prepended to ``sys.path``. May be a string, or a sequence of strings. |
| 608 | + on_error : str |
| 609 | + Specify behavior when import fails. If "raise", any exception raised during the import will be raised. |
| 610 | + If "return", any exception during the import will be returned. |
593 | 611 |
|
594 | 612 | Returns |
595 | 613 | ------- |
596 | 614 | Union[types.ModuleType, str] |
597 | 615 | If importable, the imported module is returned. |
598 | | - If not importable, the error message is instead returned. |
| 616 | + If not importable, the exception is returned. |
599 | 617 | """ |
600 | | - orig_path = sys.path |
601 | | - if path is not None: |
602 | | - path = os.path.abspath(os.path.expandvars(os.path.expanduser(path))) |
603 | | - sys.path = [path] |
604 | | - try: |
605 | | - module = importlib.import_module(module_name) |
606 | | - except ImportError as e: |
607 | | - if raise_warning: |
608 | | - warnings.warn( |
609 | | - f"{module_name} module could not be imported from {path}.", |
610 | | - stacklevel=2, |
611 | | - ) |
612 | | - module = str(e) |
613 | | - finally: |
614 | | - sys.path = orig_path |
| 618 | + if on_error.lower() not in ("raise", "return"): |
| 619 | + raise ValueError("`on_error` must be 'raise' or 'return'.") |
| 620 | + |
| 621 | + with _prepend_path(path): |
| 622 | + try: |
| 623 | + module = importlib.import_module(module_name) |
| 624 | + except ImportError as e: |
| 625 | + if on_error.lower() == "raise": |
| 626 | + raise e |
| 627 | + else: |
| 628 | + module = e |
615 | 629 | return module |
0 commit comments