1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from ..utils import PushToHubMixin , is_kernels_available , is_torch_available
15+ from ..utils import PushToHubMixin , is_torch_available
1616
1717
18- if is_kernels_available ():
19- from kernels import LayerRepository , Mode
20-
2118if is_torch_available ():
2219 import torch
2320
@@ -58,6 +55,8 @@ def infer_device(model):
5855
5956
6057def add_to_mapping (layer_name , device , repo_name , mode , compatible_mapping ):
58+ from kernels import LayerRepository
59+
6160 if device not in ["cuda" , "rocm" , "xpu" ]:
6261 raise ValueError (f"Only cuda, rocm, and xpu devices supported, got: { device } " )
6362 repo_layer_name = repo_name .split (":" )[1 ]
@@ -82,6 +81,8 @@ def __init__(self, kernel_mapping={}):
8281 self .registered_layer_names = {}
8382
8483 def update_kernel (self , repo_id , registered_name , layer_name , device , mode , revision = None ):
84+ from kernels import LayerRepository
85+
8586 self .kernel_mapping [registered_name ] = {
8687 device : {
8788 mode : LayerRepository (
@@ -204,6 +205,8 @@ def create_compatible_mapping(self, model, compile=False):
204205 The device is inferred from the model's parameters if not provided.
205206 The Mode is inferred from the model's training state.
206207 """
208+ from kernels import Mode
209+
207210 compatible_mapping = {}
208211 for layer_name , kernel in self .kernel_mapping .items ():
209212 # Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
0 commit comments