Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
## For defining dependencies used by jobs in Databricks Workflows, see
## https://docs.databricks.com/dev-tools/bundles/library-dependencies.html

## Add code completion support for DLT
databricks-dlt

## pytest is the default package used for testing
pytest

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -22,7 +32,7 @@
"sys.path.append('../src')\n",
"from {{.project_name}} import main\n",
"\n",
"main.get_taxis().show(10)"
"main.get_taxis(spark).show(10)"
{{else}}
"spark.range(10)"
{{end -}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
{{- if (eq .include_python "yes") }}
"@dlt.view\n",
"def taxi_raw():\n",
" return main.get_taxis()\n",
" return main.get_taxis(spark)\n",
{{else}}
"\n",
"@dlt.view\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
"This default notebook is executed using Databricks Workflows as defined in resources/{{.project_name}}_job.yml."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 0,
Expand All @@ -37,7 +47,7 @@
{{- if (eq .include_python "yes") }}
"from {{.project_name}} import main\n",
"\n",
"main.get_taxis().show(10)"
"main.get_taxis(spark).show(10)"
{{else}}
"spark.range(10)"
{{end -}}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
{{- /*
We use pyspark.sql rather than DatabricksSession.builder.getOrCreate()
for compatibility with older runtimes. With a new runtime, it's
equivalent to DatabricksSession.builder.getOrCreate().
*/ -}}
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, DataFrame

def get_taxis():
spark = SparkSession.builder.getOrCreate()
def get_taxis(spark: SparkSession) -> DataFrame:
return spark.read.table("samples.nyctaxi.trips")


# Create a new Databricks Connect session. If this fails,
# check that you have configured Databricks Connect correctly.
# See https://docs.databricks.com/dev-tools/databricks-connect.html.
def get_spark() -> SparkSession:
try:
from databricks.connect import DatabricksSession
return DatabricksSession.builder.getOrCreate()
except ImportError:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This really doesn't seem like it should be how we recommend customers write their code...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but until we have something better I'd rather be explicit and verbose than not supporting old DBRs or hiding it in non-standard libraries.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does seem like a very specific factory pattern we want users to follow. How about moving it inside the main function and not having the get_spark function? That should make it very clear that this is not intended to be used everywhere and is here to only enable per file runs.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I correct to understand that the only case that DatabricksSession.builder.getOrCreate() will fail is if the user is targeting this at DBR <13?

return SparkSession.builder.getOrCreate()

def main():
get_taxis().show(5)
get_taxis(get_spark()).show(5)

if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
from databricks.connect import DatabricksSession
from pyspark.sql import SparkSession
from {{.project_name}} import main
from {{.project_name}}.main import get_taxis, get_spark

# Create a new Databricks Connect session. If this fails,
# check that you have configured Databricks Connect correctly.
# See https://docs.databricks.com/dev-tools/databricks-connect.html.
{{/*
The below works around a problematic error message from Databricks Connect.
The standard SparkSession is supported in all configurations (workspace, IDE,
all runtime versions, CLI). But on the CLI it currently gives a confusing
error message if SPARK_REMOTE is not set. We can't directly use
DatabricksSession.builder in main.py, so we're re-assigning it here so
everything works out of the box, even for CLI users who don't set SPARK_REMOTE.
*/}}
SparkSession.builder = DatabricksSession.builder
SparkSession.builder.getOrCreate()

def test_main():
taxis = main.get_taxis()
taxis = get_taxis(get_spark())
assert taxis.count() > 5