Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions .github/workflows/CD_staging.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ jobs:
- name: Authenticate to Google Cloud
uses: 'google-github-actions/auth@v2'
with:
credentials_json: ${{ secrets.CLOUD_SQL_SERVICE_ACCOUNT_KEY }}
credentials_json: ${{ secrets.CLOUD_DEPLOY_SERVICE_ACCOUNT_KEY }}

- name: Run Alembic migrations on staging database
env:
DB_DRIVER: "cloudsql"
CLOUD_SQL_INSTANCE_NAME: "${{ secrets.CLOUD_SQL_INSTANCE_NAME }}"
CLOUD_SQL_DATABASE: "${{ vars.CLOUD_SQL_DATABASE }}"
CLOUD_SQL_USER: "${{ secrets.CLOUD_SQL_USER }}"
CLOUD_SQL_PASSWORD: "${{ secrets.CLOUD_SQL_PASSWORD }}"
CLOUD_SQL_IAM_AUTH: true
run: |
uv run alembic upgrade head

Expand All @@ -53,17 +53,12 @@ jobs:
CLOUD_SQL_INSTANCE_NAME: "${{ secrets.CLOUD_SQL_INSTANCE_NAME }}"
CLOUD_SQL_DATABASE: "${{ vars.CLOUD_SQL_DATABASE }}"
CLOUD_SQL_USER: "${{ secrets.CLOUD_SQL_USER }}"
CLOUD_SQL_PASSWORD: "${{ secrets.CLOUD_SQL_PASSWORD }}"
CLOUD_SQL_IAM_AUTH: true
GCS_SERVICE_ACCOUNT_KEY: "${{ secrets.GCS_SERVICE_ACCOUNT_KEY }}"
GCS_BUCKET_NAME: "${{ vars.GCS_BUCKET_NAME }}"
run: |
uv run python -m transfers.backfill.staging

- name: Authenticate to Google Cloud
uses: 'google-github-actions/auth@v2'
with:
credentials_json: ${{ secrets.CLOUD_DEPLOY_SERVICE_ACCOUNT_KEY }}

# Uses Google Cloud Secret Manager to store secret credentials
- name: Create app.yaml
run: |
Expand All @@ -82,7 +77,7 @@ jobs:
CLOUD_SQL_INSTANCE_NAME: "${{ secrets.CLOUD_SQL_INSTANCE_NAME }}"
CLOUD_SQL_DATABASE: "${{ vars.CLOUD_SQL_DATABASE }}"
CLOUD_SQL_USER: "${{ secrets.CLOUD_SQL_USER }}"
CLOUD_SQL_PASSWORD: "${{ secrets.CLOUD_SQL_PASSWORD }}"
CLOUD_SQL_IAM_AUTH: true
GCS_SERVICE_ACCOUNT_KEY: "${{ secrets.GCS_SERVICE_ACCOUNT_KEY }}"
GCS_BUCKET_NAME: "${{ vars.GCS_BUCKET_NAME }}"
AUTHENTIK_URL: "${{ vars.AUTHENTIK_URL }}"
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ reset_db.sh
tests/uploads
migrate.sh
launcher.sh
gcs_credentials.json
*credentials.json
transfers/data/assets*
transfers/data/nma_csv_cache/*
transfers/data/*.csv
Expand Down
68 changes: 54 additions & 14 deletions db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# ===============================================================================

import asyncio
import copy
import getpass
import os
from contextlib import contextmanager
Expand All @@ -29,10 +30,32 @@
)
from sqlalchemy.util import await_only

from services.util import get_bool_env

load_dotenv()
driver = os.environ.get("DB_DRIVER", "")


def get_iam_login_token() -> str:
"""
Return a short-lived IAM DB auth token for Cloud SQL Postgres.
"""
from google.auth import default
from google.auth.transport.requests import Request

scopes = ["https://www.googleapis.com/auth/sqlservice.login"]
creds, _ = default()
if hasattr(creds, "with_scopes"):
creds = creds.with_scopes(scopes=scopes)
else:
creds = copy.copy(creds)
creds._scopes = scopes # type: ignore[attr-defined]
creds.refresh(Request())
if not getattr(creds, "token", None):
raise RuntimeError("Unable to acquire IAM DB auth token.")
return creds.token


async def get_async_engine():
"""
Asynchronous database session generator.
Expand All @@ -48,14 +71,21 @@ def asyncify_connection():
user = os.environ.get("CLOUD_SQL_USER")
password = os.environ.get("CLOUD_SQL_PASSWORD")
database = os.environ.get("CLOUD_SQL_DATABASE")

connection = connector.connect_async(
instance_name,
"asyncpg",
db=database,
password=password,
user=user,
)
use_iam_auth = get_bool_env("CLOUD_SQL_IAM_AUTH", False)
ip_type = os.environ.get("CLOUD_SQL_IP_TYPE", "public")

connect_kwargs = {
"db": database,
"user": user,
"enable_iam_auth": use_iam_auth,
"ip_type": ip_type,
}
if use_iam_auth:
connect_kwargs["password"] = get_iam_login_token()
else:
connect_kwargs["password"] = password

connection = connector.connect_async(instance_name, "asyncpg", **connect_kwargs)

return AsyncAdapt_asyncpg_connection(
engine.dialect.dbapi,
Expand All @@ -78,15 +108,25 @@ def init_connection_pool(connector):
user = os.environ.get("CLOUD_SQL_USER")
password = os.environ.get("CLOUD_SQL_PASSWORD")
database = os.environ.get("CLOUD_SQL_DATABASE")
use_iam_auth = get_bool_env("CLOUD_SQL_IAM_AUTH", False)
ip_type = os.environ.get("CLOUD_SQL_IP_TYPE", "public")

def getconn():
connect_kwargs = {
"user": user,
"db": database,
"ip_type": ip_type,
"enable_iam_auth": use_iam_auth,
}
if use_iam_auth:
connect_kwargs["password"] = get_iam_login_token()
else:
connect_kwargs["password"] = password

conn = connector.connect(
instance_name, # The Cloud SQL instance name
"pg8000",
user=user,
password=password,
db=database,
ip_type="public",
**connect_kwargs,
)
return conn

Expand All @@ -107,7 +147,7 @@ def getconn():
connector = Connector()
engine = init_connection_pool(connector)

async_engine = asyncio.run(get_async_engine())
# async_engine = asyncio.run(get_async_engine())

else:
# if driver == "sqlite":
Expand Down Expand Up @@ -161,7 +201,7 @@ def getconn():
# listen(engine, "connect", on_connect)


async_database_sessionmaker = async_sessionmaker(async_engine)
# async_database_sessionmaker = async_sessionmaker(async_engine)
database_sessionmaker = sessionmaker(engine, expire_on_commit=False)


Expand Down
Loading