Skip to content

Commit

Permalink
deprecate state manager
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 committed Jul 3, 2024
1 parent fc6d087 commit de03ec3
Show file tree
Hide file tree
Showing 34 changed files with 538 additions and 566 deletions.
8 changes: 4 additions & 4 deletions playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.pipeline.worker import InterruptibleAgentResponseEvent, QueueConsumer
from vocode.streaming.utils.state_manager import AbstractConversationStateManager
from vocode.streaming.streaming_conversation import StreamingConversation

load_dotenv()

Expand Down Expand Up @@ -75,7 +75,7 @@ def create_action(self, action_config: ActionConfig) -> BaseAction:
raise Exception("Invalid action type")


class DummyConversationManager(AbstractConversationStateManager):
class DummyStreamingConversation(StreamingConversation):
"""For use with Agents operating in a non-call context."""

def __init__(
Expand Down Expand Up @@ -192,7 +192,7 @@ async def sender():
)
actions_worker.consumer = agent
agent.actions_consumer = actions_worker
actions_worker.attach_conversation_state_manager(agent.conversation_state_manager)
actions_worker.pipeline = agent.streaming_conversation
actions_worker.start()

await asyncio.gather(receiver(), sender())
Expand Down Expand Up @@ -226,7 +226,7 @@ async def agent_main():
),
action_factory=ShoutActionFactory(),
)
agent.attach_conversation_state_manager(DummyConversationManager())
agent.streaming_conversation = DummyStreamingConversation()
agent.attach_transcript(transcript)
agent.start()

Expand Down
109 changes: 107 additions & 2 deletions tests/fakedata/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@
from vocode.streaming.models.audio import AudioEncoding
from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import PlayHtSynthesizerConfig, SynthesizerConfig
from vocode.streaming.models.telephony import PhoneCallDirection, TwilioConfig, VonageConfig
from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, TranscriberConfig
from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
from vocode.streaming.output_device.audio_chunk import ChunkState
from vocode.streaming.streaming_conversation import StreamingConversation
from vocode.streaming.streaming_conversation import (
StreamingConversation,
StreamingConversationFactory,
)
from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer
from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager
from vocode.streaming.telephony.constants import DEFAULT_CHUNK_SIZE, DEFAULT_SAMPLING_RATE
from vocode.streaming.transcriber.base_transcriber import BaseTranscriber
from vocode.streaming.telephony.conversation.twilio_phone_conversation import (
TwilioPhoneConversation,
)
from vocode.streaming.telephony.conversation.vonage_phone_conversation import (
VonagePhoneConversation,
)
from vocode.streaming.transcriber.base_transcriber import AbstractTranscriber, BaseTranscriber
from vocode.streaming.transcriber.deepgram_transcriber import DeepgramEndpointingConfig
from vocode.streaming.utils.events_manager import EventsManager

