Skip to content

Commit bcc707b

Browse files
committed
feat: enforce PostgreSQL compatibility baselines
1 parent 01a2e3a commit bcc707b

2 files changed

Lines changed: 470 additions & 0 deletions

File tree

scripts/pg_compat/baseline.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python3
2+
3+
import copy
4+
import json
5+
6+
if __package__:
7+
from .common import atomic_write_text
8+
else:
9+
from common import atomic_write_text
10+
11+
12+
# Every result label a compatibility record may carry, in declaration order.
RESULTS = (
    "DEEP_SUPPORTED",
    "CLASSIFIED_ONLY",
    "PARTIAL",
    "ERROR",
    "TRAILING_INPUT",
    "TYPE_MISMATCH",
)
# Set view of RESULTS for constant-time membership checks during validation.
RESULT_SET = frozenset(RESULTS)
# Results from which a move to "CLASSIFIED_ONLY" is accepted by
# transition_allowed().
CLASSIFIED_IMPROVEMENTS = frozenset(
    ("PARTIAL", "ERROR", "TRAILING_INPUT", "TYPE_MISMATCH")
)
# Ordinal support level per result. evaluate_baseline() reports a disallowed
# transition as a regression when this value drops; every result other than
# the two "supported" ones shares level 0.
SUPPORT_LEVEL = {
    "DEEP_SUPPORTED": 2,
    "CLASSIFIED_ONLY": 1,
    "PARTIAL": 0,
    "ERROR": 0,
    "TRAILING_INPUT": 0,
    "TYPE_MISMATCH": 0,
}
32+
33+
34+
def _validate_result(result, label):
    """Raise ``ValueError`` unless *result* is a member of ``RESULT_SET``.

    Args:
        result: The value to check.
        label: Text placed at the start of the error message to say which
            value was rejected (for example ``"previous"``).

    Raises:
        ValueError: If *result* is not a string or not a known result.
    """
    # The isinstance test runs first so that unhashable values (a list, a
    # dict) are rejected with ValueError rather than failing the set lookup.
    if not isinstance(result, str) or result not in RESULT_SET:
        raise ValueError(f"{label} result has unsupported value {result!r}")
37+
38+
39+
def transition_allowed(previous, current):
    """Return whether a result change from *previous* to *current* is accepted.

    A transition is accepted when any of the following holds:

    * the result is unchanged;
    * the new result is ``"DEEP_SUPPORTED"``;
    * the new result is ``"CLASSIFIED_ONLY"`` and the old one is a member of
      ``CLASSIFIED_IMPROVEMENTS``.

    Args:
        previous: Result string recorded in the baseline.
        current: Result string observed now.

    Returns:
        ``True`` if the transition is accepted, ``False`` otherwise.

    Raises:
        ValueError: If either argument is not a known result string.
    """
    _validate_result(previous, "previous")
    _validate_result(current, "current")

    if previous == current:
        return True
    if current == "DEEP_SUPPORTED":
        return True
    return (
        previous in CLASSIFIED_IMPROVEMENTS
        and current == "CLASSIFIED_ONLY"
    )
51+
52+
53+
def _records_by_id(rows, label, *, require_sql=False):
    """Validate *rows* and index them by their ``"id"`` field.

    Args:
        rows: Iterable of row objects; each must be a ``dict``.
        label: Name of the data source, used as the prefix of error messages.
        require_sql: When true, each row must also carry non-empty string
            ``"sql"`` and ``"oracle_node"`` fields.

    Returns:
        A dict mapping each record ID to its row. The rows are stored as-is
        (not copied), in the order they were encountered.

    Raises:
        ValueError: If a row is not a dict, has a missing/empty/non-string
            ID, repeats an earlier ID, has a result outside ``RESULT_SET``,
            or (with *require_sql*) lacks a valid ``"sql"`` or
            ``"oracle_node"``.
    """
    records = {}
    for index, row in enumerate(rows):
        if not isinstance(row, dict):
            raise ValueError(f"{label} row {index}: expected a JSON object")

        record_id = row.get("id")
        if not isinstance(record_id, str) or not record_id:
            raise ValueError(
                f"{label} row {index}: field 'id' must be a non-empty string"
            )
        if record_id in records:
            raise ValueError(f"{label} has duplicate ID {record_id!r}")

        result = row.get("result")
        # isinstance first: keeps unhashable values from reaching the set
        # lookup, so they are reported as ValueError like any other bad value.
        if not isinstance(result, str) or result not in RESULT_SET:
            raise ValueError(
                f"{label} row {index}: field 'result' has unsupported value "
                f"{result!r}"
            )

        if require_sql:
            sql = row.get("sql")
            if not isinstance(sql, str) or not sql:
                raise ValueError(
                    f"{label} row {index}: field 'sql' must be a non-empty string"
                )
            oracle_node = row.get("oracle_node")
            if not isinstance(oracle_node, str) or not oracle_node:
                raise ValueError(
                    f"{label} row {index}: field 'oracle_node' must be a "
                    "non-empty string"
                )

        records[record_id] = row
    return records
