diff --git a/changelog/28875.txt b/changelog/28875.txt
new file mode 100644
index 000000000000..471920e9b26d
--- /dev/null
+++ b/changelog/28875.txt
@@ -0,0 +1,3 @@
+```release-note:change
+storage/raft: Do not allow nodes that have been removed from the raft cluster configuration to respond to requests. Shutdown and seal raft nodes when they are removed.
+```
diff --git a/http/handler.go b/http/handler.go
index c74f392aee75..d8f040251ff5 100644
--- a/http/handler.go
+++ b/http/handler.go
@@ -988,9 +988,14 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
 	// ErrCannotForward and we simply fall back
 	statusCode, header, retBytes, err := core.ForwardRequest(r)
 	if err != nil {
-		if err == vault.ErrCannotForward {
+		switch {
+		case errors.Is(err, vault.ErrCannotForward):
 			core.Logger().Trace("cannot forward request (possibly disabled on active node), falling back to redirection to standby")
-		} else {
+		case errors.Is(err, vault.StatusNotHAMember):
+			core.Logger().Trace("this node is not a member of the HA cluster", "error", err)
+			respondError(w, http.StatusInternalServerError, err)
+			return
+		default:
 			core.Logger().Error("forward request error", "error", err)
 		}
 
diff --git a/vault/core.go b/vault/core.go
index d635165e3c12..12a9a9116936 100644
--- a/vault/core.go
+++ b/vault/core.go
@@ -4579,16 +4579,8 @@ func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
 // RemovableNodeHABackend interface. The value of the `ok` result will be false
 // if the HA and underlyingPhysical backends are nil or do not support this operation.
 func (c *Core) IsRemovedFromCluster() (removed, ok bool) {
-	var haBackend any
-	if c.ha != nil {
-		haBackend = c.ha
-	} else if c.underlyingPhysical != nil {
-		haBackend = c.underlyingPhysical
-	} else {
-		return false, false
-	}
-	removableNodeHA, ok := haBackend.(physical.RemovableNodeHABackend)
-	if !ok {
+	removableNodeHA := c.getRemovableHABackend()
+	if removableNodeHA == nil {
 		return false, false
 	}
 
diff --git a/vault/core_test.go b/vault/core_test.go
index 0f2e0913fdc4..5a767cb09756 100644
--- a/vault/core_test.go
+++ b/vault/core_test.go
@@ -3724,7 +3724,7 @@ func TestCore_IsRemovedFromCluster(t *testing.T) {
 	core.underlyingPhysical = mockHA
 	removed, ok = core.IsRemovedFromCluster()
 	if removed || !ok {
-		t.Fatalf("expected removed and ok to be false, got removed: %v, ok: %v", removed, ok)
+		t.Fatalf("expected removed to be false and ok to be true, got removed: %v, ok: %v", removed, ok)
 	}
 
 	// Test case where HA backend is nil, but the underlying physical is there, supports RemovableNodeHABackend, and is removed
@@ -3735,6 +3735,7 @@ func TestCore_IsRemovedFromCluster(t *testing.T) {
 	}
 
 	// Test case where HA backend does not support RemovableNodeHABackend
+	core.underlyingPhysical = &MockHABackend{}
 	core.ha = &MockHABackend{}
 	removed, ok = core.IsRemovedFromCluster()
 	if removed || ok {
diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go
index 9f36486c3e4f..33dc48d67c43 100644
--- a/vault/external_tests/raft/raft_test.go
+++ b/vault/external_tests/raft/raft_test.go
@@ -1360,3 +1360,33 @@ func TestRaft_Join_InitStatus(t *testing.T) {
 		verifyInitStatus(i, true)
 	}
 }
+
+// TestRaftCluster_Removed creates a 3 node raft cluster and then removes one of
+// the nodes. The test verifies that a write on the removed node errors, and that
+// the removed node is sealed.
+func TestRaftCluster_Removed(t *testing.T) {
+	t.Parallel()
+	cluster, _ := raftCluster(t, nil)
+	defer cluster.Cleanup()
+
+	follower := cluster.Cores[2]
+	followerClient := follower.Client
+	_, err := followerClient.Logical().Write("secret/foo", map[string]interface{}{
+		"test": "data",
+	})
+	require.NoError(t, err)
+
+	_, err = cluster.Cores[0].Client.Logical().Write("/sys/storage/raft/remove-peer", map[string]interface{}{
+		"server_id": follower.NodeID,
+	})
+	followerClient.SetCheckRedirect(func(request *http.Request, requests []*http.Request) error {
+		require.Fail(t, "request caused a redirect", request.URL.Path)
+		return fmt.Errorf("no redirects allowed")
+	})
+	require.NoError(t, err)
+	_, err = followerClient.Logical().Write("secret/foo", map[string]interface{}{
+		"test": "other_data",
+	})
+	require.Error(t, err)
+	require.True(t, follower.Sealed())
+}
diff --git a/vault/ha.go b/vault/ha.go
index 9e063cfde985..46fc7f7757b4 100644
--- a/vault/ha.go
+++ b/vault/ha.go
@@ -1223,3 +1223,16 @@ func (c *Core) SetNeverBecomeActive(on bool) {
 		atomic.StoreUint32(c.neverBecomeActive, 0)
 	}
 }
+
+func (c *Core) getRemovableHABackend() physical.RemovableNodeHABackend {
+	var haBackend physical.RemovableNodeHABackend
+	if removableHA, ok := c.ha.(physical.RemovableNodeHABackend); ok {
+		haBackend = removableHA
+	}
+
+	if removableHA, ok := c.underlyingPhysical.(physical.RemovableNodeHABackend); ok {
+		haBackend = removableHA
+	}
+
+	return haBackend
+}
diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go
index 619222c344aa..51dfa3a9093f 100644
--- a/vault/request_forwarding.go
+++ b/vault/request_forwarding.go
@@ -22,13 +22,116 @@ import (
 	"github.com/hashicorp/vault/helper/forwarding"
 	"github.com/hashicorp/vault/sdk/helper/consts"
 	"github.com/hashicorp/vault/sdk/logical"
+	"github.com/hashicorp/vault/sdk/physical"
 	"github.com/hashicorp/vault/vault/cluster"
 	"github.com/hashicorp/vault/vault/replication"
 	"golang.org/x/net/http2"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/keepalive"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
 )
 
+var (
+	NotHAMember       = "node is not in HA cluster membership"
+	StatusNotHAMember = status.Errorf(codes.FailedPrecondition, NotHAMember)
+)
+
+const haNodeIDKey = "ha_node_id"
+
+func haIDFromContext(ctx context.Context) (string, bool) {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return "", false
+	}
+	res := md.Get(haNodeIDKey)
+	if len(res) == 0 {
+		return "", false
+	}
+	return res[0], true
+}
+
+// haMembershipServerCheck extracts the client's HA node ID from the context
+// and checks if this client has been removed. The function returns
+// StatusNotHAMember if the client has been removed
+func haMembershipServerCheck(ctx context.Context, c *Core, haBackend physical.RemovableNodeHABackend) error {
+	if haBackend == nil {
+		return nil
+	}
+	nodeID, ok := haIDFromContext(ctx)
+	if !ok {
+		return nil
+	}
+	removed, err := haBackend.IsNodeRemoved(ctx, nodeID)
+	if err != nil {
+		c.logger.Error("failed to check if node is removed", "error", err)
+		return err
+	}
+	if removed {
+		return StatusNotHAMember
+	}
+	return nil
+}
+
+func haMembershipUnaryServerInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
+		err = haMembershipServerCheck(ctx, c, haBackend)
+		if err != nil {
+			return nil, err
+		}
+		return handler(ctx, req)
+	}
+}
+
+func haMembershipStreamServerInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.StreamServerInterceptor {
+	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		err := haMembershipServerCheck(ss.Context(), c, haBackend)
+		if err != nil {
+			return err
+		}
+		return handler(srv, ss)
+	}
+}
+
+// haMembershipClientCheck checks if the given error from the server
+// is StatusNotHAMember. If so, the client will mark itself as removed
+// and shutdown
+func haMembershipClientCheck(err error, c *Core, haBackend physical.RemovableNodeHABackend) {
+	if !errors.Is(err, StatusNotHAMember) {
+		return
+	}
+	removeErr := haBackend.RemoveSelf()
+	if removeErr != nil {
+		c.logger.Debug("failed to remove self", "error", removeErr)
+	}
+	go c.ShutdownCoreError(errors.New("node removed from HA configuration"))
+}
+
+func haMembershipUnaryClientInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.UnaryClientInterceptor {
+	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		if haBackend == nil {
+			return invoker(ctx, method, req, reply, cc, opts...)
+		}
+		ctx = metadata.AppendToOutgoingContext(ctx, haNodeIDKey, haBackend.NodeID())
+		err := invoker(ctx, method, req, reply, cc, opts...)
+		haMembershipClientCheck(err, c, haBackend)
+		return err
+	}
+}
+
+func haMembershipStreamClientInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.StreamClientInterceptor {
+	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+		if haBackend == nil {
+			return streamer(ctx, desc, cc, method, opts...)
+		}
+		ctx = metadata.AppendToOutgoingContext(ctx, haNodeIDKey, haBackend.NodeID())
+		stream, err := streamer(ctx, desc, cc, method, opts...)
+		haMembershipClientCheck(err, c, haBackend)
+		return stream, err
+	}
+}
+
 type requestForwardingHandler struct {
 	fws         *http2.Server
 	fwRPCServer *grpc.Server
@@ -47,6 +150,7 @@ type requestForwardingClusterClient struct {
 func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots chan struct{}, perfStandbyRepCluster *replication.Cluster) (*requestForwardingHandler, error) {
 	// Resolve locally to avoid races
 	ha := c.ha != nil
+	removableHABackend := c.getRemovableHABackend()
 
 	fwRPCServer := grpc.NewServer(
 		grpc.KeepaliveParams(keepalive.ServerParameters{
@@ -54,6 +158,8 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch
 		}),
 		grpc.MaxRecvMsgSize(math.MaxInt32),
 		grpc.MaxSendMsgSize(math.MaxInt32),
+		grpc.StreamInterceptor(haMembershipStreamServerInterceptor(c, removableHABackend)),
+		grpc.UnaryInterceptor(haMembershipUnaryServerInterceptor(c, removableHABackend)),
 	)
 
 	if ha && c.clusterHandler != nil {
@@ -274,6 +380,8 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 		core: c,
 	})
 
+	removableHABackend := c.getRemovableHABackend()
+
 	// Set up grpc forwarding handling
 	// It's not really insecure, but we have to dial manually to get the
 	// ALPN header right. It's just "insecure" because GRPC isn't managing
@@ -285,6 +393,8 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 		grpc.WithKeepaliveParams(keepalive.ClientParameters{
 			Time: 2 * c.clusterHeartbeatInterval,
 		}),
+		grpc.WithStreamInterceptor(haMembershipStreamClientInterceptor(c, removableHABackend)),
+		grpc.WithUnaryInterceptor(haMembershipUnaryClientInterceptor(c, removableHABackend)),
 		grpc.WithDefaultCallOptions(
 			grpc.MaxCallRecvMsgSize(math.MaxInt32),
 			grpc.MaxCallSendMsgSize(math.MaxInt32),
@@ -374,6 +484,10 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
 	if err != nil {
 		metrics.IncrCounter([]string{"ha", "rpc", "client", "forward", "errors"}, 1)
 		c.logger.Error("error during forwarded RPC request", "error", err)
+
+		if errors.Is(err, StatusNotHAMember) {
+			return 0, nil, nil, fmt.Errorf("error during forwarding RPC request: %w", err)
+		}
 		return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
 	}
 
diff --git a/vault/request_forwarding_test.go b/vault/request_forwarding_test.go
new file mode 100644
index 000000000000..9df49b5a99fa
--- /dev/null
+++ b/vault/request_forwarding_test.go
@@ -0,0 +1,131 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package vault
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/sdk/physical"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc/metadata"
+)
+
+// Test_haIDFromContext verifies that the HA node ID gets correctly extracted
+// from a gRPC context
+func Test_haIDFromContext(t *testing.T) {
+	testCases := []struct {
+		name   string
+		md     metadata.MD
+		wantID string
+		wantOk bool
+	}{
+		{
+			name:   "no ID",
+			md:     metadata.MD{},
+			wantID: "",
+			wantOk: false,
+		},
+		{
+			name:   "with ID",
+			md:     metadata.MD{haNodeIDKey: {"node_id"}},
+			wantID: "node_id",
+			wantOk: true,
+		},
+		{
+			name:   "with empty string ID",
+			md:     metadata.MD{haNodeIDKey: {""}},
+			wantID: "",
+			wantOk: true,
+		},
+		{
+			name:   "with empty ID",
+			md:     metadata.MD{haNodeIDKey: {}},
+			wantID: "",
+			wantOk: false,
+		},
+
+		{
+			name:   "with multiple IDs",
+			md:     metadata.MD{haNodeIDKey: {"1", "2"}},
+			wantID: "1",
+			wantOk: true,
+		},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ctx := metadata.NewIncomingContext(context.Background(), tc.md)
+			id, ok := haIDFromContext(ctx)
+			require.Equal(t, tc.wantID, id)
+			require.Equal(t, tc.wantOk, ok)
+		})
+	}
+}
+
+type mockHARemovableNodeBackend struct {
+	physical.RemovableNodeHABackend
+	isRemoved func(context.Context, string) (bool, error)
+}
+
+func (m *mockHARemovableNodeBackend) IsNodeRemoved(ctx context.Context, nodeID string) (bool, error) {
+	return m.isRemoved(ctx, nodeID)
+}
+
+func newMockHARemovableNodeBackend(isRemoved func(context.Context, string) (bool, error)) physical.RemovableNodeHABackend {
+	return &mockHARemovableNodeBackend{isRemoved: isRemoved}
+}
+
+// Test_haMembershipServerCheck verifies that the correct error is returned
+// when the context contains a removed node ID
+func Test_haMembershipServerCheck(t *testing.T) {
+	nodeIDCtx := metadata.NewIncomingContext(context.Background(), metadata.MD{haNodeIDKey: {"node_id"}})
+	otherErr := errors.New("error checking")
+	testCases := []struct {
+		name      string
+		nodeIDCtx context.Context
+		haBackend physical.RemovableNodeHABackend
+		wantError error
+	}{
+		{
+			name:      "nil backend",
+			haBackend: nil,
+			nodeIDCtx: nodeIDCtx,
+		}, {
+			name: "no node ID context",
+			haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
+				return false, nil
+			}),
+			nodeIDCtx: context.Background(),
+		}, {
+			name: "node removed",
+			haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
+				return true, nil
+			}),
+			nodeIDCtx: nodeIDCtx,
+			wantError: StatusNotHAMember,
+		}, {
+			name: "node removed err",
+			haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
+				return false, otherErr
+			}),
+			nodeIDCtx: nodeIDCtx,
+			wantError: otherErr,
+		},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			c := &Core{
+				logger: hclog.NewNullLogger(),
+			}
+			err := haMembershipServerCheck(tc.nodeIDCtx, c, tc.haBackend)
+			if tc.wantError != nil {
+				require.EqualError(t, err, tc.wantError.Error())
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}