Expand Down Expand Up @@ -95,18 +106,47 @@ def create_fake_transcriber(mocker: MockerFixture, transcriber_config: Transcrib
return transcriber


def create_fake_transcriber_factory(
mocker: MockerFixture, transcriber: Optional[AbstractTranscriber] = None
):
factory = mocker.MagicMock()
factory.create_transcriber = mocker.MagicMock(
return_value=transcriber
or create_fake_transcriber(mocker, DEFAULT_DEEPGRAM_TRANSCRIBER_CONFIG)
)
return factory


def create_fake_agent(mocker: MockerFixture, agent_config: AgentConfig):
agent = mocker.MagicMock()
agent.get_agent_config = mocker.MagicMock(return_value=agent_config)
return agent


def create_fake_agent_factory(mocker: MockerFixture, agent: Optional[BaseAgent] = None):
factory = mocker.MagicMock()
factory.create_agent = mocker.MagicMock(
return_value=agent or create_fake_agent(mocker, DEFAULT_CHAT_GPT_AGENT_CONFIG)
)
return factory


def create_fake_synthesizer(mocker: MockerFixture, synthesizer_config: SynthesizerConfig):
synthesizer = mocker.MagicMock()
synthesizer.get_synthesizer_config = mocker.MagicMock(return_value=synthesizer_config)
return synthesizer


def create_fake_synthesizer_factory(
mocker: MockerFixture, synthesizer: Optional[BaseSynthesizer] = None
):
factory = mocker.MagicMock()
factory.create_synthesizer = mocker.MagicMock(
return_value=synthesizer or create_fake_synthesizer(mocker, DEFAULT_SYNTHESIZER_CONFIG)
)
return factory


def create_fake_streaming_conversation(
mocker: MockerFixture,
transcriber: Optional[BaseTranscriber[TranscriberConfig]] = None,
Expand All @@ -132,3 +172,68 @@ def create_fake_streaming_conversation(
conversation_id=conversation_id,
events_manager=events_manager,
)


def create_fake_streaming_conversation_factory(
mocker: MockerFixture,
transcriber: Optional[BaseTranscriber[TranscriberConfig]] = None,
agent: Optional[BaseAgent] = None,
synthesizer: Optional[BaseSynthesizer] = None,
):
return StreamingConversationFactory(
transcriber_factory=create_fake_transcriber_factory(mocker, transcriber),
agent_factory=create_fake_agent_factory(mocker, agent),
synthesizer_factory=create_fake_synthesizer_factory(mocker, synthesizer),
)


def create_fake_twilio_phone_conversation_with_streaming_conversation_pipeline(
mocker: MockerFixture,
streaming_conversation_factory: StreamingConversationFactory,
direction: PhoneCallDirection = "outbound",
from_phone: str = "+1234567890",
to_phone: str = "+0987654321",
base_url: str = "http://test.com",
twilio_sid: str = "test_sid",
twilio_config: Optional[TwilioConfig] = None,
config_manager: Optional[BaseConfigManager] = None,
events_manager: Optional[EventsManager] = None,
):
return TwilioPhoneConversation(
direction=direction,
from_phone=from_phone,
to_phone=to_phone,
base_url=base_url,
config_manager=config_manager,
pipeline_factory=streaming_conversation_factory,
pipeline_config=mocker.MagicMock(),
twilio_sid=twilio_sid,
twilio_config=twilio_config,
events_manager=events_manager,
)


def create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline(
mocker: MockerFixture,
streaming_conversation_factory: StreamingConversationFactory,
direction: PhoneCallDirection = "outbound",
from_phone: str = "+1234567890",
to_phone: str = "+0987654321",
base_url: str = "http://test.com",
vonage_uuid: str = "test_uuid",
vonage_config: Optional[VonageConfig] = None,
config_manager: Optional[BaseConfigManager] = None,
events_manager: Optional[EventsManager] = None,
):
return VonagePhoneConversation(
direction=direction,
from_phone=from_phone,
to_phone=to_phone,
base_url=base_url,
config_manager=config_manager,
pipeline_factory=streaming_conversation_factory,
pipeline_config=mocker.MagicMock(),
vonage_uuid=vonage_uuid,
vonage_config=vonage_config or mocker.MagicMock(),
events_manager=events_manager,
)
42 changes: 28 additions & 14 deletions tests/streaming/action/test_dtmf.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import pytest
from aioresponses import aioresponses

from tests.fakedata.conversation import (
create_fake_agent,
create_fake_streaming_conversation_factory,
create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline,
)
from tests.fakedata.id import generate_uuid
from vocode.streaming.action.dtmf import (
DTMFParameters,
DTMFVocodeActionConfig,
TwilioDTMF,
VonageDTMF,
)
from vocode.streaming.models.actions import (
TwilioPhoneConversationActionInput,
VonagePhoneConversationActionInput,
)
from vocode.streaming.models.actions import ActionInput
from vocode.streaming.models.agent import ChatGPTAgentConfig
from vocode.streaming.models.telephony import VonageConfig
from vocode.streaming.utils import create_conversation_id
from vocode.streaming.utils.state_manager import VonagePhoneConversationStateManager


@pytest.mark.asyncio
Expand All @@ -23,22 +25,35 @@ async def test_vonage_dtmf_press_digits(mocker, mock_env):
vonage_uuid = generate_uuid()
digits = "1234"

vonage_phone_conversation_mock = mocker.MagicMock()
vonage_config = VonageConfig(
api_key="api_key",
api_secret="api_secret",
application_id="application_id",
private_key="-----BEGIN PRIVATE KEY-----\nasdf\n-----END PRIVATE KEY-----",
)
vonage_phone_conversation_mock.vonage_config = vonage_config
vonage_phone_conversation_mock = (
create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline(
mocker,
streaming_conversation_factory=create_fake_streaming_conversation_factory(
mocker,
agent=create_fake_agent(
mocker,
agent_config=ChatGPTAgentConfig(
prompt_preamble="",
actions=[action.action_config],
),
),
),
vonage_config=vonage_config,
vonage_uuid=vonage_uuid,
)
)
mocker.patch("vonage.Client._create_jwt_auth_string", return_value=b"asdf")

action.attach_conversation_state_manager(
VonagePhoneConversationStateManager(vonage_phone_conversation_mock)
)
vonage_phone_conversation_mock.pipeline.actions_worker.attach_state(action)

assert (
action.conversation_state_manager.create_vonage_client().get_telephony_config()
vonage_phone_conversation_mock.create_vonage_client().get_telephony_config()
== vonage_config
)

Expand All @@ -48,11 +63,10 @@ async def test_vonage_dtmf_press_digits(mocker, mock_env):
status=200,
)
action_output = await action.run(
action_input=VonagePhoneConversationActionInput(
action_input=ActionInput(
action_config=DTMFVocodeActionConfig(),
conversation_id=create_conversation_id(),
params=DTMFParameters(buttons=digits),
vonage_uuid=str(vonage_uuid),
)
)

Expand All @@ -66,7 +80,7 @@ async def test_twilio_dtmf_press_digits(mocker, mock_env):
twilio_sid = "twilio_sid"

action_output = await action.run(
action_input=TwilioPhoneConversationActionInput(
action_input=ActionInput(
action_config=DTMFVocodeActionConfig(),
conversation_id=create_conversation_id(),
params=DTMFParameters(buttons=digits),
Expand Down
Loading

0 comments on commit de03ec3

Please sign in to comment.