PyTorch TorchScript / ONNX 导出
模型训练完成后,需要将模型部署到生产环境。PyTorch 提供了多种模型导出方式,其中 TorchScript 和 ONNX 是最常用的两种格式。
本节详细介绍这两种导出方式的原理、使用方法和最佳实践。
模型导出是将 PyTorch 模型转换为可以在不同平台、不同框架上运行的格式的过程。这对于模型部署、移动端推理、跨框架迁移等场景至关重要。
1. TorchScript 基础
1.1 什么是 TorchScript
TorchScript 是 PyTorch 的序列化格式,它可以将 Python 代码转换为可以独立运行的 C++ 虚拟机代码。TorchScript 程序可以在没有 Python 解释器的环境中运行。
TorchScript 的主要特点:
- 将动态图转换为静态图
- 支持 Python 子集的语法
- 可以在 C++ 环境中运行
- 保留模型的结构和参数
1.2 两种转换方式
TorchScript 提供两种将模型转换为 TorchScript 的方式:
- TorchScript 追踪(Tracing):通过执行模型记录操作,生成静态计算图
- TorchScript 脚本(Scripting):直接分析 Python 代码,编译为 TorchScript
2. TorchScript 追踪 (Tracing)
2.1 基本追踪方法
实例
import torch
import torch.nn as nn
# ── 定义模型 ──────────────────────────────────────
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc = nn.Linear(128 * 8 * 8, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 创建模型实例
model = SimpleNet()
model.eval()
# 示例输入
example_input = torch.randn(1, 3, 64, 64)
# ── 方式一:torch.jit.trace 追踪 ───────────────────
# 通过运行模型并记录操作来创建 TorchScript
traced_model = torch.jit.trace(model, example_input)
print("追踪后的模型:")
print(traced_model)
# 保存模型
traced_model.save("simple_net_traced.pt")
# 加载模型
loaded_model = torch.jit.load("simple_net_traced.pt")
# 使用加载的模型进行推理
output = loaded_model(example_input)
print(f"输出形状: {output.shape}")
import torch.nn as nn
# ── 定义模型 ──────────────────────────────────────
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc = nn.Linear(128 * 8 * 8, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 创建模型实例
model = SimpleNet()
model.eval()
# 示例输入
example_input = torch.randn(1, 3, 64, 64)
# ── 方式一:torch.jit.trace 追踪 ───────────────────
# 通过运行模型并记录操作来创建 TorchScript
traced_model = torch.jit.trace(model, example_input)
print("追踪后的模型:")
print(traced_model)
# 保存模型
traced_model.save("simple_net_traced.pt")
# 加载模型
loaded_model = torch.jit.load("simple_net_traced.pt")
# 使用加载的模型进行推理
output = loaded_model(example_input)
print(f"输出形状: {output.shape}")
2.2 追踪的限制
追踪方法有一些限制:
- 只记录实际执行的操作
- 控制流(如 if、for)会被固定
- 动态大小的输入可能有问题
实例
# 追踪的限制示例
class DynamicModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
# 控制流在追踪时被固定
if x.sum() > 0:
return x * 2
else:
return x / 2
model = DynamicModel()
model.eval()
# 追踪时,if 分支会被固定为追踪时的路径
example_input = torch.tensor([1.0])
traced = torch.jit.trace(model, example_input)
# 即使输入不同,也会执行相同的分支
print(traced(torch.tensor([1.0]))) # 2
print(traced(torch.tensor([-1.0]))) # 仍然是 2,不是 -0.5
class DynamicModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
# 控制流在追踪时被固定
if x.sum() > 0:
return x * 2
else:
return x / 2
model = DynamicModel()
model.eval()
# 追踪时,if 分支会被固定为追踪时的路径
example_input = torch.tensor([1.0])
traced = torch.jit.trace(model, example_input)
# 即使输入不同,也会执行相同的分支
print(traced(torch.tensor([1.0]))) # 2
print(traced(torch.tensor([-1.0]))) # 仍然是 2,不是 -0.5
对于包含控制流的模型,应该使用 TorchScript 脚本(Scripting)而不是追踪。
3. TorchScript 脚本 (Scripting)
3.1 基本使用方法
实例
import torch
# ── 使用 @torch.jit.script 装饰器 ───────────────
@torch.jit.script
def scripted_function(x: torch.Tensor) -> torch.Tensor:
"""使用脚本方式转换函数"""
if x.sum() > 0:
return x * 2
else:
return x / 2
# 测试脚本化的函数
input1 = torch.tensor([1.0, 2.0])
input2 = torch.tensor([-1.0, -2.0])
print(scripted_function(input1)) # [2., 4.]
print(scripted_function(input2)) # [-0.5, -1.]
# ── 脚本化模型 ───────────────────────────────────
class ScriptableModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
@torch.jit.export
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.relu(self.fc(x))
@torch.jit.export
def predict(self, x: torch.Tensor) -> torch.Tensor:
"""额外的导出方法"""
out = self.forward(x)
return torch.argmax(out, dim=1)
model = ScriptableModel()
# 脚本化模型
scripted_model = torch.jit.script(model)
print(scripted_model)
# 保存
scripted_model.save("scripted_model.pt")
# ── 使用 @torch.jit.script 装饰器 ───────────────
@torch.jit.script
def scripted_function(x: torch.Tensor) -> torch.Tensor:
"""使用脚本方式转换函数"""
if x.sum() > 0:
return x * 2
else:
return x / 2
# 测试脚本化的函数
input1 = torch.tensor([1.0, 2.0])
input2 = torch.tensor([-1.0, -2.0])
print(scripted_function(input1)) # [2., 4.]
print(scripted_function(input2)) # [-0.5, -1.]
# ── 脚本化模型 ───────────────────────────────────
class ScriptableModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
@torch.jit.export
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.relu(self.fc(x))
@torch.jit.export
def predict(self, x: torch.Tensor) -> torch.Tensor:
"""额外的导出方法"""
out = self.forward(x)
return torch.argmax(out, dim=1)
model = ScriptableModel()
# 脚本化模型
scripted_model = torch.jit.script(model)
print(scripted_model)
# 保存
scripted_model.save("scripted_model.pt")
3.2 复杂模型的脚本化
实例
# 更复杂的脚本化示例:带条件的模型
class ConditionModel(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x: torch.Tensor, use_softmax: bool = False) -> torch.Tensor:
"""
支持动态条件的模型
"""
features = self.features(x)
logits = self.classifier(features)
if use_softmax:
return torch.softmax(logits, dim=1)
else:
return logits
def get_prediction(self, x: torch.Tensor) -> torch.Tensor:
"""辅助方法"""
logits = self.forward(x, use_softmax=False)
return torch.argmax(logits, dim=1)
# 脚本化
model = ConditionModel(num_classes=10)
scripted_model = torch.jit.script(model, example_inputs=(torch.randn(1, 3, 32, 32),))
# 测试
test_input = torch.randn(2, 3, 32, 32)
output1 = scripted_model(test_input, use_softmax=False)
output2 = scripted_model(test_input, use_softmax=True)
print(f"Logits 输出形状: {output1.shape}")
print(f"Softmax 输出形状: {output2.shape}")
class ConditionModel(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x: torch.Tensor, use_softmax: bool = False) -> torch.Tensor:
"""
支持动态条件的模型
"""
features = self.features(x)
logits = self.classifier(features)
if use_softmax:
return torch.softmax(logits, dim=1)
else:
return logits
def get_prediction(self, x: torch.Tensor) -> torch.Tensor:
"""辅助方法"""
logits = self.forward(x, use_softmax=False)
return torch.argmax(logits, dim=1)
# 脚本化
model = ConditionModel(num_classes=10)
scripted_model = torch.jit.script(model, example_inputs=(torch.randn(1, 3, 32, 32),))
# 测试
test_input = torch.randn(2, 3, 32, 32)
output1 = scripted_model(test_input, use_softmax=False)
output2 = scripted_model(test_input, use_softmax=True)
print(f"Logits 输出形状: {output1.shape}")
print(f"Softmax 输出形状: {output2.shape}")
4. ONNX 导出
4.1 ONNX 基础
ONNX(Open Neural Network eXchange) 是一个开放的神经网络交换格式,支持在不同的深度学习框架之间转换模型。
ONNX 的优势:
- 跨框架:PyTorch、TensorFlow、Caffe2 等都支持
- 跨平台:支持 CPU、GPU、移动端等多种平台
- 硬件优化:可利用 ONNX Runtime 进行高效推理
- 工具丰富:有大量的优化工具和部署方案
4.2 基本 ONNX 导出
实例
import torch
import torch.nn as nn
import torchvision
# ── 定义模型 ──────────────────────────────────────
class ImageClassifier(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
)
self.classifier = nn.Linear(64, 10)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
model = ImageClassifier()
model.eval()
# 示例输入
example_input = torch.randn(1, 3, 32, 32)
# ── 导出为 ONNX ──────────────────────────────────
output_path = "image_classifier.onnx"
torch.onnx.export(
model,
example_input,
output_path,
export_params=True, # 导出模型参数
opset_version=14, # ONNX 版本
do_constant_folding=True, # 常量折叠优化
input_names=['input'], # 输入张量名称
output_names=['output'], # 输出张量名称
dynamic_axes={
'input': {0: 'batch_size'}, # 动态批次维度
'output': {0: 'batch_size'}
}
)
print(f"模型已导出到: {output_path}")
# 验证导出的模型
import onnx
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")
import torch.nn as nn
import torchvision
# ── 定义模型 ──────────────────────────────────────
class ImageClassifier(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
)
self.classifier = nn.Linear(64, 10)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
model = ImageClassifier()
model.eval()
# 示例输入
example_input = torch.randn(1, 3, 32, 32)
# ── 导出为 ONNX ──────────────────────────────────
output_path = "image_classifier.onnx"
torch.onnx.export(
model,
example_input,
output_path,
export_params=True, # 导出模型参数
opset_version=14, # ONNX 版本
do_constant_folding=True, # 常量折叠优化
input_names=['input'], # 输入张量名称
output_names=['output'], # 输出张量名称
dynamic_axes={
'input': {0: 'batch_size'}, # 动态批次维度
'output': {0: 'batch_size'}
}
)
print(f"模型已导出到: {output_path}")
# 验证导出的模型
import onnx
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")
4.3 导出复杂模型
实例
# 导出完整的 ResNet 模型
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
model.eval()
# 示例输入(ResNet 标准的 224x224)
example_input = torch.randn(1, 3, 224, 224)
# 导出
torch.onnx.export(
model,
example_input,
"resnet18.onnx",
export_params=True,
opset_version=14,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size', 2: 'height', 3: 'width'},
'output': {0: 'batch_size'}
}
)
print("ResNet18 已导出")
# 导出时处理特殊的层
# 对于需要特殊处理的层,使用 workaround
class ModelWithSpecialOps(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3, padding=1)
def forward(self, x):
# 使用 F.interpolate 而不是 nn.functional 的别名
x = self.conv(x)
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
return x
model = ModelWithSpecialOps()
model.eval()
torch.onnx.export(
model,
torch.randn(1, 3, 32, 32),
"special_ops.onnx",
input_names=['input'],
output_names=['output'],
opset_version=14
)
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
model.eval()
# 示例输入(ResNet 标准的 224x224)
example_input = torch.randn(1, 3, 224, 224)
# 导出
torch.onnx.export(
model,
example_input,
"resnet18.onnx",
export_params=True,
opset_version=14,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size', 2: 'height', 3: 'width'},
'output': {0: 'batch_size'}
}
)
print("ResNet18 已导出")
# 导出时处理特殊的层
# 对于需要特殊处理的层,使用 workaround
class ModelWithSpecialOps(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3, padding=1)
def forward(self, x):
# 使用 F.interpolate 而不是 nn.functional 的别名
x = self.conv(x)
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
return x
model = ModelWithSpecialOps()
model.eval()
torch.onnx.export(
model,
torch.randn(1, 3, 32, 32),
"special_ops.onnx",
input_names=['input'],
output_names=['output'],
opset_version=14
)
5. 导出的验证与优化
5.1 验证导出模型
实例
import numpy as np
import onnx
import onnxruntime as ort
def verify_onnx_model(onnx_path, pytorch_model, test_input):
"""
验证 ONNX 模型与 PyTorch 模型的输出一致性
"""
# 1. 检查 ONNX 模型
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("✓ ONNX 模型结构验证通过")
# 2. PyTorch 推理
pytorch_model.eval()
with torch.no_grad():
pytorch_output = pytorch_model(test_input)
# 3. ONNX Runtime 推理
ort_session = ort.InferenceSession(onnx_path)
ort_output = ort_session.run(None, {'input': test_input.numpy()})[0]
# 4. 对比输出
diff = np.abs(pytorch_output.numpy() - ort_output).max()
print(f"✓ 最大输出差异: {diff:.6f}")
if diff < 1e-5:
print("✓ 输出验证通过")
return True
else:
print("⚠ 输出差异较大")
return False
# 使用示例
model = ImageClassifier()
model.eval()
# 先导出
torch.onnx.export(
model,
torch.randn(1, 3, 32, 32),
"test_model.onnx",
input_names=['input'],
output_names=['output']
)
# 验证
test_input = torch.randn(2, 3, 32, 32)
verify_onnx_model("test_model.onnx", model, test_input)
import onnx
import onnxruntime as ort
def verify_onnx_model(onnx_path, pytorch_model, test_input):
"""
验证 ONNX 模型与 PyTorch 模型的输出一致性
"""
# 1. 检查 ONNX 模型
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("✓ ONNX 模型结构验证通过")
# 2. PyTorch 推理
pytorch_model.eval()
with torch.no_grad():
pytorch_output = pytorch_model(test_input)
# 3. ONNX Runtime 推理
ort_session = ort.InferenceSession(onnx_path)
ort_output = ort_session.run(None, {'input': test_input.numpy()})[0]
# 4. 对比输出
diff = np.abs(pytorch_output.numpy() - ort_output).max()
print(f"✓ 最大输出差异: {diff:.6f}")
if diff < 1e-5:
print("✓ 输出验证通过")
return True
else:
print("⚠ 输出差异较大")
return False
# 使用示例
model = ImageClassifier()
model.eval()
# 先导出
torch.onnx.export(
model,
torch.randn(1, 3, 32, 32),
"test_model.onnx",
input_names=['input'],
output_names=['output']
)
# 验证
test_input = torch.randn(2, 3, 32, 32)
verify_onnx_model("test_model.onnx", model, test_input)
5.2 ONNX Runtime 优化
实例
import onnxruntime as ort
from onnxruntime import GraphOptimizationLevel
# 创建推理会话并配置优化
def create_optimized_session(onnx_path, providers=None):
if providers is None:
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
# 配置会话选项
sess_options = ort.SessionOptions()
# 开启图优化
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# 启用其他优化
sess_options.intra_op_num_threads = 4
sess_options.inter_op_num_threads = 4
# 创建会话
session = ort.InferenceSession(onnx_path, sess_options, providers=providers)
return session
# 使用优化的会话进行推理
session = create_optimized_session("resnet18.onnx")
# 获取输入输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 推理
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = session.run([output_name], {input_name: input_data})[0]
print(f"推理输出形状: {output.shape}")
from onnxruntime import GraphOptimizationLevel
# 创建推理会话并配置优化
def create_optimized_session(onnx_path, providers=None):
if providers is None:
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
# 配置会话选项
sess_options = ort.SessionOptions()
# 开启图优化
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# 启用其他优化
sess_options.intra_op_num_threads = 4
sess_options.inter_op_num_threads = 4
# 创建会话
session = ort.InferenceSession(onnx_path, sess_options, providers=providers)
return session
# 使用优化的会话进行推理
session = create_optimized_session("resnet18.onnx")
# 获取输入输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 推理
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = session.run([output_name], {input_name: input_data})[0]
print(f"推理输出形状: {output.shape}")
5.3 模型量化
实例
# ONNX 模型量化
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
def quantize_onnx_model(input_path, output_path):
"""
动态量化 ONNX 模型
减少模型大小并加速推理
"""
# 动态量化(不需要校准数据)
quantize_dynamic(
input_path,
output_path,
# 指定需要量化的权重类型
weight_type=QuantType.QInt8
)
# 获取文件大小对比
import os
original_size = os.path.getsize(input_path) / 1024 / 1024
quantized_size = os.path.getsize(output_path) / 1024 / 1024
print(f"原始模型: {original_size:.2f} MB")
print(f"量化模型: {quantized_size:.2f} MB")
print(f"压缩比: {original_size/quantized_size:.2f}x")
# 使用
quantize_onnx_model("resnet18.onnx", "resnet18_quantized.onnx")
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
def quantize_onnx_model(input_path, output_path):
"""
动态量化 ONNX 模型
减少模型大小并加速推理
"""
# 动态量化(不需要校准数据)
quantize_dynamic(
input_path,
output_path,
# 指定需要量化的权重类型
weight_type=QuantType.QInt8
)
# 获取文件大小对比
import os
original_size = os.path.getsize(input_path) / 1024 / 1024
quantized_size = os.path.getsize(output_path) / 1024 / 1024
print(f"原始模型: {original_size:.2f} MB")
print(f"量化模型: {quantized_size:.2f} MB")
print(f"压缩比: {original_size/quantized_size:.2f}x")
# 使用
quantize_onnx_model("resnet18.onnx", "resnet18_quantized.onnx")
6. 移动端部署
6.1 导出到移动端格式
实例
# 方式一:TorchScript 移动端部署
# 1. 追踪模型
model = ImageClassifier()
model.eval()
example_input = torch.randn(1, 3, 32, 32)
traced_model = torch.jit.trace(model, example_input)
# 2. 优化模型(移动端优化)
optimized_model = torch.jit.optimize_for_inference(traced_model)
# 3. 保存移动端模型
optimized_model.save("mobile_model.pt")
# 方式二:使用 torchmobile(如果可用)
# torchmobile 用于更轻量的移动端部署
# 请参考官方文档进行交叉编译
# 方式三:导出为 ONNX 并使用移动端运行时
# iOS: 使用 Core ML Tools
# Android: 使用 ONNX Runtime Mobile
print("移动端模型已准备")
# 1. 追踪模型
model = ImageClassifier()
model.eval()
example_input = torch.randn(1, 3, 32, 32)
traced_model = torch.jit.trace(model, example_input)
# 2. 优化模型(移动端优化)
optimized_model = torch.jit.optimize_for_inference(traced_model)
# 3. 保存移动端模型
optimized_model.save("mobile_model.pt")
# 方式二:使用 torchmobile(如果可用)
# torchmobile 用于更轻量的移动端部署
# 请参考官方文档进行交叉编译
# 方式三:导出为 ONNX 并使用移动端运行时
# iOS: 使用 Core ML Tools
# Android: 使用 ONNX Runtime Mobile
print("移动端模型已准备")
6.2 Core ML 导出(iOS)
对于 iOS 平台,可以使用 Core ML Tools 将模型转换为 Core ML 格式。需要先转换为 ONNX 格式,再通过 Core ML Tools 转换。
实例
# 需要安装 coremltools
# pip install coremltools
# Core ML 导出步骤
# 注意:这只在 macOS 上可用
# 步骤 1: 先将 PyTorch 模型导出为 ONNX
# from pytorch_example import your_model
# model = your_model()
# model.eval()
# example_input = torch.randn(1, 3, 224, 224)
# torch.onnx.export(
# model,
# example_input,
# "model.onnx",
# input_names=['input'],
# output_names=['output']
# )
# 步骤 2: 使用 Core ML Tools 转换为 Core ML 格式
# from coremltools.converters import onnx as onnx_coreml
# coreml_model = onnx_coreml.convert(
# model="model.onnx",
# minimum_deployment_target='13',
# image_input_names=['input']
# )
# coreml_model.save("model.mlmodel")
# 使用示例
print("iOS 部署需要 macOS 环境")
print("详细步骤:1. pip install coremltools 2. 导出 ONNX 3. 使用 coremltools 转换")
# pip install coremltools
# Core ML 导出步骤
# 注意:这只在 macOS 上可用
# 步骤 1: 先将 PyTorch 模型导出为 ONNX
# from pytorch_example import your_model
# model = your_model()
# model.eval()
# example_input = torch.randn(1, 3, 224, 224)
# torch.onnx.export(
# model,
# example_input,
# "model.onnx",
# input_names=['input'],
# output_names=['output']
# )
# 步骤 2: 使用 Core ML Tools 转换为 Core ML 格式
# from coremltools.converters import onnx as onnx_coreml
# coreml_model = onnx_coreml.convert(
# model="model.onnx",
# minimum_deployment_target='13',
# image_input_names=['input']
# )
# coreml_model.save("model.mlmodel")
# 使用示例
print("iOS 部署需要 macOS 环境")
print("详细步骤:1. pip install coremltools 2. 导出 ONNX 3. 使用 coremltools 转换")
7. 最佳实践与常见问题
7.1 导出问题排查
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 导出失败 | 不支持的操作 | 使用 opset_version 更高的版本,或替换操作 |
| 输出不一致 | 动态控制流 | 使用脚本化而非追踪,或固定输入尺寸 |
| ONNX Runtime 错误 | 算子不支持 | 检查 ONNX 支持的操作列表 |
7.2 导出格式对比
| 格式 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| TorchScript (.pt) | PyTorch 原生支持 | 跨框架支持差 | 桌面/服务器部署 |
| ONNX (.onnx) | 跨框架、跨平台 | 有些操作不支持 | 通用部署 |
7.3 导出检查清单
实例
# 导出前检查清单
def export_checklist(model, example_input):
"""导出前的检查"""
model.eval()
# 1. 确保模型在推理模式下
print(f"模型模式: {'eval' if not model.training else 'train'}")
# 2. 检查输入尺寸
print(f"示例输入形状: {example_input.shape}")
# 3. 验证前向传播
with torch.no_grad():
output = model(example_input)
print(f"输出形状: {output.shape}")
# 4. 检查是否有不可序列化的操作
# 例如:lambda 函数
for name, module in model.named_modules():
if hasattr(module, '__call__'):
pass # 检查自定义模块
print("✓ 导出检查完成")
# 使用示例
model = ImageClassifier()
example_input = torch.randn(1, 3, 32, 32)
export_checklist(model, example_input)
def export_checklist(model, example_input):
"""导出前的检查"""
model.eval()
# 1. 确保模型在推理模式下
print(f"模型模式: {'eval' if not model.training else 'train'}")
# 2. 检查输入尺寸
print(f"示例输入形状: {example_input.shape}")
# 3. 验证前向传播
with torch.no_grad():
output = model(example_input)
print(f"输出形状: {output.shape}")
# 4. 检查是否有不可序列化的操作
# 例如:lambda 函数
for name, module in model.named_modules():
if hasattr(module, '__call__'):
pass # 检查自定义模块
print("✓ 导出检查完成")
# 使用示例
model = ImageClassifier()
example_input = torch.randn(1, 3, 32, 32)
export_checklist(model, example_input)
7.4 完整导出流程
实例
# 完整的模型导出流程示例
class ProductionModel(nn.Module):
"""生产级模型"""
def __init__(self):
super().__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(3, stride=2, padding=1),
# ... 更多层
)
self.head = nn.Linear(512, 10)
def forward(self, x, return_features=False):
features = self.backbone(x)
features = torch.nn.functional.adaptive_avg_pool2d(features, 1).flatten(1)
if return_features:
return features
return self.head(features)
# 完整导出流程
model = ProductionModel()
model.load_state_dict(torch.load("weights.pth")) # 加载权重
model.eval()
# 准备示例输入
example_input = torch.randn(1, 3, 224, 224)
# 1. 追踪 TorchScript
traced = torch.jit.trace(model, example_input)
traced.save("model_traced.pt")
# 2. 导出 ONNX
torch.onnx.export(
traced,
example_input,
"model.onnx",
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=14
)
# 3. 优化 ONNX
# 使用 onnx-simplifier 去除冗余操作
# onnxsim model.onnx model_simplified.onnx
print("所有格式导出完成!")
class ProductionModel(nn.Module):
"""生产级模型"""
def __init__(self):
super().__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(3, stride=2, padding=1),
# ... 更多层
)
self.head = nn.Linear(512, 10)
def forward(self, x, return_features=False):
features = self.backbone(x)
features = torch.nn.functional.adaptive_avg_pool2d(features, 1).flatten(1)
if return_features:
return features
return self.head(features)
# 完整导出流程
model = ProductionModel()
model.load_state_dict(torch.load("weights.pth")) # 加载权重
model.eval()
# 准备示例输入
example_input = torch.randn(1, 3, 224, 224)
# 1. 追踪 TorchScript
traced = torch.jit.trace(model, example_input)
traced.save("model_traced.pt")
# 2. 导出 ONNX
torch.onnx.export(
traced,
example_input,
"model.onnx",
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=14
)
# 3. 优化 ONNX
# 使用 onnx-simplifier 去除冗余操作
# onnxsim model.onnx model_simplified.onnx
print("所有格式导出完成!")
点我分享笔记