Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(middleware): add match wrappers #566

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/xmaps/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package xmaps

// BuildSet builds a set from the given values.
func BuildSet[K comparable](s ...K) map[K]struct{} {
r := make(map[K]struct{}, len(s))
for _, v := range s {
r[v] = struct{}{}
}
return r
}
79 changes: 79 additions & 0 deletions middleware/match.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package middleware

import (
"regexp"

"github.com/ogen-go/ogen/internal/xmaps"
)

// OperationID calls the next middleware if request operation ID matches the given operationID.
func OperationID(m Middleware, operationID ...string) Middleware {
switch len(operationID) {
case 0:
return justCallNext
case 1:
val := operationID[0]
return func(req Request, next Next) (Response, error) {
if req.OperationID == val {
return m(req, next)
}
return next(req)
}
default:
set := xmaps.BuildSet(operationID...)
return func(req Request, next Next) (Response, error) {
if _, ok := set[req.OperationID]; ok {
return m(req, next)
}
return next(req)
}
}
}

// OperationName calls the next middleware if request operation name matches the given operationName.
func OperationName(m Middleware, operationName ...string) Middleware {
switch len(operationName) {
case 0:
return justCallNext
case 1:
val := operationName[0]
return func(req Request, next Next) (Response, error) {
if req.OperationName == val {
return m(req, next)
}
return next(req)
}
default:
set := xmaps.BuildSet(operationName...)
return func(req Request, next Next) (Response, error) {
if _, ok := set[req.OperationName]; ok {
return m(req, next)
}
return next(req)
}
}
}

// PathRegex calls the next middleware if request path matches the given regex.
func PathRegex(re *regexp.Regexp, m Middleware) Middleware {
if re == nil {
return justCallNext
}

return func(req Request, next Next) (Response, error) {
if re.MatchString(req.Raw.URL.Path) {
return m(req, next)
}
return next(req)
}
}

// BodyType calls the next middleware if request body type matches the given type.
func BodyType[T any](m Middleware) Middleware {
return func(req Request, next Next) (Response, error) {
if _, ok := req.Body.(T); ok {
return m(req, next)
}
return next(req)
}
}
8 changes: 5 additions & 3 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,14 @@ type (
Middleware func(req Request, next Next) (Response, error)
)

func justCallNext(req Request, next Next) (Response, error) {
return next(req)
}

// ChainMiddlewares chains middlewares into a single middleware, which will be executed in the order they are passed.
func ChainMiddlewares(m ...Middleware) Middleware {
if len(m) == 0 {
return func(req Request, next Next) (Response, error) {
return next(req)
}
return justCallNext
}
tail := ChainMiddlewares(m[1:]...)
return func(req Request, next Next) (Response, error) {
Expand Down
5 changes: 1 addition & 4 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,11 @@ func TestChainMiddlewares(t *testing.T) {

func BenchmarkChainMiddlewares(b *testing.B) {
const N = 20
noop := func(req Request, next Next) (Response, error) {
return next(req)
}

var (
chain = ChainMiddlewares(func() (r []Middleware) {
for i := 0; i < N; i++ {
r = append(r, noop)
r = append(r, justCallNext)
}
return r
}()...)
Expand Down