Skip to content

Commit

Permalink
using match arm instead of string comparison
Browse files Browse the repository at this point in the history
Signed-off-by: David Anyatonwu <davidanyatonwu@gmail.com>
  • Loading branch information
onyedikachi-david committed Aug 20, 2024
1 parent fc68ada commit 2a1e1d7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 26 deletions.
45 changes: 40 additions & 5 deletions src/cli/llm/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,46 @@
use derive_more::From;
use strum_macros::Display;
use reqwest::StatusCode;
use thiserror::Error;

#[derive(Debug, From, Display, thiserror::Error)]
#[derive(Debug, Error)]
pub enum WebcError {
#[error("Response failed with status {status}: {body}")]
ResponseFailedStatus { status: StatusCode, body: String },
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
}

#[derive(Debug, Error)]
pub enum Error {
#[error("GenAI error: {0}")]
GenAI(genai::Error),
#[error("Webc error: {0}")]
Webc(WebcError),
#[error("Empty response")]
EmptyResponse,
Serde(serde_json::Error),
#[error("Serde error: {0}")]
Serde(#[from] serde_json::Error),
}

impl From<genai::Error> for Error {
fn from(err: genai::Error) -> Self {
if let genai::Error::WebModelCall { webc_error, .. } = &err {
let error_str = webc_error.to_string();
if error_str.contains("ResponseFailedStatus") {
// Extract status and body from the error message
let parts: Vec<&str> = error_str.splitn(3, ": ").collect();
if parts.len() >= 3 {
if let Ok(status) = parts[1].parse::<u16>() {
return Error::Webc(WebcError::ResponseFailedStatus {
status: StatusCode::from_u16(status)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
body: parts[2].to_string(),
});
}
}
}
}
Error::GenAI(err)
}
}

pub type Result<A> = std::result::Result<A, Error>;
pub type Result<T> = std::result::Result<T, Error>;
34 changes: 13 additions & 21 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// use std::borrow::Borrow;

use derive_setters::Setters;
use genai::adapter::AdapterKind;
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::resolver::AuthResolver;
use genai::Client;
use reqwest::StatusCode;
use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::Retry;

use super::error::{Error, Result};
use super::error::{Error, Result, WebcError};
use crate::cli::llm::model::Model;

#[derive(Setters, Clone)]
Expand Down Expand Up @@ -56,28 +59,17 @@ impl<Q, A> Wizard<Q, A> {
.await
{
Ok(response) => Ok(A::try_from(response)?),
Err(genai::Error::WebModelCall { webc_error, .. }) => {
if webc_error.to_string().contains("429") {
Err(Error::GenAI(genai::Error::WebModelCall {
model_info: genai::ModelInfo::new(
AdapterKind::from_model(self.model.as_str())
.unwrap_or(AdapterKind::Ollama),
self.model.as_str(),
),
webc_error,
}))
} else {
Ok(Err(Error::GenAI(genai::Error::WebModelCall {
model_info: genai::ModelInfo::new(
AdapterKind::from_model(self.model.as_str())
.unwrap_or(AdapterKind::Ollama),
self.model.as_str(),
),
webc_error,
}))?)
Err(err) => {
let error = Error::from(err);
match &error {
Error::Webc(WebcError::ResponseFailedStatus { status, .. })
if *status == StatusCode::TOO_MANY_REQUESTS =>
{
Err(error) // Propagate the error to trigger a retry
}
_ => Ok(Err(error)?), // Other errors are returned without retrying
}
}
Err(e) => Ok(Err(Error::GenAI(e))?),
}
})
.await
Expand Down

0 comments on commit 2a1e1d7

Please sign in to comment.