1616
1717import javax .net .ssl .SSLEngine ;
1818import javax .net .ssl .SSLEngineResult ;
19+ import javax .net .ssl .SSLEngineResult .HandshakeStatus ;
1920import javax .net .ssl .SSLException ;
2021import javax .net .ssl .SSLSession ;
2122
2223/**
2324 * Implements the relevant portions of the SocketChannel interface with the SSLEngine wrapper.
2425 */
25- public class SSLSocketChannel2 implements ByteChannel {
26+ public class SSLSocketChannel2 implements ByteChannel , WrappedByteChannel {
27+ private static ByteBuffer emptybuffer = ByteBuffer .allocate ( 0 );
28+
2629 /** raw payload incomming */
27- private ByteBuffer clientIn ;
28- /** raw payload outgoing */
29- private ByteBuffer clientOut ;
30+ private ByteBuffer inData ;
3031 /** encrypted data outgoing */
31- private ByteBuffer cTOs ;
32+ private ByteBuffer outCrypt ;
3233 /** encrypted data incoming */
33- private ByteBuffer sTOc ;
34+ private ByteBuffer inCrypt ;
3435
3536 private SocketChannel sc ;
3637 private SelectionKey key ;
3738
3839 private SSLEngineResult res ;
3940 private SSLEngine sslEngine ;
40- private int SSL ;
4141
4242 public SSLSocketChannel2 ( SelectionKey key , SSLEngine sslEngine ) throws IOException {
4343 this .sc = (SocketChannel ) key .channel ();
4444 this .key = key ;
4545 this .sslEngine = sslEngine ;
46- SSL = 1 ;
47- try {
48- sslEngine .setEnableSessionCreation ( true );
49- SSLSession session = sslEngine .getSession ();
50- createBuffers ( session );
51- // wrap
52- clientOut .clear ();
53- sc .write ( wrap ( clientOut ) );
54- assert ( !clientOut .hasRemaining () );
55- while ( res .getHandshakeStatus () != SSLEngineResult .HandshakeStatus .FINISHED ) {
56- processHandshake ();
57- }
58- clientIn .clear ();
59- clientIn .flip ();
60- SSL = 4 ;
61- } catch ( Exception e ) {
62- e .printStackTrace ( System .out );
63- SSL = 0 ;
64- }
46+
47+ key .interestOps ( key .interestOps () | SelectionKey .OP_WRITE );
48+
49+ sslEngine .setEnableSessionCreation ( true );
50+ SSLSession session = sslEngine .getSession ();
51+ createBuffers ( session );
52+
53+ // there is not yet any user data
54+ inData .flip ();
55+ inCrypt .flip ();
56+ outCrypt .flip ();
57+
58+ wrap ( emptybuffer );
59+
60+ processHandshake ();
6561 }
6662
67- private void processHandshake () throws IOException {
68- assert ( res .getHandshakeStatus () != SSLEngineResult .HandshakeStatus .FINISHED );
63+ private boolean processHandshake () throws IOException {
6964 if ( res .getHandshakeStatus () == SSLEngineResult .HandshakeStatus .NEED_UNWRAP ) {
70- // unwrap
71- if ( !sTOc .hasRemaining () )
72- sTOc .clear ();
73- sc .read ( sTOc );
74- sTOc .flip ();
75- unwrap ( sTOc );
65+ if ( !inCrypt .hasRemaining () )
66+ inCrypt .clear ();
67+ sc .read ( inCrypt );
68+ inCrypt .flip ();
69+ unwrap ();
7670 if ( res .getHandshakeStatus () != SSLEngineResult .HandshakeStatus .FINISHED ) {
77- clientOut .clear ();
78- sc .write ( wrap ( clientOut ) );
71+ // if( !outData.hasRemaining() )
72+ // outData.clear();
73+ sc .write ( wrap ( emptybuffer ) );
7974 }
8075 } else if ( res .getHandshakeStatus () == SSLEngineResult .HandshakeStatus .NEED_WRAP ) {
81- // wrap
82- clientOut .clear ();
83- sc .write ( wrap ( clientOut ) );
76+ sc .write ( wrap ( emptybuffer ) );
8477 } else {
8578 assert ( false );
8679 }
8780
81+ return false ;
8882 }
8983
9084 private synchronized ByteBuffer wrap ( ByteBuffer b ) throws SSLException {
91- cTOs .clear ();
92- res = sslEngine .wrap ( b , cTOs );
93- cTOs .flip ();
94- return cTOs ;
85+ if ( !outCrypt .hasRemaining () )
86+ outCrypt .clear ();
87+ res = sslEngine .wrap ( b , outCrypt );
88+ outCrypt .flip ();
89+ return outCrypt ;
9590 }
9691
97- private synchronized ByteBuffer unwrap ( ByteBuffer b ) throws SSLException {
98- clientIn . clear ();
99- while ( b .hasRemaining () ) {
100- res = sslEngine .unwrap ( b , clientIn );
92+ private synchronized ByteBuffer unwrap () throws SSLException {
93+ inData . compact ();
94+ while ( inCrypt .hasRemaining () ) {
95+ res = sslEngine .unwrap ( inCrypt , inData );
10196 if ( res .getHandshakeStatus () == SSLEngineResult .HandshakeStatus .NEED_TASK ) {
10297 // Task
10398 Runnable task ;
10499 while ( ( task = sslEngine .getDelegatedTask () ) != null ) {
105100 task .run ();
106101 }
107102 } else if ( res .getHandshakeStatus () == SSLEngineResult .HandshakeStatus .FINISHED ) {
108- return clientIn ;
103+ break ;
109104 } else if ( res .getStatus () == SSLEngineResult .Status .BUFFER_UNDERFLOW ) {
110- return clientIn ;
105+ break ;
106+ } else if ( res .getStatus () == SSLEngineResult .Status .BUFFER_OVERFLOW ) {
107+ assert ( false );
111108 }
112109 }
113- return clientIn ;
110+ inData .flip ();
111+ return inData ;
114112 }
115113
116114 private void createBuffers ( SSLSession session ) {
117115 int appBufferMax = session .getApplicationBufferSize ();
118116 int netBufferMax = session .getPacketBufferSize ();
119117
120- clientIn = ByteBuffer .allocate ( 65536 );
121- clientOut = ByteBuffer .allocate ( appBufferMax );
118+ inData = ByteBuffer .allocate ( 65536 );
122119
123- cTOs = ByteBuffer .allocate ( netBufferMax );
124- sTOc = ByteBuffer .allocate ( netBufferMax );
120+ outCrypt = ByteBuffer .allocate ( netBufferMax );
121+ inCrypt = ByteBuffer .allocate ( netBufferMax );
125122 }
126123
127124 public int write ( ByteBuffer src ) throws IOException {
128- if ( SSL == 4 ) {
125+ if ( !isHandShakeComplete () ) {
126+ processHandshake ();
127+ return 0 ;
128+ } else {
129129 return sc .write ( wrap ( src ) );
130130 }
131- return sc .write ( src );
132131 }
133132
134133 public int read ( ByteBuffer dst ) throws IOException {
134+ if ( !isHandShakeComplete () ) {
135+ processHandshake ();
136+ return 0 ;
137+ }
135138 int amount = 0 , limit ;
136- if ( SSL == 4 ) {
137- // test if there was a buffer overflow in dst
138- if ( clientIn .hasRemaining () ) {
139- limit = Math .min ( clientIn .remaining (), dst .remaining () );
140- for ( int i = 0 ; i < limit ; i ++ ) {
141- dst .put ( clientIn .get () );
142- amount ++;
143- }
144- return amount ;
145- }
146- // test if some bytes left from last read (e.g. BUFFER_UNDERFLOW)
147- if ( sTOc .hasRemaining () ) {
148- unwrap ( sTOc );
149- clientIn .flip ();
150- limit = Math .min ( clientIn .limit (), dst .remaining () );
151- for ( int i = 0 ; i < limit ; i ++ ) {
152- dst .put ( clientIn .get () );
153- amount ++;
154- }
155- if ( res .getStatus () != SSLEngineResult .Status .BUFFER_UNDERFLOW ) {
156- sTOc .clear ();
157- sTOc .flip ();
158- return amount ;
159- }
160- }
161- if ( !sTOc .hasRemaining () )
162- sTOc .clear ();
163- else
164- sTOc .compact ();
165-
166- if ( sc .read ( sTOc ) == -1 ) {
167- sTOc .clear ();
168- sTOc .flip ();
169- return -1 ;
170- }
171- sTOc .flip ();
172- unwrap ( sTOc );
173- // write in dst
174- clientIn .flip ();
175- limit = Math .min ( clientIn .limit (), dst .remaining () );
139+ // test if there was a buffer overflow in dst
140+ if ( inData .hasRemaining () ) {
141+ limit = Math .min ( inData .remaining (), dst .remaining () );
176142 for ( int i = 0 ; i < limit ; i ++ ) {
177- dst .put ( clientIn .get () );
143+ dst .put ( inData .get () );
178144 amount ++;
179145 }
180146 return amount ;
181147 }
182- return sc .read ( dst );
148+ // test if some bytes left from last read (e.g. BUFFER_UNDERFLOW)
149+ if ( inCrypt .hasRemaining () ) {
150+ unwrap ();
151+ inData .flip ();
152+ limit = Math .min ( inData .limit (), dst .remaining () );
153+ for ( int i = 0 ; i < limit ; i ++ ) {
154+ dst .put ( inData .get () );
155+ amount ++;
156+ }
157+ if ( res .getStatus () != SSLEngineResult .Status .BUFFER_UNDERFLOW ) {
158+ inCrypt .clear ();
159+ inCrypt .flip ();
160+ return amount ;
161+ }
162+ }
163+ if ( !inCrypt .hasRemaining () )
164+ inCrypt .clear ();
165+ else
166+ inCrypt .compact ();
167+
168+ if ( sc .read ( inCrypt ) == -1 ) {
169+ inCrypt .clear ();
170+ inCrypt .flip ();
171+ return -1 ;
172+ }
173+ inCrypt .flip ();
174+ unwrap ();
175+ // write in dst
176+ // inData.flip();
177+ limit = Math .min ( inData .limit (), dst .remaining () );
178+ for ( int i = 0 ; i < limit ; i ++ ) {
179+ dst .put ( inData .get () );
180+ amount ++;
181+ }
182+ return amount ;
183+
183184 }
184185
185186 public boolean isConnected () {
186187 return sc .isConnected ();
187188 }
188189
189190 public void close () throws IOException {
190- if ( SSL == 4 ) {
191- sslEngine .closeOutbound ();
192- sslEngine .getSession ().invalidate ();
193- clientOut .clear ();
194- sc .write ( wrap ( clientOut ) );
195- sc .close ();
196- } else
197- sc .close ();
191+ sslEngine .closeOutbound ();
192+ sslEngine .getSession ().invalidate ();
193+ outCrypt .compact ();
194+ int wr = sc .write ( wrap ( emptybuffer ) );
195+ sc .close ();
196+ }
197+
198+ private boolean isHandShakeComplete (){
199+ HandshakeStatus status = res .getHandshakeStatus ();
200+ return status == SSLEngineResult .HandshakeStatus .FINISHED || status == SSLEngineResult .HandshakeStatus .NOT_HANDSHAKING ;
198201 }
199202
200203 public SelectableChannel configureBlocking ( boolean b ) throws IOException {
@@ -221,4 +224,14 @@ public boolean isInboundDone() {
221224 public boolean isOpen () {
222225 return sc .isOpen ();
223226 }
227+
228+ @ Override
229+ public boolean isNeedWrite () {
230+ return outCrypt .hasRemaining () || !isHandShakeComplete ();
231+ }
232+
233+ @ Override
234+ public void write () throws IOException {
235+ write ( emptybuffer );
236+ }
224237}
0 commit comments