From 7f407ed116b450db89e04cd25f86b7e35b30af47 Mon Sep 17 00:00:00 2001 From: Fabian Jakobs Date: Tue, 27 Feb 2024 15:34:19 +0100 Subject: [PATCH 1/4] DABs Template: Fix DBConnect support in VS Code --- .../scratch/exploration.ipynb.tmpl | 12 +++++++- .../src/dlt_pipeline.ipynb.tmpl | 2 +- .../{{.project_name}}/src/notebook.ipynb.tmpl | 12 +++++++- .../src/{{.project_name}}/main.py.tmpl | 12 +++----- .../{{.project_name}}/tests/main_test.py.tmpl | 28 ++++++++----------- 5 files changed, 38 insertions(+), 28 deletions(-) diff --git a/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl index 04bb261cd0..42164dff07 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, { "cell_type": "code", "execution_count": null, @@ -22,7 +32,7 @@ "sys.path.append('../src')\n", "from {{.project_name}} import main\n", "\n", - "main.get_taxis().show(10)" + "main.get_taxis(spark).show(10)" {{else}} "spark.range(10)" {{end -}} diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl index 4f50294f6a..b152e9a308 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl @@ -63,7 +63,7 @@ {{- if (eq .include_python "yes") }} "@dlt.view\n", "def taxi_raw():\n", - " return main.get_taxis()\n", + " return main.get_taxis(spark)\n", {{else}} "\n", "@dlt.view\n", diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl index 0ab61db2c9..a228f8d18d 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl @@ -17,6 +17,16 @@ "This default notebook is executed using Databricks Workflows as defined in resources/{{.project_name}}_job.yml." ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, { "cell_type": "code", "execution_count": 0, @@ -37,7 +47,7 @@ {{- if (eq .include_python "yes") }} "from {{.project_name}} import main\n", "\n", - "main.get_taxis().show(10)" + "main.get_taxis(spark).show(10)" {{else}} "spark.range(10)" {{end -}} diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl index 4fe5ac8f4b..48529e974b 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl @@ -1,16 +1,12 @@ -{{- /* -We use pyspark.sql rather than DatabricksSession.builder.getOrCreate() -for compatibility with older runtimes. With a new runtime, it's -equivalent to DatabricksSession.builder.getOrCreate(). -*/ -}} from pyspark.sql import SparkSession -def get_taxis(): - spark = SparkSession.builder.getOrCreate() +def get_taxis(spark: SparkSession): return spark.read.table("samples.nyctaxi.trips") def main(): - get_taxis().show(5) + from databricks.connect import DatabricksSession as SparkSession + spark = SparkSession.builder.getOrCreate() + get_taxis(spark).show(5) if __name__ == '__main__': main() diff --git a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl index a7a6afe0a8..8ae043a65b 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl @@ -1,21 +1,15 @@ -from databricks.connect import DatabricksSession -from pyspark.sql import SparkSession +from databricks.connect import DatabricksSession as SparkSession +from pytest import fixture from {{.project_name}} import main -# Create a new Databricks Connect session. If this fails, -# check that you have configured Databricks Connect correctly. -# See https://docs.databricks.com/dev-tools/databricks-connect.html. -{{/* - The below works around a problematic error message from Databricks Connect. - The standard SparkSession is supported in all configurations (workspace, IDE, - all runtime versions, CLI). But on the CLI it currently gives a confusing - error message if SPARK_REMOTE is not set. We can't directly use - DatabricksSession.builder in main.py, so we're re-assigning it here so - everything works out of the box, even for CLI users who don't set SPARK_REMOTE. -*/}} -SparkSession.builder = DatabricksSession.builder -SparkSession.builder.getOrCreate() -def test_main(): - taxis = main.get_taxis() +@fixture(scope="session") +def spark(): + spark = SparkSession.builder.getOrCreate() + yield spark + spark.stop() + + +def test_main(spark: SparkSession): + taxis = main.get_taxis(spark) assert taxis.count() > 5 From 26a27cbc6b21ccc3a016716f561de0e9dbb9bed7 Mon Sep 17 00:00:00 2001 From: Fabian Jakobs Date: Mon, 4 Mar 2024 12:59:09 +0100 Subject: [PATCH 2/4] Add support for older runtimes --- .../src/{{.project_name}}/main.py.tmpl | 15 ++++++++++++--- .../{{.project_name}}/tests/main_test.py.tmpl | 6 +++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl index 48529e974b..658c74cf8a 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl @@ -3,10 +3,19 @@ from pyspark.sql import SparkSession def get_taxis(spark: SparkSession): return spark.read.table("samples.nyctaxi.trips") + +# Create a new Databricks Connect session. If this fails, +# check that you have configured Databricks Connect correctly. +# See https://docs.databricks.com/dev-tools/databricks-connect.html. +def get_spark(): + try: + from databricks.connect import DatabricksSession + return DatabricksSession.builder.getOrCreate() + except ImportError: + return SparkSession.builder.getOrCreate() + def main(): - from databricks.connect import DatabricksSession as SparkSession - spark = SparkSession.builder.getOrCreate() - get_taxis(spark).show(5) + get_taxis(get_spark()).show(5) if __name__ == '__main__': main() diff --git a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl index 8ae043a65b..2cffbc0d53 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl @@ -1,11 +1,11 @@ -from databricks.connect import DatabricksSession as SparkSession +from pyspark.sql import SparkSession from pytest import fixture -from {{.project_name}} import main +from {{.project_name}} import main, get_spark @fixture(scope="session") def spark(): - spark = SparkSession.builder.getOrCreate() + spark = get_spark() yield spark spark.stop() From 7e2f57139fd0d693c8cfd412719d38704b768972 Mon Sep 17 00:00:00 2001 From: Fabian Jakobs Date: Mon, 4 Mar 2024 12:59:09 +0100 Subject: [PATCH 3/4] Add support for older runtimes --- .../template/{{.project_name}}/requirements-dev.txt.tmpl | 3 +++ .../{{.project_name}}/src/{{.project_name}}/main.py.tmpl | 6 +++--- .../template/{{.project_name}}/tests/main_test.py.tmpl | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/libs/template/templates/default-python/template/{{.project_name}}/requirements-dev.txt.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/requirements-dev.txt.tmpl index 6da403219d..93dd4c4802 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/requirements-dev.txt.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/requirements-dev.txt.tmpl @@ -3,6 +3,9 @@ ## For defining dependencies used by jobs in Databricks Workflows, see ## https://docs.databricks.com/dev-tools/bundles/library-dependencies.html +## Add code completion support for DLT +databricks-dlt + ## pytest is the default package used for testing pytest diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl index 658c74cf8a..c514c6dc5d 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl @@ -1,13 +1,13 @@ -from pyspark.sql import SparkSession +from pyspark.sql import SparkSession, DataFrame -def get_taxis(spark: SparkSession): +def get_taxis(spark: SparkSession) -> DataFrame: return spark.read.table("samples.nyctaxi.trips") # Create a new Databricks Connect session. If this fails, # check that you have configured Databricks Connect correctly. # See https://docs.databricks.com/dev-tools/databricks-connect.html. -def get_spark(): +def get_spark() -> SparkSession: try: from databricks.connect import DatabricksSession return DatabricksSession.builder.getOrCreate() diff --git a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl index 2cffbc0d53..f4480bc47d 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl @@ -1,15 +1,15 @@ from pyspark.sql import SparkSession from pytest import fixture -from {{.project_name}} import main, get_spark +from {{.project_name}}.main import get_taxis, get_spark @fixture(scope="session") -def spark(): +def spark() -> SparkSession: spark = get_spark() yield spark spark.stop() def test_main(spark: SparkSession): - taxis = main.get_taxis(spark) + taxis = get_taxis(spark) assert taxis.count() > 5 From 0977481c897c0b181071d97c19a00d4efb95519e Mon Sep 17 00:00:00 2001 From: Fabian Jakobs Date: Tue, 5 Mar 2024 13:14:15 +0100 Subject: [PATCH 4/4] remove fixture --- .../{{.project_name}}/tests/main_test.py.tmpl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl index f4480bc47d..fea2f3f665 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl @@ -1,15 +1,6 @@ -from pyspark.sql import SparkSession -from pytest import fixture from {{.project_name}}.main import get_taxis, get_spark -@fixture(scope="session") -def spark() -> SparkSession: - spark = get_spark() - yield spark - spark.stop() - - -def test_main(spark: SparkSession): - taxis = get_taxis(spark) +def test_main(): + taxis = get_taxis(get_spark()) assert taxis.count() > 5