@@ -17,7 +17,7 @@ use crate::{extract_spans, Diagnostic};
1717use bincode;
1818use proc_macro2:: { Span , TokenStream as TokenStream2 } ;
1919use quote:: quote;
20- use rustpython_bytecode:: bytecode:: CodeObject ;
20+ use rustpython_bytecode:: bytecode:: { CodeObject , FrozenModule } ;
2121use rustpython_compiler:: compile;
2222use std:: collections:: HashMap ;
2323use std:: env;
@@ -52,7 +52,7 @@ impl CompilationSource {
5252 & self ,
5353 mode : & compile:: Mode ,
5454 module_name : String ,
55- ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
55+ ) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
5656 Ok ( match & self . kind {
5757 CompilationSourceKind :: File ( rel_path) => {
5858 let mut path = PathBuf :: from (
@@ -65,10 +65,20 @@ impl CompilationSource {
6565 format ! ( "Error reading file {:?}: {}" , path, err) ,
6666 )
6767 } ) ?;
68- hashmap ! { module_name. clone( ) => self . compile_string( & source, mode, module_name. clone( ) ) ?}
68+ hashmap ! {
69+ module_name. clone( ) => FrozenModule {
70+ code: self . compile_string( & source, mode, module_name. clone( ) ) ?,
71+ package: false ,
72+ } ,
73+ }
6974 }
7075 CompilationSourceKind :: SourceCode ( code) => {
71- hashmap ! { module_name. clone( ) => self . compile_string( code, mode, module_name. clone( ) ) ?}
76+ hashmap ! {
77+ module_name. clone( ) => FrozenModule {
78+ code: self . compile_string( code, mode, module_name. clone( ) ) ?,
79+ package: false ,
80+ } ,
81+ }
7282 }
7383 CompilationSourceKind :: Dir ( rel_path) => {
7484 let mut path = PathBuf :: from (
@@ -85,7 +95,7 @@ impl CompilationSource {
8595 path : & Path ,
8696 parent : String ,
8797 mode : & compile:: Mode ,
88- ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
98+ ) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
8999 let mut code_map = HashMap :: new ( ) ;
90100 let paths = fs:: read_dir ( & path) . map_err ( |err| {
91101 Diagnostic :: spans_error ( self . span , format ! ( "Error listing dir {:?}: {}" , path, err) )
@@ -95,11 +105,13 @@ impl CompilationSource {
95105 Diagnostic :: spans_error ( self . span , format ! ( "Failed to list file: {}" , err) )
96106 } ) ?;
97107 let path = path. path ( ) ;
98- let file_name = path. file_name ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
108+ let file_name = path. file_name ( ) . unwrap ( ) . to_str ( ) . ok_or_else ( || {
109+ Diagnostic :: spans_error ( self . span , format ! ( "Invalid UTF-8 in file name {:?}" , path) )
110+ } ) ?;
99111 if path. is_dir ( ) {
100112 code_map. extend ( self . compile_dir (
101113 & path,
102- format ! ( "{}{}. " , parent, file_name) ,
114+ format ! ( "{}{}" , parent, file_name) ,
103115 mode,
104116 ) ?) ;
105117 } else if file_name. ends_with ( ".py" ) {
@@ -109,11 +121,21 @@ impl CompilationSource {
109121 format ! ( "Error reading file {:?}: {}" , path, err) ,
110122 )
111123 } ) ?;
112- let file_name_splitte: Vec < & str > = file_name. splitn ( 2 , '.' ) . collect ( ) ;
113- let module_name = format ! ( "{}{}" , parent, file_name_splitte[ 0 ] ) ;
124+ let stem = path. file_stem ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
125+ let is_init = stem == "__init__" ;
126+ let module_name = if is_init {
127+ parent. clone ( )
128+ } else if parent. is_empty ( ) {
129+ stem. to_string ( )
130+ } else {
131+ format ! ( "{}.{}" , parent, stem)
132+ } ;
114133 code_map. insert (
115134 module_name. clone ( ) ,
116- self . compile_string ( & source, mode, module_name) ?,
135+ FrozenModule {
136+ code : self . compile_string ( & source, mode, module_name) ?,
137+ package : is_init,
138+ } ,
117139 ) ;
118140 }
119141 }
@@ -128,7 +150,7 @@ struct PyCompileInput {
128150}
129151
130152impl PyCompileInput {
131- fn compile ( & self ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
153+ fn compile ( & self ) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
132154 let mut module_name = None ;
133155 let mut mode = None ;
134156 let mut source: Option < CompilationSource > = None ;
@@ -225,13 +247,21 @@ pub fn impl_py_compile_bytecode(input: TokenStream2) -> Result<TokenStream2, Dia
225247
226248 let code_map = input. compile ( ) ?;
227249
228- let modules = code_map. iter ( ) . map ( |( module_name, code_obj) | {
229- let module_name = LitStr :: new ( & module_name, Span :: call_site ( ) ) ;
230- let bytes = bincode:: serialize ( & code_obj) . expect ( "Failed to serialize" ) ;
231- let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
232- quote ! { #module_name. into( ) => bincode:: deserialize:: <:: rustpython_vm:: bytecode:: CodeObject >( #bytes)
233- . expect( "Deserializing CodeObject failed" ) }
234- } ) ;
250+ let modules = code_map
251+ . into_iter ( )
252+ . map ( |( module_name, FrozenModule { code, package } ) | {
253+ let module_name = LitStr :: new ( & module_name, Span :: call_site ( ) ) ;
254+ let bytes = bincode:: serialize ( & code) . expect ( "Failed to serialize" ) ;
255+ let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
256+ quote ! {
257+ #module_name. into( ) => :: rustpython_vm:: bytecode:: FrozenModule {
258+ code: bincode:: deserialize:: <:: rustpython_vm:: bytecode:: CodeObject >(
259+ #bytes
260+ ) . expect( "Deserializing CodeObject failed" ) ,
261+ package: #package,
262+ }
263+ }
264+ } ) ;
235265
236266 let output = quote ! {
237267 ( {
0 commit comments