Skip to content

Commit

Permalink
fix(defaults): set better defaults for inferencing
Browse files Browse the repository at this point in the history
This changeset aim to have better defaults and to properly detect when
no inference settings are provided with the model.

If not specified, we defaults to mirostat sampling, and offload all the
GPU layers (if a GPU is detected).

Related to #1373 and #1723
  • Loading branch information
mudler committed Mar 12, 2024
1 parent fa4c582 commit 79aa1e6
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 125 deletions.
2 changes: 1 addition & 1 deletion core/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo

opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
Expand Down
6 changes: 3 additions & 3 deletions core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import (

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
threads := backendConfig.Threads
if threads == 0 && appConfig.Threads != 0 {
threads = appConfig.Threads
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}
gRPCOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(threads)),
model.WithThreads(uint32(*threads)),
model.WithContext(appConfig.Context),
model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(gRPCOpts),
Expand Down
6 changes: 3 additions & 3 deletions core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type TokenUsage struct {
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if threads == 0 && o.Threads != 0 {
threads = o.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
}
grpcOpts := gRPCModelOpts(c)

Expand All @@ -39,7 +39,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode

opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(threads)), // some models uses this to allocate threads during startup
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
Expand Down
41 changes: 21 additions & 20 deletions core/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
F16Memory: c.F16,
F16Memory: *c.F16,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(c.ContextSize),
ContextSize: int32(*c.ContextSize),
Seed: int32(c.Seed),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
Expand All @@ -72,18 +72,18 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
YarnBetaSlow: c.YarnBetaSlow,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
MLock: c.MMlock,
MLock: *c.MMlock,
RopeFreqBase: c.RopeFreqBase,
RopeScaling: c.RopeScaling,
Type: c.ModelType,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: c.Embeddings,
LowVRAM: c.LowVRAM,
NGPULayers: int32(c.NGPULayers),
MMap: c.MMap,
LowVRAM: *c.LowVRAM,
NGPULayers: int32(*c.NGPULayers),
MMap: *c.MMap,
MainGPU: c.MainGPU,
Threads: int32(c.Threads),
Threads: int32(*c.Threads),
TensorSplit: c.TensorSplit,
// AutoGPTQ
ModelBaseName: c.AutoGPTQ.ModelBaseName,
Expand All @@ -102,36 +102,37 @@ func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOption
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}

return &pb.PredictOptions{
Temperature: float32(c.Temperature),
TopP: float32(c.TopP),
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
TopK: int32(c.TopK),
Tokens: int32(c.Maxtokens),
Threads: int32(c.Threads),
TopK: int32(*c.TopK),
Tokens: int32(*c.Maxtokens),
Threads: int32(*c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: c.F16,
DebugMode: c.Debug,
F16KV: *c.F16,
DebugMode: *c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(c.LLMConfig.Mirostat),
MirostatETA: float32(c.LLMConfig.MirostatETA),
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
Debug: c.Debug,
Mirostat: int32(*c.LLMConfig.Mirostat),
MirostatETA: float32(*c.LLMConfig.MirostatETA),
MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
Debug: *c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: int32(c.Seed),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: c.MMlock,
MMap: c.MMap,
MLock: *c.MMlock,
MMap: *c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(c.TFZ),
Expand Down
4 changes: 2 additions & 2 deletions core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo
model.WithBackendString(model.WhisperBackend),
model.WithModel(backendConfig.Model),
model.WithContext(appConfig.Context),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
})

Expand All @@ -33,6 +33,6 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(backendConfig.Threads),
Threads: uint32(*backendConfig.Threads),
})
}
Loading

0 comments on commit 79aa1e6

Please sign in to comment.