diff --git a/ml/config/artifacts.yaml b/ml/config/artifacts.yaml
index ca594f4e6b49afeb394b7e04435df06e5222e037..3335a5de0b1e8dc961fa1bf9af21afdf131ee405 100644
--- a/ml/config/artifacts.yaml
+++ b/ml/config/artifacts.yaml
@@ -8,8 +8,8 @@ artifacts:
   # The 'id' is the direct download URL from the GitLab Package Registry.
   # 'id' format https://gitlab.wikimedia.org/api/v4/projects/:project_id/packages/generic/:package_name/:version/:filename
   # TODO: change this to main repo once the MR is merged.
-  add_a_link-0.1.1-v0.0.1.conda.tgz:
-    id: https://gitlab.wikimedia.org/api/v4/projects/3241/packages/generic/add_a_link/0.1.1-v0.0.1/add_a_link-0.1.1-v0.0.1.conda.tgz
+  add_a_link-0.1.1-v0.0.2.conda.tgz:
+    id: https://gitlab.wikimedia.org/api/v4/projects/3241/packages/generic/add_a_link/0.1.1-v0.0.2/add_a_link-0.1.1-v0.0.2.conda.tgz
     source: url

 # Supports SparkSqlOperator to work correctly with https URLs that include url-encoded bits,
diff --git a/ml/dags/add_a_link_pipeline_dag.py b/ml/dags/add_a_link_pipeline_dag.py
index 1a8b99c022c66f2fb3cf0e85535c15576da30c1c..9897337909b72587b0116779b07d9bd8805020f9 100644
--- a/ml/dags/add_a_link_pipeline_dag.py
+++ b/ml/dags/add_a_link_pipeline_dag.py
@@ -14,10 +14,23 @@ from wmf_airflow_common.hooks.spark import kwargs_for_virtualenv
 from wmf_airflow_common.operators.spark import SparkSubmitOperator

 # The name(s) of the artifact(s) as defined in ml/config/artifacts.yaml file
-artifact_name = "add_a_link-0.1.1-v0.0.1.conda.tgz"
+artifact_name = "add_a_link-0.1.1-v0.0.2.conda.tgz"

 # TODO: add more languages once ready.
