160 lines
4.0 KiB
Go
160 lines
4.0 KiB
Go
package llm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Client talks to a chat-completions HTTP endpoint.
//
// All fields are unexported; use New to build a usable value.
type Client struct {
	// baseURL is the endpoint root that "/v1/chat/completions" is appended to.
	// apiKey, when non-empty, is sent as a Bearer token.
	// model is the model name placed in every request.
	baseURL, apiKey, model string

	// http executes the outgoing requests.
	http *http.Client
}
|
|
|
|
// Message is a single chat turn as it appears on the wire.
type Message struct {
	// Role names the speaker of the turn (this package emits "system" and "user").
	Role string `json:"role"`
	// Content is the text of the turn.
	Content string `json:"content"`
}
|
|
|
|
type ChatInput struct {
|
|
System string `json:"system,omitempty"`
|
|
User string `json:"user,omitempty"`
|
|
Messages []Message `json:"messages,omitempty"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
ResponseFormat json.RawMessage `json:"response_format,omitempty"`
|
|
}
|
|
|
|
// Usage carries the token counts reported by the API for one completion.
type Usage struct {
	// PromptTokens is the reported "prompt_tokens" count.
	PromptTokens int `json:"prompt_tokens"`
	// CompletionTokens is the reported "completion_tokens" count.
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is the reported "total_tokens" count.
	TotalTokens int `json:"total_tokens"`
}
|
|
|
|
type ChatResult struct {
|
|
Content string `json:"content"`
|
|
Model string `json:"model"`
|
|
Usage *Usage `json:"usage,omitempty"`
|
|
DurationMS int64 `json:"duration_ms"`
|
|
}
|
|
|
|
type chatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Temperature float64 `json:"temperature"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
ResponseFormat *json.RawMessage `json:"response_format,omitempty"`
|
|
}
|
|
|
|
type chatResponse struct {
|
|
Model string `json:"model,omitempty"`
|
|
Choices []struct {
|
|
Message Message `json:"message"`
|
|
} `json:"choices"`
|
|
Usage *Usage `json:"usage,omitempty"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
func New(baseURL, apiKey, model string, timeout time.Duration) *Client {
|
|
return &Client{
|
|
baseURL: strings.TrimRight(strings.TrimSpace(baseURL), "/"),
|
|
apiKey: apiKey,
|
|
model: model,
|
|
http: &http.Client{Timeout: timeout},
|
|
}
|
|
}
|
|
|
|
func (c *Client) Chat(ctx context.Context, in ChatInput) (*ChatResult, error) {
|
|
if c == nil || c.baseURL == "" {
|
|
return nil, fmt.Errorf("llm not configured")
|
|
}
|
|
messages := normalizeMessages(in)
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("messages are required")
|
|
}
|
|
temp := 0.1
|
|
if in.Temperature != nil {
|
|
temp = *in.Temperature
|
|
}
|
|
reqBody := chatRequest{
|
|
Model: c.model,
|
|
Messages: messages,
|
|
Temperature: temp,
|
|
MaxTokens: in.MaxTokens,
|
|
}
|
|
if len(in.ResponseFormat) > 0 {
|
|
reqBody.ResponseFormat = &in.ResponseFormat
|
|
}
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if c.apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
|
}
|
|
|
|
start := time.Now()
|
|
resp, err := c.http.Do(req)
|
|
duration := time.Since(start)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("llm do: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
raw, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("llm read: %w", err)
|
|
}
|
|
if resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("llm HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
|
}
|
|
var out chatResponse
|
|
if err := json.Unmarshal(raw, &out); err != nil {
|
|
return nil, fmt.Errorf("llm decode: %w", err)
|
|
}
|
|
if out.Error != nil {
|
|
return nil, fmt.Errorf("llm error: %s", out.Error.Message)
|
|
}
|
|
if len(out.Choices) == 0 {
|
|
return nil, fmt.Errorf("llm: empty choices")
|
|
}
|
|
modelName := out.Model
|
|
if modelName == "" {
|
|
modelName = c.model
|
|
}
|
|
return &ChatResult{
|
|
Content: out.Choices[0].Message.Content,
|
|
Model: modelName,
|
|
Usage: out.Usage,
|
|
DurationMS: duration.Milliseconds(),
|
|
}, nil
|
|
}
|
|
|
|
func normalizeMessages(in ChatInput) []Message {
|
|
if len(in.Messages) > 0 {
|
|
return in.Messages
|
|
}
|
|
var out []Message
|
|
if strings.TrimSpace(in.System) != "" {
|
|
out = append(out, Message{Role: "system", Content: in.System})
|
|
}
|
|
if strings.TrimSpace(in.User) != "" {
|
|
out = append(out, Message{Role: "user", Content: in.User})
|
|
}
|
|
return out
|
|
}
|