1414# limitations under the License.
1515""" PyTorch CLIP model."""
1616
17- from dataclasses import dataclass
1817import math
18+ from dataclasses import dataclass
1919from typing import Any , Optional , Tuple , Union
2020
2121import torch
2222import torch .utils .checkpoint
2323from torch import nn
2424
25+ from transformers import CLIPConfig , CLIPModel , CLIPTextConfig , CLIPVisionConfig
2526from transformers .activations import ACT2FN
2627from transformers .modeling_outputs import BaseModelOutput , BaseModelOutputWithPooling
2728from transformers .modeling_utils import PreTrainedModel
3233 logging ,
3334 replace_return_docstrings ,
3435)
35- from transformers import CLIPModel , CLIPConfig , CLIPVisionConfig , CLIPTextConfig
36+
3637
3738logger = logging .get_logger (__name__ )
3839
@@ -153,11 +154,11 @@ def __init__(self, config: CLIPTextConfig):
153154 self .register_buffer ("position_ids" , torch .arange (config .max_position_embeddings ).expand ((1 , - 1 )))
154155
155156 def forward (
156- self ,
157- input_ids : Optional [torch .LongTensor ] = None ,
158- position_ids : Optional [torch .LongTensor ] = None ,
159- inputs_embeds : Optional [torch .FloatTensor ] = None ,
160- attention_mask : Optional [torch .Tensor ] = None ,
157+ self ,
158+ input_ids : Optional [torch .LongTensor ] = None ,
159+ position_ids : Optional [torch .LongTensor ] = None ,
160+ inputs_embeds : Optional [torch .FloatTensor ] = None ,
161+ attention_mask : Optional [torch .Tensor ] = None ,
161162 ) -> torch .Tensor :
162163 seq_length = input_ids .shape [- 1 ] if input_ids is not None else inputs_embeds .shape [- 2 ]
163164
@@ -193,16 +194,15 @@ def __init__(self, config):
193194 )
194195 self .scale = 1 / math .sqrt (math .sqrt (self .head_dim ))
195196
196- self .qkv_proj = nn .Linear (self .embed_dim , self .embed_dim * 3 )
197+ self .qkv_proj = nn .Linear (self .embed_dim , self .embed_dim * 3 )
197198 self .out_proj = nn .Linear (self .embed_dim , self .embed_dim )
198199
199-
200200 def forward (
201- self ,
202- hidden_states : torch .Tensor ,
203- attention_mask : Optional [torch .Tensor ] = None ,
204- causal_attention_mask : Optional [torch .Tensor ] = None ,
205- output_attentions : Optional [bool ] = False ,
201+ self ,
202+ hidden_states : torch .Tensor ,
203+ attention_mask : Optional [torch .Tensor ] = None ,
204+ causal_attention_mask : Optional [torch .Tensor ] = None ,
205+ output_attentions : Optional [bool ] = False ,
206206 ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
207207 """Input shape: Batch x Time x Channel"""
208208
@@ -212,9 +212,7 @@ def forward(
212212 qkv_states = qkv_states .view (bsz , tgt_len , self .num_heads , - 1 )
213213 query_states , key_states , value_states = torch .split (qkv_states , self .head_dim , dim = - 1 )
214214
215- attn_weights = torch .einsum (
216- "bthc,bshc->bhts" , query_states * self .scale , key_states * self .scale
217- )
215+ attn_weights = torch .einsum ("bthc,bshc->bhts" , query_states * self .scale , key_states * self .scale )
218216
219217 wdtype = attn_weights .dtype
220218 attn_weights = nn .functional .softmax (attn_weights .float (), dim = - 1 ).type (wdtype )
@@ -252,11 +250,11 @@ def __init__(self, config: CLIPConfig):
252250 self .layer_norm2 = nn .LayerNorm (self .embed_dim )
253251
254252 def forward (
255- self ,
256- hidden_states : torch .Tensor ,
257- attention_mask : torch .Tensor ,
258- causal_attention_mask : torch .Tensor ,
259- output_attentions : Optional [bool ] = False ,
253+ self ,
254+ hidden_states : torch .Tensor ,
255+ attention_mask : torch .Tensor ,
256+ causal_attention_mask : torch .Tensor ,
257+ output_attentions : Optional [bool ] = False ,
260258 ) -> Tuple [torch .FloatTensor ]:
261259 """
262260 Args:
@@ -313,31 +311,31 @@ def _init_weights(self, module):
313311 module .padding_embedding .weight .data .normal_ (mean = 0.0 , std = factor * 0.02 )
314312 elif isinstance (module , CLIPVisionEmbeddings ):
315313 factor = self .config .initializer_factor
316- nn .init .normal_ (module .class_embedding , mean = 0.0 , std = module .embed_dim ** - 0.5 * factor )
314+ nn .init .normal_ (module .class_embedding , mean = 0.0 , std = module .embed_dim ** - 0.5 * factor )
317315 nn .init .normal_ (module .patch_embedding .weight , std = module .config .initializer_range * factor )
318316 nn .init .normal_ (module .position_embedding .weight , std = module .config .initializer_range * factor )
319317 elif isinstance (module , CLIPAttention ):
320318 factor = self .config .initializer_factor
321- in_proj_std = (module .embed_dim ** - 0.5 ) * ((2 * module .config .num_hidden_layers ) ** - 0.5 ) * factor
322- out_proj_std = (module .embed_dim ** - 0.5 ) * factor
319+ in_proj_std = (module .embed_dim ** - 0.5 ) * ((2 * module .config .num_hidden_layers ) ** - 0.5 ) * factor
320+ out_proj_std = (module .embed_dim ** - 0.5 ) * factor
323321 nn .init .normal_ (module .qkv_proj .weight , std = in_proj_std )
324322 nn .init .normal_ (module .out_proj .weight , std = out_proj_std )
325323 elif isinstance (module , CLIPMLP ):
326324 factor = self .config .initializer_factor
327325 in_proj_std = (
328- (module .config .hidden_size ** - 0.5 ) * ((2 * module .config .num_hidden_layers ) ** - 0.5 ) * factor
326+ (module .config .hidden_size ** - 0.5 ) * ((2 * module .config .num_hidden_layers ) ** - 0.5 ) * factor
329327 )
330328 fc_std = (2 * module .config .hidden_size ) ** - 0.5 * factor
331329 nn .init .normal_ (module .fc1 .weight , std = fc_std )
332330 nn .init .normal_ (module .fc2 .weight , std = in_proj_std )
333331 elif isinstance (module , CLIPModel ):
334332 nn .init .normal_ (
335333 module .text_projection .weight ,
336- std = module .text_embed_dim ** - 0.5 * self .config .initializer_factor ,
334+ std = module .text_embed_dim ** - 0.5 * self .config .initializer_factor ,
337335 )
338336 nn .init .normal_ (
339337 module .visual_projection .weight ,
340- std = module .vision_embed_dim ** - 0.5 * self .config .initializer_factor ,
338+ std = module .vision_embed_dim ** - 0.5 * self .config .initializer_factor ,
341339 )
342340
343341 if isinstance (module , nn .LayerNorm ):
@@ -463,13 +461,13 @@ def __init__(self, config: CLIPConfig):
463461 self .gradient_checkpointing = False
464462
465463 def forward (
466- self ,
467- inputs_embeds ,
468- attention_mask : Optional [torch .Tensor ] = None ,
469- causal_attention_mask : Optional [torch .Tensor ] = None ,
470- output_attentions : Optional [bool ] = None ,
471- output_hidden_states : Optional [bool ] = None ,
472- return_dict : Optional [bool ] = None ,
464+ self ,
465+ inputs_embeds ,
466+ attention_mask : Optional [torch .Tensor ] = None ,
467+ causal_attention_mask : Optional [torch .Tensor ] = None ,
468+ output_attentions : Optional [bool ] = None ,
469+ output_hidden_states : Optional [bool ] = None ,
470+ return_dict : Optional [bool ] = None ,
473471 ) -> Union [Tuple , BaseModelOutput ]:
474472 r"""
475473 Args:
@@ -562,13 +560,13 @@ def __init__(self, config: CLIPTextConfig):
562560 @add_start_docstrings_to_model_forward (CLIP_TEXT_INPUTS_DOCSTRING )
563561 @replace_return_docstrings (output_type = BaseModelOutputWithPooling , config_class = CLIPTextConfig )
564562 def forward (
565- self ,
566- input_ids : Optional [torch .Tensor ] = None ,
567- attention_mask : Optional [torch .Tensor ] = None ,
568- position_ids : Optional [torch .Tensor ] = None ,
569- output_attentions : Optional [bool ] = None ,
570- output_hidden_states : Optional [bool ] = None ,
571- return_dict : Optional [bool ] = None ,
563+ self ,
564+ input_ids : Optional [torch .Tensor ] = None ,
565+ attention_mask : Optional [torch .Tensor ] = None ,
566+ position_ids : Optional [torch .Tensor ] = None ,
567+ output_attentions : Optional [bool ] = None ,
568+ output_hidden_states : Optional [bool ] = None ,
569+ return_dict : Optional [bool ] = None ,
572570 ) -> Union [Tuple , BaseModelOutputWithPooling ]:
573571 r"""
574572 Returns:
@@ -652,13 +650,13 @@ def set_input_embeddings(self, value):
652650 @add_start_docstrings_to_model_forward (CLIP_TEXT_INPUTS_DOCSTRING )
653651 @replace_return_docstrings (output_type = BaseModelOutputWithPooling , config_class = CLIPTextConfig )
654652 def forward (
655- self ,
656- input_ids : Optional [torch .Tensor ] = None ,
657- attention_mask : Optional [torch .Tensor ] = None ,
658- position_ids : Optional [torch .Tensor ] = None ,
659- output_attentions : Optional [bool ] = None ,
660- output_hidden_states : Optional [bool ] = None ,
661- return_dict : Optional [bool ] = None ,
653+ self ,
654+ input_ids : Optional [torch .Tensor ] = None ,
655+ attention_mask : Optional [torch .Tensor ] = None ,
656+ position_ids : Optional [torch .Tensor ] = None ,
657+ output_attentions : Optional [bool ] = None ,
658+ output_hidden_states : Optional [bool ] = None ,
659+ return_dict : Optional [bool ] = None ,
662660 ) -> Union [Tuple , BaseModelOutputWithPooling ]:
663661 r"""
664662 Returns:
@@ -684,4 +682,4 @@ def forward(
684682 output_attentions = output_attentions ,
685683 output_hidden_states = output_hidden_states ,
686684 return_dict = return_dict ,
687- )
685+ )
0 commit comments