-
Notifications
You must be signed in to change notification settings - Fork 107
Expand file tree
/
Copy pathtranslator.py
More file actions
301 lines (256 loc) · 9.55 KB
/
translator.py
File metadata and controls
301 lines (256 loc) · 9.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
from __future__ import annotations
import json
import re
import sys
from pathlib import Path
import sqlparse
from felderize.config import Config
from felderize.feldera_client import validate_sql
from felderize.llm import LLMClient, create_client
from felderize.models import Status, TranslationResult
from felderize.skills import build_system_prompt
def _parse_response(raw: str) -> dict:
"""Extract JSON from LLM response, handling markdown fences."""
text = raw.strip()
# Strip markdown code fences if present
match = re.search(r"```(?:json)?\s*\n(.*?)```", text, re.DOTALL)
if match:
text = match.group(1).strip()
return json.loads(text)
def _as_str(val: object) -> str:
"""Normalize a value to string — handles lists from LLM responses."""
if isinstance(val, list):
return "\n".join(str(v) for v in val)
return str(val) if val else ""
def _as_list(val: object) -> list[str]:
if isinstance(val, list):
return [str(v) for v in val]
if isinstance(val, str) and val:
return [val]
return []
def _build_result(data: dict) -> TranslationResult:
    """Build a TranslationResult from a parsed LLM response dict.

    Missing keys default to empty; status is UNSUPPORTED when the response
    lists any unsupported constructs, otherwise SUCCESS.
    """
    unsupported = _as_list(data.get("unsupported", []))
    status = Status.UNSUPPORTED if unsupported else Status.SUCCESS
    list_fields = {
        key: _as_list(data.get(key, [])) for key in ("warnings", "explanations")
    }
    return TranslationResult(
        feldera_schema=_as_str(data.get("feldera_schema", "")),
        feldera_query=_as_str(data.get("feldera_query", "")),
        unsupported=unsupported,
        status=status,
        **list_fields,
    )
