forked from googleapis/python-bigquery-dataframes
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
141 lines (110 loc) · 4.84 KB
/
Copy pathutils.py
File metadata and controls
141 lines (110 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import typing
from typing import Any, Generator, Iterable, Literal, Mapping, Optional, Union

import bigframes_vendored.constants as constants
from google.cloud import bigquery

from bigframes.core import blocks
import bigframes.pandas as bpd
# Internal type alias for the bigframes array-like values (DataFrame or Series)
# accepted by the conversion helpers in this module.
ArrayType = typing.Union[bpd.DataFrame, bpd.Series]
def convert_to_dataframe(*input: ArrayType) -> Generator[bpd.DataFrame, None, None]:
    """Lazily coerce each positional argument into a bigframes DataFrame.

    Every argument goes through ``_convert_to_dataframe``. Conversion, and any
    ValueError it raises, happens only as the returned generator is consumed.
    """
    for item in input:
        yield _convert_to_dataframe(item)
def _convert_to_dataframe(frame: ArrayType) -> bpd.DataFrame:
    """Return *frame* as a DataFrame.

    A DataFrame is returned unchanged. A Series is wrapped via ``to_frame()``.

    Raises:
        ValueError: if *frame* is neither a bigframes DataFrame nor a Series.
    """
    if isinstance(frame, bpd.DataFrame):
        return frame
    if not isinstance(frame, bpd.Series):
        raise ValueError(
            f"Unsupported type {type(frame)} to convert to DataFrame. {constants.FEEDBACK_LINK}"
        )
    # A single Series becomes a one-column DataFrame.
    return frame.to_frame()
def convert_to_series(*input: ArrayType) -> Generator[bpd.Series, None, None]:
    """Lazily coerce each positional argument into a bigframes Series.

    Every argument goes through ``_convert_to_series``. Conversion, and any
    ValueError it raises, happens only as the returned generator is consumed.
    """
    for item in input:
        yield _convert_to_series(item)
def _convert_to_series(frame: ArrayType) -> bpd.Series:
    """Return *frame* as a Series.

    A single-column DataFrame yields that column, selected by its label. A
    Series is returned unchanged.

    Raises:
        ValueError: if *frame* is a DataFrame whose column count is not exactly
            one, or if it is neither a bigframes DataFrame nor a Series.
    """
    if isinstance(frame, bpd.DataFrame):
        if len(frame.columns) == 1:
            only_label = typing.cast(blocks.Label, frame.columns.tolist()[0])
            return typing.cast(bpd.Series, frame[only_label])
        raise ValueError(
            "To convert into Series, DataFrames can only contain one column. "
            f"Try input with only one column. {constants.FEEDBACK_LINK}"
        )
    if isinstance(frame, bpd.Series):
        return frame
    raise ValueError(
        f"Unsupported type {type(frame)} to convert to Series. {constants.FEEDBACK_LINK}"
    )
def convert_to_types(
    inputs: Iterable[Union[ArrayType, None]],
    type_instances: Iterable[Union[ArrayType, None]],
) -> tuple[Union[ArrayType, None], ...]:
    """Convert each input to the kind (DataFrame, Series or None) of its paired exemplar.

    Args:
        inputs: Values to convert.
        type_instances: Exemplars, paired with ``inputs`` by position. Only the
            kind of each exemplar (None / DataFrame / Series) is used.

    Returns:
        A tuple with one converted value per (input, exemplar) pair.

    Raises:
        ValueError: Propagated from ``_convert_to_type`` when a pair cannot be
            converted.

    Note:
        Pairing uses ``zip``, so surplus items in the longer iterable are
        silently ignored.
    """
    return tuple(
        _convert_to_type(value, exemplar)
        for value, exemplar in zip(inputs, type_instances)
    )
def _convert_to_type(
    input: Union[ArrayType, None], type_instance: Union[ArrayType, None]
) -> Union[ArrayType, None]:
    """Convert *input* to the same kind as *type_instance*.

    Args:
        input: The value to convert: a DataFrame, a Series or None. The name
            shadows the builtin but is kept because callers may pass it by
            keyword.
        type_instance: An exemplar whose type selects the target kind: None,
            a DataFrame or a Series.

    Returns:
        None when *type_instance* is None. Otherwise *input* converted to a
        DataFrame or a Series.

    Raises:
        ValueError: If exactly one of *input* / *type_instance* is None, if
            *type_instance* has an unsupported type, or if the underlying
            conversion rejects *input*.
    """
    if type_instance is None:
        if input is not None:
            raise ValueError(
                f"Trying to convert not None type to None. {constants.FEEDBACK_LINK}"
            )
        return None
    if input is None:
        raise ValueError(
            f"Trying to convert None type to not None. {constants.FEEDBACK_LINK}"
        )
    if isinstance(type_instance, bpd.DataFrame):
        return _convert_to_dataframe(input)
    if isinstance(type_instance, bpd.Series):
        return _convert_to_series(input)
    raise ValueError(
        f"Unsupported conversion to {type(type_instance)}. {constants.FEEDBACK_LINK}"
    )
def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
    """Split a ``"model_name@version"`` endpoint into ``(model_name, version)``.

    Only the first ``@`` acts as the separator, so any later ``@`` characters
    stay in the version. Without an ``@`` the version is None; a trailing
    ``@`` yields an empty-string version.
    """
    name, separator, remainder = model_endpoint.partition("@")
    return name, (remainder if separator else None)
def _resolve_param_type(t: type) -> type:
    """Reduce a constructor parameter annotation to a callable converter type.

    * ``Optional[X]`` (``Union[X, None]``, or ``X | None`` on Python 3.10+)
      becomes ``X``. The remaining union members keep their declared order.
    * ``Literal[v0, v1, ...]`` (also after the step above) becomes ``type(v0)``.
    * Any other annotation is returned unchanged.
    """
    # PEP 604 unions (``X | None``) report ``types.UnionType`` as their origin
    # instead of ``typing.Union``. The getattr fallback keeps Python < 3.10 working.
    union_origins = (Union, getattr(types, "UnionType", Union))
    none_type = type(None)

    if typing.get_origin(t) in union_origins:
        args = typing.get_args(t)
        if none_type in args:
            # Filter instead of building a set so the member order is
            # deterministic and matches the declaration.
            t = Union[tuple(arg for arg in args if arg is not none_type)]  # type: ignore

    # Literal[value0, value1...] to type(value0)
    if typing.get_origin(t) is Literal:
        return type(typing.get_args(t)[0])
    return t
def retrieve_params_from_bq_model(
    cls, bq_model: bigquery.Model, params_mapping: Mapping[str, str]
) -> dict[str, Any]:
    """Rebuild constructor keyword arguments for *cls* from a BQML model.

    The function reads the training options of the model's most recent training
    run. For every annotated parameter of ``cls.__init__`` that has an entry in
    *params_mapping* and whose BQML option is present, the stored option value
    is converted with the type produced by ``_resolve_param_type``.

    Parameters that cannot be found are ignored. Unmapped hints (e.g. a
    ``return`` annotation) and options missing from the training run are
    skipped. A model with no training runs, or with no training options,
    therefore yields an empty dict.

    Args:
        cls: The class whose ``__init__`` type hints drive the lookup.
        bq_model: The BigQuery model to read training options from.
        params_mapping: Maps constructor parameter name -> BQML option name.

    Returns:
        A dict of keyword arguments, suitable for ``cls(**kwargs)``.
    """
    # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
    training_runs = bq_model.training_runs
    if not training_runs:
        return {}
    last_fitting = training_runs[-1].get("trainingOptions") or {}

    kwargs: dict[str, Any] = {}
    for bf_param, bf_param_type in typing.get_type_hints(cls.__init__).items():
        bqml_param = params_mapping.get(bf_param)
        if bqml_param is None or bqml_param not in last_fitting:
            continue
        converter = _resolve_param_type(bf_param_type)
        kwargs[bf_param] = converter(last_fitting[bqml_param])
    return kwargs