diff --git a/examples/local/vstream_client.go b/examples/local/vstream_client.go
index ab00f83871d..38f47bc5eb1 100644
--- a/examples/local/vstream_client.go
+++ b/examples/local/vstream_client.go
@@ -18,11 +18,17 @@ package main
 import (
 	"context"
+	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log"
+	"slices"
 	"time"
 
+	"google.golang.org/protobuf/proto"
+	"vitess.io/vitess/go/sqltypes"
+	querypb "vitess.io/vitess/go/vt/proto/query"
 	_ "vitess.io/vitess/go/vt/vtctl/grpcvtctlclient"
 	_ "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn"
 	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
@@ -38,35 +44,19 @@ import (
 */
 func main() {
 	ctx := context.Background()
-	streamCustomer := true
-	var vgtid *binlogdatapb.VGtid
-	if streamCustomer {
-		vgtid = &binlogdatapb.VGtid{
-			ShardGtids: []*binlogdatapb.ShardGtid{{
-				Keyspace: "customer",
-				Shard:    "-80",
-				// Gtid "" is to stream from the start, "current" is to stream from the current gtid
-				// you can also specify a gtid to start with.
-				Gtid: "", //"current" // "MySQL56/36a89abd-978f-11eb-b312-04ed332e05c2:1-265"
-			}, {
-				Keyspace: "customer",
-				Shard:    "80-",
-				Gtid:     "",
-			}}}
-	} else {
-		vgtid = &binlogdatapb.VGtid{
-			ShardGtids: []*binlogdatapb.ShardGtid{{
-				Keyspace: "commerce",
-				Shard:    "0",
-				Gtid:     "",
-			}}}
+
+	vgtid, err := getLastVgtid(ctx)
+	if err != nil {
+		log.Fatal(err)
 	}
+
 	filter := &binlogdatapb.Filter{
 		Rules: []*binlogdatapb.Rule{{
 			Match:  "customer",
 			Filter: "select * from customer",
 		}},
 	}
+
 	conn, err := vtgateconn.Dial(ctx, "localhost:15991")
 	if err != nil {
 		log.Fatal(err)
@@ -82,17 +72,260 @@ func main() {
 	if err != nil {
 		log.Fatal(err)
 	}
+
+	err = readEvents(ctx, reader)
+	if err != nil {
+		log.Fatal(err)
+	}
+}
+
+// getLastVgtid retrieves the last vgtid processed by the client, so that it can resume from that position.
+func getLastVgtid(ctx context.Context) (*binlogdatapb.VGtid, error) {
+	var vgtid binlogdatapb.VGtid
+
+	// if storeLastVgtid was implemented, you would retrieve the last vgtid and unmarshal it here
+	// err := json.Unmarshal([]byte{}, &vgtid)
+	// if err != nil {
+	// 	return nil, err
+	// }
+
+	streamCustomer := true
+	if streamCustomer {
+		vgtid = binlogdatapb.VGtid{
+			ShardGtids: []*binlogdatapb.ShardGtid{{
+				Keyspace: "customer",
+				Shard:    "-80",
+				// Gtid "" is to stream from the start, "current" is to stream from the current gtid.
+				// You can also specify a gtid to start with.
+				Gtid: "", // "current" // "MySQL56/36a89abd-978f-11eb-b312-04ed332e05c2:1-265"
+			}, {
+				Keyspace: "customer",
+				Shard:    "80-",
+				Gtid:     "",
+			}},
+		}
+	} else {
+		vgtid = binlogdatapb.VGtid{
+			ShardGtids: []*binlogdatapb.ShardGtid{{
+				Keyspace: "commerce",
+				Shard:    "0",
+				Gtid:     "",
+			}},
+		}
+	}
+
+	return &vgtid, nil
+}
+
+// storeLastVgtid stores the last vgtid processed by the client, so that it can resume from that position on restart.
+// Storing a json blob in a database is just one way to do this; you could store it anywhere.
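+// As a minimal sketch (illustrative only; the file path below is an arbitrary example), you could write the JSON
+// blob to a local file here and read it back with os.ReadFile + json.Unmarshal in getLastVgtid:
+//
+//	data, err := json.Marshal(vgtid)
+//	if err != nil {
+//		return err
+//	}
+//	return os.WriteFile("/tmp/vstream-vgtid.json", data, 0o644)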
+func storeLastVgtid(ctx context.Context, vgtid *binlogdatapb.VGtid) error {
+	_, err := json.Marshal(vgtid)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+type Customer struct {
+	ID    int64
+	Email string
+
+	// the fields below aren't actually in the schema, but are added for illustrative purposes
+	EmailConfirmed bool
+	Details        map[string]any
+	CreatedAt      time.Time
+}
+
+func readEvents(ctx context.Context, reader vtgateconn.VStreamReader) error {
+	// the first event will be the field event, which contains the schema
+	var customerFields, corderFields []*querypb.Field
+
+	// as we process events, we keep track of the vgtid so that we can resume from the last position
+	var lastFlushedVgtid, latestVgtid *binlogdatapb.VGtid
+	var lastFlushedAt time.Time
+
+	// to avoid flushing too often, we only flush if it has been at least minFlushDuration since the last flush.
+	// we're relying on heartbeat events to handle the max duration between flushes, in case there are no other events.
+	const minFlushDuration = 5 * time.Second
+	const maxCustomersToFlush = 1000
+
+	var customers []*Customer
+
+	flushFunc := func() error {
+		// if the lastFlushedVgtid is the same as the latestVgtid, we don't need to do anything
+		if proto.Equal(lastFlushedVgtid, latestVgtid) {
+			return nil
+		}
+
+		// if it hasn't been long enough since the last flush, and we haven't exceeded our max batch size, don't
+		// flush. We can always replay as needed.
+		if time.Since(lastFlushedAt) < minFlushDuration && len(customers) < maxCustomersToFlush {
+			return nil
+		}
+
+		// if the customer db is the same db you're storing the vgtid in, you could do both in the same transaction
+
+		// flush the customers to the database, using the max batch size
+		for customerChunk := range slices.Chunk(customers, maxCustomersToFlush) {
+			err := upsertCustomersToDB(ctx, customerChunk)
+			if err != nil {
+				return err
+			}
+		}
+
+		// reset the customers slice to free up memory. If you really want to be efficient, you could reuse the slice.
+		customers = slices.Delete(customers, 0, len(customers))
+
+		// always store the latest vgtid, even if there are no customers to store
+		err := storeLastVgtid(ctx, latestVgtid)
+		if err != nil {
+			return err
+		}
+
+		lastFlushedVgtid = latestVgtid
+		lastFlushedAt = time.Now()
+
+		return nil
+	}
+
 	for {
-		e, err := reader.Recv()
-		switch err {
-		case nil:
-			fmt.Printf("%v\n", e)
-		case io.EOF:
-			fmt.Printf("stream ended\n")
-			return
+		events, err := reader.Recv()
+		switch {
+		case err == nil: // no error, continue processing below
+
+		case errors.Is(err, io.EOF):
+			fmt.Println("stream ended")
+			return nil
+
 		default:
-			fmt.Printf("%s:: remote error: %v\n", time.Now(), err)
-			return
+			return fmt.Errorf("remote error: %w", err)
 		}
+
+		for _, ev := range events {
+			switch ev.Type {
+			case binlogdatapb.VEventType_FIELD:
+				// field events carry the table name on the FieldEvent itself (RowEvent is nil here)
+				switch ev.FieldEvent.TableName {
+				case "customer":
+					customerFields = ev.FieldEvent.Fields
+				case "corder":
+					corderFields = ev.FieldEvent.Fields
+				}
+
+			case binlogdatapb.VEventType_ROW:
+				// since our filter is "select * from customer", we could rely on that and not check the table name,
+				// but this shows how you might handle multiple tables in the same stream
+				switch ev.RowEvent.TableName {
+				case "customer":
+					var customer *Customer
+					customer, err = processCustomerRowEvent(customerFields, ev.RowEvent)
+					if err != nil {
+						return err
+					}
+
+					customers = append(customers, customer)
+
+				case "corder":
+					fmt.Printf("corder event: %v | fields: %v\n", ev.RowEvent, corderFields)
+
+				default:
+					return errors.New("unexpected table name: " + ev.RowEvent.TableName)
+				}
+
+			case binlogdatapb.VEventType_VGTID:
+				latestVgtid = ev.Vgtid
+
+			case binlogdatapb.VEventType_COMMIT, binlogdatapb.VEventType_DDL,
+				binlogdatapb.VEventType_OTHER, binlogdatapb.VEventType_HEARTBEAT:
+				// only flush when we have an event that guarantees we're not flushing mid-transaction
+				err = flushFunc()
+				if err != nil {
+					return err
+				}
+
+				if ev.Type == binlogdatapb.VEventType_DDL {
+					// TODO: alter the destination schema based on the DDL event
+				}
+
+			case binlogdatapb.VEventType_COPY_COMPLETED:
+				// TODO: don't flush until the copy is completed? do some sort of cleanup if we haven't received this?
+			}
+		}
 	}
 }
+
+func processCustomerRowEvent(fields []*querypb.Field, rowEvent *binlogdatapb.RowEvent) (*Customer, error) {
+	if fields == nil {
+		// Unreachable: a FIELD event always precedes the ROW events for a table.
+		return nil, errors.New("internal error: unexpected rows without fields")
+	}
+
+	var customer *Customer
+	var err error
+
+	// TODO: decide how to handle multiple row changes in a single event; for now we keep only the last one
+	for _, rc := range rowEvent.RowChanges {
+		// ignore deletes
+		if rc.After == nil {
+			continue
+		}
+
+		row := sqltypes.MakeRowTrusted(fields, rc.After)
+
+		customer, err = rowToCustomer(fields, row)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return customer, nil
+}
+
+// rowToCustomer builds a Customer from a single row, using the field metadata to map columns to struct fields.
+func rowToCustomer(fields []*querypb.Field, row []sqltypes.Value) (*Customer, error) {
+	customer := &Customer{}
+	var err error
+
+	for i := range row {
+		if row[i].IsNull() {
+			continue
+		}
+
+		switch fields[i].Name {
+		case "customer_id":
+			customer.ID, err = row[i].ToCastInt64()
+
+		case "email":
+			customer.Email = row[i].ToString()
+
+		// the fields below aren't actually in the schema, but are added to show how you might handle different data types
+
+		case "email_confirmed":
+			customer.EmailConfirmed, err = row[i].ToBool()
+
+		case "details":
+			// assume the details field is a json blob
+			var b []byte
+			b, err = row[i].ToBytes()
+			if err == nil {
+				err = json.Unmarshal(b, &customer.Details)
+			}
+
+		case "created_at":
+			customer.CreatedAt, err = row[i].ToTime()
+		}
+		if err != nil {
+			return nil, fmt.Errorf("error processing field %s: %w", fields[i].Name, err)
+		}
+	}
+
+	return customer, nil
+}
+
+// upsertCustomersToDB is a placeholder for the function that would actually store the customers in the database,
+// sync the data to another system, etc.
+func upsertCustomersToDB(ctx context.Context, customers []*Customer) error {
+	fmt.Printf("upserting %d customers\n", len(customers))
+	for i, customer := range customers {
+		fmt.Printf("upserting customer %d: %v\n", i, customer)
+	}
+	return nil
+}
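+
+// As a rough sketch of what a real implementation might do (illustrative only: the destination table name, the
+// database/sql handle db, and the placeholders/args built by the caller are all assumptions), each chunk could be
+// written as a single multi-row upsert:
+//
+//	query := "INSERT INTO customer_copy (customer_id, email) VALUES " + placeholders +
+//		" ON DUPLICATE KEY UPDATE email = VALUES(email)"
+//	_, err := db.ExecContext(ctx, query, args...)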