forked from googleapis/python-bigquery-dataframes
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtree_properties.py
More file actions
175 lines (142 loc) · 6.25 KB
/
Copy pathtree_properties.py
File metadata and controls
175 lines (142 loc) · 6.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import itertools
from typing import Callable, Dict, Optional, Sequence
import bigframes.core.nodes as nodes
def is_trivially_executable(node: nodes.BigFrameNode) -> bool:
    """Report whether the tree rooted at ``node`` is trivially executable.

    A tree qualifies when it is local-only (see ``local_only``), or when
    this node is row-preserving and not flagged ``non_local`` and every child
    subtree also passes this same check.

    Args:
        node: Root of the tree to inspect.

    Returns:
        True if the tree qualifies, otherwise a falsy value.
    """
    if local_only(node):
        return True
    # Children are examined before this node's own flags are read.
    descendants_ok = all(map(is_trivially_executable, node.child_nodes))
    this_node_ok = (not node.non_local) and node.row_preserving
    return descendants_ok and this_node_ok
def local_only(node: nodes.BigFrameNode) -> bool:
    """Return True when every entry of ``node.roots`` is a ``ReadLocalNode``.

    A node with no roots at all is also reported as local-only.
    """
    for root in node.roots:
        if not isinstance(root, nodes.ReadLocalNode):
            return False
    return True
def can_fast_peek(node: nodes.BigFrameNode) -> bool:
    """Report whether the tree rooted at ``node`` supports a fast peek.

    True when the tree is local-only (see ``local_only``), or when this node
    is not flagged ``non_local`` and every child subtree also passes this
    same check.

    Args:
        node: Root of the tree to inspect.

    Returns:
        True if a fast peek is possible, False otherwise.
    """
    if local_only(node):
        return True
    # Children are examined before this node's own flag is read.
    subtrees_peekable = all(map(can_fast_peek, node.child_nodes))
    stays_local = not node.non_local
    return subtrees_peekable and stays_local
def can_fast_head(node: nodes.BigFrameNode) -> bool:
    """Report whether a head operation can be pushed down to a supporting leaf.

    Walks down through projection and selection nodes only. The answer is the
    leaf's ``supports_fast_head`` if a leaf is reached that way; any other
    kind of node on the path makes the answer False.
    """
    pass_through_types = (nodes.ProjectionNode, nodes.SelectionNode)
    current = node
    while True:
        if isinstance(current, nodes.LeafNode):
            return current.supports_fast_head
        if not isinstance(current, pass_through_types):
            return False
        # Projection/selection: keep descending to the single child.
        current = current.child
def row_count(node: nodes.BigFrameNode) -> Optional[int]:
    """Derive the row count of ``node`` from local metadata.

    Args:
        node: Root of the tree whose row count is wanted.

    Returns:
        The row count when it can be determined without execution, or None
        when it is unknown.
    """
    if isinstance(node, nodes.LeafNode):
        return node.row_count

    if isinstance(node, nodes.AggregateNode):
        # Only an aggregation without by-columns has a known count (1).
        return 1 if len(node.by_column_ids) == 0 else None

    if isinstance(node, nodes.ConcatNode):
        child_totals = [row_count(child) for child in node.child_nodes]
        # A single unknown input makes the concatenated total unknown.
        if any(child_total is None for child_total in child_totals):
            return None
        return sum(child_totals)

    if isinstance(node, nodes.UnaryNode) and node.row_preserving:
        return row_count(node.child)

    return None
