2525import torch .nn as nn
2626import torch .nn .functional as F
2727from PIL import Image
28- from safetensors .torch import save_file
29- from safetensors .torch import load_model
28+ from safetensors .torch import load_model , save_file
3029from transformers import CLIPImageProcessor
3130
3231from ..runtime .session import Session
3332
3433
def add_multimodal_arguments(parser):
    """Register multimodal model-selection CLI options on *parser*.

    NOTE(review): this view is truncated — further ``add_argument`` calls
    (``--model_path`` and others) follow in the original file.
    """
    parser.add_argument(
        '--model_type',
        type=str,
        default=None,
        # Fixed: several choice literals in the scraped source carried
        # spurious trailing spaces (e.g. 'blip2 ', 'vila '), which would
        # make the canonical model-type names fail argparse's choices check.
        choices=[
            'blip2', 'llava', 'llava_next', 'llava_onevision',
            'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm', 'fuyu',
            'pix2struct', 'neva', 'kosmos-2', 'video-neva', 'phi-3-vision',
            'phi-4-multimodal', 'mllama', 'internvl', 'qwen2_vl',
            'internlm-xcomposer2', 'qwen2_audio', 'pixtral', 'eclair'
        ],
        help="Model type")
4847 parser .add_argument (
4948 '--model_path' ,
5049 type = str ,
@@ -1743,20 +1742,33 @@ def forward(self, pixel_values, attention_mask):
17431742 engine_name = f"model.engine" ,
17441743 dtype = torch .bfloat16 )
17451744
1745+
def build_eclair_engine(args):
    """Build the Eclair visual-encoder TensorRT engine.

    NOTE(review): only the head of this function is visible in this view;
    ``RadioWithNeck.forward`` and the export/build steps follow below.
    """

    class RadioWithNeck(torch.nn.Module):
        """RADIO v2.5-h backbone followed by a small convolutional neck."""

        def __init__(self):
            super().__init__()

            # Pretrained RADIO encoder fetched from torch hub (network I/O).
            self.model_encoder = torch.hub.load(
                "NVlabs/RADIO", "radio_model", version="radio_v2.5-h")
            self.model_encoder.summary_idxs = torch.tensor(4)

            # Neck: 1x1 conv projecting 1280 -> 1024 channels, then a
            # (1, 4)-kernel, (1, 4)-stride conv that collapses groups of
            # four positions along the last axis; LayerNorm after each.
            self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
            self.layer_norm1 = torch.nn.LayerNorm(
                1024, eps=1e-6, elementwise_affine=True)
            self.conv2 = torch.nn.Conv2d(
                1024, 1024, kernel_size=(1, 4), stride=(1, 4),
                padding=0, bias=False)
            self.layer_norm2 = torch.nn.LayerNorm(
                1024, eps=1e-6, elementwise_affine=True)
1771+
# NOTE(review): forward() is split across two diff hunks in this scrape.
# New-file lines 1775-1781 — the steps between the encoder call and the
# flatten below (presumably the conv1/layer_norm1/conv2 neck projection;
# not visible here, confirm against the full file) — are missing from
# this view, so this span is left byte-identical rather than rewritten.
17601772 @torch .no_grad
17611773 def forward (self , pixel_values ):
17621774 _ , feature = self .model_encoder (pixel_values )
# --- diff hunk boundary: intervening lines omitted from this view ---
@@ -1770,26 +1782,29 @@ def forward(self, pixel_values):
17701782 output = output .flatten (- 2 , - 1 ).permute (0 , 2 , 1 )
17711783 output = self .layer_norm2 (output )
17721784 return output
1773-
1785+
17741786 processor = NougatProcessor .from_pretrained (args .model_path )
17751787 model = VisionEncoderDecoderModel .from_pretrained ("facebook/nougat-base" )
17761788 model .encoder = RadioWithNeck ()
17771789 model .decoder .resize_token_embeddings (len (processor .tokenizer ))
1778- model .config .decoder_start_token_id = processor .tokenizer .eos_token_id # 2
1790+ model .config .decoder_start_token_id = processor .tokenizer .eos_token_id # 2
17791791 model .config .pad_token_id = processor .tokenizer .pad_token_id # 1
17801792 load_model (model , os .path .join (args .model_path , "model.safetensors" ))
1781-
1793+
17821794 wrapper = model .encoder .to (args .device )
17831795 # temporary fix due to TRT onnx export bug
17841796 for block in wrapper .model_encoder .model .blocks :
17851797 block .attn .fused_attn = False
1786-
1787- image = torch .randn ((1 , 3 , 2048 , 1648 ), device = args .device , dtype = torch .float16 )
1798+
1799+ image = torch .randn ((1 , 3 , 2048 , 1648 ),
1800+ device = args .device ,
1801+ dtype = torch .float16 )
17881802 export_onnx (wrapper , image , f'{ args .output_dir } /onnx' )
17891803 build_trt_engine (
17901804 args .model_type ,
17911805 [image .shape [1 ], image .shape [2 ], image .shape [3 ]], # [3, H, W]
17921806 f'{ args .output_dir } /onnx' ,
17931807 args .output_dir ,
17941808 args .max_batch_size ,
1795- dtype = torch .bfloat16 ,engine_name = 'visual_encoder.engine' )
1809+ dtype = torch .bfloat16 ,
1810+ engine_name = 'visual_encoder.engine' )
0 commit comments