153 changes: 148 additions & 5 deletions src/crawlee/_types.py
@@ -1,9 +1,20 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypeVar, Union
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
Optional,
Protocol,
TypeVar,
Union,
cast,
overload,
)

from pydantic import ConfigDict, Field, PlainValidator, RootModel
from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack
@@ -19,7 +30,8 @@
from crawlee.http_clients import HttpResponse
from crawlee.proxy_configuration import ProxyInfo
from crawlee.sessions._session import Session
from crawlee.storages._dataset import ExportToKwargs, GetDataKwargs, PushDataKwargs
from crawlee.storages._dataset import ExportToKwargs, GetDataKwargs
from crawlee.storages._key_value_store import KeyValueStore

# Workaround for https://github.com/pydantic/pydantic/issues/9445
J = TypeVar('J', bound='JsonSerializable')
@@ -188,6 +200,10 @@ def __call__(
) -> Coroutine[None, None, DatasetItemsListPage]: ...


class PushDataKwargs(TypedDict):
"""Keyword arguments for dataset's `push_data` method."""


class PushDataFunction(Protocol):
"""Type of a function for pushing data to the dataset.

@@ -204,6 +220,12 @@ def __call__(
) -> Coroutine[None, None, None]: ...


class PushDataFunctionCall(PushDataKwargs):
data: JsonSerializable
dataset_id: str | None
dataset_name: str | None


class ExportToFunction(Protocol):
"""Type of a function for exporting data from a dataset.

@@ -251,6 +273,42 @@ def __call__(
) -> Coroutine[None, None, HttpResponse]: ...


T = TypeVar('T')


class KeyValueStoreInterface(Protocol):
"""The (limited) part of the `KeyValueStore` interface that should be accessible from a request handler."""

@overload
async def get_value(self, key: str) -> Any: ...

@overload
async def get_value(self, key: str, default_value: T) -> T: ...

@overload
async def get_value(self, key: str, default_value: T | None = None) -> T | None: ...

async def get_value(self, key: str, default_value: T | None = None) -> T | None: ...

async def set_value(
self,
key: str,
value: Any,
content_type: str | None = None,
) -> None: ...


class GetKeyValueStoreFromRequestHandlerFunction(Protocol):
"""Type of a function for accessing a key-value store from within a request handler."""

def __call__(
self,
*,
id: str | None = None,
name: str | None = None,
) -> Coroutine[None, None, KeyValueStoreInterface]: ...


@dataclass(frozen=True)
class BasicCrawlingContext:
"""Basic crawling context intended to be extended by crawlers."""
@@ -261,19 +319,104 @@ class BasicCrawlingContext:
send_request: SendRequestFunction
add_requests: AddRequestsFunction
push_data: PushDataFunction
get_key_value_store: GetKeyValueStoreFromRequestHandlerFunction
log: logging.Logger


@dataclass()
class KeyValueStoreValue:
content: Any
content_type: str | None


class KeyValueStoreChangeRecords:
def __init__(self, actual_key_value_store: KeyValueStore) -> None:
self.updates = dict[str, KeyValueStoreValue]()
self._actual_key_value_store = actual_key_value_store

async def set_value(
self,
key: str,
value: Any,
content_type: str | None = None,
) -> None:
self.updates[key] = KeyValueStoreValue(value, content_type)

@overload
async def get_value(self, key: str) -> Any: ...

@overload
async def get_value(self, key: str, default_value: T) -> T: ...

@overload
async def get_value(self, key: str, default_value: T | None = None) -> T | None: ...

async def get_value(self, key: str, default_value: T | None = None) -> T | None:
if key in self.updates:
return cast(T, self.updates[key].content)

return await self._actual_key_value_store.get_value(key, default_value)


class GetKeyValueStoreFunction(Protocol):
"""Type of a function for accessing the live implementation of a key-value store."""

def __call__(
self,
*,
id: str | None = None,
name: str | None = None,
) -> Coroutine[None, None, KeyValueStore]: ...


class RequestHandlerRunResult:
"""Record of calls to storage-related context helpers."""

add_requests_calls: list[AddRequestsFunctionCall] = field(default_factory=list)
def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
self._key_value_store_getter = key_value_store_getter
self.add_requests_calls = list[AddRequestsFunctionCall]()
self.push_data_calls = list[PushDataFunctionCall]()
self.key_value_store_changes = dict[tuple[Optional[str], Optional[str]], KeyValueStoreChangeRecords]()

async def add_requests(
self,
requests: Sequence[str | BaseRequestData],
**kwargs: Unpack[AddRequestsKwargs],
) -> None:
"""Track a call to the `add_requests` context helper."""
self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))
self.add_requests_calls.append(
AddRequestsFunctionCall(
requests=requests,
**kwargs,
)
)

async def push_data(
self,
data: JsonSerializable,
dataset_id: str | None = None,
dataset_name: str | None = None,
**kwargs: Unpack[PushDataKwargs],
) -> None:
"""Track a call to the `push_data` context helper."""
self.push_data_calls.append(
PushDataFunctionCall(
data=data,
dataset_id=dataset_id,
dataset_name=dataset_name,
**kwargs,
)
)

async def get_key_value_store(
self,
*,
id: str | None = None,
name: str | None = None,
) -> KeyValueStoreInterface:
if (id, name) not in self.key_value_store_changes:
self.key_value_store_changes[id, name] = KeyValueStoreChangeRecords(
await self._key_value_store_getter(id=id, name=name)
)

return self.key_value_store_changes[id, name]
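Taken together, the new `KeyValueStoreInterface`, `KeyValueStoreChangeRecords`, and the reworked `RequestHandlerRunResult` expose key-value store access to request handlers while buffering writes until the handler finishes. A minimal sketch of the handler-side usage — the crawler class, key, and URL are illustrative, not taken from this PR:

```python
import asyncio

from crawlee.http_crawler import HttpCrawler, HttpCrawlingContext


async def main() -> None:
    crawler = HttpCrawler()

    @crawler.router.default_handler
    async def handler(context: HttpCrawlingContext) -> None:
        # The context helper returns the limited KeyValueStoreInterface.
        kvs = await context.get_key_value_store()

        # Reads fall back to the underlying store unless the key was already
        # written earlier in this handler run.
        visits = await kvs.get_value('visit-count', 0)

        # Writes are recorded in KeyValueStoreChangeRecords and are persisted
        # only after the handler completes successfully.
        await kvs.set_value('visit-count', visits + 1)

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```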
34 changes: 23 additions & 11 deletions src/crawlee/basic_crawler/_basic_crawler.py
@@ -23,7 +23,7 @@
from crawlee._autoscaling.snapshotter import Snapshotter
from crawlee._autoscaling.system_status import SystemStatus
from crawlee._log_config import configure_logger, get_configured_log_level
from crawlee._request import BaseRequestData, Request, RequestState
from crawlee._request import Request, RequestState
from crawlee._types import BasicCrawlingContext, HttpHeaders, RequestHandlerRunResult, SendRequestFunction
from crawlee._utils.byte_size import ByteSize
from crawlee._utils.http import is_status_code_client_error
@@ -716,41 +716,52 @@ async def _commit_request_handler_result(
request_provider = await self.get_request_provider()
origin = context.request.loaded_url or context.request.url

for call in result.add_requests_calls:
for add_requests_call in result.add_requests_calls:
requests = list[Request]()

for request in call['requests']:
if (limit := call.get('limit')) is not None and len(requests) >= limit:
for request in add_requests_call['requests']:
if (limit := add_requests_call.get('limit')) is not None and len(requests) >= limit:
break

# If the request is a Request object, keep it as it is
if isinstance(request, Request):
dst_request = request
# If the request is a string, convert it to Request object.
if isinstance(request, str):
if is_url_absolute(request):
dst_request = Request.from_url(request)

# If the request URL is relative, make it absolute using the origin URL.
else:
base_url = call['base_url'] if call.get('base_url') else origin
base_url = url if (url := add_requests_call.get('base_url')) else origin
absolute_url = convert_to_absolute_url(base_url, request)
dst_request = Request.from_url(absolute_url)

# If the request is a BaseRequestData, convert it to Request object.
elif isinstance(request, BaseRequestData):
else:
dst_request = Request.from_base_request_data(request)

if self._check_enqueue_strategy(
call.get('strategy', EnqueueStrategy.ALL),
add_requests_call.get('strategy', EnqueueStrategy.ALL),
target_url=urlparse(dst_request.url),
origin_url=urlparse(origin),
) and self._check_url_patterns(
dst_request.url,
call.get('include', None),
call.get('exclude', None),
add_requests_call.get('include', None),
add_requests_call.get('exclude', None),
):
requests.append(dst_request)

await request_provider.add_requests_batched(requests)

for push_data_call in result.push_data_calls:
await self._push_data(**push_data_call)

for (id, name), changes in result.key_value_store_changes.items():
store = await self.get_key_value_store(id=id, name=name)
for key, value in changes.updates.items():
await store.set_value(key, value.content, value.content_type)

async def __is_finished_function(self) -> bool:
request_provider = await self.get_request_provider()
is_finished = await request_provider.is_finished()
@@ -792,15 +803,16 @@ async def __run_task_function(self) -> None:

session = await self._get_session()
proxy_info = await self._get_proxy_info(request, session)
result = RequestHandlerRunResult()
result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)

crawling_context = BasicCrawlingContext(
request=request,
session=session,
proxy_info=proxy_info,
send_request=self._prepare_send_request_function(session, proxy_info),
add_requests=result.add_requests,
push_data=self._push_data,
push_data=result.push_data,
get_key_value_store=result.get_key_value_store,
log=self._logger,
)

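On the crawler side, `__run_task_function` now builds a `RequestHandlerRunResult` around `self.get_key_value_store` and hands its bound `add_requests`, `push_data`, and `get_key_value_store` methods to the context, so a handler only records intent; `_commit_request_handler_result` replays the recorded requests, dataset pushes, and key-value store writes once the handler has returned without error. A sketch of that contract in isolation, assuming `RequestHandlerRunResult` stays importable from `crawlee._types` and using `KeyValueStore.open` as the getter:

```python
from __future__ import annotations

import asyncio

from crawlee._types import RequestHandlerRunResult
from crawlee.storages import KeyValueStore


async def main() -> None:
    # Mirrors what BasicCrawler passes in: a coroutine that resolves the live store.
    async def get_store(*, id: str | None = None, name: str | None = None) -> KeyValueStore:
        return await KeyValueStore.open(id=id, name=name)

    result = RequestHandlerRunResult(key_value_store_getter=get_store)

    # Calls made through the result object are only recorded, not executed.
    await result.push_data({'url': 'https://example.com', 'title': 'Example'})
    kvs = await result.get_key_value_store()
    await kvs.set_value('state', {'done': True})

    # Nothing has been written yet; the crawler replays these records in
    # _commit_request_handler_result after the request handler succeeds.
    print(result.push_data_calls[0]['data'])
    print(result.key_value_store_changes[None, None].updates['state'].content)


if __name__ == '__main__':
    asyncio.run(main())
```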
2 changes: 2 additions & 0 deletions src/crawlee/beautifulsoup_crawler/_beautifulsoup_crawler.py
@@ -81,6 +81,7 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenera
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
get_key_value_store=context.get_key_value_store,
log=context.log,
http_response=result.http_response,
)
@@ -159,6 +160,7 @@ async def enqueue_links(
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
get_key_value_store=context.get_key_value_store,
log=context.log,
http_response=context.http_response,
soup=soup,
1 change: 1 addition & 0 deletions src/crawlee/http_crawler/_http_crawler.py
@@ -65,6 +65,7 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenera
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
get_key_value_store=context.get_key_value_store,
log=context.log,
http_response=result.http_response,
)
2 changes: 2 additions & 0 deletions src/crawlee/parsel_crawler/_parsel_crawler.py
@@ -76,6 +76,7 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenera
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
get_key_value_store=context.get_key_value_store,
log=context.log,
http_response=result.http_response,
)
@@ -158,6 +159,7 @@ async def enqueue_links(
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
get_key_value_store=context.get_key_value_store,
log=context.log,
http_response=context.http_response,
selector=parsel_selector,
1 change: 1 addition & 0 deletions src/crawlee/playwright_crawler/_playwright_crawler.py
@@ -164,6 +164,7 @@ async def enqueue_links(
send_request=context.send_request,
push_data=context.push_data,
proxy_info=context.proxy_info,
get_key_value_store=context.get_key_value_store,
log=context.log,
page=crawlee_page.page,
infinite_scroll=lambda: infinite_scroll(crawlee_page.page),
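The crawler-specific contexts above (BeautifulSoup, HTTP, Parsel, Playwright) simply forward `get_key_value_store` when they construct their extended contexts, so the same helper is available regardless of which crawler is in use. For instance, a Playwright handler might persist a screenshot through it — the store name, key, and URL below are illustrative:

```python
import asyncio

from crawlee.playwright_crawler import PlaywrightCrawler, PlaywrightCrawlingContext


async def main() -> None:
    crawler = PlaywrightCrawler()

    @crawler.router.default_handler
    async def handler(context: PlaywrightCrawlingContext) -> None:
        # Writes go through the buffered interface and are committed only
        # after the handler finishes, like any other key-value store change.
        kvs = await context.get_key_value_store(name='screenshots')
        screenshot = await context.page.screenshot()
        await kvs.set_value('last-screenshot', screenshot, content_type='image/png')

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```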
6 changes: 1 addition & 5 deletions src/crawlee/storages/_dataset.py
@@ -15,7 +15,7 @@
from crawlee.storages._key_value_store import KeyValueStore

if TYPE_CHECKING:
from crawlee._types import JsonSerializable
from crawlee._types import JsonSerializable, PushDataKwargs
from crawlee.base_storage_client import BaseStorageClient
from crawlee.base_storage_client._models import DatasetItemsListPage
from crawlee.configuration import Configuration
@@ -54,10 +54,6 @@ class GetDataKwargs(TypedDict):
view: NotRequired[str]


class PushDataKwargs(TypedDict):
"""Keyword arguments for dataset's `push_data` method."""


class ExportToKwargs(TypedDict):
"""Keyword arguments for dataset's `export_to` method.
