2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
*.db
.venv/
151 changes: 90 additions & 61 deletions main.py
@@ -5,18 +5,35 @@
import sqlite3
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from multiprocessing import Manager, Pool

from tabulate import tabulate
from tqdm import tqdm
from unstract.api_deployments.client import APIDeploymentsClient

DB_NAME = "file_processing.db"
global_arguments = None
logger = logging.getLogger(__name__)

# https://docs.unstract.com/unstract_platform/api_deployment/unstract_api_deployment_execution_api#possible-execution-status
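# This script treats COMPLETED and ERROR as terminal; any other execution status is polled again.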

# Dataclass for arguments
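# Field names must match the argparse dest names, since main() constructs this via Arguments(**vars(parser.parse_args())).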
@dataclass
class Arguments:
api_endpoint: str
api_key: str
api_timeout: int = 10
poll_interval: int = 5
input_folder_path: str = ""
parallel_call_count: int = 10
retry_failed: bool = False
retry_pending: bool = False
skip_pending: bool = False
skip_unprocessed: bool = False
log_level: str = "INFO"
print_report: bool = False


# Initialize SQLite DB
@@ -41,7 +58,7 @@ def init_db():


# Check if the file is already processed
def skip_file_processing(file_name, retry_failed, skip_unprocessed, skip_pending):
def skip_file_processing(file_name, args: Arguments):
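"""Return True if file_name should be skipped, based on its stored status and the retry/skip flags."""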
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()
c.execute(
@@ -51,20 +68,22 @@ def skip_file_processing(file_name, retry_failed, skip_unprocessed, skip_pending
conn.close()

if not row:
if skip_unprocessed:
if args.skip_unprocessed:
logger.warning(f"[{file_name}] Skipping due to the flag `skip_unprocessed`")
return skip_unprocessed # skip unprocessed files
return args.skip_unprocessed # skip unprocessed files

if row[0] == "ERROR":
if not retry_failed:
logger.warning(f"[{file_name}] Skipping due to the flag `retry_failed`")
return not retry_failed
if not args.retry_failed:
logger.warning(
f"[{file_name}] Skipping since the flag `retry_failed` is not set"
)
return not args.retry_failed
elif row[0] == "COMPLETED":
return True
else:
if skip_pending:
if args.skip_pending:
logger.warning(f"[{file_name}] Skipping due to the flag `skip_pending`")
return skip_pending
return args.skip_pending


# Update status in SQLite DB
@@ -115,7 +134,7 @@ def print_summary():
# Fetch the count of files for each status
c.execute(
"""
SELECT execution_status, COUNT(*) AS status_count, AVG(time_taken) AS avg_time_taken
SELECT execution_status, COUNT(*) AS status_count
FROM file_status
GROUP BY execution_status
"""
@@ -128,13 +147,35 @@ def print_summary():
for row in summary:
status = row[0]
count = row[1]
avg_time = row[2] or 0 # Handle NULL avg_time
print(f"Status '{status}': {count} (Avg time: {avg_time:.2f} seconds)")
print(f"Status '{status}': {count}")


def print_report():
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()

def get_status_endpoint(file_path, client, retry_pending):
"""Returns status_endpoint, status and response (if available)
# Fetch the file name, status and time taken for each file
c.execute(
"""
SELECT file_name, execution_status, time_taken
FROM file_status
"""
)
report_data = c.fetchall()
conn.close()

# Print the detailed report
print("\nDetailed Report:")
if report_data:
# Tabulate the data with column headers
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)"]
print(tabulate(report_data, headers=headers, tablefmt="pretty"))
else:
print("No records found in the database.")


def get_status_endpoint(file_path, client, args: Arguments):
"""Returns status_endpoint, status and response (if available)"""
status_endpoint = None

# If retry_pending is True, check if the status API endpoint is available
@@ -152,7 +193,7 @@ def get_status_endpoint(file_path, client, retry_pending):
status_endpoint = row[0]

# status_endpoint is only available for pending items. retry_pending will force retry and hence ignore existing.
if retry_pending:
if args.retry_pending:
status_endpoint = None

if status_endpoint:
@@ -166,7 +207,9 @@ def get_status_endpoint(file_path, client, retry_pending):
update_db(file_path, execution_status, None, None, None, None)
response = client.structure_file(file_paths=[file_path])
logger.debug(f"[{file_path}] Response of initial API call: {response}")
status_endpoint = response.get("status_check_api_endpoint") # If ERROR or completed this will be None
status_endpoint = response.get(
"status_check_api_endpoint"
) # If ERROR or COMPLETED this will be None
execution_status = response.get("execution_status")
status_code = response.get("status_code")
update_db(
@@ -181,20 +224,12 @@


def process_file(
file_path,
success_count,
failure_count,
skipped_count,
retry_failed,
skip_unprocessed,
retry_pending,
skip_pending,
file_path, success_count, failure_count, skipped_count, args: Arguments
):
global global_arguments
logger.info(f"[{file_path}]: Processing started")

# Any file which should be skipped will happen at this point.
if skip_file_processing(file_path, retry_failed, skip_unprocessed, skip_pending):
if skip_file_processing(file_name=file_path, args=args):
logger.warning(f"[{file_path}]: Skipping processing.")
skipped_count.value += 1
return
@@ -205,18 +240,18 @@ def process_file(

try:
client = APIDeploymentsClient(
api_url=global_arguments.api_endpoint,
api_key=global_arguments.api_key,
api_timeout=global_arguments.api_timeout,
logging_level=global_arguments.log_level,
api_url=args.api_endpoint,
api_key=args.api_key,
api_timeout=args.api_timeout,
logging_level=args.log_level,
)

status_endpoint, execution_status, response = get_status_endpoint(
file_path=file_path, client=client, retry_pending=retry_pending
file_path=file_path, client=client, args=args
)
# Poll until status is COMPLETED or ERROR
while execution_status not in ["COMPLETED", "ERROR"]:
time.sleep(global_arguments.poll_interval)
time.sleep(args.poll_interval)
response = client.check_execution_status(status_endpoint)
execution_status = response.get("execution_status")
status_code = response.get("status_code") # May be None if not provided
@@ -246,21 +281,14 @@ def load_folder(
logger.info(f"[{file_path}]: Processing completed: {execution_status}")


def load_folder(
folder_path,
parallel_count,
retry_failed,
skip_unprocessed,
retry_pending,
skip_pending,
):
def load_folder(args: Arguments):
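"""Process every file found in args.input_folder_path using a pool of parallel workers."""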
files = [
os.path.join(folder_path, f)
for f in os.listdir(folder_path)
if os.path.isfile(os.path.join(folder_path, f))
os.path.join(args.input_folder_path, f)
for f in os.listdir(args.input_folder_path)
if os.path.isfile(os.path.join(args.input_folder_path, f))
]

with Manager() as manager, Pool(parallel_count) as executor:
with Manager() as manager, Pool(args.parallel_call_count) as executor:
success_count = manager.Value("i", 0) # Shared integer for success count
failure_count = manager.Value("i", 0) # Shared integer for failure count
skipped_count = manager.Value("i", 0) # Shared integer for skipped count
@@ -280,10 +308,7 @@ def load_folder(
success_count=success_count,
failure_count=failure_count,
skipped_count=skipped_count,
retry_failed=retry_failed,
skip_unprocessed=skip_unprocessed,
retry_pending=retry_pending,
skip_pending=skip_pending,
args=args,
)

for _ in executor.imap_unordered(process_file_partial, files):
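# imap_unordered yields as each worker finishes, in arbitrary order; the loop just drives the pool to completion.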
@@ -377,28 +402,32 @@ def main():
choices=["DEBUG", "INFO", "WARN", "ERROR"],
help="Minimum loglevel for logging",
)
parser.add_argument(
"--print_report",
dest="print_report",
action="store_true",
help="Print a detailed report of all file processed.",
)

global global_arguments
global_arguments = parser.parse_args()
args = Arguments(**vars(parser.parse_args()))

ch = logging.StreamHandler(sys.stdout)
ch.setLevel(global_arguments.log_level)
logging.basicConfig(level=global_arguments.log_level, handlers=[ch])
ch.setLevel(args.log_level)
logging.basicConfig(level=args.log_level, handlers=[ch])

logger.warning(f"Running with params: {global_arguments}")
logger.warning(f"Running with params: {args}")

init_db() # Initialize DB

load_folder(
folder_path=global_arguments.input_folder_path,
parallel_count=global_arguments.parallel_call_count,
retry_failed=global_arguments.retry_failed,
skip_unprocessed=global_arguments.skip_unprocessed,
retry_pending=global_arguments.retry_pending,
skip_pending=global_arguments.skip_pending,
)
load_folder(args=args)

print_summary() # Print summary at the end
if args.print_report:
print_report()
logger.warning(
"Elapsed time calculation of a file which was resumed"
" from pending state will not be correct"
)


if __name__ == "__main__":
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
unstract-client~=0.1.0
tqdm~=4.66.5
tabulate~=0.9.0
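
For reference, a minimal sketch of the table that the new print_report() renders with tabulate's "pretty" format; the rows below are made up for illustration:

from tabulate import tabulate

# Made-up example rows; in the script these come from the file_status table.
rows = [
    ("invoice_01.pdf", "COMPLETED", 12.4),
    ("invoice_02.pdf", "ERROR", 3.1),
]
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)"]
print(tabulate(rows, headers=headers, tablefmt="pretty"))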