# Replace modified_cost(node) = cost(apply_cache(node))
def select_cache_target(
    root: nodes.BigFrameNode,
    min_complexity: float,
    max_complexity: float,
    cache: dict[nodes.BigFrameNode, nodes.BigFrameNode],
    heuristic: Callable[[int, int], float],
) -> Optional[nodes.BigFrameNode]:
    """Pick the subtree of ``root`` that scores highest as a caching candidate.

    Complexity is always the ``planning_complexity`` of a subtree after the
    substitutions in ``cache`` have been applied to it. Candidates are the
    nodes whose complexity lies within ``[min_complexity, max_complexity]``.
    A node above ``max_complexity`` is not a candidate itself, but its
    children are still searched. A node below ``min_complexity`` is skipped
    together with everything beneath it.

    Args:
        root: Root of the tree to search.
        min_complexity: Inclusive lower bound on candidate complexity.
        max_complexity: Inclusive upper bound on candidate complexity.
        cache: Substitutions handed to ``replace_nodes`` before complexity
            is measured.
        heuristic: Scoring function, called as
            ``heuristic(complexity, occurrence_count)``.

    Returns:
        The candidate with the highest heuristic score. On ties the candidate
        encountered first wins. Despite the ``Optional`` annotation, this
        function never returns None.

    Raises:
        ValueError: If no node falls inside the complexity range.
    """

    @functools.cache
    def _after_cache(subtree: nodes.BigFrameNode) -> nodes.BigFrameNode:
        return replace_nodes(subtree, cache)

    def _merge(
        left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
    ) -> Dict[nodes.BigFrameNode, int]:
        # Always build a new dict: memoized tallies are shared and must not be
        # mutated. Keys keep the order "left first, then new keys from right".
        merged = dict(left)
        for key, occurrences in right.items():
            merged[key] = merged.get(key, 0) + occurrences
        return merged

    @functools.cache
    def _tally(subtree: nodes.BigFrameNode) -> Dict[nodes.BigFrameNode, int]:
        """Count occurrences of in-range nodes within ``subtree``."""
        tally: Dict[nodes.BigFrameNode, int] = {}
        complexity = _after_cache(subtree).planning_complexity
        if complexity >= min_complexity:
            for child in subtree.child_nodes:
                tally = _merge(tally, _tally(child))
            if complexity <= max_complexity:
                tally = _merge(tally, {subtree: 1})
        return tally

    candidates = _tally(root)
    if not candidates:
        raise ValueError("node counts should be non-zero")

    def _score(candidate: nodes.BigFrameNode) -> float:
        return heuristic(
            _after_cache(candidate).planning_complexity, candidates[candidate]
        )

    return max(candidates, key=_score)
def count_nodes(forest: Sequence[nodes.BigFrameNode]) -> dict[nodes.BigFrameNode, int]:
    """Count how many times each subtree occurs within a forest.

    Every node reachable from the given roots is counted once per path that
    reaches it, so a subtree shared by several parents is counted several
    times. Results are memoized per subtree for the duration of one call
    only; nothing is reused between invocations.

    Args:
        forest (Sequence of BigFrameNode):
            The roots of each tree in the forest.

    Returns:
        dict[BigFrameNode, int]: The number of occurrences of each subtree.
    """

    def _merge(
        left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
    ) -> Dict[nodes.BigFrameNode, int]:
        # Always build a new dict: memoized tallies are shared and must not be
        # mutated. Keys keep the order "left first, then new keys from right".
        merged = dict(left)
        for key, occurrences in right.items():
            merged[key] = merged.get(key, 0) + occurrences
        return merged

    @functools.cache
    def _subtree_tally(
        subtree: nodes.BigFrameNode,
    ) -> Dict[nodes.BigFrameNode, int]:
        """Count every node in ``subtree``, including ``subtree`` itself."""
        tally: Dict[nodes.BigFrameNode, int] = {}
        for child in subtree.child_nodes:
            tally = _merge(tally, _subtree_tally(child))
        return _merge(tally, {subtree: 1})

    totals: Dict[nodes.BigFrameNode, int] = {}
    for tree in forest:
        totals = _merge(totals, _subtree_tally(tree))
    return totals
def replace_nodes(
    root: nodes.BigFrameNode,
    replacements: dict[nodes.BigFrameNode, nodes.BigFrameNode],
) -> nodes.BigFrameNode:
    """Substitute nodes of a tree according to a replacement mapping.

    The tree is walked top-down. A node that appears as a key in
    ``replacements`` is swapped for the mapped value as-is: neither the
    original node's children nor the replacement are processed further. Any
    other node is rebuilt through its ``transform_children`` method with the
    same substitution applied to each child.

    Args:
        root: Root of the tree to rewrite.
        replacements: Mapping from nodes to the nodes that should stand in
            for them.

    Returns:
        The rewritten tree (``replacements[root]`` itself when ``root`` is a
        key of the mapping).
    """

    # Memoized so that a subtree shared by several parents is rewritten once.
    @functools.cache
    def apply_substitution(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
        if node in replacements:
            return replacements[node]
        return node.transform_children(apply_substitution)

    return apply_substitution(root)