89+
90+
91+
def _transition_record(previous, current):
    """Return a copy of *current* annotated with both sides of the transition.

    Args:
        previous: Baseline row; only its ``"result"`` field is read.
        current: Current row; deep-copied so the caller's row is not modified.

    Returns:
        A new dict with every field of *current* plus ``"previous_result"``
        and ``"current_result"``.

    Raises:
        KeyError: If either row has no ``"result"`` field.
    """
    record = copy.deepcopy(current)
    record["previous_result"] = previous["result"]
    record["current_result"] = current["result"]
    return record
96+
97+
98+
def evaluate_baseline(previous_rows, current_rows):
    """Compare current results against a baseline and bucket every change.

    Both inputs are validated and indexed by ID first. IDs present on both
    sides are visited in sorted order and each one lands in exactly one of
    ``allowed``, ``regressions`` or ``review_required``.

    Args:
        previous_rows: Iterable of baseline row dicts.
        current_rows: Iterable of current row dicts.

    Returns:
        A dict with these keys:

        * ``"allowed"``: transitions accepted by ``transition_allowed``;
        * ``"regressions"``: disallowed transitions whose ``SUPPORT_LEVEL``
          dropped;
        * ``"review_required"``: the remaining disallowed transitions (the
          support level did not drop);
        * ``"new_cases"``: deep copies of current rows whose ID is absent
          from the baseline, sorted by ID;
        * ``"missing_ids"``: sorted IDs present in the baseline only.

    Raises:
        ValueError: If either input fails row validation.
    """
    previous = _records_by_id(previous_rows, "previous")
    current = _records_by_id(current_rows, "current")

    allowed = []
    regressions = []
    review_required = []
    for record_id in sorted(previous.keys() & current.keys()):
        previous_row = previous[record_id]
        current_row = current[record_id]
        transition = _transition_record(previous_row, current_row)
        if transition_allowed(previous_row["result"], current_row["result"]):
            allowed.append(transition)
        elif (
            SUPPORT_LEVEL[current_row["result"]]
            < SUPPORT_LEVEL[previous_row["result"]]
        ):
            regressions.append(transition)
        else:
            # Disallowed, but not a drop in support level (for example a move
            # between two level-0 results).
            review_required.append(transition)

    return {
        "allowed": allowed,
        "regressions": regressions,
        "review_required": review_required,
        "new_cases": [
            copy.deepcopy(current[record_id])
            for record_id in sorted(current.keys() - previous.keys())
        ],
        "missing_ids": sorted(previous.keys() - current.keys()),
    }
129+
130+
131+
def _case_identity(row):
    """Return the tuple of fields that must agree between duplicate records.

    Args:
        row: A row dict carrying ``"id"``, ``"sql"``, ``"oracle_node"`` and
            ``"result"``.

    Returns:
        ``(id, sql, oracle_node, result)``; any other fields are ignored.

    Raises:
        KeyError: If one of the four fields is missing.
    """
    return (
        row["id"],
        row["sql"],
        row["oracle_node"],
        row["result"],
    )
