Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@ public final class LanguageModelSession: @unchecked Sendable {
let relay = AsyncThrowingStream<ResponseStream<Content>.Snapshot, any Error> { continuation in
let stream = upstream
Task {
// Add prompt to transcript when stream starts
await MainActor.run {
session.transcript.append(promptEntry)
}

await session.beginResponding()
var lastSnapshot: ResponseStream<Content>.Snapshot?
do {
Expand Down Expand Up @@ -225,14 +220,15 @@ public final class LanguageModelSession: @unchecked Sendable {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<Content> where Content: Generable {
// Create prompt entry that will be added when stream starts
// Add prompt to transcript
let promptEntry = Transcript.Entry.prompt(
Transcript.Prompt(
segments: [.text(.init(content: prompt.description))],
options: options,
responseFormat: nil
)
)
transcript.append(promptEntry)

return wrapStream(
model.streamResponse(
Expand Down
128 changes: 94 additions & 34 deletions Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,6 @@ public struct AnthropicLanguageModel: LanguageModel {
let url = baseURL.appendingPathComponent("v1/messages")
let headers = buildHeaders()

let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
let messages = [
AnthropicMessage(role: .user, content: convertSegmentsToAnthropicContent(userSegments))
]

// Convert available tools to Anthropic format
let anthropicTools: [AnthropicTool] = try session.tools.map { tool in
try convertToolToAnthropicFormat(tool)
Expand All @@ -338,7 +333,7 @@ public struct AnthropicLanguageModel: LanguageModel {
let params = try createMessageParams(
model: model,
system: nil,
messages: messages,
messages: session.transcript.toAnthropicMessages(),
tools: anthropicTools.isEmpty ? nil : anthropicTools,
options: options
)
Expand Down Expand Up @@ -396,11 +391,6 @@ public struct AnthropicLanguageModel: LanguageModel {
fatalError("AnthropicLanguageModel only supports generating String content")
}

let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
let messages = [
AnthropicMessage(role: .user, content: convertSegmentsToAnthropicContent(userSegments))
]

let url = baseURL.appendingPathComponent("v1/messages")

let stream: AsyncThrowingStream<LanguageModelSession.ResponseStream<Content>.Snapshot, any Error> = .init {
Expand All @@ -417,7 +407,7 @@ public struct AnthropicLanguageModel: LanguageModel {
var params = try createMessageParams(
model: model,
system: nil,
messages: messages,
messages: session.transcript.toAnthropicMessages(),
tools: anthropicTools.isEmpty ? nil : anthropicTools,
options: options
)
Expand Down Expand Up @@ -640,8 +630,77 @@ private func toGeneratedContent(_ value: [String: JSONValue]?) throws -> Generat
return try GeneratedContent(json: json)
}

private func fromGeneratedContent(_ content: GeneratedContent) throws -> [String: JSONValue] {
let data = try JSONEncoder().encode(content)
let jsonValue = try JSONDecoder().decode(JSONValue.self, from: data)

guard case .object(let dict) = jsonValue else {
return [:]
}
return dict
}

// MARK: - Supporting Types

extension Transcript {
fileprivate func toAnthropicMessages() -> [AnthropicMessage] {
var messages = [AnthropicMessage]()
for item in self {
switch item {
case .instructions(let instructions):
messages.append(
.init(
role: .user,
content: convertSegmentsToAnthropicContent(instructions.segments)
)
)
case .prompt(let prompt):
messages.append(
.init(
role: .user,
content: convertSegmentsToAnthropicContent(prompt.segments)
)
)
case .response(let response):
messages.append(
.init(
role: .assistant,
content: convertSegmentsToAnthropicContent(response.segments)
)
)
case .toolCalls(let toolCalls):
// Add assistant message with tool use blocks
let toolUseBlocks: [AnthropicContent] = toolCalls.map { call in
let input = try? fromGeneratedContent(call.arguments)
return .toolUse(AnthropicToolUse(
id: call.id,
name: call.toolName,
input: input
))
}
messages.append(
.init(
role: .assistant,
content: toolUseBlocks
)
)
case .toolOutput(let toolOutput):
// Add user message with tool result
messages.append(
.init(
role: .user,
content: [.toolResult(AnthropicToolResult(
toolUseId: toolOutput.id,
content: convertSegmentsToAnthropicContent(toolOutput.segments)
))]
)
)
}
}
return messages
}
}

private struct AnthropicTool: Codable, Sendable {
let name: String
let description: String
Expand All @@ -665,10 +724,11 @@ private enum AnthropicContent: Codable, Sendable {
case text(AnthropicText)
case image(AnthropicImage)
case toolUse(AnthropicToolUse)
case toolResult(AnthropicToolResult)

enum CodingKeys: String, CodingKey { case type }

enum ContentType: String, Codable { case text = "text", image = "image", toolUse = "tool_use" }
enum ContentType: String, Codable { case text = "text", image = "image", toolUse = "tool_use", toolResult = "tool_result" }

init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
Expand All @@ -680,6 +740,8 @@ private enum AnthropicContent: Codable, Sendable {
self = .image(try AnthropicImage(from: decoder))
case .toolUse:
self = .toolUse(try AnthropicToolUse(from: decoder))
case .toolResult:
self = .toolResult(try AnthropicToolResult(from: decoder))
}
}

Expand All @@ -688,6 +750,7 @@ private enum AnthropicContent: Codable, Sendable {
case .text(let t): try t.encode(to: encoder)
case .image(let i): try i.encode(to: encoder)
case .toolUse(let u): try u.encode(to: encoder)
case .toolResult(let r): try r.encode(to: encoder)
}
}
}
Expand Down Expand Up @@ -752,27 +815,6 @@ private func convertSegmentsToAnthropicContent(_ segments: [Transcript.Segment])
return blocks
}

private func extractPromptSegments(from session: LanguageModelSession, fallbackText: String) -> [Transcript.Segment] {
for entry in session.transcript.reversed() {
if case .prompt(let p) = entry {
// Skip prompts that are effectively empty (single empty text block)
let hasMeaningfulContent = p.segments.contains { segment in
switch segment {
case .text(let t):
return !t.content.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
case .structure:
return true
case .image:
return true
}
}
if hasMeaningfulContent { return p.segments }
// Otherwise continue searching older entries
}
}
return [.text(.init(content: fallbackText))]
}

private struct AnthropicToolUse: Codable, Sendable {
let type: String
let id: String
Expand All @@ -787,6 +829,24 @@ private struct AnthropicToolUse: Codable, Sendable {
}
}

private struct AnthropicToolResult: Codable, Sendable {
let type: String
let toolUseId: String
let content: [AnthropicContent]

enum CodingKeys: String, CodingKey {
case type
case toolUseId = "tool_use_id"
case content
}

init(toolUseId: String, content: [AnthropicContent]) {
self.type = "tool_result"
self.toolUseId = toolUseId
self.content = content
}
}

private struct AnthropicMessageResponse: Codable, Sendable {
let id: String
let type: String
Expand Down
Loading
Loading