Skip to content

Commit 4c5ab89

Browse files
author
Jacob Beck
committed
move the require-dbt-version check to before parsing
1 parent fe46138 commit 4c5ab89

File tree

4 files changed

+130
-59
lines changed

4 files changed

+130
-59
lines changed

core/dbt/config/project.py

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,11 @@ class PartialProject:
215215
metadata=dict(description='The root directory of the project'),
216216
)
217217
project_dict: Dict[str, Any]
218+
verify_version: bool = field(
219+
metadata=dict(description=(
220+
'If True, verify the dbt version matches the rquired version'
221+
))
222+
)
218223

219224
def render(self, renderer):
220225
packages_dict = package_data_from_root(self.project_root)
@@ -225,6 +230,7 @@ def render(self, renderer):
225230
packages_dict,
226231
selectors_dict,
227232
renderer,
233+
verify_version=self.verify_version,
228234
)
229235

230236
def render_profile_name(self, renderer) -> Optional[str]:
@@ -292,6 +298,32 @@ def to_dict(self):
292298
return self.vars
293299

294300

301+
def validate_version(
302+
required: List[VersionSpecifier],
303+
project_name: str,
304+
) -> None:
305+
"""Ensure this package works with the installed version of dbt."""
306+
installed = get_installed_version()
307+
if not versions_compatible(*required):
308+
msg = IMPOSSIBLE_VERSION_ERROR.format(
309+
package=project_name,
310+
version_spec=[
311+
x.to_version_string() for x in required
312+
]
313+
)
314+
raise DbtProjectError(msg)
315+
316+
if not versions_compatible(installed, *required):
317+
msg = INVALID_VERSION_ERROR.format(
318+
package=project_name,
319+
installed=installed.to_version_string(),
320+
version_spec=[
321+
x.to_version_string() for x in required
322+
]
323+
)
324+
raise DbtProjectError(msg)
325+
326+
295327
@dataclass
296328
class Project:
297329
project_name: str
@@ -363,6 +395,7 @@ def from_project_config(
363395
project_dict: Dict[str, Any],
364396
packages_dict: Optional[Dict[str, Any]] = None,
365397
selectors_dict: Optional[Dict[str, Any]] = None,
398+
required_dbt_version: Optional[List[VersionSpecifier]] = None,
366399
) -> 'Project':
367400
"""Create a project from its project and package configuration, as read
368401
by yaml.safe_load().
@@ -374,6 +407,11 @@ def from_project_config(
374407
the packages file exists and is invalid.
375408
:returns: The project, with defaults populated.
376409
"""
410+
if required_dbt_version is None:
411+
dbt_version = cls._get_required_version(project_dict)
412+
else:
413+
dbt_version = required_dbt_version
414+
377415
try:
378416
project_dict = cls._preprocess(project_dict)
379417
except RecursionException:
@@ -460,18 +498,8 @@ def from_project_config(
460498
on_run_start: List[str] = value_or(cfg.on_run_start, [])
461499
on_run_end: List[str] = value_or(cfg.on_run_end, [])
462500

463-
# weird type handling: no value_or use
464-
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
465-
if cfg.require_dbt_version is not None:
466-
dbt_raw_version = cfg.require_dbt_version
467-
468501
query_comment = _query_comment_from_cfg(cfg.query_comment)
469502

470-
try:
471-
dbt_version = _parse_versions(dbt_raw_version)
472-
except SemverException as e:
473-
raise DbtProjectError(str(e)) from e
474-
475503
try:
476504
packages = package_config_from_data(packages_dict)
477505
except ValidationError as e:
@@ -583,6 +611,30 @@ def validate(self):
583611
except ValidationError as e:
584612
raise DbtProjectError(validator_error_message(e)) from e
585613

614+
@classmethod
615+
def _get_required_version(
616+
cls, rendered_project: Dict[str, Any], verify_version: bool = False
617+
) -> List[VersionSpecifier]:
618+
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
619+
required = rendered_project.get('require-dbt-version')
620+
if required is not None:
621+
dbt_raw_version = required
622+
623+
try:
624+
dbt_version = _parse_versions(dbt_raw_version)
625+
except SemverException as e:
626+
raise DbtProjectError(str(e)) from e
627+
628+
if verify_version:
629+
# no name is also an error that we want to raise
630+
if 'name' not in rendered_project:
631+
raise DbtProjectError(
632+
'Required "name" field not present in project',
633+
)
634+
validate_version(dbt_version, rendered_project['name'])
635+
636+
return dbt_version
637+
586638
@classmethod
587639
def render_from_dict(
588640
cls,
@@ -591,18 +643,26 @@ def render_from_dict(
591643
packages_dict: Dict[str, Any],
592644
selectors_dict: Dict[str, Any],
593645
renderer: DbtProjectYamlRenderer,
646+
*,
647+
verify_version: bool = False
594648
) -> 'Project':
595649
rendered_project = renderer.render_data(project_dict)
596650
rendered_project['project-root'] = project_root
597651
package_renderer = renderer.get_package_renderer()
598652
rendered_packages = package_renderer.render_data(packages_dict)
599653
selectors_renderer = renderer.get_selector_renderer()
600654
rendered_selectors = selectors_renderer.render_data(selectors_dict)
655+
601656
try:
657+
dbt_version = cls._get_required_version(
658+
rendered_project, verify_version=verify_version
659+
)
660+
602661
return cls.from_project_config(
603662
rendered_project,
604663
rendered_packages,
605664
rendered_selectors,
665+
dbt_version,
606666
)
607667
except DbtProjectError as exc:
608668
if exc.path is None:
@@ -611,7 +671,7 @@ def render_from_dict(
611671

612672
@classmethod
613673
def partial_load(
614-
cls, project_root: str
674+
cls, project_root: str, *, verify_version: bool = False
615675
) -> PartialProject:
616676
project_root = os.path.normpath(project_root)
617677
project_dict = _raw_project_from(project_root)
@@ -626,41 +686,24 @@ def partial_load(
626686
project_name=project_name,
627687
project_root=project_root,
628688
project_dict=project_dict,
689+
verify_version=verify_version,
629690
)
630691

631692
@classmethod
632693
def from_project_root(
633-
cls, project_root: str, renderer: DbtProjectYamlRenderer
694+
cls,
695+
project_root: str,
696+
renderer: DbtProjectYamlRenderer,
697+
*,
698+
verify_version: bool = False,
634699
) -> 'Project':
635-
partial = cls.partial_load(project_root)
700+
partial = cls.partial_load(project_root, verify_version=verify_version)
636701
renderer.version = partial.config_version
637702
return partial.render(renderer)
638703

639704
def hashed_name(self):
640705
return hashlib.md5(self.project_name.encode('utf-8')).hexdigest()
641706

642-
def validate_version(self):
643-
"""Ensure this package works with the installed version of dbt."""
644-
installed = get_installed_version()
645-
if not versions_compatible(*self.dbt_version):
646-
msg = IMPOSSIBLE_VERSION_ERROR.format(
647-
package=self.project_name,
648-
version_spec=[
649-
x.to_version_string() for x in self.dbt_version
650-
]
651-
)
652-
raise DbtProjectError(msg)
653-
654-
if not versions_compatible(installed, *self.dbt_version):
655-
msg = INVALID_VERSION_ERROR.format(
656-
package=self.project_name,
657-
installed=installed.to_version_string(),
658-
version_spec=[
659-
x.to_version_string() for x in self.dbt_version
660-
]
661-
)
662-
raise DbtProjectError(msg)
663-
664707
def as_v1(self, all_projects: Iterable[str]):
665708
if self.config_version == 1:
666709
return self

core/dbt/config/runtime.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,11 @@ def new_project(self, project_root: str) -> 'RuntimeConfig':
138138
# load the new project and its packages. Don't pass cli variables.
139139
renderer = DbtProjectYamlRenderer(generate_target_context(profile, {}))
140140

141-
project = Project.from_project_root(project_root, renderer)
141+
project = Project.from_project_root(
142+
project_root,
143+
renderer,
144+
verify_version=getattr(self.args, 'version_check', False),
145+
)
142146

143147
cfg = self.from_parts(
144148
project=project,
@@ -173,9 +177,6 @@ def validate(self):
173177
except ValidationError as e:
174178
raise DbtProjectError(validator_error_message(e)) from e
175179

176-
if getattr(self.args, 'version_check', False):
177-
self.validate_version()
178-
179180
@classmethod
180181
def _get_rendered_profile(
181182
cls,
@@ -193,7 +194,11 @@ def collect_parts(
193194
) -> Tuple[Project, Profile]:
194195
# profile_name from the project
195196
project_root = args.project_dir if args.project_dir else os.getcwd()
196-
partial = Project.partial_load(project_root)
197+
version_check = getattr(args, 'version_check', False)
198+
partial = Project.partial_load(
199+
project_root,
200+
verify_version=version_check
201+
)
197202

198203
# build the profile using the base renderer and the one fact we know
199204
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))

core/dbt/task/debug.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def _load_project(self):
143143

144144
try:
145145
self.project = Project.from_project_root(
146-
self.project_dir, renderer
146+
self.project_dir,
147+
renderer,
148+
getattr(self.args, 'version_check', False),
147149
)
148150
except dbt.exceptions.DbtConfigError as exc:
149151
self.project_fail_details = str(exc)
@@ -181,7 +183,8 @@ def _choose_profile_names(self) -> Optional[List[str]]:
181183
if os.path.exists(self.project_path):
182184
try:
183185
partial = Project.partial_load(
184-
os.path.dirname(self.project_path)
186+
os.path.dirname(self.project_path),
187+
verify_version=getattr(self.args, 'version_check'),
185188
)
186189
renderer = DbtProjectYamlRenderer(
187190
generate_base_context(self.cli_vars)

test/unit/test_config.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def temp_cd(path):
3535
finally:
3636
os.chdir(current_path)
3737

38+
@contextmanager
39+
def raises_nothing():
40+
yield
41+
3842

3943
def empty_profile_renderer():
4044
return dbt.config.renderer.ProfileRenderer(generate_base_context({}))
@@ -179,6 +183,12 @@ def setUp(self):
179183
'env_value_profile': 'default',
180184
}
181185

186+
def assertRaisesOrReturns(self, exc):
187+
if exc is None:
188+
return raises_nothing()
189+
else:
190+
return self.assertRaises(exc)
191+
182192

183193
class BaseFileTest(BaseConfigTest):
184194
def setUp(self):
@@ -460,7 +470,7 @@ def test_target_override(self):
460470
self.assertEqual(profile.credentials.password, 'db_pass')
461471
self.assertEqual(profile.credentials.schema, 'redshift-schema')
462472
self.assertEqual(profile.credentials.database, 'redshift-db-name')
463-
self.assertEqual(profile, from_raw)
473+
self.assertEqual(profile, from_raw)
464474

465475
def test_env_vars(self):
466476
self.args.target = 'with-vars'
@@ -947,8 +957,12 @@ def setUp(self):
947957
self.default_project_data['project-root'] = self.project_dir
948958

949959
def get_project(self):
960+
version = dbt.config.Project._get_required_version(
961+
self.default_project_data,
962+
verify_version=bool(self.args.version_check)
963+
)
950964
return dbt.config.Project.from_project_config(
951-
self.default_project_data, None
965+
self.default_project_data, None, required_dbt_version=version
952966
)
953967

954968
def get_profile(self):
@@ -958,14 +972,16 @@ def get_profile(self):
958972
)
959973

960974
def from_parts(self, exc=None):
961-
project = self.get_project()
962-
profile = self.get_profile()
963-
if exc is None:
964-
return dbt.config.RuntimeConfig.from_parts(project, profile, self.args)
975+
with self.assertRaisesOrReturns(exc) as err:
976+
project = self.get_project()
977+
profile = self.get_profile()
978+
979+
result = dbt.config.RuntimeConfig.from_parts(project, profile, self.args)
965980

966-
with self.assertRaises(exc) as err:
967-
dbt.config.RuntimeConfig.from_parts(project, profile, self.args)
968-
return err
981+
if exc is None:
982+
return result
983+
else:
984+
return err
969985

970986
def test_from_parts(self):
971987
project = self.get_project()
@@ -1124,8 +1140,11 @@ def setUp(self):
11241140
))}
11251141

