Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEATURE: partial tool call support for OpenAI and Anthropic #908

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions lib/ai_bot/bot.rb
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,23 @@ def process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
end

def invoke_tool(tool, llm, cancel, context, &update_blk)
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
show_placeholder = !context[:skip_tool_details]

update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder

result =
tool.invoke do |progress|
placeholder = build_placeholder(tool.summary, progress)
update_blk.call("", cancel, placeholder)
if show_placeholder
placeholder = build_placeholder(tool.summary, progress)
update_blk.call("", cancel, placeholder)
end
end

tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)

if context[:skip_tool_details] && tool.custom_raw.present?
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
elsif !context[:skip_tool_details]
if show_placeholder
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
update_blk.call(tool_details, cancel, nil, :tool_details)
elsif tool.custom_raw.present?
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
end

result
Expand Down
2 changes: 1 addition & 1 deletion lib/ai_bot/playground.rb
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def reply_to(post, custom_instructions: nil, &blk)
bot.reply(context) do |partial, cancel, placeholder, type|
reply << partial
raw = reply.dup
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details]
raw << "\n\n" << placeholder if placeholder.present?

blk.call(partial) if blk && type != :tool_details

Expand Down
36 changes: 32 additions & 4 deletions lib/completions/anthropic_message_processor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,50 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
class AnthropicToolCall
attr_reader :name, :raw_json, :id

def initialize(name, id)
def initialize(name, id, partial_tool_calls: false)
@name = name
@id = id
@raw_json = +""
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
@streaming_parser =
DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
end

# Accumulates one streamed JSON fragment for this tool call.
#
# The fragment is always appended to the raw JSON buffer. It is also fed to
# the streaming parser when one is set (the constructor only creates one
# when partial tool calls are enabled).
#
# Returns nil when no streaming parser is set; otherwise the result of
# pushing the fragment into the parser.
def append(json)
  @raw_json << json
  return unless @streaming_parser

  @streaming_parser << json
end

# Records a parameter value on the in-flight tool call.
# NOTE(review): presumably invoked by the ToolCallProgressTracker that is
# constructed with this object — confirm against the tracker.
#
# key   - parameter name; converted to a Symbol before being stored.
# value - the value to store for that parameter.
#
# Marks the tool call as partial and raises the "new data" flag that
# has_partial? reports. Returns true.
def notify_progress(key, value)
  @tool_call.partial = true
  param_name = key.to_sym
  @tool_call.parameters[param_name] = value
  @has_new_data = true
end

# Reports whether new partial data has arrived since the last call to
# partial_tool_call. The flag is set to true by notify_progress and reset to
# false by partial_tool_call; it is nil (falsy) until either has run.
def has_partial?
@has_new_data
end

# Hands back the in-progress tool call and clears the "new data" flag, so
# has_partial? stays falsy until notify_progress records another update.
#
# Returns the same @tool_call object on every call (not a copy).
def partial_tool_call
  @tool_call.tap { @has_new_data = false }
end

def to_tool_call
parameters = JSON.parse(raw_json, symbolize_names: true)
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
@tool_call.partial = false
@tool_call.parameters = parameters
@tool_call
end
end

attr_reader :tool_calls, :input_tokens, :output_tokens

def initialize(streaming_mode:)
def initialize(streaming_mode:, partial_tool_calls: false)
@streaming_mode = streaming_mode
@tool_calls = []
@current_tool_call = nil
@partial_tool_calls = partial_tool_calls
end

def to_tool_calls
Expand All @@ -38,11 +60,17 @@ def process_streamed_message(parsed)
tool_name = parsed.dig(:content_block, :name)
tool_id = parsed.dig(:content_block, :id)
result = @current_tool_call.to_tool_call if @current_tool_call
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
@current_tool_call =
AnthropicToolCall.new(
tool_name,
tool_id,
partial_tool_calls: @partial_tool_calls,
) if tool_name
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
if @current_tool_call
tool_delta = parsed.dig(:delta, :partial_json).to_s
@current_tool_call.append(tool_delta)
result = @current_tool_call.partial_tool_call if @current_tool_call.has_partial?
else
result = parsed.dig(:delta, :text).to_s
end
Expand Down
5 changes: 4 additions & 1 deletion lib/completions/endpoints/anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def decode(response_data)

def processor
@processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
DiscourseAi::Completions::AnthropicMessageProcessor.new(
streaming_mode: @streaming_mode,
partial_tool_calls: partial_tool_calls,
)
end

def has_tool?(_response_data)
Expand Down
4 changes: 4 additions & 0 deletions lib/completions/endpoints/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Base
attr_reader :partial_tool_calls

CompletionFailed = Class.new(StandardError)
TIMEOUT = 60

Expand Down Expand Up @@ -58,8 +60,10 @@ def perform_completion!(
model_params = {},
feature_name: nil,
feature_context: nil,
partial_tool_calls: false,
&blk
)
@partial_tool_calls = partial_tool_calls
model_params = normalize_model_params(model_params)
orig_blk = blk

Expand Down
3 changes: 2 additions & 1 deletion lib/completions/endpoints/canned_response.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def perform_completion!(
_user,
_model_params,
feature_name: nil,
feature_context: nil
feature_context: nil,
partial_tool_calls: false
)
@dialect = dialect
response = responses[completions]
Expand Down
3 changes: 2 additions & 1 deletion lib/completions/endpoints/fake.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def perform_completion!(
user,
model_params = {},
feature_name: nil,
feature_context: nil
feature_context: nil,
partial_tool_calls: false
)
last_call = { dialect: dialect, user: user, model_params: model_params }
self.class.last_call = last_call
Expand Down
17 changes: 12 additions & 5 deletions lib/completions/endpoints/open_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def perform_completion!(
model_params = {},
feature_name: nil,
feature_context: nil,
partial_tool_calls: false,
&blk
)
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
Expand Down Expand Up @@ -103,10 +104,16 @@ def decode(response_raw)

def decode_chunk(chunk)
@decoder ||= JsonStreamDecoder.new
(@decoder << chunk)
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
.flatten
.compact
elements =
(@decoder << chunk)
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
.flatten
.compact

# Remove duplicate partial tool calls
# sometimes we stream weird chunks
seen_tools = Set.new
elements.select { |item| !item.is_a?(ToolCall) || seen_tools.add?(item) }
end

def decode_chunk_finish
Expand All @@ -120,7 +127,7 @@ def xml_tools_enabled?
private

def processor
@processor ||= OpenAiMessageProcessor.new
@processor ||= OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls)
end
end
end
Expand Down
Loading