diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 25a35029..a834cfac 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -504,26 +504,34 @@ static int opt_connect_attr_add_i(VALUE key, VALUE value, VALUE arg) } #endif -static VALUE rb_mysql_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE port, VALUE database, VALUE socket, VALUE flags, VALUE conn_attrs) { +static VALUE rb_mysql_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE port, VALUE database, VALUE socket, VALUE flags, VALUE conn_attrs, VALUE tls_sni_name) { struct nogvl_connect_args args; time_t start_time, end_time, elapsed_time, connect_timeout; + const char *sni_hostname; VALUE rv; GET_CLIENT(self); - args.host = NIL_P(host) ? NULL : StringValueCStr(host); - args.unix_socket = NIL_P(socket) ? NULL : StringValueCStr(socket); - args.port = NIL_P(port) ? 0 : NUM2INT(port); - args.user = NIL_P(user) ? NULL : StringValueCStr(user); - args.passwd = NIL_P(pass) ? NULL : StringValueCStr(pass); - args.db = NIL_P(database) ? NULL : StringValueCStr(database); - args.mysql = wrapper->client; - args.client_flag = NUM2ULONG(flags); + args.host = NIL_P(host) ? NULL : StringValueCStr(host); + args.unix_socket = NIL_P(socket) ? NULL : StringValueCStr(socket); + args.port = NIL_P(port) ? 0 : NUM2INT(port); + args.user = NIL_P(user) ? NULL : StringValueCStr(user); + args.passwd = NIL_P(pass) ? NULL : StringValueCStr(pass); + args.db = NIL_P(database) ? NULL : StringValueCStr(database); + args.mysql = wrapper->client; + args.client_flag = NUM2ULONG(flags); + + sni_hostname = NIL_P(tls_sni_name) ? NULL : StringValueCStr(tls_sni_name); #ifdef CLIENT_CONNECT_ATTRS mysql_options(wrapper->client, MYSQL_OPT_CONNECT_ATTR_RESET, 0); rb_hash_foreach(conn_attrs, opt_connect_attr_add_i, (VALUE)wrapper); #endif + if(sni_hostname != NULL) { + /* Set the TLS SNI name if provided */ + mysql_options(wrapper->client, MYSQL_OPT_TLS_SNI_SERVERNAME, sni_hostname); + } + if (wrapper->connect_timeout) time(&start_time); rv = (VALUE) rb_thread_call_without_gvl(nogvl_connect, &args, RUBY_UBF_IO, 0); @@ -1619,7 +1627,7 @@ void init_mysql2_client() { rb_define_private_method(cMysql2Client, "ssl_mode=", rb_set_ssl_mode_option, 1); rb_define_private_method(cMysql2Client, "enable_cleartext_plugin=", set_enable_cleartext_plugin, 1); rb_define_private_method(cMysql2Client, "initialize_ext", initialize_ext, 0); - rb_define_private_method(cMysql2Client, "connect", rb_mysql_connect, 8); + rb_define_private_method(cMysql2Client, "connect", rb_mysql_connect, 9); rb_define_private_method(cMysql2Client, "_query", rb_mysql_query, 2); sym_id = ID2SYM(rb_intern("id")); diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb index 2bb81a87..8e1f89b1 100644 --- a/lib/mysql2/client.rb +++ b/lib/mysql2/client.rb @@ -73,12 +73,13 @@ def initialize(opts = {}) check_and_clean_query_options - user = opts[:username] || opts[:user] - pass = opts[:password] || opts[:pass] - host = opts[:host] || opts[:hostname] - port = opts[:port] - database = opts[:database] || opts[:dbname] || opts[:db] - socket = opts[:socket] || opts[:sock] + user = opts[:username] || opts[:user] + pass = opts[:password] || opts[:pass] + host = opts[:host] || opts[:hostname] + tls_sni_name = opts[:tls_sni_name] + port = opts[:port] + database = opts[:database] || opts[:dbname] || opts[:db] + socket = opts[:socket] || opts[:sock] # Correct the data types before passing these values down to the C level user = user.to_s unless user.nil? @@ -86,10 +87,11 @@ def initialize(opts = {}) host = host.to_s unless host.nil? port = port.to_i unless port.nil? database = database.to_s unless database.nil? + tls_sni_name = tls_sni_name.to_s unless tls_sni_name.nil? socket = socket.to_s unless socket.nil? conn_attrs = parse_connect_attrs(opts[:connect_attrs]) - connect user, pass, host, port, database, socket, flags, conn_attrs + connect user, pass, host, port, database, socket, flags, conn_attrs, tls_sni_name end def parse_ssl_mode(mode)