-wiki_dbs = ["nlwiki", "trwiki"]
+wiki_dbs = {
+    "shard_nl": ["nlwiki"],
+    "shard_tr": ["trwiki"],
+    "shard_ja": ["jawiki"],
+}
+snapshot = "2025-06"
+wikidata_snapshot = "2025-07-14"
+wikidata_properties = ["P31"]
+max_sentences_per_wiki = 200_000
+files_per_wiki = 5
+grid_search = False
+thresholds = [0.5]
+n_max = 10_000
+hdfs_path = "/tmp/ml"

 with DAG(
     dag_id="add_a_link_pipeline",
@@ -28,36 +41,160 @@ with DAG(
     catchup=False,
     tags=["ml"],
 ) as dag:
-    common_kwargs = kwargs_for_virtualenv(
-        virtualenv_archive=artifact(artifact_name),
-        entry_point="bin/generate_anchor_dictionary.py",
-        use_virtualenv_spark=True,
-        retries=0,
-        launcher="skein",
-        driver_memory="32G",
-        executor_memory="32g",
-        executor_cores=4,
-        conf={
+    common_kwargs = {
+        "virtualenv_archive": artifact(artifact_name),
+        "use_virtualenv_spark": True,
+        "retries": 0,
+        "launcher": "skein",
+        "driver_memory": "32G",
+        "executor_memory": "32g",
+        "executor_cores": 4,
+        "env_vars": {"LD_LIBRARY_PATH": "venv/lib"},
+        "conf": {
             "spark.dynamicAllocation.maxExecutors": 119,
             "spark.driver.maxResultSize": "40G",
             "spark.sql.execution.arrow.pyspark.enabled": True,
             "spark.executor.memoryOverhead": "4g",
         },
-    )
+    }

     generate_anchor_dictionary_task = SparkSubmitOperator.partial(
-        task_id="generate_anchor_dictionary", **common_kwargs
+        task_id="generate_anchor_dictionary",
+        **kwargs_for_virtualenv(
+            **common_kwargs, entry_point="bin/generate_anchor_dictionary.py"  # type: ignore[arg-type]
+        ),
     ).expand(
         application_args=[
             [
                 "--wiki_dbs",
-                wiki,
+                ",".join(wikis),
                 "--output",
-                f"/tmp/ml/addalink/{wiki}",
+                f"{hdfs_path}/addalink/{model_id}",
                 "--model_id",
-                f"shard_{wiki}",
+                model_id,
+            ]
+            for model_id, wikis in wiki_dbs.items()
+        ]
+    )
+
+    generate_wdproperties_task = SparkSubmitOperator.partial(
+        task_id="generate_wdproperties",
+        **kwargs_for_virtualenv(**common_kwargs, entry_point="bin/generate_wdproperties.py"),  # type: ignore[arg-type]
+    ).expand(
+        application_args=[
+            [
+                "--wikidata_snapshot",
+                wikidata_snapshot,
+                "--wiki_dbs",
+                ",".join(wikis),
+                "--wikidata_properties",
+                ",".join(wikidata_properties),
+                "--directory",
+                f"{hdfs_path}/addalink/{model_id}",
+            ]
+            for model_id, wikis in wiki_dbs.items()
+        ]
+    )
+    filter_dict_anchor = SparkSubmitOperator.partial(
+        task_id="filter_dict_anchor",
+        **kwargs_for_virtualenv(**common_kwargs, entry_point="bin/filter_dict_anchor.py"),  # type: ignore[arg-type]
+    ).expand(
+        application_args=[
+            [
+                "--directory",
+                f"{hdfs_path}/addalink/{model_id}",
             ]
-            for wiki in wiki_dbs
+            for model_id, _ in wiki_dbs.items()
         ]
     )
-    generate_anchor_dictionary_task
+
+    generate_backtesting_data = SparkSubmitOperator.partial(
+        task_id="generate_backtesting_data",
+        **kwargs_for_virtualenv(
+            **common_kwargs, entry_point="bin/generate_backtesting_data.py"  # type: ignore[arg-type]
+        ),
+    ).expand(
+        application_args=[
+            [
+                "--wiki_dbs",
+                ",".join(wikis),
+                "--directory",
+                f"{hdfs_path}/addalink/{model_id}",
+                "--max_sentences_per_wiki",
+                max_sentences_per_wiki,
+            ]
+            for model_id, wikis in wiki_dbs.items()
+        ]
+    )
+
+    generate_training_data = SparkSubmitOperator.partial(
+        task_id="generate_training_data",
+        **kwargs_for_virtualenv(**common_kwargs, entry_point="bin/generate_training_data.py"),  # type: ignore[arg-type]
+    ).expand(
+        application_args=[
+            [
+                "--snapshot",
+                snapshot,
+                "--wiki_dbs",
+                ",".join(wikis),
+                "--directory",
+                f"{hdfs_path}/addalink/{model_id}",
+                "--files_per_wiki",
+                files_per_wiki,
+                "--model_id",
+                model_id,
+            ]
+            for model_id, wikis in wiki_dbs.items()
+        ]
+    )
+
+    generate_addlink_model = SparkSubmitOperator.partial(
+        task_id="generate_addlink_model",
+        **kwargs_for_virtualenv(**common_kwargs, entry_point="bin/generate_addlink_model.py"),  # type: ignore[arg-type]
+    ).expand(
+        application_args=[
+            [
+                "--wiki_dbs",
+                ",".join(wikis),
+                "--model_id",
+                model_id,
+                "--directory",
+                f"{hdfs_path}/addalink/{model_id}",
+                "--grid_search",
+                grid_search,
+            ]
+            for model_id, wikis in wiki_dbs.items()
+        ]
+    )
+
+    generate_backtesting_eval = SparkSubmitOperator.partial(
+        task_id="generate_backtesting_eval",
+        **kwargs_for_virtualenv(
+            **common_kwargs, entry_point="bin/generate_backtesting_eval.py"  # type: ignore[arg-type]
+        ),
+    ).expand(
+        application_args=[
+            [
+                "--model_id",
+                model_id,
+                "--wiki_dbs",
+                ",".join(wikis),
+                "--directory",
+                f"{hdfs_path}/addalink/{model_id}",
+                "--thresholds",
+                ",".join((str(t) for t in thresholds)),
+                "--n_max",
+                n_max,
+            ]
+            for model_id, wikis in wiki_dbs.items()
+        ]
+    )
+    (
+        generate_anchor_dictionary_task
+        >> generate_wdproperties_task
+        >> filter_dict_anchor
+        >> generate_backtesting_data
+        >> generate_training_data
+        >> generate_addlink_model
+        >> generate_backtesting_eval
+    )
diff --git a/tests/ml/add_a_link_pipeline_dag_test.py b/tests/ml/add_a_link_pipeline_dag_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c4b95e4b69177b1a3767c39f20585420b0c1dd
--- /dev/null
+++ b/tests/ml/add_a_link_pipeline_dag_test.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import pytest
+
+
+# This fixture defines the dag_path for the shared dagbag one
+@pytest.fixture(name="dag_path")
+def fixture_dagpath():
+    return ["ml", "dags", "add_a_link_pipeline_dag.py"]
+
+
+def test_check_bad_parsing_dag_loaded(dagbag):
+    assert dagbag.import_errors == {}
+    dag = dagbag.get_dag(dag_id="add_a_link_pipeline")
+    assert dag is not None
+    assert len(dag.tasks) == 7