From 6e61c391f33bf84f20affffc528310e4fb64761c Mon Sep 17 00:00:00 2001 From: Garrick Meeker Date: Sat, 6 Jan 2024 14:37:56 -0800 Subject: [PATCH 1/2] [RPC] Fix tuning on macOS and Windows (#15771) Fix a regression introduced in #15187 that prevented tuning from working when the multiprocessing start method is not 'fork': the server loop was a nested function, which cannot be pickled and handed to a child process under the other start methods, so it is moved to module level and receives its state as arguments. This affects macOS and Windows, where 'spawn' is the default. Also, starting with Python 3.14 'fork' is no longer the default start method on the remaining POSIX platforms (it becomes 'forkserver'), so Linux would be affected as well. --- python/tvm/rpc/server.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 6ee683c73ba5..8d0881e12c6a 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -119,6 +119,15 @@ def download_linked_module(file_name): return temp +def _serve_loop(sock, load_library, work_path=None): + """Server loop""" + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + _ffi_api.ServerLoop(sockfd) + if not work_path: + temp.remove() + + def _parse_server_opt(opts): # parse client options ret = {} @@ -128,25 +137,21 @@ def _parse_server_opt(opts): return ret -def _serving(sock, addr, opts, load_library): +def _serving(conn, addr, opts, load_library): logger.info(f"connected from {addr}") work_path = utils.tempdir() old_cwd = os.getcwd() os.chdir(work_path.path) # Avoiding file name conflict between sessions. logger.info(f"start serving at {work_path.path}") - def _serve_loop(): - _server_env(load_library, work_path) - _ffi_api.ServerLoop(sock.fileno()) - - server_proc = multiprocessing.Process(target=_serve_loop) + server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, load_library, work_path)) server_proc.start() server_proc.join(opts.get("timeout", None)) # Wait until finish or timeout. 
if server_proc.is_alive(): logger.info("timeout in RPC session, kill..") _ffi_api.ReturnException( - sock.fileno(), + conn.fileno(), f'RPCSessionTimeoutError: Your {opts["timeout"]}s session has expired, ' f'try to increase the "session_timeout" value.', ) @@ -166,7 +171,7 @@ def _serve_loop(): logger.info(f"finish serving {addr}") os.chdir(old_cwd) work_path.remove() - sock.close() + conn.close() def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): From 9aa2d6946d8a497ec11788a13c92d7901986bcdd Mon Sep 17 00:00:00 2001 From: Garrick Meeker Date: Thu, 11 Jan 2024 10:48:22 -0800 Subject: [PATCH 2/2] [RPC] clean up _serve_loop function --- python/tvm/rpc/server.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 8d0881e12c6a..ea9576708667 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -119,13 +119,9 @@ def download_linked_module(file_name): return temp -def _serve_loop(sock, load_library, work_path=None): - """Server loop""" - sockfd = sock.fileno() - temp = _server_env(load_library, work_path) - _ffi_api.ServerLoop(sockfd) - if not work_path: - temp.remove() +def _serve_loop(sock, load_library, work_path): + _server_env(load_library, work_path) + _ffi_api.ServerLoop(sock.fileno()) def _parse_server_opt(opts): @@ -137,21 +133,21 @@ def _parse_server_opt(opts): return ret -def _serving(conn, addr, opts, load_library): +def _serving(sock, addr, opts, load_library): logger.info(f"connected from {addr}") work_path = utils.tempdir() old_cwd = os.getcwd() os.chdir(work_path.path) # Avoiding file name conflict between sessions. 
logger.info(f"start serving at {work_path.path}") - server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, load_library, work_path)) + server_proc = multiprocessing.Process(target=_serve_loop, args=(sock, load_library, work_path)) server_proc.start() server_proc.join(opts.get("timeout", None)) # Wait until finish or timeout. if server_proc.is_alive(): logger.info("timeout in RPC session, kill..") _ffi_api.ReturnException( - conn.fileno(), + sock.fileno(), f'RPCSessionTimeoutError: Your {opts["timeout"]}s session has expired, ' f'try to increase the "session_timeout" value.', ) @@ -171,7 +167,7 @@ def _serving(conn, addr, opts, load_library): logger.info(f"finish serving {addr}") os.chdir(old_cwd) work_path.remove() - conn.close() + sock.close() def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):