def _build_user_prompt(schema_sql: str, query_sql: str) -> str:
return f"""\
Translate the following Spark SQL schema and query to Feldera SQL.
--- Spark Schema ---
{schema_sql.strip()}
--- Spark Query ---
{query_sql.strip()}
"""
def _build_repair_prompt(
schema_sql: str, query_sql: str, feldera_sql: str, errors: list[str]
) -> str:
error_text = "\n".join(f"- {e}" for e in errors)
return f"""\
The following Feldera SQL was generated from a Spark SQL translation but has compiler errors.
Fix the Feldera SQL to resolve these errors while keeping it semantically equivalent to the original Spark SQL.
--- Original Spark Schema ---
{schema_sql.strip()}
--- Original Spark Query ---
{query_sql.strip()}
--- Failed Feldera SQL ---
{feldera_sql.strip()}
--- Compiler Errors ---
{error_text}
"""
def _translate_with_repair(
    schema_sql: str,
    query_sql: str,
    config: Config,
    client: LLMClient,
    system_prompt: str,
    validate: bool,
    max_retries: int,
    verbose: bool = False,
) -> TranslationResult:
    """Run one translation attempt with optional validation + repair loop.

    Sends a translation prompt to the LLM, parses the JSON reply into a
    TranslationResult, and — when ``validate`` is set — compiles the
    generated SQL, feeding compiler errors back to the LLM for up to
    ``max_retries`` repair rounds.

    Args:
        schema_sql: Spark CREATE TABLE statements.
        query_sql: Spark query/view statements.
        config: Project configuration (provides the Feldera compiler path).
        client: LLM client used for both translation and repair calls.
        system_prompt: System prompt (skills and/or docs) for the LLM.
        validate: When True, run the compile/repair loop below.
        max_retries: Number of repair rounds before giving up.
        verbose: When True, echo each SQL submission to stderr.

    Returns:
        A TranslationResult whose status reflects parse/validation outcome.
    """
    user_prompt = _build_user_prompt(schema_sql, query_sql)
    raw = client.translate(system_prompt, user_prompt)
    try:
        data = _parse_response(raw)
    except (json.JSONDecodeError, KeyError) as e:
        # Unparseable LLM output: surface a truncated copy of the raw reply
        # so the failure can be diagnosed from the result alone.
        return TranslationResult(
            status=Status.ERROR,
            warnings=[f"Failed to parse LLM response: {e}", raw[:500]],
        )
    result = _build_result(data)
    if not validate:
        return result
    # Validation + repair loop
    full_sql = result.feldera_schema + "\n\n" + result.feldera_query
    # If LLM produced no SQL or no query, skip validation — schema-only or empty SQL
    # always compiles, which is not a meaningful success.
    if not full_sql.strip() or not result.feldera_query.strip():
        if not result.unsupported:
            result.status = Status.UNSUPPORTED
            result.unsupported = ["No query generated — translation incomplete"]
        return result
    for attempt in range(max_retries):
        if verbose:
            print(
                f"\n--- SQL submitted to validator (attempt {attempt + 1}) ---",
                file=sys.stderr,
            )
            print(full_sql, file=sys.stderr)
            print("---", file=sys.stderr)
        errors = validate_sql(full_sql, config.feldera_compiler or None)
        # Missing compiler is an environment problem, not a translation
        # failure: record a warning and return the unvalidated result.
        if errors and any("Compiler not found" in e for e in errors):
            print(
                "Warning: Feldera compiler not found — skipping validation.",
                file=sys.stderr,
            )
            result.warnings.append("Compiler not found — output SQL is not validated")
            result.status = Status.UNSUPPORTED if result.unsupported else Status.SUCCESS
            return result
        if not errors:
            result.warnings.append(f"Validated successfully (attempt {attempt + 1})")
            result.status = Status.UNSUPPORTED if result.unsupported else Status.SUCCESS
            return result
        print(
            f"Validation attempt {attempt + 1}/{max_retries} failed: {len(errors)} error(s)",
            file=sys.stderr,
        )
        for err in errors:
            print(f" {err}", file=sys.stderr)
        # Ask the LLM to repair the failing SQL; fields fall back to their
        # previous values when a key is absent from the repair response.
        repair_prompt = _build_repair_prompt(schema_sql, query_sql, full_sql, errors)
        raw = client.translate(system_prompt, repair_prompt)
        try:
            data = _parse_response(raw)
            result.feldera_schema = _as_str(
                data.get("feldera_schema", result.feldera_schema)
            )
            result.feldera_query = _as_str(
                data.get("feldera_query", result.feldera_query)
            )
            result.unsupported = _as_list(data.get("unsupported", result.unsupported))
            result.warnings = _as_list(data.get("warnings", result.warnings))
            result.explanations = _as_list(
                data.get("explanations", result.explanations)
            )
            full_sql = result.feldera_schema + "\n\n" + result.feldera_query
        except (json.JSONDecodeError, KeyError):
            # Keep the previous SQL and retry validation on the next round.
            result.warnings.append(
                f"Repair attempt {attempt + 1} produced invalid JSON"
            )
    # Final validation after all retries
    if verbose:
        print(
            f"\n--- SQL submitted to validator (attempt {max_retries + 1}) ---",
            file=sys.stderr,
        )
        print(full_sql, file=sys.stderr)
        print("---", file=sys.stderr)
    errors = validate_sql(full_sql, config.feldera_compiler or None)
    if not errors:
        result.warnings.append(f"Validated successfully (attempt {max_retries + 1})")
        result.status = Status.UNSUPPORTED if result.unsupported else Status.SUCCESS
    else:
        result.status = Status.ERROR
        result.warnings.extend(
            [f"Still failing after {max_retries} repairs: {e}" for e in errors]
        )
    return result
