|
10 | 10 |
|
11 | 11 | import torch |
12 | 12 |
|
13 | | -from sglang.srt.utils import is_hip, is_npu |
| 13 | +from sglang.srt.utils import is_hip, is_musa, is_npu |
14 | 14 |
|
15 | 15 |
|
16 | 16 | def is_cuda_v2(): |
@@ -423,11 +423,103 @@ def get_topology(self): |
423 | 423 | return {} |
424 | 424 |
|
425 | 425 |
|
| 426 | +class MUSAEnv(BaseEnv): |
| 427 | + """Environment checker for MThreads GPU""" |
| 428 | + |
| 429 | + def get_info(self): |
| 430 | + musa_info = {"MUSA available": torch.musa.is_available()} |
| 431 | + |
| 432 | + if musa_info["MUSA available"]: |
| 433 | + musa_info.update(self.get_device_info()) |
| 434 | + musa_info.update(self._get_musa_version_info()) |
| 435 | + |
| 436 | + return musa_info |
| 437 | + |
| 438 | + def _get_musa_version_info(self): |
| 439 | + """ |
| 440 | + Get MUSA version information. |
| 441 | + """ |
| 442 | + from torch_musa.utils.musa_extension import MUSA_HOME |
| 443 | + |
| 444 | + musa_info = {"MUSA_HOME": MUSA_HOME} |
| 445 | + |
| 446 | + if MUSA_HOME and os.path.isdir(MUSA_HOME): |
| 447 | + musa_info.update(self._get_mcc_info()) |
| 448 | + musa_info.update(self._get_musa_driver_version()) |
| 449 | + |
| 450 | + return musa_info |
| 451 | + |
| 452 | + def _get_mcc_info(self): |
| 453 | + """ |
| 454 | + Get MCC version information. |
| 455 | + """ |
| 456 | + from torch_musa.utils.musa_extension import MUSA_HOME |
| 457 | + |
| 458 | + try: |
| 459 | + mcc = os.path.join(MUSA_HOME, "bin/mcc") |
| 460 | + mcc_output = ( |
| 461 | + subprocess.check_output(f'"{mcc}" --version', shell=True) |
| 462 | + .decode("utf-8") |
| 463 | + .strip() |
| 464 | + ) |
| 465 | + return { |
| 466 | + "MCC": mcc_output[ |
| 467 | + mcc_output.rfind("mcc version") : mcc_output.rfind("Target") |
| 468 | + ].strip() |
| 469 | + } |
| 470 | + except subprocess.SubprocessError: |
| 471 | + return {"MCC": "Not Available"} |
| 472 | + |
| 473 | + def _get_musa_driver_version(self): |
| 474 | + """ |
| 475 | + Get MUSA driver version. |
| 476 | + """ |
| 477 | + try: |
| 478 | + output = subprocess.check_output( |
| 479 | + [ |
| 480 | + "mthreads-gmi", |
| 481 | + "-q", |
| 482 | + ], |
| 483 | + text=True, |
| 484 | + ) |
| 485 | + driver_version = None |
| 486 | + for line in output.splitlines(): |
| 487 | + if "Driver Version" in line: |
| 488 | + driver_version = line.split(":", 1)[1].strip() |
| 489 | + break |
| 490 | + |
| 491 | + return {"MUSA Driver Version": driver_version} |
| 492 | + except subprocess.SubprocessError: |
| 493 | + return {"MUSA Driver Version": "Not Available"} |
| 494 | + |
| 495 | + def get_topology(self): |
| 496 | + """ |
| 497 | + Get GPU topology information. |
| 498 | + """ |
| 499 | + try: |
| 500 | + result = subprocess.run( |
| 501 | + ["mthreads-gmi", "topo", "-m"], |
| 502 | + stdout=subprocess.PIPE, |
| 503 | + stderr=subprocess.PIPE, |
| 504 | + text=True, |
| 505 | + check=True, |
| 506 | + ) |
| 507 | + return { |
| 508 | + "MTHREADS Topology": ( |
| 509 | + "\n" + result.stdout if result.returncode == 0 else None |
| 510 | + ) |
| 511 | + } |
| 512 | + except subprocess.SubprocessError: |
| 513 | + return {} |
| 514 | + |
| 515 | + |
426 | 516 | if __name__ == "__main__": |
427 | 517 | if is_cuda_v2(): |
428 | 518 | env = GPUEnv() |
429 | 519 | elif is_hip(): |
430 | 520 | env = HIPEnv() |
431 | 521 | elif is_npu(): |
432 | 522 | env = NPUEnv() |
| 523 | + elif is_musa(): |
| 524 | + env = MUSAEnv() |
433 | 525 | env.check_env() |
0 commit comments