[SPARK-50286][SQL] Correctly propagate SQL options to WriteBuilder #48822

Draft: wants to merge 3 commits into base: master
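
For context before the diff (illustration only, not code from this PR): in the DataSource V2 write path, a connector's WriteBuilder can only see the options Spark places into its LogicalWriteInfo. This change merges options attached to the relation (SQL options on DataSourceV2Relation) into that map instead of dropping them. A minimal connector-side sketch, using a hypothetical "compression" option, of where those options become visible:

```scala
import org.apache.spark.sql.connector.write.LogicalWriteInfo

// Sketch only: what a connector-side WriteBuilder can observe. The option name
// "compression" is hypothetical; the point is that after this change
// info.options() also carries options set on the relation, not just options
// passed directly to the write command.
object WriteInfoInspector {
  def describe(info: LogicalWriteInfo): String = {
    val compression = info.options().getOrDefault("compression", "none")
    s"write to schema ${info.schema().catalogString} with compression=$compression"
  }
}
```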
@@ -295,7 +295,7 @@ abstract class InMemoryBaseTable(
TableCapability.TRUNCATE)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema)
new InMemoryScanBuilder(schema, options)
}

private def canEvaluate(filter: Filter): Boolean = {
@@ -309,16 +309,18 @@ abstract class InMemoryBaseTable(
}
}

class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
class InMemoryScanBuilder(
tableSchema: StructType,
options: CaseInsensitiveStringMap) extends ScanBuilder
with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
private var schema: StructType = tableSchema
private var postScanFilters: Array[Filter] = Array.empty
private var evaluableFilters: Array[Filter] = Array.empty
private var _pushedFilters: Array[Filter] = Array.empty

override def build: Scan = {
val scan = InMemoryBatchScan(
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
@@ -442,7 +444,8 @@ abstract class InMemoryBaseTable(
case class InMemoryBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

override def filterAttributes(): Array[NamedReference] = {
@@ -474,17 +477,17 @@ abstract class InMemoryBaseTable(
}
}

abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite
with SupportsStreamingUpdateAsAppend {
abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo)
extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {

protected var writer: BatchWrite = Append
protected var streamingWriter: StreamingWrite = StreamingAppend
protected var writer: BatchWrite = new Append(info)
protected var streamingWriter: StreamingWrite = new StreamingAppend(info)

override def overwriteDynamicPartitions(): WriteBuilder = {
if (writer != Append) {
if (!writer.isInstanceOf[Append]) {
throw new IllegalArgumentException(s"Unsupported writer type: $writer")
}
writer = DynamicOverwrite
writer = new DynamicOverwrite(info)
streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
this
}
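
A side effect of turning the test writers (Append, DynamicOverwrite, and the streaming variants) from singleton objects into classes carrying LogicalWriteInfo is that reference-equality guards such as `writer != Append` no longer have a meaning, which is why the checks above switch to `isInstanceOf`. A small self-contained sketch of that difference, with stand-in names rather than the classes in this file:

```scala
object WriterIdentitySketch extends App {
  // Stand-in classes for illustration only; not the classes in this diff.
  class Info(val options: Map[String, String])
  class Append(val info: Info)

  val info = new Info(Map("k" -> "v"))
  val writer: AnyRef = new Append(info)

  assert(writer.isInstanceOf[Append]) // a type test still identifies the writer kind
  assert(writer ne new Append(info))  // but each call builds a new instance, so the old
                                      // singleton-style reference check has no equivalent
}
```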
@@ -525,21 +528,21 @@ abstract class InMemoryBaseTable(
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

protected object Append extends TestBatchWrite {
class Append(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}

private object DynamicOverwrite extends TestBatchWrite {
class DynamicOverwrite(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val newData = messages.map(_.asInstanceOf[BufferedRows])
dataMap --= newData.flatMap(_.rows.map(getKey))
withData(newData)
}
}

protected object TruncateAndAppend extends TestBatchWrite {
class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap.clear()
withData(messages.map(_.asInstanceOf[BufferedRows]))
@@ -568,15 +571,15 @@ abstract class InMemoryBaseTable(
s"${operation} isn't supported for streaming query.")
}

private object StreamingAppend extends TestStreamingWrite {
class StreamingAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
}

protected object StreamingTruncateAndAppend extends TestStreamingWrite {
class StreamingTruncateAndAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
dataMap.clear()
@@ -59,7 +59,7 @@ class InMemoryRowLevelOperationTable(
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema) {
new InMemoryScanBuilder(schema, options) {
override def build: Scan = {
val scan = super.build()
configuredScan = scan.asInstanceOf[InMemoryBatchScan]
@@ -115,7 +115,7 @@ class InMemoryRowLevelOperationTable(
override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema)
new InMemoryScanBuilder(schema, options)
}

override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder =
@@ -84,23 +84,23 @@ class InMemoryTable(
InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)

new InMemoryWriterBuilderWithOverWrite()
new InMemoryWriterBuilderWithOverWrite(info)
}

private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
with SupportsOverwrite {
class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
extends InMemoryWriterBuilder(info) with SupportsOverwrite {

override def truncate(): WriteBuilder = {
if (writer != Append) {
if (!writer.isInstanceOf[Append]) {
throw new IllegalArgumentException(s"Unsupported writer type: $writer")
}
writer = TruncateAndAppend
streamingWriter = StreamingTruncateAndAppend
writer = new TruncateAndAppend(info)
streamingWriter = new StreamingTruncateAndAppend(info)
this
}

override def overwrite(filters: Array[Filter]): WriteBuilder = {
if (writer != Append) {
if (!writer.isInstanceOf[Append]) {
throw new IllegalArgumentException(s"Unsupported writer type: $writer")
}
writer = new Overwrite(filters)
@@ -47,19 +47,22 @@ class InMemoryTableWithV2Filter(
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryV2FilterScanBuilder(schema)
new InMemoryV2FilterScanBuilder(schema, options)
}

class InMemoryV2FilterScanBuilder(tableSchema: StructType)
extends InMemoryScanBuilder(tableSchema) {
class InMemoryV2FilterScanBuilder(
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends InMemoryScanBuilder(tableSchema, options) {
override def build: Scan = InMemoryV2FilterBatchScan(
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
}

case class InMemoryV2FilterBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {

override def filterAttributes(): Array[NamedReference] = {
@@ -93,21 +96,21 @@ class InMemoryTableWithV2Filter(
InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)

new InMemoryWriterBuilderWithOverWrite()
new InMemoryWriterBuilderWithOverWrite(info)
}

private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
with SupportsOverwriteV2 {
class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
extends InMemoryWriterBuilder(info) with SupportsOverwriteV2 {

override def truncate(): WriteBuilder = {
assert(writer == Append)
writer = TruncateAndAppend
streamingWriter = StreamingTruncateAndAppend
assert(writer.isInstanceOf[Append])
writer = new TruncateAndAppend(info)
streamingWriter = new StreamingTruncateAndAppend(info)
this
}

override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
assert(writer == Append)
assert(writer.isInstanceOf[Append])
writer = new Overwrite(predicates)
streamingWriter = new StreamingNotSupportedOperation(
s"overwrite (${predicates.mkString("filters(", ", ", ")")})")
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.v2

import java.util.{Optional, UUID}

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -44,7 +46,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) =>
val writeBuilder = newWriteBuilder(r.table, options, query.schema)
val writeBuilder = newWriteBuilder(r.table, r.options.asScala.toMap ++ options, query.schema)
Contributor comment: can we add an assert that only one of them can be non-empty?
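
A sketch of what that suggested assertion could look like (not part of this diff); here `r.options` is the relation's options and `options` the per-command write options:

```scala
// Reviewer's suggestion, sketched: fail fast if both sources of options are populated.
assert(r.options.isEmpty || options.isEmpty,
  s"Expected at most one of relation options (${r.options}) and write options ($options) to be non-empty")
```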

val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
a.copy(write = Some(write), query = newQuery)
@@ -61,7 +63,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
}.toArray

val table = r.table
val writeBuilder = newWriteBuilder(table, options, query.schema)
val writeBuilder = newWriteBuilder(table, r.options.asScala.toMap ++ options, query.schema)
val write = writeBuilder match {
case builder: SupportsTruncate if isTruncate(predicates) =>
builder.truncate().build()
@@ -76,7 +78,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {

case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
val table = r.table
val writeBuilder = newWriteBuilder(table, options, query.schema)
val writeBuilder = newWriteBuilder(table, r.options.asScala.toMap ++ options, query.schema)
val write = writeBuilder match {
case builder: SupportsDynamicOverwrite =>
builder.overwriteDynamicPartitions().build()
@@ -87,26 +89,26 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
o.copy(write = Some(write), query = newQuery)

case WriteToMicroBatchDataSource(
relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) =>

relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
val writeOptions = relationOpt.map(r => r.options.asScala.toMap ++ options).getOrElse(options)
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId)
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq
val funCatalogOpt = relation.flatMap(_.funCatalog)
val funCatalogOpt = relationOpt.flatMap(_.funCatalog)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics)

case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) =>
val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput)
val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
val writeBuilder = newWriteBuilder(r.table, r.options.asScala.toMap, rowSchema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
// project away any metadata columns that could be used for distribution and ordering
rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))

case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
val deltaWriteBuilder = newDeltaWriteBuilder(r.table, Map.empty, projections)
val deltaWriteBuilder = newDeltaWriteBuilder(r.table, r.options.asScala.toMap, projections)
val deltaWrite = deltaWriteBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(deltaWrite, query, r.funCatalog)
wd.copy(write = Some(deltaWrite), query = newQuery)
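
For reference, the merge semantics used throughout this rule: Scala's `++` on maps keeps the right-hand value on duplicate keys, so options passed on the write command override same-named options carried by the relation. A tiny illustration with hypothetical option names:

```scala
// Hypothetical option maps standing in for r.options (relation) and `options` (command).
val relationOptions = Map("compression" -> "zstd", "retries" -> "3")
val commandOptions  = Map("compression" -> "lz4")

val effective = relationOptions ++ commandOptions
// effective == Map("compression" -> "lz4", "retries" -> "3"):
// the command-level value wins, everything else from the relation is preserved.
```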