diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala index 2016a84ac5a35..f3f52366df666 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala @@ -103,4 +103,25 @@ class GeographyConnectDataFrameSuite extends QueryTest with RemoteSparkSession { val expectedGeog = Geography.fromWKB(point, 4326) checkAnswer(df, Seq(Row(expectedGeog))) } + + test("geospatial feature disabled") { + withSQLConf("spark.sql.geospatial.enabled" -> "false") { + val geography = Geography.fromWKB(point1, 4326) + val schema = StructType(Seq(StructField("col1", GeographyType(4326)))) + // Java List[Row] + schema. + val javaList = java.util.Arrays.asList(Row(geography)) + checkError( + exception = intercept[AnalysisException] { + spark.createDataFrame(javaList, schema).collect() + }, + condition = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED") + // Implicit encoder path. + import testImplicits._ + checkError( + exception = intercept[AnalysisException] { + Seq(geography).toDF("g").collect() + }, + condition = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED") + } + } } diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala index 1450ac54184bd..b66c8a6a3d788 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala @@ -109,4 +109,25 @@ class GeometryConnectDataFrameSuite extends QueryTest with RemoteSparkSession { val expectedGeom = Geometry.fromWKB(point, 0) checkAnswer(df, Seq(Row(expectedGeom))) } + + test("geospatial feature disabled") { + withSQLConf("spark.sql.geospatial.enabled" -> "false") { + val geometry = Geometry.fromWKB(point1, 0) + val schema = StructType(Seq(StructField("col1", GeometryType(0)))) + // Java List[Row] + schema. + val javaList = java.util.Arrays.asList(Row(geometry)) + checkError( + exception = intercept[AnalysisException] { + spark.createDataFrame(javaList, schema).collect() + }, + condition = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED") + // Implicit encoder path. + import testImplicits._ + checkError( + exception = intercept[AnalysisException] { + Seq(geometry).toDF("g").collect() + }, + condition = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED") + } + } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index f5cb2696d849b..dc20af8e3700d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -29,6 +29,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.st.STExpressionUtils import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto @@ -126,6 +127,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) val sessionId = executePlan.sessionHolder.sessionId val spark = dataframe.sparkSession val schema = dataframe.schema + val geospatialEnabled = spark.sessionState.conf.geospatialEnabled + if (!geospatialEnabled && schema.existsRecursively(STExpressionUtils.isGeoSpatialType)) { + throw new org.apache.spark.sql.AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED", + messageParameters = scala.collection.immutable.Map.empty) + } val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone val largeVarTypes = spark.sessionState.conf.arrowUseLargeVarTypes