44from typing import List , Tuple
55
66# import psycopg2.errors
7+ import psycopg2
8+ from psycopg2 import sql
79from sqlalchemy import create_engine , text
10+ from sqlalchemy .exc import OperationalError
811from sqlalchemy .orm import sessionmaker , Session
912
1013from goldfig import models
1114
1215_log = logging .getLogger (__name__ )
1316
1417DBNAME = 'goldfig'
15- HOST = 'localhost'
18+ HOST = os .environ .get ('GOLDFIG_DB_HOST' , 'localhost' )
19+ PORT = int (os .environ .get ('GOLDFIG_DB_PORT' , 5432 ))
1620
1721
1822@dataclass
@@ -29,12 +33,12 @@ def connection_string(self) -> str:
2933_ImportCredential = DbCredential (db_name = DBNAME ,
3034 user = 'goldfig' ,
3135 password = 'goldfig' ,
32- host = HOST )
36+ host = f' { HOST } :5432' )
3337
3438_ReadonlyCredential = DbCredential (db_name = DBNAME ,
3539 user = 'goldfig_ro' ,
3640 password = 'goldfig_ro' ,
37- host = HOST )
41+ host = f' { HOST } :5432' )
3842
3943_import_engine = None
4044_readonly_engine = None
@@ -43,8 +47,7 @@ def connection_string(self) -> str:
4347
4448
4549def _view_files () -> Tuple [str , List [str ]]:
46- path = os .path .realpath (
47- os .path .join (os .path .dirname (__file__ ), 'views' ))
50+ path = os .path .realpath (os .path .join (os .path .dirname (__file__ ), 'views' ))
4851 files = [
4952 f for f in os .listdir (path )
5053 if os .path .isfile (os .path .join (path , f )) and f [- 4 :] == '.sql'
@@ -76,9 +79,7 @@ def _install_views(db: Session):
7679 raise
7780
7881
79- # TODO: switch to something like alembic?
80- def init_db ():
81- db = import_session ()
82+ def _install_schema (db : Session ) -> None :
8283 version = None
8384 try :
8485 version = db .query (models .SchemaVersion ).one_or_none ()
@@ -100,10 +101,79 @@ def init_db():
100101 db .commit ()
101102
102103
104+ def _install_db_and_roles ():
105+ print ('creating goldfig database and installing roles' )
106+ cred = DbCredential (db_name = 'postgres' ,
107+ host = HOST ,
108+ user = os .environ .get ('GOLDFIG_DB_SU_USER' , 'postgres' ),
109+ password = os .environ .get ('GOLDFIG_DB_SU_PASSWORD' ,
110+ 'postgres' ))
111+ su_conn = psycopg2 .connect (dbname = cred .db_name ,
112+ user = cred .user ,
113+ password = cred .password ,
114+ host = cred .host )
115+ su_conn .autocommit = True
116+ cursor = su_conn .cursor ()
117+ cursor .execute ('CREATE DATABASE goldfig' )
118+ cursor .close ()
119+ su_conn .autocommit = False
120+ cursor = su_conn .cursor ()
121+ for user in (_ImportCredential , _ReadonlyCredential ):
122+ cursor .execute (
123+ sql .SQL ('CREATE USER {} WITH ENCRYPTED PASSWORD %s' ).format (
124+ sql .Identifier (user .user )), (user .password , ))
125+
126+ su_conn .commit ()
127+ su_conn .close ()
128+
129+ cred = DbCredential (db_name = 'goldfig' ,
130+ host = cred .host ,
131+ user = cred .user ,
132+ password = cred .password )
133+ su_conn = psycopg2 .connect (dbname = cred .db_name ,
134+ user = cred .user ,
135+ password = cred .password ,
136+ host = cred .host )
137+ su_conn .autocommit = False
138+ cursor = su_conn .cursor ()
139+ import_user = sql .Identifier (_ImportCredential .user )
140+ ro_user = sql .Identifier (_ReadonlyCredential .user )
141+ cursor .execute ('revoke create on schema public from public' )
142+ cursor .execute (
143+ sql .SQL ('grant all privileges on schema public to {}' ).format (
144+ import_user ))
145+ cursor .execute (
146+ sql .SQL ('grant select on all tables in schema public to {}' ).format (
147+ ro_user ))
148+ cursor .execute (
149+ sql .SQL (
150+ 'alter default privileges for role {} in schema public grant select on tables to {}'
151+ ).format (import_user , ro_user ))
152+ cursor .close ()
153+ su_conn .commit ()
154+
155+
156+ # TODO: switch to something like alembic?
157+ def init_db ():
158+ try :
159+ db = import_session ()
160+ except OperationalError as e :
161+ if 'password authentication failed' in str (e ):
162+ # need to set up db
163+ _install_db_and_roles ()
164+ db = import_session ()
165+ else :
166+ raise
167+ _install_schema (db )
168+
169+
103170def import_session () -> Session :
104171 global _import_engine
105172 if _import_engine is None :
106- _import_engine = create_engine (_ImportCredential .connection_string ())
173+ _import_engine = create_engine (_ImportCredential .connection_string (),
174+ connect_args = {'connect_timeout' : 3 })
175+ # Force connection errors early
176+ _import_engine .connect ()
107177 return sessionmaker (bind = _import_engine )()
108178
109179
0 commit comments