Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import copy
import os
from logging.config import fileConfig

from alembic import context
from dotenv import load_dotenv
from sqlalchemy import engine_from_config, pool, create_engine

from services.util import get_bool_env

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down Expand Up @@ -46,7 +48,10 @@ def build_database_url():
user = os.environ.get("CLOUD_SQL_USER", "")
password = os.environ.get("CLOUD_SQL_PASSWORD", "")
database = os.environ.get("CLOUD_SQL_DATABASE", "")
use_iam_auth = get_bool_env("CLOUD_SQL_IAM_AUTH", False)
# Host is provided by connector, so leave blank.
if use_iam_auth:
return f"postgresql+pg8000://{user}@/{database}"
return f"postgresql+pg8000://{user}:{password}@/{database}"

# Default/Postgres
Expand Down Expand Up @@ -96,22 +101,46 @@ def run_migrations_online() -> None:
if db_driver == "cloudsql":
# Use the Cloud SQL Python Connector for direct Cloud SQL access.
from google.cloud.sql.connector import Connector
from google.auth import default
from google.auth.transport.requests import Request

instance_name = os.environ.get("CLOUD_SQL_INSTANCE_NAME")
user = os.environ.get("CLOUD_SQL_USER")
password = os.environ.get("CLOUD_SQL_PASSWORD")
database = os.environ.get("CLOUD_SQL_DATABASE")
use_iam_auth = get_bool_env("CLOUD_SQL_IAM_AUTH", False)
ip_type = os.environ.get("CLOUD_SQL_IP_TYPE", "public")

connector = Connector()

def get_iam_login_token() -> str:
scopes = ["https://www.googleapis.com/auth/sqlservice.login"]
creds, _ = default()
if hasattr(creds, "with_scopes"):
creds = creds.with_scopes(scopes=scopes)
else:
creds = copy.copy(creds)
creds._scopes = scopes # type: ignore[attr-defined]
creds.refresh(Request())
if not getattr(creds, "token", None):
raise RuntimeError("Unable to acquire IAM DB auth token.")
return creds.token

def getconn():
connect_kwargs = {
"user": user,
"db": database,
"ip_type": ip_type,
"enable_iam_auth": use_iam_auth,
}
if use_iam_auth:
connect_kwargs["password"] = get_iam_login_token()
else:
connect_kwargs["password"] = password
return connector.connect(
instance_name,
"pg8000",
user=user,
password=password,
db=database,
ip_type="public",
**connect_kwargs,
)

connectable = create_engine(
Expand Down
Loading