def split_combined_sql(sql: str) -> tuple[str, str]:
    """Split a combined SQL file into (schema_sql, query_sql).

    Schema: CREATE TABLE statements.
    Query: CREATE [OR REPLACE] [TEMP[ORARY]] VIEW statements.
    Comments and blank lines are preserved with their associated statement.
    Unrecognised statements are placed in schema_sql.
    """
    view_pattern = re.compile(
        r"CREATE\s+(OR\s+REPLACE\s+)?(TEMP(ORARY)?\s+)?VIEW\b"
    )
    schema_stmts: list[str] = []
    view_stmts: list[str] = []
    for statement in sqlparse.split(sql):
        statement = statement.strip()
        if not statement:
            continue
        # The first non-comment, non-blank line decides the statement type.
        leading = ""
        for line in statement.splitlines():
            candidate = line.strip()
            if candidate and not candidate.startswith("--"):
                leading = candidate
                break
        if not leading:
            continue  # comment-only block
        target = view_stmts if view_pattern.match(leading.upper()) else schema_stmts
        target.append(statement + ";")
    return "\n\n".join(schema_stmts), "\n\n".join(view_stmts)
def translate_spark_to_feldera(
    schema_sql: str,
    query_sql: str,
    config: Config,
    validate: bool = False,
    max_retries: int = 3,
    docs_only_fallback: bool = True,
    skills_dir: str | None = None,
    docs_dir: str | None = None,
    include_docs: bool = True,
    force_docs: bool = False,
    verbose: bool = False,
) -> TranslationResult:
    """Translate Spark SQL to Feldera SQL, with an optional docs-only retry.

    Runs a first translation pass using a skills-only system prompt; if that
    pass errors (or fails and ``force_docs`` is set), retries once with a
    docs-only prompt.

    Args:
        schema_sql: Spark CREATE TABLE statements.
        query_sql: Spark query/view statements.
        config: Project configuration used to build the LLM client.
        validate: Forwarded to the per-pass validation/repair loop.
        max_retries: Repair rounds per pass when validating.
        docs_only_fallback: When True, retry with docs on an ERROR result.
        skills_dir: Optional directory of skill files for the prompt.
        docs_dir: Optional directory of documentation for the prompt.
        include_docs: When False, never run the docs-only second pass.
        force_docs: When True, also retry with docs on any non-SUCCESS result.
        verbose: Forwarded to the per-pass loop for stderr diagnostics.

    Returns:
        The TranslationResult of the last pass that ran.
    """
    combined_sql = schema_sql + "\n" + query_sql
    docs_dir_path = Path(docs_dir) if docs_dir else None
    client = create_client(config)
    # First pass: skills only (no examples, no docs).
    system_prompt_skills = build_system_prompt(
        skills_dir,
        docs_dir=docs_dir_path,
        spark_sql=combined_sql,
        with_docs=False,
        with_examples=False,
    )
    result = _translate_with_repair(
        schema_sql,
        query_sql,
        config,
        client,
        system_prompt_skills,
        validate,
        max_retries,
        verbose,
    )
    # Determine whether to retry with docs:
    # - always retry on ERROR (existing behaviour, controlled by docs_only_fallback)
    # - also retry on NOT SUCCESS when force_docs is set
    should_retry = include_docs and (
        (result.status == Status.ERROR and docs_only_fallback)
        or (result.status != Status.SUCCESS and force_docs)
    )
    if not should_retry:
        return result
    # Second pass: docs only (no skills, no examples).
    print("Retrying with docs-only prompt...", file=sys.stderr)
    system_prompt_docs = build_system_prompt(
        skills_dir,
        docs_dir=docs_dir_path,
        spark_sql=combined_sql,
        with_docs=True,
        with_examples=False,
        with_skills=False,
    )
    result = _translate_with_repair(
        schema_sql,
        query_sql,
        config,
        client,
        system_prompt_docs,
        validate,
        max_retries,
        verbose,
    )
    # Note which path produced the final answer (only when the fallback
    # actually recovered something usable).
    if result.status != Status.ERROR:
        result.warnings.append("Resolved with docs-only fallback")
    return result