From d37e12f207a31c9a15b7c87aa77b67d483e7ba86 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 19 Aug 2025 23:17:53 +0800 Subject: [PATCH 1/6] init Signed-off-by: Weichen Xu --- python/pyspark/ml/util.py | 7 +++++++ python/pyspark/sql/connect/client/core.py | 15 +++++++++++++++ .../src/main/protobuf/spark/connect/ml.proto | 6 ++++++ .../org/apache/spark/sql/connect/ml/MLCache.scala | 1 + .../apache/spark/sql/connect/ml/MLHandler.scala | 9 +++++++++ 5 files changed, 38 insertions(+) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 5f23826c73cce..880752edf18e0 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -360,6 +360,13 @@ def del_remote_cache(ref_id: str) -> None: return +# query estimated model in-memory size +def query_model_size(ref_id: str) -> None: + assert ref_id is not None and "." not in ref_id + session = SparkSession.getActiveSession() + assert session is not None + + def try_remote_del(f: FuncT) -> FuncT: """Mark the function/property to delete a model on the server side.""" diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 6ac4cc1894c72..7132be434f461 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2027,3 +2027,18 @@ def _get_ml_cache_info(self) -> List[str]: return [item.string for item in ml_command_result.param.array.elements] return [] + + def _query_model_size(self, model_ref_id) -> int: + command = pb2.Command() + command.ml_command.delete.obj_refs.extend( + [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids] + ) + command.ml_command.delete.evict_only = evict_only + (_, properties, _) = self.execute_command(command) + + assert properties is not None + + if properties is not None and "ml_command_result" in properties: + ml_command_result = properties["ml_command_result"] + deleted = ml_command_result.operator_info.obj_ref.id.split(",") + return cast(List[str], deleted) \ No newline at end of file diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto index 3497284af4ab8..ef5c406dedd26 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto @@ -39,6 +39,7 @@ message MlCommand { CleanCache clean_cache = 7; GetCacheInfo get_cache_info = 8; CreateSummary create_summary = 9; + GetModelSize get_model_size = 10; } // Command for estimator.fit(dataset) @@ -109,6 +110,11 @@ message MlCommand { ObjectRef model_ref = 1; Relation dataset = 2; } + + // This is for query the model estimated in-memory size + message GetModelSize { + ObjectRef model_ref = 1; + } } // The result of MlCommand diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index f944110a54c77..ef51c0c7498c8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -241,6 +241,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { def clear(): Int = this.synchronized { val size = cachedModel.size() cachedModel.clear() + totalMLCacheSizeBytes.set(0L) if (getMemoryControlEnabled) { SparkFileUtils.cleanDirectory(new File(offloadedModelsDir.toString)) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index 4de4f238e41a9..40f1172677a50 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -448,6 +448,15 @@ private[connect] object MLHandler extends Logging { val createSummaryCmd = mlCommand.getCreateSummary createModelSummary(sessionHolder, createSummaryCmd) + case proto.MlCommand.CommandCase.GET_MODEL_SIZE => + val modelRefId = mlCommand.getGetModelSize.getModelRef.getId + val model = mlCache.get(modelRefId) + val modelSize = model.asInstanceOf[Model[_]].estimatedSize + proto.MlCommandResult + .newBuilder() + .setParam(LiteralValueProtoConverter.toLiteralProto(modelSize)) + .build() + case other => throw MlUnsupportedException(s"$other not supported") } } From f0100be59b077999fcbea163af4858bd7ae4ab8c Mon Sep 17 00:00:00 2001 From: weichenxu123 Date: Tue, 19 Aug 2025 15:32:46 +0000 Subject: [PATCH 2/6] update Signed-off-by: weichenxu123 --- python/pyspark/sql/connect/proto/ml_pb2.py | 50 +++++++++++---------- python/pyspark/sql/connect/proto/ml_pb2.pyi | 29 ++++++++++++ 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py index 1ede558b94140..4c1b4038c35e3 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.py +++ b/python/pyspark/sql/connect/proto/ml_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\r\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xc7\x0e\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x12M\n\x0eget_model_size\x18\n \x01(\x0b\x32%.spark.connect.MlCommand.GetModelSizeH\x00R\x0cgetModelSize\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61taset\x1a\x45\n\x0cGetModelSize\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRefB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -54,27 +54,29 @@ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001" _globals["_MLCOMMAND"]._serialized_start = 137 - _globals["_MLCOMMAND"]._serialized_end = 1850 - _globals["_MLCOMMAND_FIT"]._serialized_start = 712 - _globals["_MLCOMMAND_FIT"]._serialized_end = 890 - _globals["_MLCOMMAND_DELETE"]._serialized_start = 892 - _globals["_MLCOMMAND_DELETE"]._serialized_end = 1004 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1006 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1018 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1020 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1034 - _globals["_MLCOMMAND_WRITE"]._serialized_start = 1037 - _globals["_MLCOMMAND_WRITE"]._serialized_end = 1447 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1349 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1407 - _globals["_MLCOMMAND_READ"]._serialized_start = 1449 - _globals["_MLCOMMAND_READ"]._serialized_end = 1530 - _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1533 - _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1716 - _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1718 - _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1839 - _globals["_MLCOMMANDRESULT"]._serialized_start = 1853 - _globals["_MLCOMMANDRESULT"]._serialized_end = 2322 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2046 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2307 + _globals["_MLCOMMAND"]._serialized_end = 2000 + _globals["_MLCOMMAND_FIT"]._serialized_start = 791 + _globals["_MLCOMMAND_FIT"]._serialized_end = 969 + _globals["_MLCOMMAND_DELETE"]._serialized_start = 971 + _globals["_MLCOMMAND_DELETE"]._serialized_end = 1083 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1085 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1097 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1099 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1113 + _globals["_MLCOMMAND_WRITE"]._serialized_start = 1116 + _globals["_MLCOMMAND_WRITE"]._serialized_end = 1526 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1428 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1486 + _globals["_MLCOMMAND_READ"]._serialized_start = 1528 + _globals["_MLCOMMAND_READ"]._serialized_end = 1609 + _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1612 + _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1795 + _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1797 + _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1918 + _globals["_MLCOMMAND_GETMODELSIZE"]._serialized_start = 1920 + _globals["_MLCOMMAND_GETMODELSIZE"]._serialized_end = 1989 + _globals["_MLCOMMANDRESULT"]._serialized_start = 2003 + _globals["_MLCOMMANDRESULT"]._serialized_end = 2472 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2196 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2457 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi index 0a72c207b5264..156ef846a8d10 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.pyi +++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi @@ -388,6 +388,26 @@ class MlCommand(google.protobuf.message.Message): field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"], ) -> None: ... + class GetModelSize(google.protobuf.message.Message): + """This is for query the model estimated in-memory size""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ... + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> None: ... + FIT_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int DELETE_FIELD_NUMBER: builtins.int @@ -397,6 +417,7 @@ class MlCommand(google.protobuf.message.Message): CLEAN_CACHE_FIELD_NUMBER: builtins.int GET_CACHE_INFO_FIELD_NUMBER: builtins.int CREATE_SUMMARY_FIELD_NUMBER: builtins.int + GET_MODEL_SIZE_FIELD_NUMBER: builtins.int @property def fit(self) -> global___MlCommand.Fit: ... @property @@ -415,6 +436,8 @@ class MlCommand(google.protobuf.message.Message): def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ... @property def create_summary(self) -> global___MlCommand.CreateSummary: ... + @property + def get_model_size(self) -> global___MlCommand.GetModelSize: ... def __init__( self, *, @@ -427,6 +450,7 @@ class MlCommand(google.protobuf.message.Message): clean_cache: global___MlCommand.CleanCache | None = ..., get_cache_info: global___MlCommand.GetCacheInfo | None = ..., create_summary: global___MlCommand.CreateSummary | None = ..., + get_model_size: global___MlCommand.GetModelSize | None = ..., ) -> None: ... def HasField( self, @@ -447,6 +471,8 @@ class MlCommand(google.protobuf.message.Message): b"fit", "get_cache_info", b"get_cache_info", + "get_model_size", + b"get_model_size", "read", b"read", "write", @@ -472,6 +498,8 @@ class MlCommand(google.protobuf.message.Message): b"fit", "get_cache_info", b"get_cache_info", + "get_model_size", + b"get_model_size", "read", b"read", "write", @@ -491,6 +519,7 @@ class MlCommand(google.protobuf.message.Message): "clean_cache", "get_cache_info", "create_summary", + "get_model_size", ] | None ): ... From 7ef0e3fcd3aae040928b3d9e1c66b3beef4d759f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 19 Aug 2025 23:53:42 +0800 Subject: [PATCH 3/6] update Signed-off-by: Weichen Xu --- python/pyspark/ml/util.py | 7 ------- python/pyspark/sql/connect/client/core.py | 14 +++++++------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 880752edf18e0..5f23826c73cce 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -360,13 +360,6 @@ def del_remote_cache(ref_id: str) -> None: return -# query estimated model in-memory size -def query_model_size(ref_id: str) -> None: - assert ref_id is not None and "." not in ref_id - session = SparkSession.getActiveSession() - assert session is not None - - def try_remote_del(f: FuncT) -> FuncT: """Mark the function/property to delete a model on the server side.""" diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 7132be434f461..a1f8b659117f2 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2030,15 +2030,15 @@ def _get_ml_cache_info(self) -> List[str]: def _query_model_size(self, model_ref_id) -> int: command = pb2.Command() - command.ml_command.delete.obj_refs.extend( - [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids] + command.ml_command.read.CopyFrom( + pb2.MlCommand.GetModelSize( + model_ref=pb2.ObjectRef(id=model_ref_id) + ) ) - command.ml_command.delete.evict_only = evict_only + command.ml_command.get_model_size.model_ref = pb2.ObjectRef(id=model_ref_id) (_, properties, _) = self.execute_command(command) assert properties is not None - if properties is not None and "ml_command_result" in properties: - ml_command_result = properties["ml_command_result"] - deleted = ml_command_result.operator_info.obj_ref.id.split(",") - return cast(List[str], deleted) \ No newline at end of file + ml_command_result = properties["ml_command_result"] + return ml_command_result.param.long From 2297f1932fd5d76d8c667f0bc305bdbd1af11332 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 20 Aug 2025 12:09:28 +0000 Subject: [PATCH 4/6] fix Signed-off-by: Weichen Xu --- python/pyspark/ml/tests/connect/test_connect_cache.py | 3 +++ python/pyspark/sql/connect/client/core.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/tests/connect/test_connect_cache.py b/python/pyspark/ml/tests/connect/test_connect_cache.py index f911ab22286c0..b6c801f32eaf0 100644 --- a/python/pyspark/ml/tests/connect/test_connect_cache.py +++ b/python/pyspark/ml/tests/connect/test_connect_cache.py @@ -51,6 +51,9 @@ def test_delete_model(self): # the `model._summary` holds another ref to the remote model. assert model._java_obj._ref_count == 2 + model_size = spark.client._query_model_size(model._java_obj.ref_id) + assert isinstance(model_size, int) and model_size > 0 + model2 = model.copy() cache_info = spark.client._get_ml_cache_info() self.assertEqual(len(cache_info), 1) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index a1f8b659117f2..bb90c5b8f70d5 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2030,12 +2030,11 @@ def _get_ml_cache_info(self) -> List[str]: def _query_model_size(self, model_ref_id) -> int: command = pb2.Command() - command.ml_command.read.CopyFrom( + command.ml_command.get_model_size.CopyFrom( pb2.MlCommand.GetModelSize( model_ref=pb2.ObjectRef(id=model_ref_id) ) ) - command.ml_command.get_model_size.model_ref = pb2.ObjectRef(id=model_ref_id) (_, properties, _) = self.execute_command(command) assert properties is not None From 859e8c0204a3449beca962d0c61b264b78a8f169 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 20 Aug 2025 12:09:57 +0000 Subject: [PATCH 5/6] format Signed-off-by: Weichen Xu --- python/pyspark/sql/connect/client/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index bb90c5b8f70d5..bd657a55a2a1e 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2031,9 +2031,7 @@ def _get_ml_cache_info(self) -> List[str]: def _query_model_size(self, model_ref_id) -> int: command = pb2.Command() command.ml_command.get_model_size.CopyFrom( - pb2.MlCommand.GetModelSize( - model_ref=pb2.ObjectRef(id=model_ref_id) - ) + pb2.MlCommand.GetModelSize(model_ref=pb2.ObjectRef(id=model_ref_id)) ) (_, properties, _) = self.execute_command(command) From d8c90b30f6fec44f6b82c3b0370b3251d20f532f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 20 Aug 2025 21:57:04 +0800 Subject: [PATCH 6/6] update Signed-off-by: Weichen Xu --- python/pyspark/sql/connect/client/core.py | 2 +- .../main/scala/org/apache/spark/sql/connect/ml/MLCache.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index bd657a55a2a1e..9d2e18ebb7600 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2028,7 +2028,7 @@ def _get_ml_cache_info(self) -> List[str]: return [] - def _query_model_size(self, model_ref_id) -> int: + def _query_model_size(self, model_ref_id: str) -> int: command = pb2.Command() command.ml_command.get_model_size.CopyFrom( pb2.MlCommand.GetModelSize(model_ref=pb2.ObjectRef(id=model_ref_id)) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index ef51c0c7498c8..f944110a54c77 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -241,7 +241,6 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { def clear(): Int = this.synchronized { val size = cachedModel.size() cachedModel.clear() - totalMLCacheSizeBytes.set(0L) if (getMemoryControlEnabled) { SparkFileUtils.cleanDirectory(new File(offloadedModelsDir.toString)) }