Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor boto3 client creation #23

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 30 additions & 60 deletions pydantic_settings_aws/aws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, AnyStr, Dict, Optional, Type, Union
from typing import Any, AnyStr, Dict, Literal, Optional, Type, Union

import boto3 # type: ignore[import-untyped]
from pydantic import ValidationError
Expand All @@ -8,11 +8,15 @@
from .logger import logger
from .models import AwsSecretsArgs, AwsSession

AWSService = Literal["ssm", "secretsmanager"]

ClientParam = Literal["secrets_client", "ssm_client"]


def get_ssm_content(
settings: Type[BaseSettings],
field_name: str,
ssm_info: Optional[Union[Dict[Any, AnyStr], AnyStr]] = None
ssm_info: Optional[Union[Dict[Any, AnyStr], AnyStr]] = None,
) -> Optional[str]:
client = None
ssm_name = field_name
Expand All @@ -33,18 +37,20 @@ def get_ssm_content(

if not client:
logger.debug("Boto3 client not specified in metadata")
client = _get_ssm_boto3_client(settings)
client = _create_client_from_settings(settings, "ssm", "ssm_client")

logger.debug(f"Getting parameter {ssm_name} value with boto3 client")
ssm_response: Dict[str, Any] = client.get_parameter( # type: ignore
ssm_response: Dict[str, Any] = client.get_parameter( # type: ignore
Name=ssm_name, WithDecryption=True
)

return ssm_response.get("Parameter", {}).get("Value", None)


def get_secrets_content(settings: Type[BaseSettings]) -> Dict[str, Any]:
client = _get_secrets_boto3_client(settings)
client = _create_client_from_settings(
settings, "secretsmanager", "secrets_client"
)
secrets_args: AwsSecretsArgs = _get_secrets_args(settings)

logger.debug("Getting secrets manager value with boto3 client")
Expand All @@ -69,42 +75,6 @@ def get_secrets_content(settings: Type[BaseSettings]) -> Dict[str, Any]:
raise json_err


def _get_secrets_boto3_client( settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
logger.debug("Getting secrets manager content.")
client = settings.model_config.get("secrets_client", None)

if client:
return client

logger.debug("No boto3 client was informed. Will try to create a new one")
return _create_secrets_client(settings)


def _create_secrets_client(settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
"""Create a boto3 client for secrets manager.

Neither `boto3` nor `pydantic` exceptions will be handled.

Args:
settings (BaseSettings): Settings from `pydantic_settings`

Returns:
SecretsManagerClient: A secrets manager boto3 client.
"""
logger.debug("Extracting settings prefixed with aws_")
args: Dict[str, Any] = {
k: v for k, v in settings.model_config.items() if k.startswith("aws_")
}

session_args = AwsSession(**args)

session: boto3.Session = boto3.Session(
**session_args.model_dump(by_alias=True, exclude_none=True)
)

return session.client("secretsmanager")


def _get_secrets_args(settings: Type[BaseSettings]) -> AwsSecretsArgs:
logger.debug(
"Extracting settings prefixed with secrets_, except _client and _dir"
Expand Down Expand Up @@ -152,39 +122,39 @@ def _get_secrets_content(
return secrets_content


def _get_ssm_boto3_client(settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
logger.debug("Getting secrets manager content.")
client = settings.model_config.get("ssm_client", None)
def _create_client_from_settings( # type: ignore[no-untyped-def]
settings: Type[BaseSettings], service: AWSService, client_param: ClientParam
):
client = settings.model_config.get(client_param)

if client:
return client

logger.debug(
"No ssm boto3 client was informed. Will try to create a new one"
)
return _create_ssm_client(settings)
logger.debug("Extracting settings prefixed with aws_")
args: Dict[str, Any] = {
k: v for k, v in settings.model_config.items() if k.startswith("aws_")
}

session_args = AwsSession(**args)

return _create_boto3_client(session_args, service)

def _create_ssm_client(settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
"""Create a boto3 client for parameter store.

def _create_boto3_client(session_args: AwsSession, service: AWSService): # type: ignore[no-untyped-def]
"""Create a boto3 client for the service informed.

Neither `boto3` nor `pydantic` exceptions will be handled.

Args:
settings (BaseSettings): Settings from `pydantic_settings`
session_args (AwsSession): Settings informed in `SettingsConfigDict` to create
the boto3 session.
service (str): The service client that will be created.

Returns:
SSMClient: A parameter ssm boto3 client.
boto3.client: An aws service boto3 client.
"""
logger.debug("Extracting settings prefixed with aws_")
args: Dict[str, Any] = {
k: v for k, v in settings.model_config.items() if k.startswith("aws_")
}

session_args = AwsSession(**args)

session: boto3.Session = boto3.Session(
**session_args.model_dump(by_alias=True, exclude_none=True)
)

return session.client("ssm")
return session.client(service)
4 changes: 2 additions & 2 deletions tests/aws_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

TARGET_SSM_BOTO3_CLIENT = "pydantic_settings_aws.aws._get_ssm_boto3_client"

TARGET_SECRETS_CLIENT = "pydantic_settings_aws.aws._create_secrets_client"
TARGET_SECRETS_CLIENT = "pydantic_settings_aws.aws._create_boto3_client"

TARGET_SSM_CLIENT = "pydantic_settings_aws.aws._create_ssm_client"
TARGET_CREATE_CLIENT_FROM_SETTINGS = "pydantic_settings_aws.aws._create_client_from_settings"

TARGET_SECRET_CONTENT = "pydantic_settings_aws.aws._get_secrets_content"

Expand Down
26 changes: 13 additions & 13 deletions tests/aws_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from pydantic_settings_aws import aws

from .aws_mocks import (
TARGET_CREATE_CLIENT_FROM_SETTINGS,
TARGET_SECRET_CONTENT,
TARGET_SECRETS_BOTO3_CLIENT,
TARGET_SECRETS_CLIENT,
TARGET_SESSION,
TARGET_SSM_CLIENT,
BaseSettingsMock,
mock_create_client,
mock_secrets_content_empty,
Expand All @@ -21,7 +21,7 @@
from .boto3_mocks import SessionMock


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_return_parameter_content_if_annotated_with_parameter_name(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
Expand All @@ -31,7 +31,7 @@ def test_get_ssm_content_must_return_parameter_content_if_annotated_with_paramet
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_return_parameter_content_if_annotated_with_dict_args(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
Expand All @@ -41,7 +41,7 @@ def test_get_ssm_content_must_return_parameter_content_if_annotated_with_dict_ar
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_use_client_if_present_in_metadata(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
Expand All @@ -51,7 +51,7 @@ def test_get_ssm_content_must_use_client_if_present_in_metadata(*args):
assert isinstance(parameter_value, str)


@mock.patch(TARGET_SSM_CLIENT, mock_ssm)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_ssm)
def test_get_ssm_content_must_use_field_name_if_ssm_name_not_in_metadata(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
Expand All @@ -65,16 +65,16 @@ def test_get_ssm_content_must_use_field_name_if_ssm_name_not_in_metadata(*args):
def test_create_ssm_client(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
client = aws._create_ssm_client(settings)
client = aws._create_client_from_settings(settings, "ssm", "ssm_client")

assert client is not None


@mock.patch(TARGET_SSM_CLIENT, mock_create_client)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_create_client)
def test_get_ssm_boto3_client_must_create_a_client_if_its_not_given(*args):
settings = BaseSettingsMock()
settings.model_config = {}
client = aws._get_ssm_boto3_client(settings)
client = aws._create_client_from_settings(settings, "ssm", "ssm_client")

assert client is not None

Expand All @@ -83,22 +83,22 @@ def test_get_ssm_boto3_client_must_create_a_client_if_its_not_given(*args):
def test_create_secrets_client(*args):
settings = BaseSettingsMock()
settings.model_config = {"aws_region": "region", "aws_profile": "profile"}
client = aws._create_secrets_client(settings)
client = aws._create_client_from_settings(settings, "secretsmanager", "secrets_client")

assert client is not None


@mock.patch(TARGET_SECRETS_CLIENT, mock_create_client)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_create_client)
def test_get_secrets_boto3_client_must_create_a_client_if_its_not_given(*args):
settings = BaseSettingsMock()
settings.model_config = {}
client = aws._get_secrets_boto3_client(settings)
client = aws._create_client_from_settings(settings, "secretsmanager", "secrets_client")

assert client is not None


@mock.patch(TARGET_SECRETS_BOTO3_CLIENT, mock_secrets_content_empty)
@mock.patch(TARGET_SECRET_CONTENT, lambda *args: None)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_secrets_content_empty)
def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none(
*args,
):
Expand All @@ -113,7 +113,7 @@ def test_get_secrets_content_must_raise_value_error_if_secrets_content_is_none(
aws.get_secrets_content(settings)


@mock.patch(TARGET_SECRETS_BOTO3_CLIENT, mock_secrets_content_invalid_json)
@mock.patch(TARGET_CREATE_CLIENT_FROM_SETTINGS, mock_secrets_content_invalid_json)
def test_should_not_obfuscate_json_error_in_case_of_invalid_secrets(*args):
settings = BaseSettingsMock()
settings.model_config = {
Expand Down
Loading