diff --git a/taosWS/connection.go b/taosWS/connection.go index 8228661..5038ea4 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -50,6 +50,7 @@ type taosConn struct { writeTimeout time.Duration cfg *config endpoint string + closed atomic.Bool // set when conn is closed, } func (tc *taosConn) generateReqID() uint64 { @@ -100,6 +101,10 @@ func (tc *taosConn) Begin() (driver.Tx, error) { } func (tc *taosConn) Close() (err error) { + if tc.closed.Swap(true) { + return nil + } + if tc.client != nil { err = tc.client.Close() } @@ -110,6 +115,9 @@ func (tc *taosConn) Close() (err error) { } func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { + if tc.closed.Load() { + return nil, driver.ErrBadConn + } stmtID, err := tc.stmtInit() if err != nil { return nil, err @@ -410,6 +418,9 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver } func (tc *taosConn) execCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if tc.closed.Load() { + return nil, driver.ErrBadConn + } if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -463,6 +474,9 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive } func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if tc.closed.Load() { + return nil, driver.ErrBadConn + } if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -521,6 +535,9 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name } func (tc *taosConn) Ping(ctx context.Context) (err error) { + if tc.closed.Load() { + return driver.ErrBadConn + } return nil } diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go index 6aee030..e93d710 100644 --- a/taosWS/connection_test.go +++ b/taosWS/connection_test.go @@ -46,3 +46,29 @@ func Test_formatBytes(t *testing.T) { }) } } + +func TestBadConnection(t *testing.T) { + defer func() { + if r := recover(); r != nil { + // bad connection should not panic + t.Fatalf("panic: %v", r) + } + }() + + cfg, err := parseDSN(dataSourceName) + if err != nil { + t.Fatalf("parseDSN error: %v", err) + } + conn, err := newTaosConn(cfg) + if err != nil { + t.Fatalf("newTaosConn error: %v", err) + } + + // to test bad connection, we manually close the connection + conn.Close() + + _, err = conn.Query("select 1", nil) + if err == nil { + t.Fatalf("query should fail") + } +} diff --git a/taosWS/statement.go b/taosWS/statement.go index d313820..cd1e910 100644 --- a/taosWS/statement.go +++ b/taosWS/statement.go @@ -28,6 +28,9 @@ type Stmt struct { } func (stmt *Stmt) Close() error { + if stmt.conn == nil || stmt.conn.closed.Load() { + return driver.ErrBadConn + } err := stmt.conn.stmtClose(stmt.stmtID) stmt.buffer.Reset() stmt.conn = nil @@ -42,6 +45,9 @@ func (stmt *Stmt) NumInput() int { } func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.conn.closed.Load() { + return nil, driver.ErrBadConn + } if stmt.conn == nil { return nil, driver.ErrBadConn } @@ -68,6 +74,9 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.conn.closed.Load() { + return nil, driver.ErrBadConn + } if stmt.conn == nil { return nil, driver.ErrBadConn }