Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 86 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlite3
import sys
import time
import textwrap
from dataclasses import dataclass
from datetime import datetime
from functools import partial
Expand Down Expand Up @@ -51,6 +52,10 @@ def init_db():
time_taken REAL,
status_code INTEGER,
status_api_endpoint TEXT,
total_embedding_cost REAL,
total_embedding_tokens INTEGER,
total_llm_cost REAL,
total_llm_tokens INTEGER,
updated_at TEXT,
created_at TEXT
)"""
Expand Down Expand Up @@ -97,6 +102,15 @@ def update_db(
status_code,
status_api_endpoint,
):

total_embedding_cost = None
total_embedding_tokens = None
total_llm_cost = None
total_llm_tokens = None

if result is not None:
total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens = calculate_cost_and_tokens(result)

conn = sqlite3.connect(DB_NAME)
conn.set_trace_callback(
lambda x: (
Expand All @@ -109,16 +123,20 @@ def update_db(
now = datetime.now().isoformat()
c.execute(
"""
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
""",
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
""",
(
file_name,
execution_status,
json.dumps(result),
time_taken,
status_code,
status_api_endpoint,
total_embedding_cost,
total_embedding_tokens,
total_llm_cost,
total_llm_tokens,
now,
file_name,
now,
Expand All @@ -127,10 +145,57 @@ def update_db(
conn.commit()
conn.close()

# Calculate total cost and tokens for detailed report
def calculate_cost_and_tokens(result):
    """Aggregate embedding and LLM cost/token usage from an API result.

    Reads ``result["extraction_result"][0]["result"]`` (either a dict or a
    JSON string encoding one) and sums the usage entries found under its
    ``metadata["embedding"]`` and ``metadata["extraction_llm"]`` lists.

    Args:
        result: Response payload; expected to be a dict.

    Returns:
        A 4-tuple ``(total_embedding_cost, total_llm_cost,
        total_embedding_tokens, total_llm_tokens)`` -- note that both costs
        come first. The embedding (respectively LLM) values are ``None``
        when the corresponding usage list is missing or empty, and all four
        are ``None`` when the payload does not have the expected shape.
    """
    no_usage = (None, None, None, None)

    if not isinstance(result, dict):
        return no_usage

    extraction_result = result.get("extraction_result")
    if not isinstance(extraction_result, list) or not extraction_result:
        return no_usage

    first_entry = extraction_result[0]
    extraction_data = (
        first_entry.get("result") if isinstance(first_entry, dict) else None
    )

    # The payload may arrive as a JSON-encoded string; decode it first.
    if isinstance(extraction_data, str):
        try:
            extraction_data = json.loads(extraction_data) if extraction_data else {}
        except json.JSONDecodeError:
            logger.warning("Failed to decode JSON for extraction data; defaulting to empty dictionary.")
            extraction_data = {}

    # A null "result", or JSON that decodes to a list/scalar, carries no
    # metadata; bail out instead of failing on ``.get``.
    if not isinstance(extraction_data, dict):
        return no_usage

    metadata = extraction_data.get("metadata")
    if not isinstance(metadata, dict):
        return no_usage

    total_embedding_cost, total_embedding_tokens = _sum_cost_and_tokens(
        metadata.get("embedding"), "embedding_tokens"
    )
    total_llm_cost, total_llm_tokens = _sum_cost_and_tokens(
        metadata.get("extraction_llm"), "total_tokens"
    )
    return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens


def _sum_cost_and_tokens(usage_items, token_key):
    """Sum ``cost_in_dollars`` and ``token_key`` over a list of usage dicts.

    Returns ``(None, None)`` when ``usage_items`` is not a non-empty list,
    otherwise ``(total_cost, total_tokens)``. Missing or null fields count
    as zero; entries that are not dicts are ignored.
    """
    if not isinstance(usage_items, list) or not usage_items:
        return None, None

    total_cost = 0.0
    total_tokens = 0
    for item in usage_items:
        if not isinstance(item, dict):
            continue
        # ``or 0`` covers keys that are present but null, which a ``.get``
        # default alone would not.
        total_cost += float(item.get("cost_in_dollars") or 0)
        total_tokens += item.get(token_key) or 0
    return total_cost, total_tokens


# Print final summary with count of each status and average time using a single SQL query
def print_summary():
conn = sqlite3.connect("file_processing.db")
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()

# Fetch count and average time for each status
Expand All @@ -153,13 +218,13 @@ def print_summary():


def print_report():
conn = sqlite3.connect("file_processing.db")
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()

# Fetch count and average time for each status
# Fetch required fields, including total_cost and total_tokens
c.execute(
"""
SELECT file_name, execution_status, time_taken
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
FROM file_status
"""
)
Expand All @@ -170,8 +235,20 @@ def print_report():
print("\nDetailed Report:")
if report_data:
# Tabulate the data with column headers
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)"]
print(tabulate(report_data, headers=headers, tablefmt="pretty"))
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)", "Total Embedding Cost", "Total Embedding Tokens", "Total LLM Cost", "Total LLM Tokens"]

# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
formatted_data = []
for row in report_data:
formatted_row = [
"None" if cell is None else
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
f"{cell:.8f}" if isinstance(cell, float) else cell
for cell in row
]
formatted_data.append(formatted_row)

print(tabulate(formatted_data, headers=headers, tablefmt="pretty"))
else:
print("No records found in the database.")

Expand Down