Skip to content

Commit

Permalink
Add lock for usage tracking.
Browse files Browse the repository at this point in the history
This avoids racing requests to be counted just once.

PiperOrigin-RevId: 653809899
  • Loading branch information
daiyip authored and langfun authors committed Jul 19, 2024
1 parent 4499244 commit e5dfc27
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import contextlib
import dataclasses
import enum
import threading
import time
from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
from langfun.core import component
Expand Down Expand Up @@ -728,17 +729,19 @@ class _UsageTracker:

def __init__(self, model_ids: set[str] | None):
self.model_ids = model_ids
self._lock = threading.Lock()
self.usages = {
m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids
} if model_ids else {}

def track(self, model_id: str, usage: LMSamplingUsage):
if self.model_ids is not None and model_id not in self.model_ids:
return
if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
self.usages[model_id] += usage
else:
self.usages[model_id] = usage
with self._lock:
if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
self.usages[model_id] += usage
else:
self.usages[model_id] = usage


@contextlib.contextmanager
Expand Down

0 comments on commit e5dfc27

Please sign in to comment.