This repository was archived by the owner on Apr 1, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 68
Expand file tree
/
Copy pathutils.py
More file actions
245 lines (189 loc) · 7.96 KB
/
utils.py
File metadata and controls
245 lines (189 loc) · 7.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import (
Any,
Generator,
Hashable,
Iterable,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import bigframes_vendored.constants as constants
from google.cloud import bigquery
import pandas as pd
from bigframes.core import convert, guid
import bigframes.pandas as bpd
from bigframes.session import Session
# Internal type alias
ArrayType = Union[bpd.DataFrame, bpd.Series, pd.DataFrame, pd.Series]
BigFramesArrayType = Union[bpd.DataFrame, bpd.Series]
def batch_convert_to_dataframe(
    *input: ArrayType,
    session: Optional[Session] = None,
) -> Generator[bpd.DataFrame, None, None]:
    """Lazily converts every input to a BigFrames DataFrame.

    Args:
        session:
            The session to convert local pandas instances to BigFrames counter-parts.
            It is not used if the input itself is already a BigFrame data frame or series.
    """
    # Validate eagerly so a session mismatch surfaces before iteration starts.
    _validate_sessions(*input, session=session)

    def _converted():
        for frame in input:
            yield convert.to_bf_dataframe(frame, default_index=None, session=session)

    return _converted()
def batch_convert_to_series(
    *input: ArrayType, session: Optional[Session] = None
) -> Generator[bpd.Series, None, None]:
    """Lazily converts every input to a BigFrames Series.

    Args:
        session:
            The session to convert local pandas instances to BigFrames counter-parts.
            It is not used if the input itself is already a BigFrame data frame or series.
    """
    # Validate eagerly so a session mismatch surfaces before iteration starts.
    _validate_sessions(*input, session=session)

    def _converted():
        for frame in input:
            # Each input must reduce to a single column before Series conversion.
            single_column = _get_only_column(frame)
            yield convert.to_bf_series(
                single_column, default_index=None, session=session
            )

    return _converted()
def batch_convert_to_bf_equivalent(
    *input: ArrayType, session: Optional[Session] = None
) -> Generator[Union[bpd.DataFrame, bpd.Series], None, None]:
    """Converts the input to BigFrames DataFrame or Series.

    Args:
        session:
            The session to convert local pandas instances to BigFrames counter-parts.
            It is not used if the input itself is already a BigFrame data frame or series.

    Raises:
        ValueError: if an input is neither a (BigFrames or pandas) DataFrame
            nor Series. Raised lazily, when the offending item is reached.
    """
    # Fix: this used to be a generator function itself, so _validate_sessions
    # only ran on the first next() call. Validate eagerly instead, matching
    # batch_convert_to_dataframe / batch_convert_to_series above.
    _validate_sessions(*input, session=session)

    def _converted():
        for frame in input:
            if isinstance(frame, (bpd.DataFrame, pd.DataFrame)):
                yield convert.to_bf_dataframe(
                    frame, default_index=None, session=session
                )
            elif isinstance(frame, (bpd.Series, pd.Series)):
                yield convert.to_bf_series(
                    _get_only_column(frame), default_index=None, session=session
                )
            else:
                raise ValueError(f"Unsupported type: {type(frame)}")

    return _converted()
def _validate_sessions(*input: ArrayType, session: Optional[Session]):
    """Raise if the BigFrames inputs come from more than one session.

    Local pandas objects carry no session and are ignored.
    NOTE(review): ``session`` is accepted but not inspected here — it appears
    to exist only so callers can forward their keyword; confirm before relying
    on it being checked against the inputs' sessions.
    """
    session_ids = set()
    for item in input:
        if isinstance(item, (bpd.DataFrame, bpd.Series)):
            session_ids.add(item._session.session_id)
    if len(session_ids) > 1:
        raise ValueError("Cannot convert data from multiple sessions")
def _get_only_column(input: ArrayType) -> Union[pd.Series, bpd.Series]:
    """Return the input as a Series; a DataFrame must hold exactly one column."""
    # Series pass straight through.
    if isinstance(input, (pd.Series, bpd.Series)):
        return input

    if len(input.columns) != 1:
        raise ValueError(
            "To convert into Series, DataFrames can only contain one column. "
            f"Try input with only one column. {constants.FEEDBACK_LINK}"
        )

    # Select the sole column by label, preserving the static type per branch.
    label = typing.cast(Hashable, input.columns.tolist()[0])
    if isinstance(input, pd.DataFrame):
        return typing.cast(pd.Series, input[label])
    return typing.cast(bpd.Series, input[label])  # type: ignore
def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
    """Parse model endpoint string to model_name and version."""
    # multimodalembedding endpoints keep any "@version" suffix in the name.
    if model_endpoint.startswith("multimodalembedding"):
        return model_endpoint, None

    name, sep, version = model_endpoint.partition("@")
    if not sep:
        # No "@" present: the whole string is the model name, no version.
        return model_endpoint, None
    return name, version
def _resolve_param_type(t: type) -> type:
def is_optional(t):
return typing.get_origin(t) is Union and type(None) in typing.get_args(t)
# Optional[type] to type
if is_optional(t):
union_set = set(typing.get_args(t))
union_set.remove(type(None))
t = Union[tuple(union_set)] # type: ignore
# Literal[value0, value1...] to type(value0)
if typing.get_origin(t) is Literal:
return type(typing.get_args(t)[0])
return t
def retrieve_params_from_bq_model(
    cls, bq_model: bigquery.Model, params_mapping: Mapping[str, str]
) -> dict[str, Any]:
    """Retrieve parameters of class constructor from BQ model. params_mapping specifies the names mapping param_name -> bqml_name. Params couldn't be found will be ignored."""
    kwargs: dict[str, Any] = {}

    # Only the latest training run reflects the fitted model's options.
    # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
    last_fitting = bq_model.training_runs[-1]["trainingOptions"]

    for bf_param, bf_param_type in typing.get_type_hints(cls.__init__).items():
        bqml_param = params_mapping.get(bf_param)
        if bqml_param not in last_fitting:
            # Unmapped or absent in the training options: skip silently.
            continue
        # Coerce the BQML value to the constructor's (unwrapped) param type.
        concrete_type = _resolve_param_type(bf_param_type)
        kwargs[bf_param] = concrete_type(last_fitting[bqml_param])
    return kwargs
def combine_training_and_evaluation_data(
    X_train: bpd.DataFrame,
    y_train: bpd.DataFrame,
    X_eval: bpd.DataFrame,
    y_eval: bpd.DataFrame,
    bqml_options: dict,
) -> Tuple[bpd.DataFrame, bpd.DataFrame, dict]:
    """
    Combine training data and labels with evaluation data and labels, and keep
    them differentiated through a split column in the combined data and labels.

    Returns:
        A tuple ``(X, y, options)`` where ``X`` is the combined feature frame
        plus the generated split column (False = training row, True =
        evaluation row), ``y`` is the combined label frame with the original
        column names, and ``options`` is a copy of ``bqml_options`` with
        ``data_split_method``/``data_split_col`` set for BQML.
    """
    # Both halves must share identical schemas for the concat below to align.
    assert X_train.columns.equals(X_eval.columns)
    assert y_train.columns.equals(y_eval.columns)

    # create a custom split column for BQML and supply the evaluation
    # data along with the training data in a combined single table
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-dnn-models#data_split_col.
    split_col = guid.generate_guid()
    assert split_col not in X_train.columns

    # To prevent side effects on the input dataframes, we operate on copies
    X_train = X_train.copy()
    X_eval = X_eval.copy()
    X_train[split_col] = False
    X_eval[split_col] = True

    # Rename y columns to avoid collision with X columns during join
    y_mapping = {col: guid.generate_guid() + str(col) for col in y_train.columns}
    y_train_renamed = y_train.rename(columns=y_mapping)
    y_eval_renamed = y_eval.rename(columns=y_mapping)

    # Join X and y first to preserve row alignment
    train_combined = X_train.join(y_train_renamed, how="outer")
    eval_combined = X_eval.join(y_eval_renamed, how="outer")
    combined = bpd.concat([train_combined, eval_combined])

    # NOTE: X_train.columns now includes split_col (added above), so the
    # returned X deliberately carries the split column for BQML to consume.
    X = combined[X_train.columns]
    # Restore the original label names by inverting the collision-avoidance map.
    y = combined[list(y_mapping.values())].rename(
        columns={v: k for k, v in y_mapping.items()}
    )

    # create options copy to not mutate the incoming one
    bqml_options = bqml_options.copy()
    bqml_options["data_split_method"] = "CUSTOM"
    bqml_options["data_split_col"] = split_col

    return X, y, bqml_options
def standardize_type(v: str, supported_dtypes: Optional[Iterable[str]] = None):
    """Normalize a dtype name to its lowercase canonical form.

    Args:
        v: dtype name, e.g. "Int64" or "BOOLEAN".
        supported_dtypes: optional allow-list of canonical names; when provided
            and non-empty, any other dtype raises.

    Returns:
        The lowercased name, with "boolean" canonicalized to "bool".

    Raises:
        ValueError: if ``supported_dtypes`` is given and the normalized name is
            not among them.
    """
    t = v.lower()
    t = t.replace("boolean", "bool")

    if supported_dtypes:
        # Fix: materialize once. ``supported_dtypes`` is only required to be an
        # Iterable, and the previous code iterated it twice (membership test,
        # then join in the message) — a one-shot iterator was exhausted by the
        # membership test, yielding an empty "We only support ." message.
        supported = tuple(supported_dtypes)
        if supported and t not in supported:
            raise ValueError(
                f"Data type {v} is not supported. We only support {', '.join(supported)}."
            )
    return t