diff --git a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala
index 710f570..e9e03a8 100644
--- a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala
+++ b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala
@@ -68,9 +68,14 @@ class SQLServerBulkJdbcOptions(val params: CaseInsensitiveMap[String])
     val allowEncryptedValueModifications =
         params.getOrElse("allowEncryptedValueModifications", "false").toBoolean
 
     val schemaCheckEnabled = params.getOrElse("schemaCheckEnabled", "true").toBoolean
 
+    // Hide the graph pseudo-columns ($node_id, $from_id, $to_id, $edge_id)
+    // by default when writing to SQL graph node and edge tables
+    val hideGraphColumns =
+        params.getOrElse("hideGraphColumns", "true").toBoolean
+
     // Not a feature
     // Only used for internally testing data idempotency
     val testDataIdempotency =
         params.getOrElse("testDataIdempotency", "false").toBoolean
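
Note: a minimal usage sketch of the `hideGraphColumns` option added above. The format name `com.microsoft.sqlserver.jdbc.spark` is this connector's documented source name; the `spark` session, credentials, and the `dbo.PersonNode` table are illustrative placeholders, not part of this diff.

```scala
// Toggling the new option on a bulk write; all connection values are placeholders.
val df = spark.read.parquet("/tmp/input")

df.write
    .format("com.microsoft.sqlserver.jdbc.spark")
    .mode("append")
    .option("url", "jdbc:sqlserver://myserver.database.windows.net:1433;databaseName=mydb")
    .option("dbtable", "dbo.PersonNode") // a SQL graph node table
    .option("user", user)
    .option("password", password)
    // "true" is the default per the diff above; with "false" the hidden graph
    // pseudo-columns ($node_id etc.) stay visible to schema matching
    .option("hideGraphColumns", "true")
    .save()
```
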
diff --git a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
index 347694a..86cd003 100644
--- a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
+++ b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
@@ -186,8 +186,29 @@ object BulkCopyUtils extends Logging {
      */
     private[spark] def getComputedCols(
         conn: Connection,
-        table: String): List[String] = {
-        val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
+        table: String,
+        hideGraphColumns: Boolean): List[String] = {
+        // TODO: optimize this query, and evaluate SQL injection exposure since the table name is interpolated
+        val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
+exec sp_executesql N'SELECT name
+    FROM sys.computed_columns
+    WHERE object_id = OBJECT_ID(''${table}'')
+UNION ALL
+SELECT C.name
+    FROM sys.tables AS T
+    JOIN sys.columns AS C
+        ON T.object_id = C.object_id
+    WHERE T.object_id = OBJECT_ID(''${table}'')
+    AND (T.is_edge = 1 OR T.is_node = 1)
+    AND C.is_hidden = 0
+    AND C.graph_type = 2'
+ELSE
+SELECT name
+    FROM sys.computed_columns
+    WHERE object_id = OBJECT_ID('${table}')
+"""
+        else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
+
         val computedColRs = conn.createStatement.executeQuery(queryStr)
         val computedCols = ListBuffer[String]()
         while (computedColRs.next()) {
@@ -263,7 +284,7 @@ object BulkCopyUtils extends Logging {
         val colMetaData = {
             if(checkSchema) {
                 checkExTableType(conn, options)
-                matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
+                matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.hideGraphColumns)
             } else {
                 defaultColMetadataMap(rs.getMetaData())
             }
@@ -289,6 +310,7 @@ object BulkCopyUtils extends Logging {
      * @param url: String,
      * @param isCaseSensitive: Boolean
      * @param strictSchemaCheck: Boolean
+     * @param hideGraphColumns: Boolean - whether to hide the $node_id, $from_id, $to_id and $edge_id columns of SQL graph tables
      */
     private[spark] def matchSchemas(
         conn: Connection,
@@ -297,13 +319,14 @@ object BulkCopyUtils extends Logging {
         rs: ResultSet,
         url: String,
         isCaseSensitive: Boolean,
-        strictSchemaCheck: Boolean): Array[ColumnMetadata]= {
+        strictSchemaCheck: Boolean,
+        hideGraphColumns: Boolean): Array[ColumnMetadata] = {
         val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
             zip df.schema.fieldNames.toList).toMap
         val dfCols = df.schema
 
         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable)
+        val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)
 
         val prefix = "Spark Dataframe and SQL Server table have differing"
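
Two notes on the `getComputedCols` query above. First, the version guard: `SERVERPROPERTY('EngineEdition') = 5` identifies Azure SQL Database and `ProductMajorVersion >= 14` identifies SQL Server 2017+, the first version with graph support; wrapping the graph-aware branch in `sp_executesql` defers its compilation, so the batch still parses on older engines where `sys.tables.is_node`, `is_edge`, and `sys.columns.graph_type` do not exist.

Second, on the SQL injection TODO: the table name could be bound as a JDBC parameter rather than interpolated. A minimal sketch of the non-graph branch only; `getComputedColsSafe` is a hypothetical name, not part of this diff:

```scala
import java.sql.Connection
import scala.collection.mutable.ListBuffer

// Bind the table name via a parameter marker; OBJECT_ID() resolves it
// server-side, so no string concatenation of user input is needed.
def getComputedColsSafe(conn: Connection, table: String): List[String] = {
    val stmt = conn.prepareStatement(
        "SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID(?);")
    stmt.setString(1, table)
    val rs = stmt.executeQuery()
    val computedCols = ListBuffer[String]()
    while (rs.next()) {
        computedCols.append(rs.getString("name"))
    }
    computedCols.toList
}
```

The graph-aware branch could take the same approach by passing the name through `sp_executesql`'s own parameter list (e.g. `N'@table sysname'`) instead of doubling the quotes around `${table}`.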