Skip to content

Commit a77729a

Browse files
authored
[MUSA][1/N] sglang.check_env (sgl-project#16959)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent bdaa3de commit a77729a

File tree

3 files changed

+102
-2
lines changed

3 files changed

+102
-2
lines changed

python/pyproject_other.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ runtime_common = [
6666
"grpcio==1.75.1", # keep it align with compile_proto.py
6767
"grpcio-tools==1.75.1", # keep it align with compile_proto.py
6868
"grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py
69-
"bidict",
7069
]
7170

7271
tracing = [

python/sglang/check_env.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212

13-
from sglang.srt.utils import is_hip, is_npu
13+
from sglang.srt.utils import is_hip, is_musa, is_npu
1414

1515

1616
def is_cuda_v2():
@@ -423,11 +423,103 @@ def get_topology(self):
423423
return {}
424424

425425

426+
class MUSAEnv(BaseEnv):
427+
"""Environment checker for MThreads GPU"""
428+
429+
def get_info(self):
430+
musa_info = {"MUSA available": torch.musa.is_available()}
431+
432+
if musa_info["MUSA available"]:
433+
musa_info.update(self.get_device_info())
434+
musa_info.update(self._get_musa_version_info())
435+
436+
return musa_info
437+
438+
def _get_musa_version_info(self):
439+
"""
440+
Get MUSA version information.
441+
"""
442+
from torch_musa.utils.musa_extension import MUSA_HOME
443+
444+
musa_info = {"MUSA_HOME": MUSA_HOME}
445+
446+
if MUSA_HOME and os.path.isdir(MUSA_HOME):
447+
musa_info.update(self._get_mcc_info())
448+
musa_info.update(self._get_musa_driver_version())
449+
450+
return musa_info
451+
452+
def _get_mcc_info(self):
453+
"""
454+
Get MCC version information.
455+
"""
456+
from torch_musa.utils.musa_extension import MUSA_HOME
457+
458+
try:
459+
mcc = os.path.join(MUSA_HOME, "bin/mcc")
460+
mcc_output = (
461+
subprocess.check_output(f'"{mcc}" --version', shell=True)
462+
.decode("utf-8")
463+
.strip()
464+
)
465+
return {
466+
"MCC": mcc_output[
467+
mcc_output.rfind("mcc version") : mcc_output.rfind("Target")
468+
].strip()
469+
}
470+
except subprocess.SubprocessError:
471+
return {"MCC": "Not Available"}
472+
473+
def _get_musa_driver_version(self):
474+
"""
475+
Get MUSA driver version.
476+
"""
477+
try:
478+
output = subprocess.check_output(
479+
[
480+
"mthreads-gmi",
481+
"-q",
482+
],
483+
text=True,
484+
)
485+
driver_version = None
486+
for line in output.splitlines():
487+
if "Driver Version" in line:
488+
driver_version = line.split(":", 1)[1].strip()
489+
break
490+
491+
return {"MUSA Driver Version": driver_version}
492+
except subprocess.SubprocessError:
493+
return {"MUSA Driver Version": "Not Available"}
494+
495+
def get_topology(self):
496+
"""
497+
Get GPU topology information.
498+
"""
499+
try:
500+
result = subprocess.run(
501+
["mthreads-gmi", "topo", "-m"],
502+
stdout=subprocess.PIPE,
503+
stderr=subprocess.PIPE,
504+
text=True,
505+
check=True,
506+
)
507+
return {
508+
"MTHREADS Topology": (
509+
"\n" + result.stdout if result.returncode == 0 else None
510+
)
511+
}
512+
except subprocess.SubprocessError:
513+
return {}
514+
515+
426516
if __name__ == "__main__":
427517
if is_cuda_v2():
428518
env = GPUEnv()
429519
elif is_hip():
430520
env = HIPEnv()
431521
elif is_npu():
432522
env = NPUEnv()
523+
elif is_musa():
524+
env = MUSAEnv()
433525
env.check_env()

python/sglang/srt/utils/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,15 @@ def is_cpu() -> bool:
185185
return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_supported
186186

187187

188+
@lru_cache(maxsize=1)
189+
def is_musa() -> bool:
190+
try:
191+
import torchada # noqa: F401
192+
except ImportError:
193+
return False
194+
return hasattr(torch.version, "musa") and torch.version.musa is not None
195+
196+
188197
def is_float4_e2m1fn_x2(dtype) -> bool:
189198
"""Check if dtype is float4_e2m1fn_x2 and CUDA is available."""
190199
target_dtype = getattr(torch, "float4_e2m1fn_x2", None)

0 commit comments

Comments
 (0)