Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion r2dbc-mysql/src/main/java/MysqlConnectionFactoryProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class MysqlConnectionFactoryProvider : ConnectionFactoryProvider {
database = connectionFactoryOptions.getValue(DATABASE) as String?,
applicationName = connectionFactoryOptions.getValue(APPLICATION_NAME) as String?,
connectionTimeout = (connectionFactoryOptions.getValue(CONNECT_TIMEOUT) as Duration?)?.toMillis()?.toInt() ?: 5000,
queryTimeout = connectionFactoryOptions.getValue(STATEMENT_TIMEOUT) as Duration?
queryTimeout = connectionFactoryOptions.getValue(STATEMENT_TIMEOUT) as Duration?,
ssl = MysqlSSLConfigurationFactory.create(connectionFactoryOptions)
)
return JasyncConnectionFactory(MySQLConnectionFactory(configuration))
}
Expand Down
33 changes: 33 additions & 0 deletions r2dbc-mysql/src/main/java/MysqlSSLConfigurationFactory.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.github.jasync.r2dbc.mysql

import com.github.jasync.sql.db.SSLConfiguration
import com.github.jasync.sql.db.SSLConfiguration.Mode.Disable
import com.github.jasync.sql.db.SSLConfiguration.Mode.Prefer
import com.github.jasync.sql.db.SSLConfiguration.Mode.Require
import com.github.jasync.sql.db.SSLConfiguration.Mode.VerifyCA
import com.github.jasync.sql.db.SSLConfiguration.Mode.VerifyFull
import io.r2dbc.spi.ConnectionFactoryOptions
import io.r2dbc.spi.Option

object MysqlSSLConfigurationFactory {

private val SSL_MODE_OPTION = Option.valueOf<String>("sslMode")
private val SSL_MODE_MAP = mapOf(
"disabled" to Disable,
"preferred" to Prefer,
"required" to Require,
"verify_ca" to VerifyCA,
"verify_identity" to VerifyFull
)

fun create(options: ConnectionFactoryOptions): SSLConfiguration {
if (!options.hasOption(ConnectionFactoryOptions.SSL)) {
return SSLConfiguration(mode = Disable)
}
if (!options.hasOption(SSL_MODE_OPTION)) {
return SSLConfiguration(mode = Prefer)
}
val sslMode = options.getValue(SSL_MODE_OPTION) as String
return SSLConfiguration(mode = SSL_MODE_MAP.getOrDefault(sslMode.lowercase(), Prefer))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.github.jasync.r2dbc.mysql

import com.github.jasync.sql.db.SSLConfiguration
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import io.r2dbc.spi.ConnectionFactoryOptions
import org.junit.Test

internal class MysqlConnectionFactoryProviderTest {

private val provider = MysqlConnectionFactoryProvider()

@Test
fun shouldCreateMysqlConnectionWithMysqlSSLConfigurationFactory() {
// given
mockkObject(MysqlSSLConfigurationFactory)
every { MysqlSSLConfigurationFactory.create(any()) } returns SSLConfiguration()

val options =
ConnectionFactoryOptions.parse("r2dbc:mysql://user@host:443/")

// when
provider.create(options)

// then
verify {
MysqlSSLConfigurationFactory.create(options)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package com.github.jasync.r2dbc.mysql

import com.github.jasync.sql.db.SSLConfiguration
import com.github.jasync.sql.db.SSLConfiguration.Mode.Disable
import com.github.jasync.sql.db.SSLConfiguration.Mode.Prefer
import com.github.jasync.sql.db.SSLConfiguration.Mode.Require
import com.github.jasync.sql.db.SSLConfiguration.Mode.VerifyCA
import com.github.jasync.sql.db.SSLConfiguration.Mode.VerifyFull
import io.r2dbc.spi.ConnectionFactoryOptions
import io.r2dbc.spi.ConnectionFactoryOptions.SSL
import io.r2dbc.spi.Option
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.test.assertEquals

@RunWith(Parameterized::class)
internal class MysqlSSLConfigurationFactoryTest(
private val options: ConnectionFactoryOptions,
private val expectedSSLConfiguration: SSLConfiguration,
private val message: String
) {

companion object {

private val SSL_MODE_OPTION = Option.valueOf<String>("sslMode")

@JvmStatic
@Parameterized.Parameters
fun data() = listOf(
createTestParams(
options = ConnectionFactoryOptions.builder().build(),
expected = SSLConfiguration(),
message = "sslMode should be 'disabled' for non-secure protocol"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.build(),
expected = SSLConfiguration(mode = Prefer),
message = "sslMode should be 'preferred' by default"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "invalid")
.build(),
expected = SSLConfiguration(mode = Prefer),
message = "sslMode should be 'preferred' for invalid value"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "REQUIRED")
.build(),
expected = SSLConfiguration(mode = Require),
message = "sslMode should be case insensitive"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "disabled")
.build(),
expected = SSLConfiguration(mode = Disable),
message = "sslMode should be 'disabled'"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "preferred")
.build(),
expected = SSLConfiguration(mode = Prefer),
message = "sslMode should be 'preferred'"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "required")
.build(),
expected = SSLConfiguration(mode = Require),
message = "sslMode should be 'required'"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "verify_ca")
.build(),
expected = SSLConfiguration(mode = VerifyCA),
message = "sslMode should be 'verify_ca'"
),
createTestParams(
options = ConnectionFactoryOptions.builder()
.option(SSL, true)
.option(SSL_MODE_OPTION, "verify_identity")
.build(),
expected = SSLConfiguration(mode = VerifyFull),
message = "sslMode should be 'verify_identity'"
),
)

private fun createTestParams(
options: ConnectionFactoryOptions,
expected: SSLConfiguration,
message: String
) = arrayOf(options, expected, message)
}

@Test
fun shouldCreateProperSSLConfiguration() {
// when
val result = MysqlSSLConfigurationFactory.create(options)

// then
assertEquals(expectedSSLConfiguration, result, message)
}
}