Skip to content

Commit 5b6c209

Browse files
authored
[kernels] change import time in KernelConfig (#42004)
* change import time * style
1 parent 258c76e commit 5b6c209

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/transformers/utils/kernel_config.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from ..utils import PushToHubMixin, is_kernels_available, is_torch_available
15+
from ..utils import PushToHubMixin, is_torch_available
1616

1717

18-
if is_kernels_available():
19-
from kernels import LayerRepository, Mode
20-
2118
if is_torch_available():
2219
import torch
2320

@@ -58,6 +55,8 @@ def infer_device(model):
5855

5956

6057
def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
58+
from kernels import LayerRepository
59+
6160
if device not in ["cuda", "rocm", "xpu"]:
6261
raise ValueError(f"Only cuda, rocm, and xpu devices supported, got: {device}")
6362
repo_layer_name = repo_name.split(":")[1]
@@ -82,6 +81,8 @@ def __init__(self, kernel_mapping={}):
8281
self.registered_layer_names = {}
8382

8483
def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revision=None):
84+
from kernels import LayerRepository
85+
8586
self.kernel_mapping[registered_name] = {
8687
device: {
8788
mode: LayerRepository(
@@ -204,6 +205,8 @@ def create_compatible_mapping(self, model, compile=False):
204205
The device is inferred from the model's parameters if not provided.
205206
The Mode is inferred from the model's training state.
206207
"""
208+
from kernels import Mode
209+
207210
compatible_mapping = {}
208211
for layer_name, kernel in self.kernel_mapping.items():
209212
# Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE

0 commit comments

Comments
 (0)