Skip to content

Commit

Permalink
fix: be consistent in downloading files, check for scanner errors (#3108
Browse files Browse the repository at this point in the history
)

* fix(downloader): be consistent in downloading files

This PR puts some order in the downloader such as functions are re-used
across several places.

This fixes an issue with having uri's inside the model YAML file, it
would resolve to MD5 rather then using the filename

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(scanner): do raise error only if unsafeFiles are found

Fixes: #3114

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  • Loading branch information
mudler committed Aug 2, 2024
1 parent fc50a90 commit a36b721
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 171 deletions.
4 changes: 3 additions & 1 deletion core/cli/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
return err
}

if !downloader.LooksLikeOCI(modelName) {
modelURI := downloader.URI(modelName)

if !modelURI.LooksLikeOCI() {
model := gallery.FindModel(models, modelName, mi.ModelsPath)
if model == nil {
log.Error().Str("model", modelName).Msg("model not found")
Expand Down
4 changes: 2 additions & 2 deletions core/cli/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
var errs error = nil
for _, uri := range hfscmd.ToScan {
log.Info().Str("uri", uri).Msg("scanning specific uri")
scanResults, err := downloader.HuggingFaceScan(uri)
if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
scanResults, err := downloader.HuggingFaceScan(downloader.URI(uri))
if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
log.Error().Err(err).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("! WARNING ! A known-vulnerable model is included in this repo!")
errs = errors.Join(errs, err)
}
Expand Down
27 changes: 15 additions & 12 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/utils"
)

const (
Expand Down Expand Up @@ -72,9 +71,9 @@ type BackendConfig struct {
}

type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI downloader.URI `yaml:"uri" json:"uri"`
}

type VallE struct {
Expand Down Expand Up @@ -213,28 +212,32 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool {
// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) MMProjFileName() string {
modelURL := downloader.ConvertURL(c.MMProj)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
uri := downloader.URI(c.MMProj)
if uri.LooksLikeURL() {
f, _ := uri.FilenameFromUrl()
return f
}

return c.MMProj
}

func (c *BackendConfig) IsMMProjURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj))
uri := downloader.URI(c.MMProj)
return uri.LooksLikeURL()
}

func (c *BackendConfig) IsModelURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.Model))
uri := downloader.URI(c.Model)
return uri.LooksLikeURL()
}

// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) ModelFileName() string {
modelURL := downloader.ConvertURL(c.Model)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
uri := downloader.URI(c.Model)
if uri.LooksLikeURL() {
f, _ := uri.FilenameFromUrl()
return f
}

return c.Model
Expand Down
10 changes: 5 additions & 5 deletions core/config/backend_config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,18 +244,18 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
// Create file path
filePath := filepath.Join(modelPath, file.Filename)

if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
if err := file.URI.DownloadFile(filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
return err
}
}

// If the model is an URL, expand it, and download the file
if config.IsModelURL() {
modelFileName := config.ModelFileName()
modelURL := downloader.ConvertURL(config.Model)
uri := downloader.URI(config.Model)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
Expand All @@ -269,10 +269,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {

if config.IsMMProjURL() {
modelFileName := config.MMProjFileName()
modelURL := downloader.ConvertURL(config.MMProj)
uri := downloader.URI(config.MMProj)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion core/dependencies_manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ func main() {

// download the assets
for _, asset := range assets {
if err := downloader.DownloadFile(asset.URL, filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil {
uri := downloader.URI(asset.URL)
if err := uri.DownloadFile(filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil {
panic(err)
}
}
Expand Down
10 changes: 6 additions & 4 deletions core/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal

func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
err := downloader.DownloadAndUnmarshal(url, basePath, func(url string, d []byte) error {
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
Expand All @@ -153,8 +154,9 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel,
return models, err
}
}
uri := downloader.URI(gallery.URL)

err := downloader.DownloadAndUnmarshal(gallery.URL, basePath, func(url string, d []byte) error {
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
Expand Down Expand Up @@ -252,8 +254,8 @@ func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error

func SafetyScanGalleryModel(galleryModel *GalleryModel) error {
for _, file := range galleryModel.AdditionalFiles {
scanResults, err := downloader.HuggingFaceScan(file.URI)
if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI))
if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
log.Error().Str("model", galleryModel.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
return err
}
Expand Down
11 changes: 6 additions & 5 deletions core/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ type PromptTemplate struct {

func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config
err := downloader.DownloadAndUnmarshal(url, basePath, func(url string, d []byte) error {
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand Down Expand Up @@ -118,14 +119,14 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
filePath := filepath.Join(basePath, file.Filename)

if enforceScan {
scanResults, err := downloader.HuggingFaceScan(file.URI)
if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI))
if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
return err
}
}

if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
uri := downloader.URI(file.URI)
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return err
}
}
Expand Down
3 changes: 2 additions & 1 deletion core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ func getModelStatus(url string) (response map[string]interface{}) {
}

func getModels(url string) (response []gallery.GalleryModel) {
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
downloader.DownloadAndUnmarshal(url, "", func(url string, i []byte) error {
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
Expand Down
4 changes: 2 additions & 2 deletions embedded/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func init() {

func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
remoteLibrary := map[string]string{}

err := downloader.DownloadAndUnmarshal(url, basePath, func(_ string, i []byte) error {
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary)
})
if err != nil {
Expand Down
49 changes: 49 additions & 0 deletions pkg/downloader/huggingface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package downloader

import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
)

type HuggingFaceScanResult struct {
RepositoryId string `json:"repositoryId"`
Revision string `json:"revision"`
HasUnsafeFiles bool `json:"hasUnsafeFile"`
ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
DangerousPickles []string `json:"dangerousPickles"`
ScansDone bool `json:"scansDone"`
}

var ErrNonHuggingFaceFile = errors.New("not a huggingface repo")
var ErrUnsafeFilesFound = errors.New("unsafe files found")

func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) {
cleanParts := strings.Split(uri.ResolveURL(), "/")
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" {
return nil, ErrNonHuggingFaceFile
}
results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4]))
if err != nil {
return nil, err
}
if results.StatusCode != 200 {
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
}
scanResult := &HuggingFaceScanResult{}
bodyBytes, err := io.ReadAll(results.Body)
if err != nil {
return nil, err
}
err = json.Unmarshal(bodyBytes, scanResult)
if err != nil {
return nil, err
}
if scanResult.HasUnsafeFiles {
return scanResult, ErrUnsafeFilesFound
}
return scanResult, nil
}
Loading

0 comments on commit a36b721

Please sign in to comment.