From 7798aee65a76882fe771c4d98eb2962a2395f949 Mon Sep 17 00:00:00 2001 From: httpjamesm Date: Mon, 11 Sep 2023 17:24:47 -0400 Subject: [PATCH] feat: stop string slices --- README.md | 2 +- inference.go | 18 +++++++++--------- inference_test.go | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 073eaba..0ab7595 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ resp, err := client.NewInference(InferenceConfig{ Model: "togethercomputer/RedPajama-INCITE-7B-Instruct", Prompt: "The capital of France is", MaxTokens: 128, - Stop: &stopString, + Stop: &stopStrings, }) if err != nil { panic(err) diff --git a/inference.go b/inference.go index 29c7751..882621a 100644 --- a/inference.go +++ b/inference.go @@ -6,15 +6,15 @@ import ( ) type InferenceConfig struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokens int32 `json:"max_tokens"` - Stop *string `json:"stop"` - Temperature *float32 `json:"temperature"` - TopP *float32 `json:"top_p"` - TopK *int32 `json:"top_k"` - RepetitionPenalty *float32 `json:"repetition_penalty"` - LogProbs *int32 `json:"logprobs"` + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokens int32 `json:"max_tokens"` + Stop *[]string `json:"stop"` + Temperature *float32 `json:"temperature"` + TopP *float32 `json:"top_p"` + TopK *int32 `json:"top_k"` + RepetitionPenalty *float32 `json:"repetition_penalty"` + LogProbs *int32 `json:"logprobs"` } type inferenceRequestBody struct { diff --git a/inference_test.go b/inference_test.go index 80364ea..c6e5887 100644 --- a/inference_test.go +++ b/inference_test.go @@ -11,13 +11,13 @@ import ( func TestNewInference(t *testing.T) { client := NewClient(os.Getenv("TOGETHERAI_API_KEY")) - stopString := "*" + stopStrings := []string{"*"} respBody, err := client.NewInference(InferenceConfig{ Model: "togethercomputer/RedPajama-INCITE-7B-Instruct", Prompt: "The capital of France is", MaxTokens: 128, - Stop: &stopString, + Stop: &stopStrings, }) assert.NoError(t, err) assert.NotNil(t, respBody)