11261142
def get_project(self):
1143+
version = dbt.config.Project._get_required_version(
1144+
self.default_project_data, verify_version=True
1145+
)
11271146
return dbt.config.Project.from_project_config(
1128-
self.default_project_data, None
1147+
self.default_project_data, None, required_dbt_version=version
11291148
)
11301149

11311150
def get_profile(self):
@@ -1135,15 +1154,16 @@ def get_profile(self):
11351154
)
11361155

11371156
def from_parts(self, exc=None):
1138-
project = self.get_project()
1139-
profile = self.get_profile()
1140-
if exc is None:
1141-
return dbt.config.RuntimeConfig.from_parts(project, profile, self.args)
1157+
with self.assertRaisesOrReturns(exc) as err:
1158+
project = self.get_project()
1159+
profile = self.get_profile()
11421160

1143-
with self.assertRaises(exc) as err:
1144-
dbt.config.RuntimeConfig.from_parts(project, profile, self.args)
1145-
return err
1161+
result = dbt.config.RuntimeConfig.from_parts(project, profile, self.args)
11461162

1163+
if exc is None:
1164+
return result
1165+
else:
1166+
return err
11471167

11481168
def test__get_unused_resource_config_paths(self):
11491169
project = self.from_parts()

0 commit comments

Comments
 (0)