138+
139+
140+
def build_ci_cases(inventory_rows, release_delta_rows, witness_rows):
    """Assemble the list of CI cases from the three record sources.

    Selection rules:

    * from the inventory, every row whose result is not
      ``"DEEP_SUPPORTED"``;
    * from the inventory, the smallest ID for each
      ``(oracle_node, result)`` pair among ``"DEEP_SUPPORTED"`` and
      ``"CLASSIFIED_ONLY"`` rows;
    * every release-delta row and every witness row.

    Args:
        inventory_rows: Iterable of inventory row dicts.
        release_delta_rows: Iterable of release-delta row dicts.
        witness_rows: Iterable of witness row dicts.

    Returns:
        A list of deep-copied rows sorted by ID, each with an added
        ``"expected_result"`` field equal to its ``"result"``.

    Raises:
        ValueError: If any source fails row validation (``sql`` and
            ``oracle_node`` are required), or if the same ID is selected
            twice with a different ``(id, sql, oracle_node, result)``.
    """
    sources = (
        ("inventory", inventory_rows),
        ("release delta", release_delta_rows),
        ("witness", witness_rows),
    )
    validated = {
        label: _records_by_id(rows, label, require_sql=True)
        for label, rows in sources
    }

    selected_ids = set()
    inventory = validated["inventory"]
    for record_id, row in inventory.items():
        if row["result"] != "DEEP_SUPPORTED":
            selected_ids.add(record_id)

    # NOTE(review): the pass above already selects every CLASSIFIED_ONLY row,
    # so only the DEEP_SUPPORTED representatives add new IDs here.
    representatives = {}
    for record_id, row in inventory.items():
        if row["result"] not in ("DEEP_SUPPORTED", "CLASSIFIED_ONLY"):
            continue
        key = (row["oracle_node"], row["result"])
        # Keep the smallest ID per key so the choice does not depend on the
        # order of the inventory rows.
        representatives[key] = min(record_id, representatives.get(key, record_id))
    selected_ids.update(representatives.values())

    selected = {}

    def add_case(row):
        # Register a row once; a second row with the same ID must match the
        # first on the identity fields, otherwise the sources disagree.
        record_id = row["id"]
        existing = selected.get(record_id)
        if existing is not None:
            if _case_identity(existing) != _case_identity(row):
                raise ValueError(f"conflicting records for ID {record_id!r}")
            return

        case = copy.deepcopy(row)
        case["expected_result"] = case["result"]
        selected[record_id] = case

    for record_id in sorted(selected_ids):
        add_case(inventory[record_id])
    for label in ("release delta", "witness"):
        for record_id in sorted(validated[label]):
            add_case(validated[label][record_id])

    return [selected[record_id] for record_id in sorted(selected)]
186+
187+
188+
def write_ci_cases(path, cases):
    """Validate *cases* and write them to *path* as JSON Lines.

    Each case becomes one line of compact JSON with sorted keys; lines are
    ordered by record ID. A trailing newline is added only when there is at
    least one case, so an empty case list produces an empty text. The text is
    handed to ``atomic_write_text(path, text)``.

    Args:
        path: Destination passed through to ``atomic_write_text``.
        cases: Iterable of case dicts, each carrying ``"expected_result"`` in
            addition to the fields required of a CI row.

    Raises:
        ValueError: If a case fails row validation, has an
            ``"expected_result"`` that is not a known result or differs from
            ``"result"``, or contains NaN/Infinity floats (``json.dumps`` is
            called with ``allow_nan=False``).
    """
    records = _records_by_id(cases, "CI case", require_sql=True)
    output = []
    for record_id in sorted(records):
        # The row is only read below, so no defensive copy is needed.
        row = records[record_id]
        expected_result = row.get("expected_result")
        if (
            not isinstance(expected_result, str)
            or expected_result not in RESULT_SET
        ):
            raise ValueError(
                f"CI case {record_id!r}: field 'expected_result' has "
                f"unsupported value {expected_result!r}"
            )
        if expected_result != row["result"]:
            raise ValueError(
                f"CI case {record_id!r}: expected_result does not match result"
            )
        output.append(
            json.dumps(
                row,
                allow_nan=False,
                ensure_ascii=False,
                separators=(",", ":"),
                sort_keys=True,
            )
        )

    text = "\n".join(output)
    if output:
        text += "\n"
    atomic_write_text(path, text)

0 commit comments

Comments
 (0)