2025-12-01
This commit is contained in:
@@ -16,10 +16,8 @@
|
||||
#
|
||||
# ##### END GPL LICENSE BLOCK #####
|
||||
|
||||
from concurrent.futures import Future
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from errno import EACCES, ENOSPC
|
||||
import functools
|
||||
import os
|
||||
@@ -30,8 +28,6 @@ from .assets import (AssetType,
|
||||
AssetData,
|
||||
ModelType,
|
||||
SIZES)
|
||||
from .user import (PoliigonUser,
|
||||
PoliigonSubscription)
|
||||
|
||||
from .plan_manager import SubscriptionState, PoliigonPlanUpgradeManager
|
||||
|
||||
@@ -44,7 +40,7 @@ from .logger import (DEBUG, # noqa F401, allowing downstream const usage
|
||||
get_addon_logger,
|
||||
NOT_SET,
|
||||
WARNING)
|
||||
from .notifications import NotificationSystem
|
||||
from .notifications import NotificationSystem, Notification
|
||||
from . import settings
|
||||
from . import updater
|
||||
from .multilingual import Multilingual
|
||||
@@ -66,6 +62,22 @@ class PoliigonAddon():
|
||||
|
||||
library_paths: List = []
|
||||
|
||||
# Variables stored in the addon class to handle the top-level assets
|
||||
# requests. Values are set on api remote control jobs (also api_rc params).
|
||||
# NOTE: These values should only be changed in addon-core side;
|
||||
# DO NOT CHANGE IT ON ANY DCC ADDON CODE;
|
||||
all_assets_fetched: bool = False
|
||||
my_assets_fetched: bool = False
|
||||
recent_downloads_fetched: bool = False
|
||||
|
||||
# Variables to control inject token from web process
|
||||
install_emitted: bool = False
|
||||
is_user_injected: bool = False
|
||||
_injected_token: Optional[str] = None
|
||||
|
||||
# WM Onboarding notice
|
||||
wm_onboarding_notice: Optional[Notification] = None
|
||||
|
||||
def __init__(self,
|
||||
addon_name: str,
|
||||
addon_version: tuple,
|
||||
@@ -78,7 +90,10 @@ class PoliigonAddon():
|
||||
language: str = "en-US",
|
||||
# See ThreadManager.__init__ for signature below,
|
||||
# e.g. print_exc(fut: Future, key_pool: PoolKeys)
|
||||
callback_print_exc: Optional[Callable] = None):
|
||||
callback_print_exc: Optional[Callable] = None,
|
||||
# Used to get access token file - if None, token file will
|
||||
# need to be injected in the addon side;
|
||||
addon_root_path: Optional[str] = None):
|
||||
self.log_manager = get_addon_logger(env=addon_env)
|
||||
|
||||
if addon_env.env_name == "prod":
|
||||
@@ -105,6 +120,7 @@ class PoliigonAddon():
|
||||
self.software_version = software_version
|
||||
self.addon_convention = addon_convention
|
||||
|
||||
self.addon_root_path = addon_root_path
|
||||
self.user = None
|
||||
self.login_error = None
|
||||
self.api_rc = None # To be set on the DCC side
|
||||
@@ -116,6 +132,7 @@ class PoliigonAddon():
|
||||
self.set_logger_verbose(verbose=False)
|
||||
|
||||
self._settings = addon_settings
|
||||
self.settings_config = self._settings.config
|
||||
self._api = api.PoliigonConnector(
|
||||
env=self._env,
|
||||
software=software_source,
|
||||
@@ -129,6 +146,8 @@ class PoliigonAddon():
|
||||
".".join([str(x) for x in addon_version]),
|
||||
".".join([str(x) for x in software_version])
|
||||
)
|
||||
self.set_api_token()
|
||||
|
||||
self._tm = tm.ThreadManager(callback_print_exc=callback_print_exc)
|
||||
self.notify = NotificationSystem(self)
|
||||
self._api.notification_system = self.notify
|
||||
@@ -140,8 +159,6 @@ class PoliigonAddon():
|
||||
local_json=self._env.local_updater_json
|
||||
)
|
||||
|
||||
self.settings_config = self._settings.config
|
||||
|
||||
self.user_addon_dir = os.path.join(
|
||||
os.path.expanduser("~"),
|
||||
"Poliigon"
|
||||
@@ -175,12 +192,19 @@ class PoliigonAddon():
|
||||
report_exception: Callable,
|
||||
report_thread: Callable,
|
||||
status_listener: Callable,
|
||||
get_renderer_name: Callable,
|
||||
urls_dcc: Dict[str, str],
|
||||
notify_icon_info: Any,
|
||||
notify_icon_no_connection: Any,
|
||||
notify_icon_survey: Any,
|
||||
notify_icon_warn: Any,
|
||||
notify_update_body: str
|
||||
notify_icon_wm_onboarding: Any,
|
||||
notify_update_body: str,
|
||||
onboarding_wm_title: Optional[str] = None,
|
||||
onboarding_wm_label: Optional[str] = None,
|
||||
onboarding_wm_tooltip: Optional[str] = None,
|
||||
allow_banner_notice: bool = False
|
||||
|
||||
# TODO(Andreas): Once API RC gets instanced here, add:
|
||||
# page_size_online_assets: int,
|
||||
# page_size_my_assets: int,
|
||||
@@ -194,6 +218,7 @@ class PoliigonAddon():
|
||||
self._api.get_optin = get_optin
|
||||
self._api.set_on_invalidated(callback_on_invalidated_token)
|
||||
self._api._status_listener = status_listener
|
||||
self._api.get_current_renderer_name = get_renderer_name
|
||||
self._api.add_poliigon_urls(urls_dcc)
|
||||
self._api._report_message = report_message
|
||||
self._api._report_exception = report_exception
|
||||
@@ -204,8 +229,17 @@ class PoliigonAddon():
|
||||
icon_info=notify_icon_info,
|
||||
icon_no_connection=notify_icon_no_connection,
|
||||
icon_survey=notify_icon_survey,
|
||||
icon_warn=notify_icon_warn)
|
||||
icon_warn=notify_icon_warn,
|
||||
notify_icon_wm_onboarding=notify_icon_wm_onboarding)
|
||||
self.notify.addon_params.update_body = notify_update_body
|
||||
self.notify.addon_params.allow_banner_notice = allow_banner_notice
|
||||
|
||||
if onboarding_wm_title is not None:
|
||||
self.notify.addon_params.onboarding_wm_title = onboarding_wm_title
|
||||
if onboarding_wm_label is not None:
|
||||
self.notify.addon_params.onboarding_wm_label = onboarding_wm_label
|
||||
if onboarding_wm_tooltip is not None:
|
||||
self.notify.addon_params.onboarding_wm_tooltip = onboarding_wm_tooltip
|
||||
|
||||
# TODO(Andreas): Once API RC gets instanced in constructor,
|
||||
# add the following here:
|
||||
@@ -232,6 +266,196 @@ class PoliigonAddon():
|
||||
return wrapped_func_call
|
||||
return wrapped_func
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_preview_asset(self, asset_id: int) -> None:
|
||||
"""Signals an asset preview in the background if user opted in."""
|
||||
self._api.signal_preview_asset(asset_id)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_import_asset(self, asset_id: int) -> None:
|
||||
"""Signals an asset imported in the background if user opted in."""
|
||||
self._api.signal_import_asset(asset_id)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_view_screen(self, screen_name: str) -> None:
|
||||
"""Signals a screen tab was viewed."""
|
||||
self._api.signal_view_screen(screen_name)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_search(self, search: str) -> None:
|
||||
"""Signals a search text was triggered."""
|
||||
if search != "":
|
||||
self._api.signal_search(search)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_category_filter(self, categories: str) -> None:
|
||||
"""Signals a category was clicked."""
|
||||
self._api.signal_view_category(categories)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_view_notification(self, notification_id: str) -> None:
|
||||
"""Signals a notification was viewed."""
|
||||
self._api.signal_view_notification(notification_id)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def signal_click_notification(
|
||||
self, notification_id: str, action: str) -> None:
|
||||
"""Signals a notification click event."""
|
||||
self._api.signal_click_notification(notification_id, action)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def update_user_use(self) -> None:
|
||||
"""Update user profile on api."""
|
||||
|
||||
user_use = self.user.user_profile
|
||||
if user_use is None:
|
||||
return
|
||||
|
||||
# Check if user already has primary_3d_software set
|
||||
# Only assign DCC if not already set
|
||||
assign_dcc = self.user.primary_3d_software is None
|
||||
|
||||
user_use_value = self.user.user_profile.value
|
||||
email_preference = self.user.email_preference
|
||||
render_engine = self.user.primary_rendering_engine
|
||||
|
||||
self._api.update_user_profile(user_use=user_use_value,
|
||||
email_preference=email_preference,
|
||||
primary_render_engine=render_engine,
|
||||
assign_dcc=assign_dcc)
|
||||
|
||||
def check_install_event(self):
|
||||
""" Triggers an install event. Should only be called when the automated
|
||||
login (Injected token) happened.
|
||||
If api token is not set, means that the user will have to log in and then
|
||||
the backend will emit the event using time_since_enabled (server side).
|
||||
Install event should be emmit only once;"""
|
||||
|
||||
if self._api.token is None or self.install_emitted:
|
||||
return
|
||||
|
||||
same_injected_token = self._api.token == self._injected_token
|
||||
if self.is_user_injected and same_injected_token and not self._api.invalidated:
|
||||
self.install_emitted = True
|
||||
self._api.signal_install()
|
||||
|
||||
def set_api_token(self) -> None:
|
||||
if self.addon_root_path is None:
|
||||
return
|
||||
access_token_file = os.path.join(self.addon_root_path, "access_token.txt")
|
||||
self.inject_token(access_token_file)
|
||||
|
||||
def inject_token(self,
|
||||
downloaded_token_file: str,
|
||||
callback: Optional[callable] = None,
|
||||
delete_token_file: bool = True
|
||||
) -> None:
|
||||
|
||||
valid_file = os.path.isfile(downloaded_token_file)
|
||||
if not valid_file:
|
||||
self.logger.debug("Token file not found.")
|
||||
return
|
||||
txt_file = downloaded_token_file.endswith(".txt")
|
||||
if not txt_file:
|
||||
self.logger.error("Invalid token file provided. Ignoring auto login;")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(downloaded_token_file, "r") as f:
|
||||
_token_content = f.read()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing login token: {e}")
|
||||
return
|
||||
|
||||
if delete_token_file:
|
||||
try:
|
||||
os.remove(downloaded_token_file)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unable to delete token file. {e};")
|
||||
|
||||
local_token = self.settings_config.get("user", "token", fallback=None)
|
||||
if local_token:
|
||||
self._api.token = local_token
|
||||
return
|
||||
|
||||
# After splitting, two fields are expected: the token and the opt-in flag.
|
||||
# If any of them is missing, do not inject the token;
|
||||
split_token = _token_content.split(",")
|
||||
if len(_token_content) >= 2 and _token_content[-2] != "," or len(split_token) != 2:
|
||||
self.logger.error(f"Unexpected token format: {_token_content}")
|
||||
return
|
||||
|
||||
self._injected_token, _opt_in = split_token
|
||||
|
||||
try:
|
||||
opt_in_int = int(_opt_in)
|
||||
except ValueError:
|
||||
self.logger.error(f"OptIn value is not an Integer: {_opt_in};")
|
||||
return
|
||||
|
||||
opt_in_str = "true" if bool(opt_in_int) else "false"
|
||||
self.logger.debug(f"Injecting Token\n"
|
||||
f"Token: {self._injected_token} | OptIn: {opt_in_str}")
|
||||
self._api.token = self._injected_token
|
||||
self.settings_config.set("user", "token", self._injected_token)
|
||||
self.settings_config.set("logging", "reporting_opt_in", opt_in_str)
|
||||
# User Confirmed flag is currently only being used by P4Max
|
||||
self.settings_config.set("logging", "user_confirmed", "true")
|
||||
self._settings.save_settings()
|
||||
|
||||
self.is_user_injected = True
|
||||
|
||||
if callback is not None:
|
||||
callback(self._injected_token)
|
||||
|
||||
def is_addon_first_onboarding_done(self) -> bool:
|
||||
return self.settings_config.getboolean("onboarding",
|
||||
"addon_first_completed",
|
||||
fallback=False)
|
||||
|
||||
def refresh_top_level_queries_flags(self):
|
||||
self.all_assets_fetched = False
|
||||
self.my_assets_fetched = False
|
||||
self.recent_downloads_fetched = False
|
||||
|
||||
def are_user_assets_fetched(self) -> bool:
|
||||
if self.user_legacy_own_assets():
|
||||
return self.my_assets_fetched and self.recent_downloads_fetched
|
||||
return self.recent_downloads_fetched
|
||||
|
||||
def set_addon_first_onboarding_done(self) -> None:
|
||||
self.settings_config.set("onboarding", "addon_first_completed", "true")
|
||||
self._settings.save_settings()
|
||||
|
||||
def is_onboarding_wm_preview_done(self) -> bool:
|
||||
# This parameter should be set as True in the dcc side, everytime a
|
||||
# watermarked preview is imported;
|
||||
return self.settings_config.getboolean("onboarding",
|
||||
"did_watermark_preview",
|
||||
fallback=False)
|
||||
|
||||
def dismiss_onboarding_wm_notice(self):
|
||||
if self.wm_onboarding_notice:
|
||||
self.notify.dismiss_notice(self.wm_onboarding_notice, force=True)
|
||||
self.wm_onboarding_notice = None
|
||||
|
||||
def set_onboarding_wm_preview_done(self, check_banner: bool = False):
|
||||
self.settings_config.set("onboarding", "did_watermark_preview", "true")
|
||||
self._settings.save_settings()
|
||||
|
||||
self.dismiss_onboarding_wm_notice()
|
||||
if check_banner:
|
||||
self.upgrade_manager.check_show_banner()
|
||||
|
||||
def check_onboarding_notice(self):
|
||||
did_wm = self.is_onboarding_wm_preview_done()
|
||||
|
||||
if self.wm_onboarding_notice is not None:
|
||||
self.dismiss_onboarding_wm_notice()
|
||||
|
||||
if not did_wm and not self.is_unlimited_user():
|
||||
self.wm_onboarding_notice = self.notify.create_watermarked_onboarding()
|
||||
|
||||
def setup_libraries(self):
|
||||
default_lib_path = os.path.join(self.user_addon_dir, "Library")
|
||||
multi_dir = self.settings_config["directories"]
|
||||
@@ -334,35 +558,6 @@ class PoliigonAddon():
|
||||
"""Clears any invalidation flag for a user."""
|
||||
self._api.invalidated = False
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def log_in_with_credentials(self,
|
||||
email: str,
|
||||
password: str,
|
||||
*,
|
||||
wait_for_user: bool = False) -> Future:
|
||||
self.clear_user_invalidated()
|
||||
|
||||
req = self._api.log_in(
|
||||
email,
|
||||
password
|
||||
)
|
||||
|
||||
if req.ok:
|
||||
user_data = req.body.get("user", {})
|
||||
|
||||
fut = self.create_user(user_data.get("name"), user_data.get("id"))
|
||||
if wait_for_user:
|
||||
fut.result(timeout=api.TIMEOUT)
|
||||
|
||||
self.login_error = None
|
||||
else:
|
||||
self.login_error = req.error
|
||||
|
||||
return req
|
||||
|
||||
def log_in_with_website(self):
|
||||
pass
|
||||
|
||||
def check_for_survey_notice(
|
||||
self,
|
||||
free_user_url: str,
|
||||
@@ -403,19 +598,6 @@ class PoliigonAddon():
|
||||
on_dismiss_callable=set_user_survey_flag
|
||||
)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def log_out(self):
|
||||
req = self._api.log_out()
|
||||
if req.ok:
|
||||
print("Logout success")
|
||||
else:
|
||||
print(req.error)
|
||||
|
||||
self._api.token = None
|
||||
|
||||
# Clear out user on logout.
|
||||
self.user = None
|
||||
|
||||
def add_library_path(self,
|
||||
path: str,
|
||||
primary: bool = True,
|
||||
@@ -494,79 +676,6 @@ class PoliigonAddon():
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_user_info(self) -> Tuple:
|
||||
req = self._api.get_user_info()
|
||||
user_name = None
|
||||
user_id = None
|
||||
|
||||
if req.ok:
|
||||
data = req.body
|
||||
user_name = data["user"]["name"]
|
||||
user_id = data["user"]["id"]
|
||||
self.login_error = None
|
||||
else:
|
||||
# TODO(SOFT-1029): Create an error log for fail in get user info
|
||||
self.login_error = req.error
|
||||
|
||||
return user_name, user_id
|
||||
|
||||
def _get_credits(self):
|
||||
if self.user is None:
|
||||
msg = "_get_credits() called without user."
|
||||
self._api.report_message(
|
||||
"addon_get_credits", msg, "error")
|
||||
return
|
||||
|
||||
req = self._api.get_user_balance()
|
||||
if req.ok:
|
||||
data = req.body
|
||||
self.user.credits = data.get("subscription_balance")
|
||||
self.user.credits_od = data.get("ondemand_balance")
|
||||
else:
|
||||
self.user.credits = None
|
||||
self.user.credits_od = None
|
||||
msg = f"ERROR: {req.error}"
|
||||
self._api.report_message(
|
||||
"addon_get_credits", msg, "error")
|
||||
|
||||
def _get_subscription_details(self):
|
||||
"""Fetches the current user's subscription status."""
|
||||
req = self._api.get_subscription_details()
|
||||
|
||||
if req.ok:
|
||||
plan = req.body
|
||||
self.user.plan.update_from_dict(plan)
|
||||
|
||||
@run_threaded(tm.PoolKeys.INTERACTIVE)
|
||||
def update_plan_data(self, done_callback: Optional[Callable] = None) -> None:
|
||||
# TODO(Joao): sub thread the two private functions
|
||||
self._get_credits()
|
||||
self._get_subscription_details()
|
||||
if done_callback is not None:
|
||||
done_callback()
|
||||
|
||||
def create_user(
|
||||
self,
|
||||
user_name: Optional[str] = None,
|
||||
user_id: Optional[int] = None,
|
||||
done_callback: Optional[Callable] = None) -> Optional[Future]:
|
||||
|
||||
if user_name is None or user_id is None:
|
||||
user_name, user_id = self._get_user_info()
|
||||
|
||||
if user_name is None or user_id is None:
|
||||
return None
|
||||
|
||||
self.user = PoliigonUser(
|
||||
user_name=user_name,
|
||||
user_id=user_id,
|
||||
plan=PoliigonSubscription(
|
||||
subscription_state=SubscriptionState.NOT_POPULATED)
|
||||
)
|
||||
|
||||
future = self.update_plan_data(done_callback)
|
||||
return future
|
||||
|
||||
def is_free_user(self) -> bool:
|
||||
"""Identifies a free user which neither
|
||||
has a plan nor on demand credits."""
|
||||
@@ -591,6 +700,15 @@ class PoliigonAddon():
|
||||
return False
|
||||
return self.user.plan.is_unlimited
|
||||
|
||||
def is_legacy_limited_user(self) -> bool:
|
||||
if self.user is None:
|
||||
return False
|
||||
elif self.user.plan in [None, SubscriptionState.NOT_POPULATED]:
|
||||
return False
|
||||
elif self.user.plan.is_limited_legacy is None:
|
||||
return False
|
||||
return self.user.plan.is_limited_legacy
|
||||
|
||||
def is_paused_subscription(self) -> Optional[bool]:
|
||||
"""Return True, if the Subscription is in paused state.
|
||||
|
||||
@@ -601,6 +719,11 @@ class PoliigonAddon():
|
||||
return None
|
||||
return self.user.plan.subscription_state == SubscriptionState.PAUSED
|
||||
|
||||
def user_legacy_own_assets(self):
|
||||
"""Return True if the user has any Owned Asset (Legacy)"""
|
||||
|
||||
return self.user is not None and self.user.count_assets_owned
|
||||
|
||||
def get_user_credits(self, incl_od: bool = True) -> int:
|
||||
"""Returns the number of _spendable_ credits."""
|
||||
|
||||
@@ -718,23 +841,6 @@ class PoliigonAddon():
|
||||
"user", "first_purchase", str(time_stamp))
|
||||
self._settings.save_settings()
|
||||
|
||||
def print_debug(self, *args, dbg=False, bg=True):
|
||||
"""Print out a debug statement with no separator line.
|
||||
|
||||
Cache based on args up to a limit, to avoid excessive repeat prints.
|
||||
All args must be flat values, such as already casted to strings, else
|
||||
an error will be thrown.
|
||||
"""
|
||||
if dbg:
|
||||
# Ensure all inputs are hashable, otherwise lru_cache fails.
|
||||
stringified = [str(arg) for arg in args]
|
||||
self._cached_print(*stringified, bg=bg)
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _cached_print(self, *args, bg: bool):
|
||||
"""A safe-to-cache function for printing."""
|
||||
print(*args)
|
||||
|
||||
def open_asset_url(self, asset_id: int) -> None:
|
||||
asset_data = self._asset_index.get_asset(asset_id)
|
||||
url = self._api.add_utm_suffix(asset_data.url)
|
||||
@@ -759,50 +865,6 @@ class PoliigonAddon():
|
||||
path_wm_previews = os.path.join(path_thumbs, asset_name)
|
||||
return path_wm_previews
|
||||
|
||||
def download_material_wm(
|
||||
self, files_to_download: List[Tuple[str, str]]) -> None:
|
||||
"""Synchronous function to download material preview."""
|
||||
|
||||
urls = []
|
||||
files_dl = []
|
||||
for _url_wm, _filename_wm_dl in files_to_download:
|
||||
urls.append(_url_wm)
|
||||
files_dl.append(_filename_wm_dl)
|
||||
|
||||
resp = self._api.pooled_preview_download(urls, files_dl)
|
||||
if not resp.ok:
|
||||
msg = f"Failed to download WM preview\n{resp}"
|
||||
self._api.report_message(
|
||||
"download_mat_preview_dl_failed", msg, "error")
|
||||
# Continue, as some may have worked.
|
||||
|
||||
for _filename_wm_dl in files_dl:
|
||||
filename_wm = _filename_wm_dl[:-3] # cut of _dl
|
||||
|
||||
try:
|
||||
file_exists = os.path.exists(filename_wm)
|
||||
dl_exists = os.path.exists(_filename_wm_dl)
|
||||
if file_exists and dl_exists:
|
||||
os.remove(filename_wm)
|
||||
elif not file_exists and not dl_exists:
|
||||
raise FileNotFoundError
|
||||
if dl_exists:
|
||||
os.rename(_filename_wm_dl, filename_wm)
|
||||
except FileNotFoundError:
|
||||
msg = f"Neither {filename_wm}, nor {_filename_wm_dl} exist"
|
||||
self._api.report_message(
|
||||
"download_mat_existing_file", msg, "error")
|
||||
except FileExistsError:
|
||||
msg = f"File {filename_wm} already exists, failed to rename"
|
||||
self._api.report_message(
|
||||
"download_mat_rename", msg, "error")
|
||||
except Exception as e:
|
||||
self.logger.exception("Unexpected exception while renaming WM preview")
|
||||
msg = f"Unexpected exception while renaming {_filename_wm_dl}\n{e}"
|
||||
self._api.report_message(
|
||||
"download_wm_exception", msg, "error")
|
||||
return resp
|
||||
|
||||
def get_config_param(self,
|
||||
name_param: str,
|
||||
name_group: str = "DEFAULT",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -47,8 +47,10 @@ from .api_remote_control_params import (
|
||||
ApiJobParamsPutUpgradePlan,
|
||||
ApiJobParamsResumePlan,
|
||||
ApiJobParamsGetAssets,
|
||||
ApiJobParamsGetAllAssets,
|
||||
ApiJobParamsLogin,
|
||||
ApiJobParamsPurchaseAsset,
|
||||
CATEGORY_ALL,
|
||||
CmdLoginMode
|
||||
)
|
||||
from .assets import AssetData
|
||||
@@ -82,11 +84,12 @@ class JobType(IntEnum):
|
||||
PUT_UPGRADE_PLAN = 6
|
||||
RESUME_PLAN = 7
|
||||
GET_ASSETS = 10
|
||||
DOWNLOAD_THUMB = 11
|
||||
PURCHASE_ASSET = 12
|
||||
DOWNLOAD_ASSET = 13
|
||||
DOWNLOAD_WM_PREVIEW = 14,
|
||||
UNIT_TEST = 15,
|
||||
GET_ALL_ASSETS = 11
|
||||
DOWNLOAD_THUMB = 12
|
||||
PURCHASE_ASSET = 13
|
||||
DOWNLOAD_ASSET = 14
|
||||
DOWNLOAD_WM_PREVIEW = 15,
|
||||
UNIT_TEST = 16,
|
||||
EXIT = 99999
|
||||
|
||||
|
||||
@@ -242,6 +245,9 @@ class ApiRemoteControl():
|
||||
def add_job_get_user_data(self,
|
||||
user_name: str,
|
||||
user_id: str,
|
||||
do_fetch_plans: bool = True,
|
||||
do_fetch_categories: bool = True,
|
||||
do_fetch_asset_data: bool = True,
|
||||
callback_cancel: Optional[Callable] = None,
|
||||
callback_progress: Optional[Callable] = None,
|
||||
callback_done: Optional[Callable] = None,
|
||||
@@ -249,7 +255,12 @@ class ApiRemoteControl():
|
||||
) -> None:
|
||||
"""Convenience function to add a get user data job."""
|
||||
|
||||
params = ApiJobParamsGetUserData(user_name, user_id)
|
||||
params = ApiJobParamsGetUserData(
|
||||
user_name=user_name,
|
||||
user_id=user_id,
|
||||
do_fetch_plans=do_fetch_plans,
|
||||
do_fetch_categories=do_fetch_categories,
|
||||
do_fetch_asset_data=do_fetch_asset_data)
|
||||
self.add_job(
|
||||
job_type=JobType.GET_USER_DATA,
|
||||
params=params,
|
||||
@@ -363,7 +374,7 @@ class ApiRemoteControl():
|
||||
def add_job_get_assets(self,
|
||||
library_paths: List[str],
|
||||
tab: str, # KEY_TAB_ONLINE, KEY_TAB_MY_ASSETS
|
||||
category_list: List[str] = ["All Assets"],
|
||||
category_list: List[str] = [CATEGORY_ALL],
|
||||
search: str = "",
|
||||
idx_page: int = 1,
|
||||
page_size: int = 10,
|
||||
@@ -373,10 +384,14 @@ class ApiRemoteControl():
|
||||
callback_progress: Optional[Callable] = None,
|
||||
callback_done: Optional[Callable] = None,
|
||||
force: bool = True,
|
||||
ignore_old_names: bool = True
|
||||
do_my_assets: bool = False
|
||||
) -> None:
|
||||
"""Convenience function to add a get assets job."""
|
||||
|
||||
if search != "":
|
||||
# With API V2 it's either search or categories, not both
|
||||
category_list = [CATEGORY_ALL]
|
||||
|
||||
params = ApiJobParamsGetAssets(library_paths,
|
||||
tab,
|
||||
category_list,
|
||||
@@ -385,7 +400,8 @@ class ApiRemoteControl():
|
||||
page_size,
|
||||
force_request,
|
||||
do_get_all,
|
||||
ignore_old_names)
|
||||
do_my_assets)
|
||||
|
||||
self.add_job(
|
||||
job_type=JobType.GET_ASSETS,
|
||||
params=params,
|
||||
@@ -395,6 +411,31 @@ class ApiRemoteControl():
|
||||
force=force,
|
||||
timeout=TIMEOUT)
|
||||
|
||||
def add_job_get_all_assets(
|
||||
self,
|
||||
library_paths: List[str],
|
||||
do_my_assets: bool = False,
|
||||
force_request: bool = False,
|
||||
callback_cancel: Optional[Callable] = None,
|
||||
callback_progress: Optional[Callable] = None,
|
||||
callback_done: Optional[Callable] = None,
|
||||
force: bool = True
|
||||
) -> None:
|
||||
"""Convenience function to add a get assets job."""
|
||||
|
||||
params = ApiJobParamsGetAllAssets(
|
||||
library_paths=library_paths,
|
||||
force_request=force_request,
|
||||
do_my_assets=do_my_assets)
|
||||
self.add_job(
|
||||
job_type=JobType.GET_ALL_ASSETS,
|
||||
params=params,
|
||||
callback_cancel=callback_cancel,
|
||||
callback_progress=callback_progress,
|
||||
callback_done=callback_done,
|
||||
force=force,
|
||||
timeout=TIMEOUT)
|
||||
|
||||
def add_job_download_thumb(self,
|
||||
asset_id: int,
|
||||
url: str,
|
||||
|
||||
+786
-272
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,103 @@
|
||||
# #### BEGIN GPL LICENSE BLOCK #####
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or
|
||||
# modify it under the terms of the GNU General Public License
|
||||
# as published by the Free Software Foundation; either version 2
|
||||
# of the License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation,
|
||||
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||
#
|
||||
# ##### END GPL LICENSE BLOCK #####
|
||||
|
||||
import enum
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from .multilingual import _m, _t
|
||||
from .user import PoliigonUser
|
||||
from .api_remote_control_params import (KEY_TAB_ONLINE,
|
||||
KEY_TAB_MY_ASSETS,
|
||||
KEY_TAB_RECENT_DOWNLOADS,
|
||||
KEY_TAB_IMPORTED,
|
||||
KEY_TAB_LOCAL)
|
||||
|
||||
|
||||
class FilterOptions(enum.Enum):
|
||||
ALL_ASSETS = _m("All Assets")
|
||||
OWNED_ASSETS = _m("Owned")
|
||||
RECENT_DOWNLOADED_ASSETS = _m("Downloads")
|
||||
LOCAL_ASSETS = _m("Local")
|
||||
IMPORTED_ASSETS = _m("Imported")
|
||||
|
||||
@classmethod
|
||||
def get_list(cls):
|
||||
return [member for member in FilterOptions]
|
||||
|
||||
@classmethod
|
||||
def get_list_for_user(cls, user: PoliigonUser):
|
||||
# If a user doesn't own any asset, the owned filter is not shown
|
||||
if user.count_assets_owned is not None and user.count_assets_owned == 0:
|
||||
return [member for member in FilterOptions
|
||||
if member != cls.OWNED_ASSETS]
|
||||
return cls.get_list()
|
||||
|
||||
@classmethod
|
||||
def get_string_list(cls, short_string: bool = False) -> List[str]:
|
||||
if short_string:
|
||||
return [member.map_to_short_value() for member in FilterOptions]
|
||||
return [member.value for member in FilterOptions]
|
||||
|
||||
@classmethod
|
||||
def get_translated_string_list(cls, short_string: bool = False) -> List[str]:
|
||||
if short_string:
|
||||
return [_t(member.map_to_short_value()) for member in FilterOptions]
|
||||
return [_t(member.value) for member in FilterOptions]
|
||||
|
||||
@classmethod
|
||||
def get_from_query(cls, key_query: str) -> Optional[Any]:
|
||||
if key_query == KEY_TAB_ONLINE:
|
||||
return cls.ALL_ASSETS
|
||||
elif key_query == KEY_TAB_MY_ASSETS:
|
||||
return cls.OWNED_ASSETS
|
||||
elif key_query == KEY_TAB_RECENT_DOWNLOADS:
|
||||
return cls.RECENT_DOWNLOADED_ASSETS
|
||||
elif key_query == KEY_TAB_LOCAL:
|
||||
return cls.LOCAL_ASSETS
|
||||
elif key_query == KEY_TAB_IMPORTED:
|
||||
return cls.IMPORTED_ASSETS
|
||||
else:
|
||||
return None
|
||||
|
||||
def map_to_query(self) -> Optional[str]:
|
||||
if self == self.ALL_ASSETS:
|
||||
return KEY_TAB_ONLINE
|
||||
elif self == self.OWNED_ASSETS:
|
||||
return KEY_TAB_MY_ASSETS
|
||||
elif self == self.RECENT_DOWNLOADED_ASSETS:
|
||||
return KEY_TAB_RECENT_DOWNLOADS
|
||||
elif self == self.LOCAL_ASSETS:
|
||||
return KEY_TAB_LOCAL
|
||||
elif self == self.IMPORTED_ASSETS:
|
||||
return KEY_TAB_IMPORTED
|
||||
else:
|
||||
return None
|
||||
|
||||
def map_to_short_value(self) -> Optional[str]:
|
||||
if self == self.ALL_ASSETS:
|
||||
return _m("All")
|
||||
elif self == self.OWNED_ASSETS:
|
||||
return self.value
|
||||
elif self == self.RECENT_DOWNLOADED_ASSETS:
|
||||
return self.value
|
||||
elif self == self.LOCAL_ASSETS:
|
||||
return self.value
|
||||
elif self == self.IMPORTED_ASSETS:
|
||||
return self.value
|
||||
else:
|
||||
return None
|
||||
@@ -34,13 +34,15 @@ from . import logger
|
||||
from .maps import MAPS_TYPE_NAMES, MapType
|
||||
|
||||
|
||||
ISOFORMAT_EPOCH_BEGIN = "1970-01-01T00:00:00.000Z"
|
||||
REGEX_DATETIME_ISOFORMAT = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}\.[0-9]{3}Z$"
|
||||
|
||||
# Compiled regex to avoid re-instancing each time.
|
||||
# Checks for preview being in the last section of a filename split into _'s
|
||||
# The [^_]* means match any character except another _, and $ asserts it's at
|
||||
# the end of the match string.
|
||||
_PREVIEW_PATTERN = re.compile(r"_[^_]*preview[^_]*$", re.IGNORECASE)
|
||||
|
||||
|
||||
# Used to report unknown asset type exactly once per session
|
||||
g_asset_unsupported_type_reported: bool = False
|
||||
|
||||
@@ -100,6 +102,7 @@ class AssetIndex():
|
||||
self.all_assets = {}
|
||||
self.cached_queries = {}
|
||||
self.reported_asset_ids = []
|
||||
self.reported_map_codes = []
|
||||
|
||||
self.use_lod_extras = use_lod_extras
|
||||
if use_lod_extras:
|
||||
@@ -115,9 +118,33 @@ class AssetIndex():
|
||||
return
|
||||
self.reporting_callable(message, code_msg, level, max_reports)
|
||||
|
||||
@staticmethod
|
||||
def _filter_image_urls(urls: List[str]) -> List[str]:
|
||||
return [url for url in urls if ".png" in url.lower() or ".jpg" in url.lower()]
|
||||
def _convert_api_time_to_timestamp(self, time_api: str) -> float:
|
||||
try:
|
||||
dt = datetime.fromisoformat(time_api)
|
||||
except ValueError:
|
||||
dt = datetime.fromisoformat(ISOFORMAT_EPOCH_BEGIN)
|
||||
warn = f"Time string from API off expected format: {time_api}"
|
||||
self.logger.warning(warn)
|
||||
|
||||
t_tuple = dt.utctimetuple()
|
||||
seconds_since_epoch = time.mktime(t_tuple)
|
||||
return seconds_since_epoch
|
||||
|
||||
def _get_api_time(self, data: Dict, key: str) -> float:
|
||||
t = data.get(key, ISOFORMAT_EPOCH_BEGIN)
|
||||
if t is None:
|
||||
# TODO(Andreas): Not sure how we would want to report this, here
|
||||
return 0.0
|
||||
if not re.match(REGEX_DATETIME_ISOFORMAT, t):
|
||||
# TODO(Andreas): Not sure how we would want to report this, here
|
||||
return 0.0
|
||||
# We know our times are UTC, yet, for unknown reasons Python's
|
||||
# datetime.fromisoformat() does not allow for the ISO conform
|
||||
# timezone qualifier ('Z' for UTC at the end).
|
||||
# So, we need to split it off.
|
||||
t = t[:-1]
|
||||
secs_since_epoch = self._convert_api_time_to_timestamp(t)
|
||||
return secs_since_epoch
|
||||
|
||||
def _get_cloudflare_thumbnails(
|
||||
self,
|
||||
@@ -129,27 +156,18 @@ class AssetIndex():
|
||||
return []
|
||||
|
||||
for thumb in cloudflare_thumbs:
|
||||
filename = thumb.get("file_name", None)
|
||||
thumb_time = thumb.get("time", None)
|
||||
|
||||
filename = thumb.get("fileName", None)
|
||||
if filename is None:
|
||||
warn = "Filename not available in Cloudflare Thumbnail."
|
||||
self.logger.warning(warn)
|
||||
continue
|
||||
filename = filename.split("?")[0]
|
||||
|
||||
if thumb_time is not None:
|
||||
try:
|
||||
thumb_time = datetime.strptime(
|
||||
thumb_time, "%Y-%m-%d %H:%M:%S")
|
||||
thumb_time = thumb_time.astimezone(timezone.utc).timestamp()
|
||||
except (ValueError, OSError):
|
||||
warn = "Thumbnail time string off expected format."
|
||||
self.logger.warning(warn)
|
||||
thumb_time = self._get_api_time(data=thumb, key="time")
|
||||
|
||||
thumb_class = assets.AssetThumbnail(
|
||||
filename=filename,
|
||||
base_url=thumb.get("base_url", None),
|
||||
base_url=thumb.get("baseUrl", None),
|
||||
index=thumb.get("position", None),
|
||||
time=thumb_time,
|
||||
type=thumb.get("type", None)
|
||||
@@ -160,7 +178,7 @@ class AssetIndex():
|
||||
@staticmethod
|
||||
def _get_texture_real_world_dimension(
|
||||
asset_dict: Dict) -> Optional[Tuple[float, float]]:
|
||||
""" Method to get Real World Dimensions of an asset in Convention 0.
|
||||
"""Method to get Real World Dimensions of an asset in Convention 0.
|
||||
For convention 1, use _decode_tex_convention_1 to get this info. """
|
||||
|
||||
dimension_height = None
|
||||
@@ -168,27 +186,23 @@ class AssetIndex():
|
||||
dimension_unit_str = ""
|
||||
dimension_dict = {}
|
||||
|
||||
asset_name = asset_dict.get("asset_name", "")
|
||||
asset_name = asset_dict.get("AssetName", "")
|
||||
# Ignore dimensions for Atlas Textures
|
||||
if (asset_name.lower()).startswith("atlas"):
|
||||
return None
|
||||
|
||||
dimension_str = None
|
||||
render_custom_schema = asset_dict.get("render_custom_schema", None)
|
||||
if render_custom_schema is not None:
|
||||
dimension_str = render_custom_schema.get("dimensions", None)
|
||||
|
||||
technical_desc = render_custom_schema.get("technical_description", {})
|
||||
if type(technical_desc) is dict:
|
||||
dimension_dict = technical_desc.get("Dimensions", {})
|
||||
if dimension_dict is None:
|
||||
dimension_dict = {}
|
||||
specs = asset_dict.get("Specifications", None)
|
||||
if specs is not None:
|
||||
dimension_str = specs.get("dimensions", None)
|
||||
|
||||
dimension_dict = specs.get("physical_size_cm", {})
|
||||
if dimension_dict is None:
|
||||
dimension_dict = {}
|
||||
|
||||
for key in dimension_dict.keys():
|
||||
dimension_key = key.split(" ")
|
||||
if len(dimension_key) < 2:
|
||||
continue
|
||||
dimension_unit_str = dimension_key[1]
|
||||
dimension_unit_str = "cm"
|
||||
if "height" in key.lower():
|
||||
dimension_height = float(dimension_dict[key])
|
||||
elif "width" in key.lower():
|
||||
@@ -203,7 +217,7 @@ class AssetIndex():
|
||||
dimension_width = float(dimension_val[2])
|
||||
dimension_unit_str = str(dimension_val[3])
|
||||
except (ValueError, IndexError):
|
||||
# if the first value is not integer, consider the format invalid
|
||||
# if the first value is not float, consider the format invalid
|
||||
return None
|
||||
|
||||
if dimension_height is None and dimension_width is None:
|
||||
@@ -233,19 +247,29 @@ class AssetIndex():
|
||||
|
||||
# Always get Metalness as Workflow for Convention 1
|
||||
workflow = "METALNESS"
|
||||
asset_maps = asset_dict.get("maps", [])
|
||||
asset_maps = asset_dict.get("Maps", [])
|
||||
|
||||
all_sizes = asset_dict.get("resolutions", [])
|
||||
all_sizes = asset_dict.get("Resolutions", [])
|
||||
sizes = [_size for _size in all_sizes if self._is_valid_size(_size)]
|
||||
|
||||
asset_id = asset_dict.get("AssetID", -1)
|
||||
maps_dictionary = {workflow: []}
|
||||
for _map_desc in asset_maps:
|
||||
try:
|
||||
map_code = _map_desc.get("type", "UNKNOWN")
|
||||
file_formats = _map_desc.get("file_formats", [])
|
||||
|
||||
if MapType.from_type_code(map_code) == MapType.UNKNOWN:
|
||||
msg = (f"AssetId: {asset_id} Unknown Map "
|
||||
f"type from server: {map_code}")
|
||||
self.logger.warning(msg)
|
||||
if map_code not in self.reported_map_codes:
|
||||
self.reported_map_codes.append(map_code)
|
||||
self.capture_message("unknown_map_type_found", msg)
|
||||
continue
|
||||
except Exception:
|
||||
map_code = "UNKNOWN"
|
||||
file_formats = []
|
||||
asset_id = asset_dict.get("id", -1)
|
||||
msg = f"Asset {asset_id}: Invalid maps data\n{asset_maps}"
|
||||
self.capture_message("assetindex_invalid_maps_data", msg)
|
||||
|
||||
@@ -257,7 +281,7 @@ class AssetIndex():
|
||||
variants=[])
|
||||
maps_dictionary[workflow].append(description)
|
||||
|
||||
specs = asset_dict.get("specifications", {})
|
||||
specs = asset_dict.get("Specifications", {})
|
||||
|
||||
creation_str = specs.get("creation_method", "")
|
||||
creation_method = assets.CreationMethodId.from_string(creation_str)
|
||||
@@ -271,8 +295,8 @@ class AssetIndex():
|
||||
|
||||
return maps_dictionary, sizes, dimensions, creation_method
|
||||
|
||||
@staticmethod
|
||||
def _decode_render_schema_tex(asset_dict: Dict
|
||||
def _decode_render_schema_tex(self,
|
||||
asset_dict: Dict
|
||||
) -> Tuple[Dict[str, assets.TextureMapDesc],
|
||||
List[str],
|
||||
List[str]]:
|
||||
@@ -286,26 +310,34 @@ class AssetIndex():
|
||||
NOTE: This is the default Texture decode function for convention 0
|
||||
"""
|
||||
|
||||
if "render_schema" not in asset_dict.keys():
|
||||
if "Maps" not in asset_dict.keys():
|
||||
return ({}, [], [])
|
||||
|
||||
all_sizes = []
|
||||
asset_id = asset_dict.get("AssetID", -1)
|
||||
all_sizes = asset_dict.get("Resolutions", [])
|
||||
all_variants = []
|
||||
tex_desc_dict = {} # {workflow: List[TextureMapDesc]
|
||||
for schema in asset_dict["render_schema"]:
|
||||
if "types" not in schema.keys():
|
||||
continue
|
||||
|
||||
workflow = schema.get("name", "REGULAR")
|
||||
dict_maps = asset_dict.get("Maps", [])
|
||||
|
||||
for _maps_dict in dict_maps:
|
||||
workflow = _maps_dict.get("workflow", "REGULAR")
|
||||
|
||||
tex_descs = []
|
||||
for tex_type in schema.get("types", []):
|
||||
tex_code = tex_type.get("type_code", "")
|
||||
for _tex_code in _maps_dict.get("maps", []):
|
||||
variant = None
|
||||
if "_" in tex_code:
|
||||
tex_code, variant = tex_code.split("_")
|
||||
if tex_code not in MAPS_TYPE_NAMES:
|
||||
tex_code = MapType.UNKNOWN.name
|
||||
map_type = MapType[tex_code]
|
||||
if "_" in _tex_code:
|
||||
_tex_code, variant = _tex_code.split("_")
|
||||
|
||||
map_type = MapType.from_type_code(_tex_code)
|
||||
if MapType.from_type_code(_tex_code) == MapType.UNKNOWN:
|
||||
msg = (f"AssetId: {asset_id} Unknown Map "
|
||||
f"type from server: {_tex_code}")
|
||||
self.logger.warning(msg)
|
||||
if _tex_code not in self.reported_map_codes:
|
||||
self.reported_map_codes.append(_tex_code)
|
||||
self.capture_message("unknown_map_type_found", msg)
|
||||
continue
|
||||
|
||||
if variant is not None:
|
||||
variants = [variant]
|
||||
@@ -313,16 +345,15 @@ class AssetIndex():
|
||||
else:
|
||||
variants = None
|
||||
|
||||
type_code = tex_type.get("type_code", "")
|
||||
type_name = tex_type.get("type_name", "")
|
||||
type_options = tex_type.get("type_options", [])
|
||||
type_preview = tex_type.get("type_preview", "")
|
||||
tex_desc = assets.TextureMapDesc(map_type_code=type_code,
|
||||
file_formats=[], # convention 0 assets don't have formats
|
||||
display_name=type_name,
|
||||
sizes=type_options,
|
||||
filename_preview=type_preview,
|
||||
variants=variants)
|
||||
type_preview = "" # TODO(Andreas): no longer available?
|
||||
tex_desc = assets.TextureMapDesc(
|
||||
map_type_code=_tex_code,
|
||||
# TODO(Andreas): It seems with API V2 we would have formats available
|
||||
file_formats=[], # convention 0 assets don't have formats
|
||||
display_name=_tex_code,
|
||||
sizes=all_sizes,
|
||||
filename_preview=type_preview,
|
||||
variants=variants)
|
||||
tex_desc_variant = None
|
||||
if variant is None:
|
||||
for tex_desc_prev in tex_descs:
|
||||
@@ -337,8 +368,6 @@ class AssetIndex():
|
||||
# variants are otherwise identical
|
||||
tex_desc_variant.variants.extend(tex_desc.variants)
|
||||
|
||||
all_sizes.extend(tex_desc.sizes)
|
||||
|
||||
tex_desc_dict[workflow] = tex_descs
|
||||
|
||||
# consolidate all sizes and variants for use in menus
|
||||
@@ -367,31 +396,27 @@ class AssetIndex():
|
||||
tuple[1] - Default size from included_resolution (if available)
|
||||
"""
|
||||
|
||||
if "render_schema" not in asset_dict.keys():
|
||||
msg = f"'render_schema' missing in asset dict\n{asset_dict}"
|
||||
self.capture_message("assetindex_no_renderschema", msg)
|
||||
return [], ""
|
||||
render_schema = asset_dict.get("render_schema", {})
|
||||
if "options" not in render_schema.keys():
|
||||
msg = f"'options' missing in 'render_schema'\n{render_schema}"
|
||||
self.capture_message("assetindex_no_renderschema_options", msg)
|
||||
return [], ""
|
||||
|
||||
all_sizes = render_schema.get("options", [])
|
||||
all_sizes = asset_dict.get("Resolutions", [])
|
||||
all_sizes = [size for size in all_sizes if self._is_valid_size(size)]
|
||||
|
||||
render_custom_schema = asset_dict.get("render_custom_schema", {})
|
||||
incl_size = None
|
||||
if "included_resolution" in render_custom_schema.keys():
|
||||
incl_size = render_custom_schema.get("included_resolution", "")
|
||||
if "DefaultResolution" in asset_dict.keys():
|
||||
incl_size = asset_dict.get("DefaultResolution", "")
|
||||
if self._is_valid_size(incl_size):
|
||||
all_sizes.extend([incl_size])
|
||||
|
||||
all_sizes = sorted(list(set(all_sizes)),
|
||||
key=lambda s: int(s[:-1]))
|
||||
|
||||
return all_sizes, incl_size
|
||||
|
||||
def _construct_watermarked_urls(self, asset_dict: Dict) -> List[str]:
|
||||
toolbox_previews = asset_dict.get("ToolboxPreviews", [])
|
||||
urls_wm = []
|
||||
for _preview in toolbox_previews:
|
||||
url_base = _preview.get("baseUrl")
|
||||
filename = _preview.get("fileName")
|
||||
urls_wm.append(f"{url_base}/{filename}")
|
||||
return urls_wm
|
||||
|
||||
def _construct_brush(self,
|
||||
asset_dict: Dict,
|
||||
convention: int
|
||||
@@ -406,8 +431,8 @@ class AssetIndex():
|
||||
"""Constructs a Model"""
|
||||
|
||||
model = assets.Model()
|
||||
if "lods" in asset_dict.keys():
|
||||
model.lods = asset_dict["lods"]
|
||||
if "LODs" in asset_dict.keys():
|
||||
model.lods = asset_dict.get("LODs", [])
|
||||
model.sizes, model.size_default = self._decode_render_schema_model(
|
||||
asset_dict)
|
||||
return model
|
||||
@@ -438,8 +463,7 @@ class AssetIndex():
|
||||
tex_bg = assets.Texture(map_descs=tex_map_descs_bg,
|
||||
sizes=sizes,
|
||||
variants=variants)
|
||||
tex_bg.watermarked_urls = self._filter_image_urls(
|
||||
asset_dict["toolbox_previews"])
|
||||
tex_bg.watermarked_urls = []
|
||||
tex_bg.maps = {}
|
||||
|
||||
tex_light = assets.Texture(map_descs=tex_map_descs_light,
|
||||
@@ -459,7 +483,10 @@ class AssetIndex():
|
||||
|
||||
creation_method = None
|
||||
if convention == 1:
|
||||
maps, sizes, dimension, creation_method = self._decode_tex_convention_1(asset_dict)
|
||||
(maps,
|
||||
sizes,
|
||||
dimension,
|
||||
creation_method) = self._decode_tex_convention_1(asset_dict)
|
||||
variants = []
|
||||
else:
|
||||
maps, sizes, variants = self._decode_render_schema_tex(asset_dict)
|
||||
@@ -470,16 +497,17 @@ class AssetIndex():
|
||||
variants=variants,
|
||||
real_world_dimension=dimension,
|
||||
creation_method=creation_method)
|
||||
tex.watermarked_urls = self._filter_image_urls(
|
||||
asset_dict["toolbox_previews"])
|
||||
tex.watermarked_urls = self._construct_watermarked_urls(asset_dict)
|
||||
tex.maps = {}
|
||||
return tex
|
||||
|
||||
def _construct_asset_base(self, asset_dict: Dict) -> assets.AssetData:
|
||||
global g_asset_unsupported_type_reported
|
||||
|
||||
asset_name = asset_dict["name"]
|
||||
asset_type_api = asset_dict["type"]
|
||||
asset_id = asset_dict.get("AssetID", 0)
|
||||
asset_name = asset_dict.get("AssetName", "NoName")
|
||||
display_name = asset_dict.get("Name", "No Name")
|
||||
asset_type_api = asset_dict.get("Type", "Unsupported")
|
||||
|
||||
try:
|
||||
asset_type = assets.API_TYPE_TO_ASSET_TYPE[asset_type_api]
|
||||
@@ -487,18 +515,17 @@ class AssetIndex():
|
||||
asset_type = assets.AssetType.UNSUPPORTED
|
||||
|
||||
if asset_type == assets.AssetType.SUBSTANCE:
|
||||
msg = f"{asset_name}: {asset_type_api} not supported, yet"
|
||||
msg = f"{display_name}: {asset_type_api} not supported, yet"
|
||||
raise NotImplementedError(msg)
|
||||
elif asset_type == assets.AssetType.UNSUPPORTED:
|
||||
msg = f"{asset_name}: {asset_type_api} not supported, yet"
|
||||
msg = f"{display_name}: {asset_type_api} not supported, yet"
|
||||
if g_asset_unsupported_type_reported is False:
|
||||
self.capture_message("assetindex_unsupported_type", msg)
|
||||
g_asset_unsupported_type_reported = True
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return assets.AssetData(asset_id=asset_dict["id"],
|
||||
asset_type=asset_type,
|
||||
asset_name=asset_dict["asset_name"])
|
||||
return assets.AssetData(
|
||||
asset_id=asset_id, asset_type=asset_type, asset_name=asset_name)
|
||||
|
||||
def _populate_default_asset_info(self,
|
||||
asset_data: assets.AssetData,
|
||||
@@ -507,39 +534,49 @@ class AssetIndex():
|
||||
) -> None:
|
||||
"""Constructs AssetData part common to all types"""
|
||||
|
||||
asset_data.display_name = asset_dict["name"]
|
||||
asset_data.display_name = asset_dict.get("Name", "No Name")
|
||||
asset_data.categories = []
|
||||
for category in asset_dict["categories"]:
|
||||
category = category.title()
|
||||
if category in assets.CATEGORY_TRANSLATION:
|
||||
category = assets.CATEGORY_TRANSLATION[category]
|
||||
asset_id = asset_data.asset_id
|
||||
for category in asset_dict.get("Categories", []):
|
||||
asset_data.categories.append(category)
|
||||
asset_data.url = asset_dict["url"]
|
||||
asset_data.slug = asset_dict["slug"]
|
||||
asset_data.credits = asset_dict["credit"]
|
||||
slug = asset_dict.get("Slug", "no-name")
|
||||
asset_data.slug = slug
|
||||
|
||||
asset_data.url = asset_dict.get("WebsiteUrl", "")
|
||||
|
||||
asset_data.credits = asset_dict.get("Credit", 1)
|
||||
|
||||
try:
|
||||
asset_data.api_convention = int(asset_dict.get("convention", 0))
|
||||
asset_data.api_convention = int(asset_dict.get("Convention", 0))
|
||||
except ValueError:
|
||||
asset_data.api_convention = 0
|
||||
asset_data.local_convention = None # to be determined during update_from_directory()
|
||||
|
||||
dict_previews = asset_dict.get("Previews", [])
|
||||
asset_data.cloudflare_thumb_urls = self._get_cloudflare_thumbnails(
|
||||
asset_dict.get("cloudflare_previews", None))
|
||||
dict_previews)
|
||||
|
||||
asset_data.thumb_urls = self._filter_image_urls(asset_dict["previews"])
|
||||
published_at = asset_dict["published_at"]
|
||||
t_published_at = time.strptime(published_at, "%Y-%m-%d %H:%M:%S")
|
||||
seconds_since_epoch = time.mktime(t_published_at)
|
||||
asset_data.published_at = seconds_since_epoch # TODO(Andreas): need to take timezone into account
|
||||
seconds_since_epoch = self._get_api_time(
|
||||
data=asset_dict, key="PublishedAt")
|
||||
|
||||
asset_data.published_at = seconds_since_epoch
|
||||
asset_data.is_local = None
|
||||
asset_data.downloaded_at = None
|
||||
asset_data.is_purchased = purchased
|
||||
asset_data.purchased_at = None
|
||||
asset_data.render_custom_schema = asset_dict.get(
|
||||
"render_custom_schema", {})
|
||||
asset_data.old_asset_names = asset_data.render_custom_schema.get(
|
||||
"previous_filenames", None)
|
||||
|
||||
rcs = asset_dict.get("RenderCustomSchema", {})
|
||||
specs = asset_dict.get("Specifications", {})
|
||||
if type(rcs) is list:
|
||||
asset_data.status.has_error = True
|
||||
asset_data.status.error = "No Render Custom Schema"
|
||||
rcs = {}
|
||||
if asset_id not in self.reported_asset_ids:
|
||||
msg = f"AssetId: {asset_id} Error: {asset_data.status.error}"
|
||||
self.capture_message("fail_asset_data_populate_rcs", msg)
|
||||
self.reported_asset_ids.append(asset_id)
|
||||
asset_data.render_custom_schema = rcs
|
||||
asset_data.specifications = specs
|
||||
|
||||
def construct_asset(self,
|
||||
asset_dict: Dict,
|
||||
@@ -570,6 +607,7 @@ class AssetIndex():
|
||||
except Exception as e:
|
||||
asset_data.state.has_error = True
|
||||
asset_data.state.error = e
|
||||
asset_data.create_dummy_type_data()
|
||||
if asset_data.asset_id not in self.reported_asset_ids:
|
||||
msg = f"AssetId: {asset_data.asset_id} Error: {e}"
|
||||
self.capture_message("fail_asset_data_populate", msg)
|
||||
@@ -623,6 +661,13 @@ class AssetIndex():
|
||||
utc_s_since_epoch = datetime.now(timezone.utc).timestamp()
|
||||
self.all_assets[asset_id].purchased_at = utc_s_since_epoch
|
||||
|
||||
def mark_recent_downloaded(self, asset_id: int) -> None:
|
||||
"""Marks an AssetData as Recent Downloaded"""
|
||||
|
||||
if asset_id not in self.all_assets:
|
||||
return
|
||||
self.all_assets[asset_id].is_recent_downloaded = True
|
||||
|
||||
def _map_type_from_filename_parts(self, filename_parts: List[str]):
|
||||
"""Gets a MapType (and its workflow) from a list of parts of a filename.
|
||||
|
||||
@@ -787,7 +832,6 @@ class AssetIndex():
|
||||
if lod is not None:
|
||||
lods.append(lod)
|
||||
else:
|
||||
lods.append("NONE")
|
||||
lod = "NONE"
|
||||
if size is not None:
|
||||
sizes.append(size)
|
||||
@@ -1108,10 +1152,8 @@ class AssetIndex():
|
||||
asset_data.is_local = False
|
||||
return files_found
|
||||
|
||||
def gather_asset_name_dict(self,
|
||||
asset_id_list: List[int],
|
||||
ignore_old_names: bool = True
|
||||
) -> Dict[str, int]:
|
||||
def gather_asset_name_dict(
|
||||
self, asset_id_list: List[int]) -> Dict[str, int]:
|
||||
"""Returns a dictionary with asset name keys (including optional old
|
||||
names) and asset ID values for all asset IDs in list."""
|
||||
|
||||
@@ -1124,16 +1166,6 @@ class AssetIndex():
|
||||
asset_name = asset_data.asset_name
|
||||
asset_name_dict[asset_name] = asset_id
|
||||
|
||||
if ignore_old_names:
|
||||
continue
|
||||
|
||||
old_asset_names = asset_data.old_asset_names
|
||||
if old_asset_names is None or len(old_asset_names) == 0:
|
||||
continue
|
||||
|
||||
for _old_name in old_asset_names:
|
||||
asset_name_dict[_old_name] = asset_id
|
||||
|
||||
return asset_name_dict
|
||||
|
||||
def update_all_local_assets(self,
|
||||
@@ -1179,7 +1211,7 @@ class AssetIndex():
|
||||
asset_type=None, purchased=purchased)
|
||||
|
||||
asset_name_dict = self.gather_asset_name_dict(
|
||||
asset_id_list, ignore_old_names=ignore_old_names)
|
||||
asset_id_list)
|
||||
|
||||
# update_from_directory() overwrites file reference entries.
|
||||
# Thus the primary library directory has to be scanned last.
|
||||
@@ -1333,7 +1365,7 @@ class AssetIndex():
|
||||
except NotImplementedError:
|
||||
# Silence Substance exceptions
|
||||
self.logger.info(("Unsupported asset type encountered.\n"
|
||||
f" {asset_dict['name']}: {asset_dict['type']}"))
|
||||
f" {asset_dict['Name']}: {asset_dict['Type']}"))
|
||||
if query_tuple in self.cached_queries:
|
||||
self.cached_queries[query_tuple].extend(tmp_cached_query)
|
||||
return tmp_cached_query
|
||||
@@ -1412,12 +1444,16 @@ class AssetIndex():
|
||||
asset_ids: List[int],
|
||||
key_query: str,
|
||||
chunk: Optional[int] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
chunk_size: Optional[int] = None,
|
||||
do_append: bool = False
|
||||
) -> None:
|
||||
"""Stores a list of asset IDs in query cache."""
|
||||
|
||||
query_tuple = self._query_key_to_tuple(key_query, chunk, chunk_size)
|
||||
self.cached_queries[query_tuple] = asset_ids
|
||||
if do_append:
|
||||
self.cached_queries[query_tuple].extend(asset_ids)
|
||||
else:
|
||||
self.cached_queries[query_tuple] = asset_ids
|
||||
|
||||
def query_exists(self,
|
||||
key_query: str,
|
||||
@@ -1615,7 +1651,8 @@ class AssetIndex():
|
||||
lod: Optional[str] = None,
|
||||
model_type: Optional[assets.ModelType] = None,
|
||||
native_only: bool = False,
|
||||
renderer: Optional[str] = None
|
||||
renderer: Optional[str] = None,
|
||||
check_map_preferences: bool = False
|
||||
) -> bool:
|
||||
"""Checks if an asset (or a flavor thereof) has been downloaded.
|
||||
|
||||
@@ -1646,7 +1683,7 @@ class AssetIndex():
|
||||
incl_watermarked = size == "WM"
|
||||
|
||||
local_sizes = self.check_asset_local_sizes(
|
||||
asset_id, workflow, incl_watermarked)
|
||||
asset_id, workflow, incl_watermarked, check_map_preferences)
|
||||
if size is None:
|
||||
tex_is_local = any(local_sizes.values())
|
||||
else:
|
||||
@@ -1682,7 +1719,8 @@ class AssetIndex():
|
||||
def check_asset_local_sizes(self,
|
||||
asset_id: int,
|
||||
workflow: Optional[str] = "REGULAR",
|
||||
incl_watermarked: bool = False
|
||||
incl_watermarked: bool = False,
|
||||
check_map_preferences: bool = False
|
||||
) -> Dict[str, bool]:
|
||||
"""Returns texture 'locality' by size.
|
||||
|
||||
@@ -1700,13 +1738,24 @@ class AssetIndex():
|
||||
asset_data = self.all_assets[asset_id]
|
||||
type_data = asset_data.get_type_data()
|
||||
|
||||
local_sizes = {}
|
||||
all_sizes = type_data.get_size_list(incl_watermarked) # local only False, no conv needed
|
||||
local_convention = asset_data.get_convention(local=True)
|
||||
if local_convention is not None:
|
||||
convention_min = min(local_convention, self.addon_convention)
|
||||
else:
|
||||
convention_min = self.addon_convention
|
||||
|
||||
map_prefs = None
|
||||
if check_map_preferences:
|
||||
map_prefs = self.addon.user.map_preferences
|
||||
|
||||
local_sizes = {}
|
||||
all_sizes = type_data.get_size_list(incl_watermarked=incl_watermarked,
|
||||
# Local only set as false to get
|
||||
# False values on return dict;
|
||||
local_only=False,
|
||||
local_convention=local_convention,
|
||||
addon_convention=self.addon_convention,
|
||||
map_preferences=map_prefs)
|
||||
for size in all_sizes:
|
||||
if workflow is None:
|
||||
workflow_list = self.get_asset_workflow_list(asset_id)
|
||||
@@ -1773,37 +1822,6 @@ class AssetIndex():
|
||||
|
||||
return local_lods
|
||||
|
||||
# TODO(Joao): Deprecated - Delete function and use cloudflare previews
|
||||
def get_thumbnail_url_list(self, asset_id: int) -> List[str]:
|
||||
"""Gets _all_ URLs for an asset's thumbnails"""
|
||||
|
||||
if asset_id not in self.all_assets:
|
||||
return None
|
||||
ad = self.all_assets[asset_id]
|
||||
return ad.thumb_urls
|
||||
|
||||
# TODO(Joao): Deprecated - Delete function and use cloudflare previews
|
||||
def get_thumbnail_url_by_index(self,
|
||||
asset_id: int,
|
||||
index: int = 0) -> Optional[str]:
|
||||
"""Returns preview url via index, if index exists,
|
||||
otherwise the first preview url will be returned.
|
||||
|
||||
Return value may be None, e.g. in case of dummy entries.
|
||||
"""
|
||||
|
||||
if index < 0:
|
||||
raise ValueError
|
||||
if asset_id not in self.all_assets:
|
||||
return None
|
||||
ad = self.all_assets[asset_id]
|
||||
if ad.thumb_urls is None or len(ad.thumb_urls) == 0:
|
||||
return None
|
||||
elif index < len(ad.thumb_urls):
|
||||
return ad.thumb_urls[index]
|
||||
else:
|
||||
return ad.thumb_urls[0]
|
||||
|
||||
def get_asset_cf_thumbnails(self, asset_id):
|
||||
if asset_id not in self.all_assets:
|
||||
return None
|
||||
@@ -1880,44 +1898,6 @@ class AssetIndex():
|
||||
return path_thumb, url
|
||||
return None, None
|
||||
|
||||
# TODO(Joao): Deprecated - Delete function and use cloudflare previews
|
||||
def get_thumbnail_url_by_name(self,
|
||||
asset_id: int,
|
||||
name: str = "sphere") -> Optional[str]:
|
||||
"""Returns preview url via name extension, if it exists.
|
||||
|
||||
Return value may be None, e.g. in case name not found.
|
||||
"""
|
||||
if asset_id not in self.all_assets:
|
||||
return None
|
||||
asset_data = self.all_assets[asset_id]
|
||||
if asset_data.thumb_urls is None or len(asset_data.thumb_urls) == 0:
|
||||
return None
|
||||
|
||||
name = name.lower()
|
||||
result_url = None
|
||||
for url in asset_data.thumb_urls:
|
||||
if name in url.lower():
|
||||
result_url = url
|
||||
break
|
||||
return result_url
|
||||
|
||||
# TODO(Andreas): maybe not URLs...
|
||||
def get_large_preview_url_list(self, asset_id: int) -> List[str]:
|
||||
"""Gets _all_ URLs for an asset's large previews"""
|
||||
|
||||
# TODO(Andreas)
|
||||
return []
|
||||
|
||||
def get_large_preview_url(self,
|
||||
asset_id: int,
|
||||
index: int = 0
|
||||
) -> Optional[str]:
|
||||
"""Gets URL for an asset's larrge preview"""
|
||||
|
||||
# TODO(Andreas)
|
||||
return ""
|
||||
|
||||
def get_watermark_preview_url_list(self,
|
||||
asset_id: int
|
||||
) -> Optional[List[str]]:
|
||||
@@ -2046,7 +2026,8 @@ class AssetIndex():
|
||||
def get_asset_id_list(self,
|
||||
asset_type: Optional[assets.AssetType] = None,
|
||||
purchased: bool = None,
|
||||
local: bool = None
|
||||
local: bool = None,
|
||||
check_map_prefs: bool = False
|
||||
) -> List[int]:
|
||||
"""Return a list of asset IDs in AssetIndex.
|
||||
Optionally restricted by per type and/or is_purchased flag.
|
||||
@@ -2055,30 +2036,31 @@ class AssetIndex():
|
||||
asset_type: Restrict list to a specific type. Use None for any type.
|
||||
purchased: Restrict list to (non-)purchased assets. Use None for both.
|
||||
local: Restrict list to (non-)local assets. Use None for both.
|
||||
check_map_prefs: Restrict list to ready to import assets (considering map prefs).
|
||||
"""
|
||||
|
||||
# TODO(Andreas): Just realized, we could likely speed this up
|
||||
# by considering only asset_ids in "my_assets"
|
||||
# (from query cache) in case purchased == True
|
||||
asset_id_list = [
|
||||
all_asset_id_list = [
|
||||
asset_data.asset_id for asset_data in self.all_assets.values()
|
||||
if asset_type is None or asset_data.asset_type == asset_type
|
||||
]
|
||||
if purchased is None and local is None:
|
||||
return asset_id_list
|
||||
return all_asset_id_list
|
||||
|
||||
asset_id_list = [
|
||||
asset_id for asset_id in asset_id_list
|
||||
purchased_asset_id_list = [
|
||||
asset_id for asset_id in all_asset_id_list
|
||||
if self.all_assets[asset_id].is_purchased == purchased
|
||||
]
|
||||
if local is None:
|
||||
return asset_id_list
|
||||
return purchased_asset_id_list
|
||||
|
||||
asset_id_list = [
|
||||
asset_id for asset_id in asset_id_list
|
||||
if self.all_assets[asset_id].is_local == local
|
||||
filter_local_list = all_asset_id_list if purchased is None else purchased_asset_id_list
|
||||
local_asset_id_list = [
|
||||
asset_id for asset_id in filter_local_list
|
||||
if self.check_asset_is_local(asset_id,
|
||||
check_map_preferences=check_map_prefs
|
||||
) == local
|
||||
]
|
||||
return asset_id_list
|
||||
return local_asset_id_list
|
||||
|
||||
def num_assets(self, asset_type: Optional[assets.AssetType] = None) -> int:
|
||||
"""Returns the number of assets, optionally per type"""
|
||||
@@ -2104,19 +2086,19 @@ class AssetIndex():
|
||||
def _init_categories(self, categories):
|
||||
for category in categories:
|
||||
category["asset_count"] = 0
|
||||
self._init_categories(category["children"])
|
||||
self._init_categories(category.get("children", []))
|
||||
|
||||
def _count_asset(self, categories, asset_categories):
|
||||
num_asset_categories = len(asset_categories)
|
||||
for category in categories:
|
||||
category_name = category["name"]
|
||||
if category_name not in asset_categories:
|
||||
id_category = category.get("id", -1)
|
||||
if id_category not in asset_categories:
|
||||
continue
|
||||
asset_categories.remove(category_name)
|
||||
category["asset_count"] += 1
|
||||
self._count_asset(category["children"], asset_categories)
|
||||
asset_categories.remove(id_category)
|
||||
category["asset_count"] = category.get("asset_count", 0) + 1
|
||||
self._count_asset(category.get("children", []), asset_categories)
|
||||
break
|
||||
if len(asset_categories) > 0 and len(asset_categories) < num_asset_categories:
|
||||
if 0 < len(asset_categories) < num_asset_categories:
|
||||
self._count_asset(categories, asset_categories)
|
||||
|
||||
def get_asset_count_per_category(self,
|
||||
@@ -2133,13 +2115,14 @@ class AssetIndex():
|
||||
# Top level is different,
|
||||
# as it actually contains AssetTypes, not categories
|
||||
for category in categories:
|
||||
asset_type_name = category["name"]
|
||||
id_cat_asset_type = category.get("id", -1)
|
||||
asset_type_name = category.get("name", "Unknown Category")
|
||||
if asset_type_name == CATEGORY_FREE:
|
||||
continue
|
||||
|
||||
asset_type = assets.AssetType.type_from_api(asset_type_name)
|
||||
|
||||
# filter depending on purchased and downloaded
|
||||
# Filter depending on purchased and downloaded
|
||||
if purchased:
|
||||
asset_ids_per_type[asset_type] = [
|
||||
asset_id for asset_id in asset_ids_per_type[asset_type]
|
||||
@@ -2158,10 +2141,11 @@ class AssetIndex():
|
||||
# important copy(), as we remove categories from the list
|
||||
asset_categories = asset_data.categories.copy()
|
||||
|
||||
if asset_type_name in asset_categories:
|
||||
asset_categories.remove(asset_type_name)
|
||||
if id_cat_asset_type in asset_categories:
|
||||
asset_categories.remove(id_cat_asset_type)
|
||||
|
||||
self._count_asset(category["children"], asset_categories)
|
||||
self._count_asset(
|
||||
category.get("children", []), asset_categories)
|
||||
if len(asset_categories):
|
||||
msg_warn = (f"Did not count all categories ({asset_id})!\n"
|
||||
f"Left over: {asset_categories}\n"
|
||||
@@ -2180,6 +2164,12 @@ class AssetIndex():
|
||||
type_data.get_files(files_dict)
|
||||
return files_dict
|
||||
|
||||
def flush_state(self) -> None:
|
||||
"""Flushes all Downloads and Purchases params"""
|
||||
|
||||
for asset_data in self.all_assets.values():
|
||||
asset_data.flush_state()
|
||||
|
||||
def flush_is_local(self) -> None:
|
||||
"""Flushes all is_local flags"""
|
||||
|
||||
@@ -2224,9 +2214,9 @@ class AssetIndex():
|
||||
raise ValueError("Asset ID already in use")
|
||||
if len(asset_name) == 0:
|
||||
raise ValueError("Please specify an asset name")
|
||||
if asset_type not in ["HDRIs", "Models", "Textures"]:
|
||||
if asset_type not in ["HDRIS", "Models", "Textures"]:
|
||||
msg = (f"Unknown asset type: {asset_type}\n"
|
||||
"Known types: HDRIs, Models, Textures")
|
||||
"Known types: HDRIS, Models, Textures")
|
||||
raise ValueError(msg)
|
||||
return True
|
||||
|
||||
@@ -2351,13 +2341,13 @@ class AssetIndex():
|
||||
asset_data.url = None
|
||||
asset_data.slug = None
|
||||
asset_data.credits = 0
|
||||
asset_data.thumb_urls = None
|
||||
asset_data.published_at = date_now
|
||||
asset_data.is_local = True
|
||||
asset_data.downloaded_at = None
|
||||
asset_data.is_purchased = True
|
||||
asset_data.purchased_at = None
|
||||
asset_data.render_custom_schema = None
|
||||
asset_data.specifications = None
|
||||
asset_data.api_convention = convention
|
||||
asset_data.local_convention = convention
|
||||
|
||||
|
||||
@@ -42,11 +42,10 @@ PREVIEWS = ["_atlas",
|
||||
"_grid",
|
||||
]
|
||||
PREVIEW_EXTS_LOWER = [".jpg", ".jpeg", ".png"]
|
||||
MAP_EXT_LOWER = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".exr", ".hdr", ".psd"]
|
||||
MAP_EXT_LOWER = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".exr", ".hdr", ".webp", ".psd"]
|
||||
SIZES = ["256", "512"] + [f"{idx+1}K" for idx in range(18)] + ["HIRES", "WM"]
|
||||
VARIANTS = [f"VAR{idx}" for idx in range(1, 10)]
|
||||
WORKFLOWS = ["REGULAR", "METALNESS", "SPECULAR"]
|
||||
CATEGORY_TRANSLATION = {"Hdrs": "HDRIs"}
|
||||
|
||||
|
||||
class ModelType(IntEnum):
|
||||
@@ -138,7 +137,7 @@ CREATION_METHODS = {
|
||||
CATEGORY_NAME_TO_ASSET_TYPE = {
|
||||
"All Assets": None,
|
||||
"Brushes": AssetType.BRUSH,
|
||||
"HDRIs": AssetType.HDRI,
|
||||
"HDRIS": AssetType.HDRI,
|
||||
"Models": AssetType.MODEL,
|
||||
"Substances": AssetType.SUBSTANCE,
|
||||
"Textures": AssetType.TEXTURE,
|
||||
@@ -162,9 +161,24 @@ class AssetThumbnail():
|
||||
filename: str
|
||||
base_url: str
|
||||
index: int
|
||||
time: datetime
|
||||
time: datetime # TODO(Andreas): This seems rather float!
|
||||
type: str
|
||||
|
||||
@classmethod
|
||||
def _from_dict(cls, d: Dict):
|
||||
"""Alternate constructor,
|
||||
used after loading AssetIndex from JSON to reconstruct class.
|
||||
"""
|
||||
|
||||
# TODO(Andreas): If time really was a datetime (see above),
|
||||
# something like this would be needed:
|
||||
# if "time" not in d:
|
||||
# raise KeyError("time")
|
||||
# d["time"] = create datetime from value
|
||||
|
||||
new = cls(**d)
|
||||
return new
|
||||
|
||||
|
||||
# NOT a data class
|
||||
# Reason: this way it does not get saved with the asset index.
|
||||
@@ -327,6 +341,33 @@ class AssetState():
|
||||
self.purchase = AssetStatePurchase()
|
||||
self.dl = AssetStateDownload()
|
||||
|
||||
def has_any_error(self) -> bool:
|
||||
"""Returns True, if any of the state classes signals an error."""
|
||||
|
||||
has_dl_error = self.dl.has_error()
|
||||
has_purchase_error = self.purchase.has_error()
|
||||
return self.has_error or has_dl_error or has_purchase_error
|
||||
|
||||
def get_any_error(
|
||||
self,
|
||||
force_string: bool = False
|
||||
) -> Optional[Union[str, Exception]]:
|
||||
"""Returns any error inside of state with following priority:
|
||||
Asset error -> Purchase error -> Download error
|
||||
"""
|
||||
|
||||
error = None
|
||||
if self.has_error:
|
||||
error = self.error
|
||||
elif self.purchase.has_error():
|
||||
error = self.purchase.error
|
||||
elif self.dl.has_error():
|
||||
error = self.dl.error
|
||||
|
||||
if force_string and error is not None:
|
||||
error = str(error)
|
||||
return error
|
||||
|
||||
|
||||
class AssetDataRuntime():
|
||||
"""Stores additional asset data, which will only be valid during runtime.
|
||||
@@ -595,6 +636,18 @@ class TextureMap:
|
||||
def get_path(self):
|
||||
return os.path.join(self.directory, self.filename)
|
||||
|
||||
def copy(self):
|
||||
"""Creates a copy of the TextureMap instance"""
|
||||
return TextureMap(
|
||||
directory=self.directory,
|
||||
filename=self.filename,
|
||||
file_format=self.file_format,
|
||||
lod=self.lod,
|
||||
map_type=self.map_type,
|
||||
size=self.size,
|
||||
variant=self.variant,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextureMapDesc:
|
||||
@@ -730,9 +783,9 @@ class Texture(BaseTex):
|
||||
return asset_map_list
|
||||
|
||||
@staticmethod
|
||||
def _convention_0_filter_16bit(tex_map_dict: Dict[MapType, List[TextureMap]],
|
||||
prefer_16_bit: bool = False
|
||||
) -> Dict[MapType, List[TextureMap]]:
|
||||
def _convention_0_filter_maps(tex_map_dict: Dict[MapType, List[TextureMap]],
|
||||
prefer_16_bit: bool = False
|
||||
) -> Dict[MapType, List[TextureMap]]:
|
||||
# Decide between 8-Bit and 16-Bit, if both are available
|
||||
if MapType.BUMP in tex_map_dict and MapType.BUMP16 in tex_map_dict:
|
||||
if prefer_16_bit:
|
||||
@@ -749,6 +802,20 @@ class Texture(BaseTex):
|
||||
del tex_map_dict[MapType.NRM]
|
||||
else:
|
||||
del tex_map_dict[MapType.NRM16]
|
||||
|
||||
# TODO(Joao): Implement this in the dccs side once all the map
|
||||
# management is being done there;
|
||||
# If ALBEDO map is found, we replace the COL list with the ALBEDO maps;
|
||||
# To be compatible with the addons, we have to keep using the MapType.COL
|
||||
# as key;
|
||||
albedo_list = tex_map_dict.get(MapType.ALBEDO, [])
|
||||
if len(albedo_list) > 0:
|
||||
tex_map_dict[MapType.COL] = []
|
||||
for albedo_map in albedo_list:
|
||||
albedo_map = albedo_map.copy()
|
||||
albedo_map.map_type = MapType.COL
|
||||
tex_map_dict[MapType.COL].append(albedo_map)
|
||||
|
||||
return tex_map_dict
|
||||
|
||||
def update(self, type_data_new, purge_maps: bool = False) -> None:
|
||||
@@ -959,7 +1026,10 @@ class Texture(BaseTex):
|
||||
raise KeyError(f"Workflow not found: {workflow}")
|
||||
|
||||
map_descs = self.map_descs[workflow]
|
||||
return [map_desc.map_type_code for map_desc in map_descs]
|
||||
|
||||
# No UNKNOWN map should get here - they are filtered on the map_descs constructor ()
|
||||
return [map_desc.map_type_code for map_desc in map_descs
|
||||
if MapType.from_type_code(map_desc.map_type_code) != MapType.UNKNOWN]
|
||||
|
||||
def get_maps_per_preferences(self,
|
||||
map_preferences: Any,
|
||||
@@ -1049,8 +1119,8 @@ class Texture(BaseTex):
|
||||
tex_map_dict[_map_type] = tex_map_dict.get(_map_type, []) + [tex_map]
|
||||
|
||||
if map_type_list is None:
|
||||
tex_map_dict = self._convention_0_filter_16bit(tex_map_dict,
|
||||
prefer_16_bit)
|
||||
tex_map_dict = self._convention_0_filter_maps(tex_map_dict,
|
||||
prefer_16_bit)
|
||||
|
||||
tex_maps = []
|
||||
# Get rid of multiple files for the same texture (e.g. .png and .psd)
|
||||
@@ -1880,11 +1950,18 @@ class AssetData:
|
||||
downloaded_at: Optional[int] = None
|
||||
# is_purchased: None until proven true or false.
|
||||
is_purchased: Optional[bool] = None
|
||||
# is_recent_downloaded: None until proven true or false
|
||||
# (new flag after implementing recent downloads filter in addon-first)
|
||||
is_recent_downloaded: Optional[bool] = None
|
||||
# UTC, seconds since epoch
|
||||
purchased_at: Optional[int] = None
|
||||
# render_custom_schema: Filled with what ever meta data
|
||||
# ApiResponse contains for this key.
|
||||
# TODO(Joao): Remove Render Custom Schema from this class.
|
||||
# Render custom schema was removed from asset dict api, specifications
|
||||
# key was added here as substitute - to discuss if we should keep both;
|
||||
render_custom_schema: Optional[Dict] = None
|
||||
specifications: Optional[Dict] = None
|
||||
|
||||
api_convention: Optional[int] = None
|
||||
local_convention: Optional[int] = None
|
||||
@@ -1922,6 +1999,11 @@ class AssetData:
|
||||
elif new.texture is not None:
|
||||
new.texture = Texture._from_dict(new.texture)
|
||||
|
||||
cf_thumb_urls = []
|
||||
for _dict in new.cloudflare_thumb_urls:
|
||||
cf_thumb_urls.append(AssetThumbnail._from_dict(_dict))
|
||||
new.cloudflare_thumb_urls = cf_thumb_urls
|
||||
|
||||
new.state = AssetState()
|
||||
new.runtime = AssetDataRuntime()
|
||||
|
||||
@@ -1978,12 +2060,13 @@ class AssetData:
|
||||
data_dict[_t("Creation Method")] = method_str
|
||||
|
||||
elif self.asset_type == AssetType.MODEL:
|
||||
tech_data = self.render_custom_schema.get("technical_description", {})
|
||||
lod_tri_count = tech_data.get("LODs", {})
|
||||
specs = self.specifications
|
||||
tech_data = specs if specs is not None else {}
|
||||
lod_tri_count = tech_data.get("lods", {})
|
||||
|
||||
meshes = tech_data.get("Meshes", None)
|
||||
meshes = tech_data.get("meshes", None)
|
||||
source_tri_count = lod_tri_count.get("SOURCE", None)
|
||||
dimensions = tech_data.get("Dimensions", {})
|
||||
dimensions = tech_data.get("physical_size_cm", {})
|
||||
incl_res = self.render_custom_schema.get("included_resolution")
|
||||
|
||||
# TODO(Joao): The following lines are for mark dimensions strings
|
||||
@@ -1993,7 +2076,8 @@ class AssetData:
|
||||
depth_str = _m("Depth (cm)") # noqa
|
||||
|
||||
for key, value in dimensions.items():
|
||||
data_dict[_t(key)] = value
|
||||
key_title = _t(key.title())
|
||||
data_dict[f"{key_title} (cm)"] = value
|
||||
|
||||
if meshes is not None:
|
||||
data_dict[_t("Mesh Count")] = meshes
|
||||
@@ -2060,9 +2144,10 @@ class AssetData:
|
||||
self.categories = _cond_set(asset_data_new.categories, self.categories)
|
||||
self.url = _cond_set(asset_data_new.url, self.url)
|
||||
self.slug = _cond_set(asset_data_new.slug, self.slug)
|
||||
self.thumb_urls = _cond_set(asset_data_new.thumb_urls, self.thumb_urls)
|
||||
self.published_at = _cond_set(asset_data_new.thumb_urls,
|
||||
self.thumb_urls)
|
||||
self.published_at = _cond_set(asset_data_new.published_at,
|
||||
self.published_at)
|
||||
self.thumb_urls = _cond_set(asset_data_new.thumb_urls,
|
||||
self.thumb_urls)
|
||||
self.is_local = _cond_set(asset_data_new.is_local, self.is_local)
|
||||
self.downloaded_at = _cond_set(asset_data_new.downloaded_at,
|
||||
self.downloaded_at)
|
||||
@@ -2072,9 +2157,15 @@ class AssetData:
|
||||
self.purchased_at)
|
||||
self.render_custom_schema = _cond_set(
|
||||
asset_data_new.render_custom_schema, self.render_custom_schema)
|
||||
self.specifications = _cond_set(asset_data_new.specifications,
|
||||
self.specifications)
|
||||
|
||||
self.get_type_data().update(asset_data_new.get_type_data(), purge_maps)
|
||||
|
||||
def flush_state(self):
|
||||
"""Reset all Download and Purchase states."""
|
||||
self.state = AssetState()
|
||||
|
||||
def flush_local(self):
|
||||
"""Resets all information about local files."""
|
||||
|
||||
@@ -2187,6 +2278,43 @@ class AssetData:
|
||||
name_mat += f"_{renderer}"
|
||||
return name_mat
|
||||
|
||||
def create_dummy_type_data(
|
||||
self, asset_type: Optional[AssetType] = None) -> None:
|
||||
"""Creates a dummy type data. For example to be used in case we run
|
||||
into an exception during evaluation of server side asset data and thus
|
||||
can't build complete data structures.
|
||||
"""
|
||||
|
||||
if self.get_type_data() is not None:
|
||||
return
|
||||
if asset_type is None:
|
||||
asset_type = self.asset_type
|
||||
|
||||
tex_dummy = Texture(maps={}, map_descs={}, sizes=["1K"])
|
||||
if asset_type == AssetType.BRUSH:
|
||||
self.brush = Brush(alpha=tex_dummy)
|
||||
elif asset_type == AssetType.HDRI:
|
||||
tex_dummy_2 = Texture(maps={}, map_descs={}, sizes=["1K"])
|
||||
self.hdri = Hdri(light=tex_dummy, bg=tex_dummy_2)
|
||||
elif asset_type == AssetType.MODEL:
|
||||
self.model = Model(texture=tex_dummy)
|
||||
elif asset_type == AssetType.TEXTURE:
|
||||
self.texture = tex_dummy
|
||||
|
||||
def is_new(self) -> bool:
|
||||
"""Returns True if the asset was published within the last 30 days.
|
||||
|
||||
Returns:
|
||||
bool: True if the asset is new (< 30 days old), False otherwise
|
||||
"""
|
||||
if not self.published_at:
|
||||
return False
|
||||
days_threshold = 30
|
||||
# Convert timestamp to datetime object
|
||||
published_datetime = datetime.fromtimestamp(self.published_at, tz=timezone.utc)
|
||||
days_old = (datetime.now(timezone.utc) - published_datetime).days
|
||||
return days_old <= days_threshold
|
||||
|
||||
|
||||
# Currently constants are defined here at the end,
|
||||
# as some require above classes to be defined.
|
||||
@@ -2205,3 +2333,11 @@ ASSET_TYPE_TO_CATEGORY_NAME = {AssetType.BRUSH: "Brushes",
|
||||
AssetType.TEXTURE: "Textures",
|
||||
AssetType.UNSUPPORTED: "Unsupported",
|
||||
}
|
||||
|
||||
ASSET_TYPE_TO_URL_PATH = {AssetType.BRUSH: "brush",
|
||||
AssetType.HDRI: "hdr",
|
||||
AssetType.MODEL: "model",
|
||||
AssetType.SUBSTANCE: "substance",
|
||||
AssetType.TEXTURE: "texture",
|
||||
AssetType.UNSUPPORTED: "unsupported",
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ MAX_RETRIES_PER_ASSET = 2
|
||||
|
||||
# This list defines a priority to fallback available formats
|
||||
# NOTE: This is only for Convention 1 downloads
|
||||
SUPPORTED_TEX_FORMATS = ["jpg", "png", "tiff", "exr"]
|
||||
SUPPORTED_TEX_FORMATS = ["jpg", "png", "tiff", "exr", "hdr"]
|
||||
MODEL_FILE_EXT = ["fbx", "blend", "max", "c4d", "skp", "ma"]
|
||||
|
||||
|
||||
@@ -166,20 +166,30 @@ class AssetDownload:
|
||||
asset_data: AssetData,
|
||||
size: str,
|
||||
dir_target: str,
|
||||
lod: str = "NONE",
|
||||
download_lods: bool = False,
|
||||
native_mesh: bool = False,
|
||||
renderer: Optional[str] = None,
|
||||
update_callback: Optional[Callable] = None
|
||||
update_callback: Optional[Callable] = None,
|
||||
get_download_data_callback: Optional[Callable] = None
|
||||
) -> None:
|
||||
self.addon = addon
|
||||
self.asset_data = asset_data
|
||||
self.size = size
|
||||
self.lod = lod
|
||||
|
||||
# TODO(Joao): Add lods list parameter to the class and assign to
|
||||
# self.lods when select download lods feature is available.
|
||||
# Currently if the download is set do download lods, all the lods will
|
||||
# be downloaded. There is no way to download only some given lods;
|
||||
self.download_lods = download_lods
|
||||
self.lods = None
|
||||
|
||||
self.native_mesh = native_mesh
|
||||
self.renderer = renderer
|
||||
|
||||
self.update_callback = update_callback
|
||||
self.get_download_data_callback = get_download_data_callback
|
||||
self.get_download_data_callback_called = False
|
||||
|
||||
self.dir_target = os.path.join(dir_target, asset_data.asset_name)
|
||||
self.download_list = []
|
||||
|
||||
@@ -234,15 +244,18 @@ class AssetDownload:
|
||||
]
|
||||
}
|
||||
|
||||
if self.convention == 0:
|
||||
self.data_payload["assets"][0]["sizes"] = [self.size]
|
||||
elif self.convention == 1:
|
||||
self.data_payload["assets"][0]["resolution"] = self.size
|
||||
self.data_payload["assets"][0]["sizes"] = [self.size]
|
||||
|
||||
if self.asset_data.asset_type in [AssetType.HDRI, AssetType.TEXTURE]:
|
||||
if self.asset_data.asset_type == AssetType.TEXTURE:
|
||||
self.set_texture_payload()
|
||||
elif self.asset_data.asset_type == AssetType.MODEL:
|
||||
self.set_model_payload()
|
||||
elif self.asset_data.asset_type == AssetType.HDRI:
|
||||
# HDRIs don't require any additional payload. The logic for defining
|
||||
# which maps should be downloaded is implemented on the server backend.
|
||||
# This change was made to support .hdr files and specify which DCCs
|
||||
# they should be downloaded/imported for;
|
||||
pass
|
||||
|
||||
def set_texture_payload(self) -> None:
|
||||
if self.convention == 0:
|
||||
@@ -282,13 +295,15 @@ class AssetDownload:
|
||||
self.data_payload["assets"][0]["maps"] = map_list
|
||||
|
||||
def set_model_payload(self) -> None:
|
||||
self.data_payload["assets"][0]["lods"] = int(self.download_lods)
|
||||
|
||||
if self.download_lods:
|
||||
self.lods = self.asset_data.get_type_data().get_lod_list()
|
||||
if self.lods not in [[], ["NONE"]]:
|
||||
self.data_payload["assets"][0]["lods"] = self.lods
|
||||
if self.native_mesh and self.renderer is not None:
|
||||
self.data_payload["assets"][0]["softwares"] = [self.addon._api.software_dl_dcc]
|
||||
self.data_payload["assets"][0]["renders"] = [self.renderer]
|
||||
else:
|
||||
self.data_payload["assets"][0]["softwares"] = ["ALL_OTHERS"]
|
||||
self.data_payload["assets"][0]["softwares"] = ["Generic"]
|
||||
|
||||
def create_download_folder(self) -> bool:
|
||||
try:
|
||||
@@ -354,8 +369,23 @@ class AssetDownload:
|
||||
|
||||
return dl_folder
|
||||
|
||||
@staticmethod
|
||||
def get_files_list(res) -> Tuple[List, List]:
|
||||
files_dict_list = [_files_dict
|
||||
for _files_dict in res.body.get("payload", [])
|
||||
if "files" in _files_dict.keys()]
|
||||
files_list = None
|
||||
dynamic_files_list = None
|
||||
if len(files_dict_list) > 0:
|
||||
files_list = files_dict_list[0].get("files")
|
||||
# Dynamic files expected only for textures
|
||||
if files_dict_list[0].get("dynamic_files"):
|
||||
dynamic_files_list = files_dict_list[0].get("dynamic_files")
|
||||
return files_list, dynamic_files_list
|
||||
|
||||
def build_download_list(self, res: ApiResponse) -> None:
|
||||
files_list = res.body.get("files", [])
|
||||
files_list, dynamic_files = self.get_files_list(res)
|
||||
|
||||
self.uuid = res.body.get("uuid", None)
|
||||
if self.uuid in [None, ""]:
|
||||
self.addon.logger_dl.error("No UUID for download")
|
||||
@@ -408,7 +438,7 @@ class AssetDownload:
|
||||
self.addon._api.report_message(
|
||||
"model_with_only_source_lod", msg, level="info")
|
||||
|
||||
self.set_dynamic_files(res.body.get("dynamic_files", None))
|
||||
self.set_dynamic_files(dynamic_files)
|
||||
|
||||
def set_dynamic_files(self,
|
||||
dynamic_files_api: Optional[List[Dict]]
|
||||
@@ -597,6 +627,14 @@ class AssetDownload:
|
||||
progress = self.size_asset_bytes_downloaded / max(self.size_asset_bytes_expected, 1)
|
||||
self.asset_data.state.dl.set_progress(max(progress, 0.001))
|
||||
self.asset_data.state.dl.set_downloaded_bytes(self.size_asset_bytes_expected)
|
||||
|
||||
try:
|
||||
if not self.get_download_data_callback_called:
|
||||
self.get_download_data_callback()
|
||||
self.get_download_data_callback_called = True
|
||||
except TypeError:
|
||||
pass # No Download data callback
|
||||
|
||||
try: # Init progress bar
|
||||
self.update_callback()
|
||||
except TypeError:
|
||||
|
||||
@@ -77,6 +77,11 @@ class MapType(IntEnum):
|
||||
NA_ORM = 50
|
||||
NA_VERTEXBLEND = 51
|
||||
|
||||
# Legacy maps - used to identify and download - not yet applied to the materials;
|
||||
ALBEDO = 52
|
||||
OVERLAY16 = 53
|
||||
DIRECTION = 54
|
||||
|
||||
# Convention 1, values 100 to 149 (150 and up for convention 1 only maps)
|
||||
# NOTE: Value - 100 should match convention 0
|
||||
AmbientOcclusion = 104
|
||||
|
||||
@@ -40,19 +40,27 @@ NOTICE_PRIO_NO_INET = NOTICE_PRIO_LOW # Show, but other errors have precedent
|
||||
NOTICE_PRIO_PROXY = NOTICE_PRIO_MEDIUM
|
||||
NOTICE_PRIO_SETTINGS_WRITE = NOTICE_PRIO_HIGH
|
||||
NOTICE_PRIO_SURVEY = NOTICE_PRIO_LOWEST
|
||||
|
||||
NOTICE_PRIO_UPDATE = NOTICE_PRIO_HIGH + 5 # urgent, but room for "more urgent"
|
||||
NOTICE_PRIO_RESTART = NOTICE_PRIO_LOW
|
||||
|
||||
# unsupported renderer is low (only onboarding is lower), once is not dismissible
|
||||
NOTICE_PRIO_UNSUPPORTED_RENDERER = NOTICE_PRIO_LOWEST + 5
|
||||
# Onboarding WM is now the lowest, once is not dismissible and less prio than warnings
|
||||
NOTICE_PRIO_WM_ONBOARDING = NOTICE_PRIO_LOWEST + 6
|
||||
|
||||
# Predefined notice IDs
|
||||
NOTICE_ID_MAT_TEMPLATE = "MATERIAL_TEMPLATE_ERROR"
|
||||
NOTICE_ID_NO_INET = "NO_INTERNET_CONNECTION"
|
||||
NOTICE_ID_PROXY = "PROXY_CONNECTION_ERROR"
|
||||
NOTICE_ID_SETTINGS_WRITE = "SETTINGS_WRITE_ERROR"
|
||||
NOTICE_ID_UNSUPPORTED_RENDERER = "UNSUPPORTED_RENDERER"
|
||||
NOTICE_ID_SURVEY_FREE = "NPS_INAPP_FREE"
|
||||
NOTICE_ID_SURVEY_ACTIVE = "NPS_INAPP_ACTIVE"
|
||||
NOTICE_ID_UPDATE = "UPDATE_READY_MANUAL_INSTALL"
|
||||
NOTICE_ID_VERSION_ALERT = "ADDON_VERSION_ALERT"
|
||||
NOTICE_ID_RESTART_ALERT = "NOTICE_ID_RESTART_ALERT"
|
||||
NOTICE_ID_WM_ONBOARDING = "NOTICE_WM_ONBOARDING"
|
||||
|
||||
# Predefined notice titles
|
||||
# Used a default param in create functions, but should usually be overridden
|
||||
@@ -61,10 +69,12 @@ NOTICE_TITLE_MAT_TEMPLATE = _m("Material template error")
|
||||
NOTICE_TITLE_NO_INET = _m("No internet access")
|
||||
NOTICE_TITLE_PROXY = _m("Encountered proxy error")
|
||||
NOTICE_TITLE_SETTINGS_WRITE = _m("Failed to write settings")
|
||||
NOTICE_UNSUPPORTED_RENDERER = _m("Renderer not supported")
|
||||
NOTICE_TITLE_SURVEY = _m("How's the addon?")
|
||||
NOTICE_TITLE_UPDATE = _m("Update ready")
|
||||
NOTICE_TITLE_DEPRECATED = _m("Deprecated version")
|
||||
NOTICE_TITLE_RESTART = _m("Restart needed")
|
||||
NOTICE_TITLE_WM_ONBOARDING = _m("Watermarked Previews")
|
||||
|
||||
# Predefined notice Labels (text to be displayed on the notification Banner)
|
||||
# Used a default param in create functions, but should usually be overridden
|
||||
@@ -87,6 +97,7 @@ NOTICE_ICON_WARN = "ICON_WARN"
|
||||
NOTICE_ICON_INFO = "ICON_INFO"
|
||||
NOTICE_ICON_SURVEY = "ICON_SURVEY"
|
||||
NOTICE_ICON_NO_CONNECTION = "ICON_NO_CONNECTION"
|
||||
NOTICE_ICON_WM_ONBOARDING = "ICON_WM_ONBOARDING"
|
||||
|
||||
|
||||
class ActionType(IntEnum):
|
||||
@@ -97,6 +108,9 @@ class ActionType(IntEnum):
|
||||
RUN_OPERATOR = 4
|
||||
UPDATE_READY = 2
|
||||
|
||||
# Adding None for notifications not related with any further actions (e.g. onboarding)
|
||||
NONE = 99
|
||||
|
||||
|
||||
class SignalType(IntEnum):
|
||||
"""Types of each interaction with the notifications."""
|
||||
@@ -147,6 +161,10 @@ class Notification():
|
||||
# function to be called when the notification is dismissed (viewed or not)
|
||||
on_dismiss_callable: Optional[Callable] = None
|
||||
|
||||
# Parameter for notifications that can be shown as banners on the asset
|
||||
# browser (only available for P4Max and P4C)
|
||||
display_as_banner: bool = False
|
||||
|
||||
viewed: bool = False # False until actually drawn
|
||||
clicked: bool = False # False until user interact with the notice
|
||||
|
||||
@@ -167,6 +185,12 @@ class AddonNotificationsParameters:
|
||||
update_action_text: str = NOTICE_TITLE_UPDATE
|
||||
update_body: str = ""
|
||||
|
||||
onboarding_wm_title: str = NOTICE_TITLE_WM_ONBOARDING
|
||||
onboarding_wm_label: str = ""
|
||||
onboarding_wm_tooltip: str = ""
|
||||
|
||||
allow_banner_notice: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationOpenUrl(Notification):
|
||||
@@ -217,7 +241,16 @@ class NotificationUpdateReady(Notification):
|
||||
self.action = ActionType.UPDATE_READY
|
||||
|
||||
def get_key(self) -> str:
|
||||
return "".join([self.action.name, self.download_url, self.download_label])
|
||||
return "".join([self.action.name, self.id_notice])
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationNoAction(Notification):
|
||||
def __post_init__(self):
|
||||
self.action = ActionType.NONE
|
||||
|
||||
def get_key(self) -> str:
|
||||
return "".join([self.action.name, self.id_notice])
|
||||
|
||||
|
||||
class NotificationSystem():
|
||||
@@ -239,7 +272,8 @@ class NotificationSystem():
|
||||
NOTICE_ICON_WARN: None,
|
||||
NOTICE_ICON_INFO: None,
|
||||
NOTICE_ICON_SURVEY: None,
|
||||
NOTICE_ICON_NO_CONNECTION: None
|
||||
NOTICE_ICON_NO_CONNECTION: None,
|
||||
NOTICE_ICON_WM_ONBOARDING: None
|
||||
}
|
||||
|
||||
addon_params = AddonNotificationsParameters()
|
||||
@@ -255,12 +289,14 @@ class NotificationSystem():
|
||||
icon_warn: Optional[Any] = None,
|
||||
icon_info: Optional[Any] = None,
|
||||
icon_survey: Optional[Any] = None,
|
||||
icon_no_connection: Optional[Any] = None
|
||||
icon_no_connection: Optional[Any] = None,
|
||||
notify_icon_wm_onboarding: Optional[Any] = None
|
||||
) -> None:
|
||||
self.icon_dcc_map[NOTICE_ICON_WARN] = icon_warn
|
||||
self.icon_dcc_map[NOTICE_ICON_INFO] = icon_info
|
||||
self.icon_dcc_map[NOTICE_ICON_SURVEY] = icon_survey
|
||||
self.icon_dcc_map[NOTICE_ICON_NO_CONNECTION] = icon_no_connection
|
||||
self.icon_dcc_map[NOTICE_ICON_WM_ONBOARDING] = notify_icon_wm_onboarding
|
||||
|
||||
def _run_threaded(key_pool: PoolKeys,
|
||||
max_threads: Optional[int] = None,
|
||||
@@ -351,7 +387,9 @@ class NotificationSystem():
|
||||
if not notice.allow_dismiss and not force:
|
||||
return
|
||||
|
||||
if not notice.clicked:
|
||||
if not notice.clicked and force is False:
|
||||
# Don't signal dismiss if it's a forced removal, since there's no
|
||||
# human interaction.
|
||||
self._signal_dismiss(notice)
|
||||
|
||||
if notice.on_dismiss_callable is not None:
|
||||
@@ -630,6 +668,32 @@ class NotificationSystem():
|
||||
self.enqueue_notice(notice)
|
||||
return notice
|
||||
|
||||
def create_unsupported_renderer(self,
|
||||
title: str = NOTICE_UNSUPPORTED_RENDERER,
|
||||
*,
|
||||
label: str = NOTICE_UNSUPPORTED_RENDERER,
|
||||
tooltip: str,
|
||||
body: str,
|
||||
auto_enqueue: bool = True
|
||||
) -> Notification:
|
||||
"""Returns an Unsupported Render Engine notice."""
|
||||
|
||||
notice = NotificationPopup(
|
||||
id_notice=NOTICE_ID_UNSUPPORTED_RENDERER,
|
||||
title=title,
|
||||
priority=NOTICE_PRIO_UNSUPPORTED_RENDERER,
|
||||
allow_dismiss=False,
|
||||
tooltip=tooltip,
|
||||
icon=self.icon_dcc_map[NOTICE_ICON_WARN],
|
||||
body=body,
|
||||
label=label,
|
||||
open_popup=True,
|
||||
alert=True
|
||||
)
|
||||
if auto_enqueue:
|
||||
self.enqueue_notice(notice)
|
||||
return notice
|
||||
|
||||
def create_update(self,
|
||||
title: str = NOTICE_TITLE_UPDATE,
|
||||
*,
|
||||
@@ -676,3 +740,30 @@ class NotificationSystem():
|
||||
if auto_enqueue:
|
||||
self.enqueue_notice(notice)
|
||||
return notice
|
||||
|
||||
def create_watermarked_onboarding(self,
|
||||
auto_enqueue: bool = True,
|
||||
) -> NotificationNoAction:
|
||||
"""Returns a pre-built 'WM Onboarding' notice."""
|
||||
|
||||
title = NOTICE_TITLE_WM_ONBOARDING
|
||||
if self.addon_params.onboarding_wm_title is not None:
|
||||
title = self.addon_params.onboarding_wm_title
|
||||
|
||||
notice = NotificationNoAction(
|
||||
id_notice=NOTICE_ID_WM_ONBOARDING,
|
||||
title=title,
|
||||
priority=NOTICE_PRIO_WM_ONBOARDING,
|
||||
display_as_banner=True,
|
||||
label=self.addon_params.onboarding_wm_label,
|
||||
allow_dismiss=False,
|
||||
tooltip=self.addon_params.onboarding_wm_tooltip,
|
||||
icon=self.icon_dcc_map[NOTICE_ICON_WM_ONBOARDING]
|
||||
)
|
||||
|
||||
allow_banner = self.addon_params.allow_banner_notice
|
||||
handle_as_banner_notice = allow_banner and notice.display_as_banner
|
||||
if auto_enqueue and not handle_as_banner_notice:
|
||||
self.enqueue_notice(notice)
|
||||
|
||||
return notice
|
||||
|
||||
@@ -154,6 +154,7 @@ class PoliigonSubscription:
|
||||
base_price: Optional[float] = None # e.g. 123.45
|
||||
currency_symbol: Optional[str] = None # e.g. "$" (special character)
|
||||
is_unlimited: Optional[bool] = None
|
||||
is_limited_legacy: Optional[bool] = None
|
||||
has_team: Optional[bool] = None
|
||||
|
||||
@staticmethod
|
||||
@@ -275,6 +276,10 @@ class PoliigonSubscription:
|
||||
if unlimited is not None:
|
||||
self.is_unlimited = bool(unlimited)
|
||||
|
||||
legacy = plan_dictionary.get("limited_legacy", None)
|
||||
if legacy is not None:
|
||||
self.is_limited_legacy = bool(legacy)
|
||||
|
||||
has_team = bool(plan_dictionary.get("team_id", None))
|
||||
if has_team is not None:
|
||||
self.has_team = bool(has_team)
|
||||
@@ -403,6 +408,10 @@ class PoliigonPlanUpgradeManager:
|
||||
self.addon.notify._signal_clicked(signal_notice)
|
||||
|
||||
def check_show_banner(self) -> bool:
|
||||
# Checks if the wm onboarding banner is drawn or not
|
||||
if not self.addon.is_onboarding_wm_preview_done():
|
||||
return False
|
||||
|
||||
if self.user is None:
|
||||
return False
|
||||
do_show_banner = self.show_banner
|
||||
@@ -485,11 +494,13 @@ class PoliigonPlanUpgradeManager:
|
||||
if self.user.plan.is_unlimited:
|
||||
# The only benefit to upgrade is to get more downloads, if you're
|
||||
# already unlimited, there's nothing to upgrade to
|
||||
self.upgrade_plan = None
|
||||
return
|
||||
|
||||
if self.user.plan.has_team:
|
||||
# Let's not offer updates to team members, since these contracts
|
||||
# are handled separately
|
||||
self.upgrade_plan = None
|
||||
return
|
||||
|
||||
upgrade_pro_plan = None
|
||||
|
||||
@@ -128,7 +128,7 @@ class UpgradeContent:
|
||||
self.populate()
|
||||
|
||||
def student_discount(self, is_teacher: bool = False) -> Any:
|
||||
primary = _t("Access the entire library by joining Pro")
|
||||
primary = _t("Access the entire library with a Poliigon Plan")
|
||||
secondary = _t("{0} can claim a 50% discount".format(
|
||||
_t("Students") if not is_teacher else _t("Teachers")))
|
||||
if self.as_single_paragraph:
|
||||
@@ -144,7 +144,7 @@ class UpgradeContent:
|
||||
self.icon_path = self.icons.check
|
||||
|
||||
def become_pro(self) -> Any:
|
||||
primary = _t("Access the entire library by joining Pro")
|
||||
primary = _t("Access the entire library with a Poliigon Plan")
|
||||
secondary = _t("Download and import from the entire Poliigon library")
|
||||
if self.as_single_paragraph:
|
||||
# To keep it slimmer, do just the primary text.
|
||||
|
||||
@@ -17,11 +17,12 @@
|
||||
# ##### END GPL LICENSE BLOCK #####
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
from .assets import (MapType)
|
||||
from .plan_manager import PoliigonSubscription
|
||||
|
||||
from .multilingual import _m, _t
|
||||
from .logger import (DEBUG, # noqa F401, allowing downstream const usage
|
||||
ERROR,
|
||||
INFO,
|
||||
@@ -30,6 +31,42 @@ from .logger import (DEBUG, # noqa F401, allowing downstream const usage
|
||||
WARNING)
|
||||
|
||||
|
||||
class PoliigonUserProfiles(Enum):
|
||||
USER_HOBBYIST = _m("As a hobbyist")
|
||||
USER_STUDENT = _m("As a student")
|
||||
USER_PROFESSIONAL = _m("As a professional (individual)")
|
||||
USER_STUDIO = _m("As a professional (studio/team)")
|
||||
USER_TEACHER = _m("As a teacher")
|
||||
|
||||
def to_ui_display(self) -> Optional[str]:
|
||||
if self == self.USER_STUDENT:
|
||||
return _t("Student")
|
||||
elif self == self.USER_TEACHER:
|
||||
return _t("Teacher")
|
||||
elif self == self.USER_HOBBYIST:
|
||||
return _t("Hobbyist")
|
||||
elif self == self.USER_PROFESSIONAL:
|
||||
return _t("Freelancer")
|
||||
elif self == self.USER_STUDIO:
|
||||
return _t("Team")
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, use_string: str) -> Optional[Any]:
|
||||
try:
|
||||
return cls(use_string)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_list(cls, addon_display: bool = False) -> List[str]:
|
||||
if not addon_display:
|
||||
return [member.value for member in PoliigonUserProfiles]
|
||||
else:
|
||||
return [member.to_ui_display() for member in PoliigonUserProfiles]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MapFormats:
|
||||
map_type: MapType
|
||||
@@ -123,7 +160,15 @@ class PoliigonUser:
|
||||
is_teacher: Optional[bool] = False
|
||||
credits: Optional[int] = None
|
||||
credits_od: Optional[int] = None
|
||||
count_assets_owned: Optional[int] = None
|
||||
count_assets_downloads: Optional[int] = None
|
||||
plan: Optional[PoliigonSubscription] = None
|
||||
map_preferences: Optional[UserDownloadPreferences] = None
|
||||
user_profile: PoliigonUserProfiles = None
|
||||
email_preference: Optional[bool] = None
|
||||
primary_3d_software: Optional[str] = None
|
||||
# Primary Render engine might be changed/decided in the Dcc side if
|
||||
# not yet available (None) in user data;
|
||||
primary_rendering_engine: Optional[str] = None
|
||||
# Todo(Joao): remove this flag when all addons are using map prefs
|
||||
use_preferences_on_download: bool = False
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from sentry_sdk import profiler
|
||||
from sentry_sdk import metrics
|
||||
from sentry_sdk.scope import Scope
|
||||
from sentry_sdk.transport import Transport, HttpTransport
|
||||
from sentry_sdk.client import Client
|
||||
|
||||
from sentry_sdk.api import * # noqa
|
||||
|
||||
from sentry_sdk.consts import VERSION # noqa
|
||||
from sentry_sdk.consts import VERSION
|
||||
|
||||
__all__ = [ # noqa
|
||||
"Hub",
|
||||
@@ -12,9 +13,11 @@ __all__ = [ # noqa
|
||||
"Client",
|
||||
"Transport",
|
||||
"HttpTransport",
|
||||
"VERSION",
|
||||
"integrations",
|
||||
# From sentry_sdk.api
|
||||
"init",
|
||||
"add_attachment",
|
||||
"add_breadcrumb",
|
||||
"capture_event",
|
||||
"capture_exception",
|
||||
@@ -45,6 +48,13 @@ __all__ = [ # noqa
|
||||
"start_transaction",
|
||||
"trace",
|
||||
"monitor",
|
||||
"logger",
|
||||
"metrics",
|
||||
"profiler",
|
||||
"start_session",
|
||||
"end_session",
|
||||
"set_transaction_name",
|
||||
"update_current_span",
|
||||
]
|
||||
|
||||
# Initialize the debug support after everything is loaded
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Callable, TYPE_CHECKING, Any
|
||||
|
||||
from sentry_sdk.utils import format_timestamp, safe_repr
|
||||
from sentry_sdk.envelope import Envelope, Item, PayloadRef
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentry_sdk._types import Log
|
||||
|
||||
|
||||
class LogBatcher:
|
||||
MAX_LOGS_BEFORE_FLUSH = 100
|
||||
FLUSH_WAIT_TIME = 5.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capture_func, # type: Callable[[Envelope], None]
|
||||
):
|
||||
# type: (...) -> None
|
||||
self._log_buffer = [] # type: List[Log]
|
||||
self._capture_func = capture_func
|
||||
self._running = True
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._flush_event = threading.Event() # type: threading.Event
|
||||
|
||||
self._flusher = None # type: Optional[threading.Thread]
|
||||
self._flusher_pid = None # type: Optional[int]
|
||||
|
||||
def _ensure_thread(self):
|
||||
# type: (...) -> bool
|
||||
"""For forking processes we might need to restart this thread.
|
||||
This ensures that our process actually has that thread running.
|
||||
"""
|
||||
if not self._running:
|
||||
return False
|
||||
|
||||
pid = os.getpid()
|
||||
if self._flusher_pid == pid:
|
||||
return True
|
||||
|
||||
with self._lock:
|
||||
# Recheck to make sure another thread didn't get here and start the
|
||||
# the flusher in the meantime
|
||||
if self._flusher_pid == pid:
|
||||
return True
|
||||
|
||||
self._flusher_pid = pid
|
||||
|
||||
self._flusher = threading.Thread(target=self._flush_loop)
|
||||
self._flusher.daemon = True
|
||||
|
||||
try:
|
||||
self._flusher.start()
|
||||
except RuntimeError:
|
||||
# Unfortunately at this point the interpreter is in a state that no
|
||||
# longer allows us to spawn a thread and we have to bail.
|
||||
self._running = False
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _flush_loop(self):
|
||||
# type: (...) -> None
|
||||
while self._running:
|
||||
self._flush_event.wait(self.FLUSH_WAIT_TIME + random.random())
|
||||
self._flush_event.clear()
|
||||
self._flush()
|
||||
|
||||
def add(
|
||||
self,
|
||||
log, # type: Log
|
||||
):
|
||||
# type: (...) -> None
|
||||
if not self._ensure_thread() or self._flusher is None:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
self._log_buffer.append(log)
|
||||
if len(self._log_buffer) >= self.MAX_LOGS_BEFORE_FLUSH:
|
||||
self._flush_event.set()
|
||||
|
||||
def kill(self):
|
||||
# type: (...) -> None
|
||||
if self._flusher is None:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._flush_event.set()
|
||||
self._flusher = None
|
||||
|
||||
def flush(self):
|
||||
# type: (...) -> None
|
||||
self._flush()
|
||||
|
||||
@staticmethod
|
||||
def _log_to_transport_format(log):
|
||||
# type: (Log) -> Any
|
||||
def format_attribute(val):
|
||||
# type: (int | float | str | bool) -> Any
|
||||
if isinstance(val, bool):
|
||||
return {"value": val, "type": "boolean"}
|
||||
if isinstance(val, int):
|
||||
return {"value": val, "type": "integer"}
|
||||
if isinstance(val, float):
|
||||
return {"value": val, "type": "double"}
|
||||
if isinstance(val, str):
|
||||
return {"value": val, "type": "string"}
|
||||
return {"value": safe_repr(val), "type": "string"}
|
||||
|
||||
if "sentry.severity_number" not in log["attributes"]:
|
||||
log["attributes"]["sentry.severity_number"] = log["severity_number"]
|
||||
if "sentry.severity_text" not in log["attributes"]:
|
||||
log["attributes"]["sentry.severity_text"] = log["severity_text"]
|
||||
|
||||
res = {
|
||||
"timestamp": int(log["time_unix_nano"]) / 1.0e9,
|
||||
"trace_id": log.get("trace_id", "00000000-0000-0000-0000-000000000000"),
|
||||
"level": str(log["severity_text"]),
|
||||
"body": str(log["body"]),
|
||||
"attributes": {
|
||||
k: format_attribute(v) for (k, v) in log["attributes"].items()
|
||||
},
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
def _flush(self):
|
||||
# type: (...) -> Optional[Envelope]
|
||||
|
||||
envelope = Envelope(
|
||||
headers={"sent_at": format_timestamp(datetime.now(timezone.utc))}
|
||||
)
|
||||
with self._lock:
|
||||
if len(self._log_buffer) == 0:
|
||||
return None
|
||||
|
||||
envelope.add_item(
|
||||
Item(
|
||||
type="log",
|
||||
content_type="application/vnd.sentry.items.log+json",
|
||||
headers={
|
||||
"item_count": len(self._log_buffer),
|
||||
},
|
||||
payload=PayloadRef(
|
||||
json={
|
||||
"items": [
|
||||
self._log_to_transport_format(log)
|
||||
for log in self._log_buffer
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
self._log_buffer.clear()
|
||||
|
||||
self._capture_func(envelope)
|
||||
return envelope
|
||||
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Callable, TYPE_CHECKING, Any, Union
|
||||
|
||||
from sentry_sdk.utils import format_timestamp, safe_repr
|
||||
from sentry_sdk.envelope import Envelope, Item, PayloadRef
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentry_sdk._types import Metric
|
||||
|
||||
|
||||
class MetricsBatcher:
|
||||
MAX_METRICS_BEFORE_FLUSH = 1000
|
||||
FLUSH_WAIT_TIME = 5.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capture_func, # type: Callable[[Envelope], None]
|
||||
):
|
||||
# type: (...) -> None
|
||||
self._metric_buffer = [] # type: List[Metric]
|
||||
self._capture_func = capture_func
|
||||
self._running = True
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._flush_event = threading.Event() # type: threading.Event
|
||||
|
||||
self._flusher = None # type: Optional[threading.Thread]
|
||||
self._flusher_pid = None # type: Optional[int]
|
||||
|
||||
def _ensure_thread(self):
|
||||
# type: (...) -> bool
|
||||
if not self._running:
|
||||
return False
|
||||
|
||||
pid = os.getpid()
|
||||
if self._flusher_pid == pid:
|
||||
return True
|
||||
|
||||
with self._lock:
|
||||
if self._flusher_pid == pid:
|
||||
return True
|
||||
|
||||
self._flusher_pid = pid
|
||||
|
||||
self._flusher = threading.Thread(target=self._flush_loop)
|
||||
self._flusher.daemon = True
|
||||
|
||||
try:
|
||||
self._flusher.start()
|
||||
except RuntimeError:
|
||||
self._running = False
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _flush_loop(self):
|
||||
# type: (...) -> None
|
||||
while self._running:
|
||||
self._flush_event.wait(self.FLUSH_WAIT_TIME + random.random())
|
||||
self._flush_event.clear()
|
||||
self._flush()
|
||||
|
||||
def add(
|
||||
self,
|
||||
metric, # type: Metric
|
||||
):
|
||||
# type: (...) -> None
|
||||
if not self._ensure_thread() or self._flusher is None:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
self._metric_buffer.append(metric)
|
||||
if len(self._metric_buffer) >= self.MAX_METRICS_BEFORE_FLUSH:
|
||||
self._flush_event.set()
|
||||
|
||||
def kill(self):
|
||||
# type: (...) -> None
|
||||
if self._flusher is None:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._flush_event.set()
|
||||
self._flusher = None
|
||||
|
||||
def flush(self):
|
||||
# type: (...) -> None
|
||||
self._flush()
|
||||
|
||||
@staticmethod
|
||||
def _metric_to_transport_format(metric):
|
||||
# type: (Metric) -> Any
|
||||
def format_attribute(val):
|
||||
# type: (Union[int, float, str, bool]) -> Any
|
||||
if isinstance(val, bool):
|
||||
return {"value": val, "type": "boolean"}
|
||||
if isinstance(val, int):
|
||||
return {"value": val, "type": "integer"}
|
||||
if isinstance(val, float):
|
||||
return {"value": val, "type": "double"}
|
||||
if isinstance(val, str):
|
||||
return {"value": val, "type": "string"}
|
||||
return {"value": safe_repr(val), "type": "string"}
|
||||
|
||||
res = {
|
||||
"timestamp": metric["timestamp"],
|
||||
"trace_id": metric["trace_id"],
|
||||
"name": metric["name"],
|
||||
"type": metric["type"],
|
||||
"value": metric["value"],
|
||||
"attributes": {
|
||||
k: format_attribute(v) for (k, v) in metric["attributes"].items()
|
||||
},
|
||||
}
|
||||
|
||||
if metric.get("span_id") is not None:
|
||||
res["span_id"] = metric["span_id"]
|
||||
|
||||
if metric.get("unit") is not None:
|
||||
res["unit"] = metric["unit"]
|
||||
|
||||
return res
|
||||
|
||||
def _flush(self):
|
||||
# type: (...) -> Optional[Envelope]
|
||||
|
||||
envelope = Envelope(
|
||||
headers={"sent_at": format_timestamp(datetime.now(timezone.utc))}
|
||||
)
|
||||
with self._lock:
|
||||
if len(self._metric_buffer) == 0:
|
||||
return None
|
||||
|
||||
envelope.add_item(
|
||||
Item(
|
||||
type="trace_metric",
|
||||
content_type="application/vnd.sentry.items.trace-metric+json",
|
||||
headers={
|
||||
"item_count": len(self._metric_buffer),
|
||||
},
|
||||
payload=PayloadRef(
|
||||
json={
|
||||
"items": [
|
||||
self._metric_to_transport_format(metric)
|
||||
for metric in self._metric_buffer
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
self._metric_buffer.clear()
|
||||
|
||||
self._capture_func(envelope)
|
||||
return envelope
|
||||
@@ -30,6 +30,17 @@ class AnnotatedValue:
|
||||
|
||||
return self.value == other.value and self.metadata == other.metadata
|
||||
|
||||
def __str__(self):
|
||||
# type: (AnnotatedValue) -> str
|
||||
return str({"value": str(self.value), "metadata": str(self.metadata)})
|
||||
|
||||
def __len__(self):
|
||||
# type: (AnnotatedValue) -> int
|
||||
if self.value is not None:
|
||||
return len(self.value)
|
||||
else:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def removed_because_raw_data(cls):
|
||||
# type: () -> AnnotatedValue
|
||||
@@ -47,11 +58,14 @@ class AnnotatedValue:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def removed_because_over_size_limit(cls):
|
||||
# type: () -> AnnotatedValue
|
||||
"""The actual value was removed because the size of the field exceeded the configured maximum size (specified with the max_request_body_size sdk option)"""
|
||||
def removed_because_over_size_limit(cls, value=""):
|
||||
# type: (Any) -> AnnotatedValue
|
||||
"""
|
||||
The actual value was removed because the size of the field exceeded the configured maximum size,
|
||||
for example specified with the max_request_body_size sdk option.
|
||||
"""
|
||||
return AnnotatedValue(
|
||||
value="",
|
||||
value=value,
|
||||
metadata={
|
||||
"rem": [ # Remark
|
||||
[
|
||||
@@ -149,14 +163,14 @@ if TYPE_CHECKING:
|
||||
Event = TypedDict(
|
||||
"Event",
|
||||
{
|
||||
"breadcrumbs": dict[
|
||||
Literal["values"], list[dict[str, Any]]
|
||||
"breadcrumbs": Annotated[
|
||||
dict[Literal["values"], list[dict[str, Any]]]
|
||||
], # TODO: We can expand on this type
|
||||
"check_in_id": str,
|
||||
"contexts": dict[str, dict[str, object]],
|
||||
"dist": str,
|
||||
"duration": Optional[float],
|
||||
"environment": str,
|
||||
"environment": Optional[str],
|
||||
"errors": list[dict[str, Any]], # TODO: We can expand on this type
|
||||
"event_id": str,
|
||||
"exception": dict[
|
||||
@@ -174,7 +188,7 @@ if TYPE_CHECKING:
|
||||
"monitor_slug": Optional[str],
|
||||
"platform": Literal["python"],
|
||||
"profile": object, # Should be sentry_sdk.profiler.Profile, but we can't import that here due to circular imports
|
||||
"release": str,
|
||||
"release": Optional[str],
|
||||
"request": dict[str, object],
|
||||
"sdk": Mapping[str, object],
|
||||
"server_name": str,
|
||||
@@ -196,7 +210,6 @@ if TYPE_CHECKING:
|
||||
"type": Literal["check_in", "transaction"],
|
||||
"user": dict[str, object],
|
||||
"_dropped_spans": int,
|
||||
"_metrics_summary": dict[str, object],
|
||||
},
|
||||
total=False,
|
||||
)
|
||||
@@ -206,17 +219,61 @@ if TYPE_CHECKING:
|
||||
tuple[None, None, None],
|
||||
]
|
||||
|
||||
# TODO: Make a proper type definition for this (PRs welcome!)
|
||||
Hint = Dict[str, Any]
|
||||
|
||||
Log = TypedDict(
|
||||
"Log",
|
||||
{
|
||||
"severity_text": str,
|
||||
"severity_number": int,
|
||||
"body": str,
|
||||
"attributes": dict[str, str | bool | float | int],
|
||||
"time_unix_nano": int,
|
||||
"trace_id": Optional[str],
|
||||
},
|
||||
)
|
||||
|
||||
MetricType = Literal["counter", "gauge", "distribution"]
|
||||
|
||||
MetricAttributeValue = TypedDict(
|
||||
"MetricAttributeValue",
|
||||
{
|
||||
"value": Union[str, bool, float, int],
|
||||
"type": Literal["string", "boolean", "double", "integer"],
|
||||
},
|
||||
)
|
||||
|
||||
Metric = TypedDict(
|
||||
"Metric",
|
||||
{
|
||||
"timestamp": float,
|
||||
"trace_id": Optional[str],
|
||||
"span_id": Optional[str],
|
||||
"name": str,
|
||||
"type": MetricType,
|
||||
"value": float,
|
||||
"unit": Optional[str],
|
||||
"attributes": dict[str, str | bool | float | int],
|
||||
},
|
||||
)
|
||||
|
||||
MetricProcessor = Callable[[Metric, Hint], Optional[Metric]]
|
||||
|
||||
# TODO: Make a proper type definition for this (PRs welcome!)
|
||||
Breadcrumb = Dict[str, Any]
|
||||
|
||||
# TODO: Make a proper type definition for this (PRs welcome!)
|
||||
BreadcrumbHint = Dict[str, Any]
|
||||
|
||||
# TODO: Make a proper type definition for this (PRs welcome!)
|
||||
SamplingContext = Dict[str, Any]
|
||||
|
||||
EventProcessor = Callable[[Event, Hint], Optional[Event]]
|
||||
ErrorProcessor = Callable[[Event, ExcInfo], Optional[Event]]
|
||||
BreadcrumbProcessor = Callable[[Breadcrumb, BreadcrumbHint], Optional[Breadcrumb]]
|
||||
TransactionProcessor = Callable[[Event, Hint], Optional[Event]]
|
||||
LogProcessor = Callable[[Log, Hint], Optional[Log]]
|
||||
|
||||
TracesSampler = Callable[[SamplingContext], Union[float, int, bool]]
|
||||
|
||||
@@ -234,35 +291,16 @@ if TYPE_CHECKING:
|
||||
"internal",
|
||||
"profile",
|
||||
"profile_chunk",
|
||||
"metric_bucket",
|
||||
"monitor",
|
||||
"span",
|
||||
"log_item",
|
||||
"trace_metric",
|
||||
]
|
||||
SessionStatus = Literal["ok", "exited", "crashed", "abnormal"]
|
||||
|
||||
ContinuousProfilerMode = Literal["thread", "gevent", "unknown"]
|
||||
ProfilerMode = Union[ContinuousProfilerMode, Literal["sleep"]]
|
||||
|
||||
# Type of the metric.
|
||||
MetricType = Literal["d", "s", "g", "c"]
|
||||
|
||||
# Value of the metric.
|
||||
MetricValue = Union[int, float, str]
|
||||
|
||||
# Internal representation of tags as a tuple of tuples (this is done in order to allow for the same key to exist
|
||||
# multiple times).
|
||||
MetricTagsInternal = Tuple[Tuple[str, str], ...]
|
||||
|
||||
# External representation of tags as a dictionary.
|
||||
MetricTagValue = Union[str, int, float, None]
|
||||
MetricTags = Mapping[str, MetricTagValue]
|
||||
|
||||
# Value inside the generator for the metric value.
|
||||
FlushedMetricValue = Union[int, float]
|
||||
|
||||
BucketKey = Tuple[MetricType, str, MeasurementUnit, MetricTagsInternal]
|
||||
MetricMetaKey = Tuple[MetricType, str, MeasurementUnit]
|
||||
|
||||
MonitorConfigScheduleType = Literal["crontab", "interval"]
|
||||
MonitorConfigScheduleUnit = Literal[
|
||||
"year",
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from .utils import (
|
||||
set_data_normalized,
|
||||
GEN_AI_MESSAGE_ROLE_MAPPING,
|
||||
GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING,
|
||||
normalize_message_role,
|
||||
normalize_message_roles,
|
||||
) # noqa: F401
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from sentry_sdk.consts import SPANDATA
|
||||
import sentry_sdk.utils
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
@@ -9,7 +10,9 @@ from sentry_sdk.utils import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional, Callable, Any
|
||||
from typing import Optional, Callable, Awaitable, Any, Union, TypeVar
|
||||
|
||||
F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])
|
||||
|
||||
_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
|
||||
|
||||
@@ -25,13 +28,13 @@ def get_ai_pipeline_name():
|
||||
|
||||
|
||||
def ai_track(description, **span_kwargs):
|
||||
# type: (str, Any) -> Callable[..., Any]
|
||||
# type: (str, Any) -> Callable[[F], F]
|
||||
def decorator(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
# type: (F) -> F
|
||||
def sync_wrapped(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
curr_pipeline = _ai_pipeline_name.get()
|
||||
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
|
||||
op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
|
||||
|
||||
with start_span(name=description, op=op, **span_kwargs) as span:
|
||||
for k, v in kwargs.pop("sentry_tags", {}).items():
|
||||
@@ -39,7 +42,7 @@ def ai_track(description, **span_kwargs):
|
||||
for k, v in kwargs.pop("sentry_data", {}).items():
|
||||
span.set_data(k, v)
|
||||
if curr_pipeline:
|
||||
span.set_data("ai.pipeline.name", curr_pipeline)
|
||||
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
_ai_pipeline_name.set(description)
|
||||
@@ -60,7 +63,7 @@ def ai_track(description, **span_kwargs):
|
||||
async def async_wrapped(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
curr_pipeline = _ai_pipeline_name.get()
|
||||
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
|
||||
op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
|
||||
|
||||
with start_span(name=description, op=op, **span_kwargs) as span:
|
||||
for k, v in kwargs.pop("sentry_tags", {}).items():
|
||||
@@ -68,7 +71,7 @@ def ai_track(description, **span_kwargs):
|
||||
for k, v in kwargs.pop("sentry_data", {}).items():
|
||||
span.set_data(k, v)
|
||||
if curr_pipeline:
|
||||
span.set_data("ai.pipeline.name", curr_pipeline)
|
||||
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
|
||||
return await f(*args, **kwargs)
|
||||
else:
|
||||
_ai_pipeline_name.set(description)
|
||||
@@ -87,29 +90,48 @@ def ai_track(description, **span_kwargs):
|
||||
return res
|
||||
|
||||
if inspect.iscoroutinefunction(f):
|
||||
return wraps(f)(async_wrapped)
|
||||
return wraps(f)(async_wrapped) # type: ignore
|
||||
else:
|
||||
return wraps(f)(sync_wrapped)
|
||||
return wraps(f)(sync_wrapped) # type: ignore
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def record_token_usage(
|
||||
span, prompt_tokens=None, completion_tokens=None, total_tokens=None
|
||||
span,
|
||||
input_tokens=None,
|
||||
input_tokens_cached=None,
|
||||
output_tokens=None,
|
||||
output_tokens_reasoning=None,
|
||||
total_tokens=None,
|
||||
):
|
||||
# type: (Span, Optional[int], Optional[int], Optional[int]) -> None
|
||||
# type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
|
||||
|
||||
# TODO: move pipeline name elsewhere
|
||||
ai_pipeline_name = get_ai_pipeline_name()
|
||||
if ai_pipeline_name:
|
||||
span.set_data("ai.pipeline.name", ai_pipeline_name)
|
||||
if prompt_tokens is not None:
|
||||
span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
|
||||
if completion_tokens is not None:
|
||||
span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
|
||||
if (
|
||||
total_tokens is None
|
||||
and prompt_tokens is not None
|
||||
and completion_tokens is not None
|
||||
):
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, ai_pipeline_name)
|
||||
|
||||
if input_tokens is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
|
||||
|
||||
if input_tokens_cached is not None:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
|
||||
input_tokens_cached,
|
||||
)
|
||||
|
||||
if output_tokens is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
|
||||
|
||||
if output_tokens_reasoning is not None:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
|
||||
output_tokens_reasoning,
|
||||
)
|
||||
|
||||
if total_tokens is None and input_tokens is not None and output_tokens is not None:
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
if total_tokens is not None:
|
||||
span.set_measurement("ai_total_tokens_used", total_tokens)
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
|
||||
|
||||
@@ -1,32 +1,144 @@
|
||||
import json
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
from sys import getsizeof
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from sentry_sdk.tracing import Span
|
||||
from sentry_sdk.tracing import Span
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.utils import logger
|
||||
|
||||
MAX_GEN_AI_MESSAGE_BYTES = 20_000 # 20KB
|
||||
|
||||
def _normalize_data(data):
|
||||
# type: (Any) -> Any
|
||||
|
||||
class GEN_AI_ALLOWED_MESSAGE_ROLES:
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = {
|
||||
GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"],
|
||||
GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"],
|
||||
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"],
|
||||
GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"],
|
||||
}
|
||||
|
||||
GEN_AI_MESSAGE_ROLE_MAPPING = {}
|
||||
for target_role, source_roles in GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING.items():
|
||||
for source_role in source_roles:
|
||||
GEN_AI_MESSAGE_ROLE_MAPPING[source_role] = target_role
|
||||
|
||||
|
||||
def _normalize_data(data, unpack=True):
|
||||
# type: (Any, bool) -> Any
|
||||
# convert pydantic data (e.g. OpenAI v1+) to json compatible format
|
||||
if hasattr(data, "model_dump"):
|
||||
try:
|
||||
return data.model_dump()
|
||||
return _normalize_data(data.model_dump(), unpack=unpack)
|
||||
except Exception as e:
|
||||
logger.warning("Could not convert pydantic data to JSON: %s", e)
|
||||
return data
|
||||
return data if isinstance(data, (int, float, bool, str)) else str(data)
|
||||
|
||||
if isinstance(data, list):
|
||||
if len(data) == 1:
|
||||
return _normalize_data(data[0]) # remove empty dimensions
|
||||
return list(_normalize_data(x) for x in data)
|
||||
if unpack and len(data) == 1:
|
||||
return _normalize_data(data[0], unpack=unpack) # remove empty dimensions
|
||||
return list(_normalize_data(x, unpack=unpack) for x in data)
|
||||
|
||||
if isinstance(data, dict):
|
||||
return {k: _normalize_data(v) for (k, v) in data.items()}
|
||||
return data
|
||||
return {k: _normalize_data(v, unpack=unpack) for (k, v) in data.items()}
|
||||
|
||||
return data if isinstance(data, (int, float, bool, str)) else str(data)
|
||||
|
||||
|
||||
def set_data_normalized(span, key, value):
|
||||
# type: (Span, str, Any) -> None
|
||||
normalized = _normalize_data(value)
|
||||
span.set_data(key, normalized)
|
||||
def set_data_normalized(span, key, value, unpack=True):
|
||||
# type: (Span, str, Any, bool) -> None
|
||||
normalized = _normalize_data(value, unpack=unpack)
|
||||
if isinstance(normalized, (int, float, bool, str)):
|
||||
span.set_data(key, normalized)
|
||||
else:
|
||||
span.set_data(key, json.dumps(normalized))
|
||||
|
||||
|
||||
def normalize_message_role(role):
|
||||
# type: (str) -> str
|
||||
"""
|
||||
Normalize a message role to one of the 4 allowed gen_ai role values.
|
||||
Maps "ai" -> "assistant" and keeps other standard roles unchanged.
|
||||
"""
|
||||
return GEN_AI_MESSAGE_ROLE_MAPPING.get(role, role)
|
||||
|
||||
|
||||
def normalize_message_roles(messages):
|
||||
# type: (list[dict[str, Any]]) -> list[dict[str, Any]]
|
||||
"""
|
||||
Normalize roles in a list of messages to use standard gen_ai role values.
|
||||
Creates a deep copy to avoid modifying the original messages.
|
||||
"""
|
||||
normalized_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
normalized_messages.append(message)
|
||||
continue
|
||||
normalized_message = message.copy()
|
||||
if "role" in message:
|
||||
normalized_message["role"] = normalize_message_role(message["role"])
|
||||
normalized_messages.append(normalized_message)
|
||||
|
||||
return normalized_messages
|
||||
|
||||
|
||||
def get_start_span_function():
|
||||
# type: () -> Callable[..., Any]
|
||||
current_span = sentry_sdk.get_current_span()
|
||||
transaction_exists = (
|
||||
current_span is not None and current_span.containing_transaction is not None
|
||||
)
|
||||
return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction
|
||||
|
||||
|
||||
def _find_truncation_index(messages, max_bytes):
|
||||
# type: (List[Dict[str, Any]], int) -> int
|
||||
"""
|
||||
Find the index of the first message that would exceed the max bytes limit.
|
||||
Compute the individual message sizes, and return the index of the first message from the back
|
||||
of the list that would exceed the max bytes limit.
|
||||
"""
|
||||
running_sum = 0
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
size = len(json.dumps(messages[idx], separators=(",", ":")).encode("utf-8"))
|
||||
running_sum += size
|
||||
if running_sum > max_bytes:
|
||||
return idx + 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
|
||||
# type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
|
||||
serialized_json = json.dumps(messages, separators=(",", ":"))
|
||||
current_size = len(serialized_json.encode("utf-8"))
|
||||
|
||||
if current_size <= max_bytes:
|
||||
return messages, 0
|
||||
|
||||
truncation_index = _find_truncation_index(messages, max_bytes)
|
||||
return messages[truncation_index:], truncation_index
|
||||
|
||||
|
||||
def truncate_and_annotate_messages(
|
||||
messages, span, scope, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
|
||||
):
|
||||
# type: (Optional[List[Dict[str, Any]]], Any, Any, int) -> Optional[List[Dict[str, Any]]]
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
|
||||
if removed_count > 0:
|
||||
scope._gen_ai_original_message_count[span.span_id] = len(messages)
|
||||
|
||||
return truncated_messages
|
||||
|
||||
@@ -51,6 +51,7 @@ else:
|
||||
# When changing this, update __all__ in __init__.py too
|
||||
__all__ = [
|
||||
"init",
|
||||
"add_attachment",
|
||||
"add_breadcrumb",
|
||||
"capture_event",
|
||||
"capture_exception",
|
||||
@@ -81,6 +82,10 @@ __all__ = [
|
||||
"start_transaction",
|
||||
"trace",
|
||||
"monitor",
|
||||
"start_session",
|
||||
"end_session",
|
||||
"set_transaction_name",
|
||||
"update_current_span",
|
||||
]
|
||||
|
||||
|
||||
@@ -184,6 +189,20 @@ def capture_exception(
|
||||
return get_current_scope().capture_exception(error, scope=scope, **scope_kwargs)
|
||||
|
||||
|
||||
@scopemethod
|
||||
def add_attachment(
|
||||
bytes=None, # type: Union[None, bytes, Callable[[], bytes]]
|
||||
filename=None, # type: Optional[str]
|
||||
path=None, # type: Optional[str]
|
||||
content_type=None, # type: Optional[str]
|
||||
add_to_transactions=False, # type: bool
|
||||
):
|
||||
# type: (...) -> None
|
||||
return get_isolation_scope().add_attachment(
|
||||
bytes, filename, path, content_type, add_to_transactions
|
||||
)
|
||||
|
||||
|
||||
@scopemethod
|
||||
def add_breadcrumb(
|
||||
crumb=None, # type: Optional[Breadcrumb]
|
||||
@@ -388,6 +407,10 @@ def start_transaction(
|
||||
|
||||
def set_measurement(name, value, unit=""):
|
||||
# type: (str, float, MeasurementUnit) -> None
|
||||
"""
|
||||
.. deprecated:: 2.28.0
|
||||
This function is deprecated and will be removed in the next major release.
|
||||
"""
|
||||
transaction = get_current_scope().transaction
|
||||
if transaction is not None:
|
||||
transaction.set_measurement(name, value, unit)
|
||||
@@ -431,3 +454,102 @@ def continue_trace(
|
||||
return get_isolation_scope().continue_trace(
|
||||
environ_or_headers, op, name, source, origin
|
||||
)
|
||||
|
||||
|
||||
@scopemethod
|
||||
def start_session(
|
||||
session_mode="application", # type: str
|
||||
):
|
||||
# type: (...) -> None
|
||||
return get_isolation_scope().start_session(session_mode=session_mode)
|
||||
|
||||
|
||||
@scopemethod
|
||||
def end_session():
|
||||
# type: () -> None
|
||||
return get_isolation_scope().end_session()
|
||||
|
||||
|
||||
@scopemethod
|
||||
def set_transaction_name(name, source=None):
|
||||
# type: (str, Optional[str]) -> None
|
||||
return get_current_scope().set_transaction_name(name, source)
|
||||
|
||||
|
||||
def update_current_span(op=None, name=None, attributes=None, data=None):
|
||||
# type: (Optional[str], Optional[str], Optional[dict[str, Union[str, int, float, bool]]], Optional[dict[str, Any]]) -> None
|
||||
"""
|
||||
Update the current active span with the provided parameters.
|
||||
|
||||
This function allows you to modify properties of the currently active span.
|
||||
If no span is currently active, this function will do nothing.
|
||||
|
||||
:param op: The operation name for the span. This is a high-level description
|
||||
of what the span represents (e.g., "http.client", "db.query").
|
||||
You can use predefined constants from :py:class:`sentry_sdk.consts.OP`
|
||||
or provide your own string. If not provided, the span's operation will
|
||||
remain unchanged.
|
||||
:type op: str or None
|
||||
|
||||
:param name: The human-readable name/description for the span. This provides
|
||||
more specific details about what the span represents (e.g., "GET /api/users",
|
||||
"SELECT * FROM users"). If not provided, the span's name will remain unchanged.
|
||||
:type name: str or None
|
||||
|
||||
:param data: A dictionary of key-value pairs to add as data to the span. This
|
||||
data will be merged with any existing span data. If not provided,
|
||||
no data will be added.
|
||||
|
||||
.. deprecated:: 2.35.0
|
||||
Use ``attributes`` instead. The ``data`` parameter will be removed
|
||||
in a future version.
|
||||
:type data: dict[str, Union[str, int, float, bool]] or None
|
||||
|
||||
:param attributes: A dictionary of key-value pairs to add as attributes to the span.
|
||||
Attribute values must be strings, integers, floats, or booleans. These
|
||||
attributes will be merged with any existing span data. If not provided,
|
||||
no attributes will be added.
|
||||
:type attributes: dict[str, Union[str, int, float, bool]] or None
|
||||
|
||||
:returns: None
|
||||
|
||||
.. versionadded:: 2.35.0
|
||||
|
||||
Example::
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP
|
||||
|
||||
sentry_sdk.update_current_span(
|
||||
op=OP.FUNCTION,
|
||||
name="process_user_data",
|
||||
attributes={"user_id": 123, "batch_size": 50}
|
||||
)
|
||||
"""
|
||||
current_span = get_current_span()
|
||||
|
||||
if current_span is None:
|
||||
return
|
||||
|
||||
if op is not None:
|
||||
current_span.op = op
|
||||
|
||||
if name is not None:
|
||||
# internally it is still description
|
||||
current_span.description = name
|
||||
|
||||
if data is not None and attributes is not None:
|
||||
raise ValueError(
|
||||
"Cannot provide both `data` and `attributes`. Please use only `attributes`."
|
||||
)
|
||||
|
||||
if data is not None:
|
||||
warnings.warn(
|
||||
"The `data` parameter is deprecated. Please use `attributes` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
attributes = data
|
||||
|
||||
if attributes is not None:
|
||||
current_span.update_data(attributes)
|
||||
|
||||
@@ -8,6 +8,7 @@ from importlib import import_module
|
||||
from typing import TYPE_CHECKING, List, Dict, cast, overload
|
||||
import warnings
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk._compat import PY37, check_uwsgi_thread_support
|
||||
from sentry_sdk.utils import (
|
||||
AnnotatedValue,
|
||||
@@ -22,11 +23,16 @@ from sentry_sdk.utils import (
|
||||
handle_in_app,
|
||||
is_gevent,
|
||||
logger,
|
||||
get_before_send_log,
|
||||
get_before_send_metric,
|
||||
has_logs_enabled,
|
||||
has_metrics_enabled,
|
||||
)
|
||||
from sentry_sdk.serializer import serialize
|
||||
from sentry_sdk.tracing import trace
|
||||
from sentry_sdk.transport import BaseHttpTransport, make_transport
|
||||
from sentry_sdk.consts import (
|
||||
SPANDATA,
|
||||
DEFAULT_MAX_VALUE_LENGTH,
|
||||
DEFAULT_OPTIONS,
|
||||
INSTRUMENTER,
|
||||
@@ -34,6 +40,7 @@ from sentry_sdk.consts import (
|
||||
ClientConstructor,
|
||||
)
|
||||
from sentry_sdk.integrations import _DEFAULT_INTEGRATIONS, setup_integrations
|
||||
from sentry_sdk.integrations.dedupe import DedupeIntegration
|
||||
from sentry_sdk.sessions import SessionFlusher
|
||||
from sentry_sdk.envelope import Envelope
|
||||
from sentry_sdk.profiler.continuous_profiler import setup_continuous_profiler
|
||||
@@ -44,7 +51,6 @@ from sentry_sdk.profiler.transaction_profiler import (
|
||||
)
|
||||
from sentry_sdk.scrubber import EventScrubber
|
||||
from sentry_sdk.monitor import Monitor
|
||||
from sentry_sdk.spotlight import setup_spotlight
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
@@ -55,13 +61,14 @@ if TYPE_CHECKING:
|
||||
from typing import Union
|
||||
from typing import TypeVar
|
||||
|
||||
from sentry_sdk._types import Event, Hint, SDKInfo
|
||||
from sentry_sdk._types import Event, Hint, SDKInfo, Log, Metric
|
||||
from sentry_sdk.integrations import Integration
|
||||
from sentry_sdk.metrics import MetricsAggregator
|
||||
from sentry_sdk.scope import Scope
|
||||
from sentry_sdk.session import Session
|
||||
from sentry_sdk.spotlight import SpotlightClient
|
||||
from sentry_sdk.transport import Transport
|
||||
from sentry_sdk._log_batcher import LogBatcher
|
||||
from sentry_sdk._metrics_batcher import MetricsBatcher
|
||||
|
||||
I = TypeVar("I", bound=Integration) # noqa: E741
|
||||
|
||||
@@ -107,7 +114,7 @@ def _get_options(*args, **kwargs):
|
||||
rv["environment"] = os.environ.get("SENTRY_ENVIRONMENT") or "production"
|
||||
|
||||
if rv["debug"] is None:
|
||||
rv["debug"] = env_to_bool(os.environ.get("SENTRY_DEBUG", "False"), strict=True)
|
||||
rv["debug"] = env_to_bool(os.environ.get("SENTRY_DEBUG"), strict=True) or False
|
||||
|
||||
if rv["server_name"] is None and hasattr(socket, "gethostname"):
|
||||
rv["server_name"] = socket.gethostname()
|
||||
@@ -139,6 +146,11 @@ def _get_options(*args, **kwargs):
|
||||
)
|
||||
rv["socket_options"] = None
|
||||
|
||||
if rv["keep_alive"] is None:
|
||||
rv["keep_alive"] = (
|
||||
env_to_bool(os.environ.get("SENTRY_KEEP_ALIVE"), strict=True) or False
|
||||
)
|
||||
|
||||
if rv["enable_tracing"] is not None:
|
||||
warnings.warn(
|
||||
"The `enable_tracing` parameter is deprecated. Please use `traces_sample_rate` instead.",
|
||||
@@ -168,13 +180,12 @@ class BaseClient:
|
||||
|
||||
def __init__(self, options=None):
|
||||
# type: (Optional[Dict[str, Any]]) -> None
|
||||
self.options = (
|
||||
options if options is not None else DEFAULT_OPTIONS
|
||||
) # type: Dict[str, Any]
|
||||
self.options = options if options is not None else DEFAULT_OPTIONS # type: Dict[str, Any]
|
||||
|
||||
self.transport = None # type: Optional[Transport]
|
||||
self.monitor = None # type: Optional[Monitor]
|
||||
self.metrics_aggregator = None # type: Optional[MetricsAggregator]
|
||||
self.log_batcher = None # type: Optional[LogBatcher]
|
||||
self.metrics_batcher = None # type: Optional[MetricsBatcher]
|
||||
|
||||
def __getstate__(self, *args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
@@ -206,6 +217,14 @@ class BaseClient:
|
||||
# type: (*Any, **Any) -> Optional[str]
|
||||
return None
|
||||
|
||||
def _capture_log(self, log):
|
||||
# type: (Log) -> None
|
||||
pass
|
||||
|
||||
def _capture_metric(self, metric):
|
||||
# type: (Metric) -> None
|
||||
pass
|
||||
|
||||
def capture_session(self, *args, **kwargs):
|
||||
# type: (*Any, **Any) -> None
|
||||
return None
|
||||
@@ -348,25 +367,19 @@ class _Client(BaseClient):
|
||||
|
||||
self.session_flusher = SessionFlusher(capture_func=_capture_envelope)
|
||||
|
||||
self.metrics_aggregator = None # type: Optional[MetricsAggregator]
|
||||
experiments = self.options.get("_experiments", {})
|
||||
if experiments.get("enable_metrics", True):
|
||||
# Context vars are not working correctly on Python <=3.6
|
||||
# with gevent.
|
||||
metrics_supported = not is_gevent() or PY37
|
||||
if metrics_supported:
|
||||
from sentry_sdk.metrics import MetricsAggregator
|
||||
self.log_batcher = None
|
||||
|
||||
self.metrics_aggregator = MetricsAggregator(
|
||||
capture_func=_capture_envelope,
|
||||
enable_code_locations=bool(
|
||||
experiments.get("metric_code_locations", True)
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Metrics not supported on Python 3.6 and lower with gevent."
|
||||
)
|
||||
if has_logs_enabled(self.options):
|
||||
from sentry_sdk._log_batcher import LogBatcher
|
||||
|
||||
self.log_batcher = LogBatcher(capture_func=_capture_envelope)
|
||||
|
||||
self.metrics_batcher = None
|
||||
|
||||
if has_metrics_enabled(self.options):
|
||||
from sentry_sdk._metrics_batcher import MetricsBatcher
|
||||
|
||||
self.metrics_batcher = MetricsBatcher(capture_func=_capture_envelope)
|
||||
|
||||
max_request_body_size = ("always", "never", "small", "medium")
|
||||
if self.options["max_request_body_size"] not in max_request_body_size:
|
||||
@@ -409,7 +422,17 @@ class _Client(BaseClient):
|
||||
)
|
||||
|
||||
if self.options.get("spotlight"):
|
||||
# This is intentionally here to prevent setting up spotlight
|
||||
# stuff we don't need unless spotlight is explicitly enabled
|
||||
from sentry_sdk.spotlight import setup_spotlight
|
||||
|
||||
self.spotlight = setup_spotlight(self.options)
|
||||
if not self.options["dsn"]:
|
||||
sample_all = lambda *_args, **_kwargs: 1.0
|
||||
self.options["send_default_pii"] = True
|
||||
self.options["error_sampler"] = sample_all
|
||||
self.options["traces_sampler"] = sample_all
|
||||
self.options["profiles_sampler"] = sample_all
|
||||
|
||||
sdk_name = get_sdk_name(list(self.integrations.keys()))
|
||||
SDK_INFO["name"] = sdk_name
|
||||
@@ -437,7 +460,7 @@ class _Client(BaseClient):
|
||||
|
||||
if (
|
||||
self.monitor
|
||||
or self.metrics_aggregator
|
||||
or self.log_batcher
|
||||
or has_profiling_enabled(self.options)
|
||||
or isinstance(self.transport, BaseHttpTransport)
|
||||
):
|
||||
@@ -461,11 +484,7 @@ class _Client(BaseClient):
|
||||
|
||||
Returns whether the client should send default PII (Personally Identifiable Information) data to Sentry.
|
||||
"""
|
||||
result = self.options.get("send_default_pii")
|
||||
if result is None:
|
||||
result = not self.options["dsn"] and self.spotlight is not None
|
||||
|
||||
return result
|
||||
return self.options.get("send_default_pii") or False
|
||||
|
||||
@property
|
||||
def dsn(self):
|
||||
@@ -482,12 +501,14 @@ class _Client(BaseClient):
|
||||
# type: (...) -> Optional[Event]
|
||||
|
||||
previous_total_spans = None # type: Optional[int]
|
||||
previous_total_breadcrumbs = None # type: Optional[int]
|
||||
|
||||
if event.get("timestamp") is None:
|
||||
event["timestamp"] = datetime.now(timezone.utc)
|
||||
|
||||
is_transaction = event.get("type") == "transaction"
|
||||
|
||||
if scope is not None:
|
||||
is_transaction = event.get("type") == "transaction"
|
||||
spans_before = len(cast(List[Dict[str, object]], event.get("spans", [])))
|
||||
event_ = scope.apply_to_event(event, hint, self.options)
|
||||
|
||||
@@ -518,9 +539,20 @@ class _Client(BaseClient):
|
||||
dropped_spans = event.pop("_dropped_spans", 0) + spans_delta # type: int
|
||||
if dropped_spans > 0:
|
||||
previous_total_spans = spans_before + dropped_spans
|
||||
if scope._n_breadcrumbs_truncated > 0:
|
||||
breadcrumbs = event.get("breadcrumbs", {})
|
||||
values = (
|
||||
breadcrumbs.get("values", [])
|
||||
if not isinstance(breadcrumbs, AnnotatedValue)
|
||||
else []
|
||||
)
|
||||
previous_total_breadcrumbs = (
|
||||
len(values) + scope._n_breadcrumbs_truncated
|
||||
)
|
||||
|
||||
if (
|
||||
self.options["attach_stacktrace"]
|
||||
not is_transaction
|
||||
and self.options["attach_stacktrace"]
|
||||
and "exception" not in event
|
||||
and "stacktrace" not in event
|
||||
and "threads" not in event
|
||||
@@ -566,10 +598,30 @@ class _Client(BaseClient):
|
||||
if event_scrubber:
|
||||
event_scrubber.scrub_event(event)
|
||||
|
||||
if scope is not None and scope._gen_ai_original_message_count:
|
||||
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
|
||||
if isinstance(spans, list):
|
||||
for span in spans:
|
||||
span_id = span.get("span_id", None)
|
||||
span_data = span.get("data", {})
|
||||
if (
|
||||
span_id
|
||||
and span_id in scope._gen_ai_original_message_count
|
||||
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
|
||||
):
|
||||
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
|
||||
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
|
||||
{"len": scope._gen_ai_original_message_count[span_id]},
|
||||
)
|
||||
if previous_total_spans is not None:
|
||||
event["spans"] = AnnotatedValue(
|
||||
event.get("spans", []), {"len": previous_total_spans}
|
||||
)
|
||||
if previous_total_breadcrumbs is not None:
|
||||
event["breadcrumbs"] = AnnotatedValue(
|
||||
event.get("breadcrumbs", {"values": []}),
|
||||
{"len": previous_total_breadcrumbs},
|
||||
)
|
||||
|
||||
# Postprocess the event here so that annotated types do
|
||||
# generally not surface in before_send
|
||||
@@ -599,6 +651,14 @@ class _Client(BaseClient):
|
||||
self.transport.record_lost_event(
|
||||
"before_send", data_category="error"
|
||||
)
|
||||
|
||||
# If this is an exception, reset the DedupeIntegration. It still
|
||||
# remembers the dropped exception as the last exception, meaning
|
||||
# that if the same exception happens again and is not dropped
|
||||
# in before_send, it'd get dropped by DedupeIntegration.
|
||||
if event.get("exception"):
|
||||
DedupeIntegration.reset_last_seen()
|
||||
|
||||
event = new_event
|
||||
|
||||
before_send_transaction = self.options["before_send_transaction"]
|
||||
@@ -740,6 +800,8 @@ class _Client(BaseClient):
|
||||
if exceptions:
|
||||
errored = True
|
||||
for error in exceptions:
|
||||
if isinstance(error, AnnotatedValue):
|
||||
error = error.value or {}
|
||||
mechanism = error.get("mechanism")
|
||||
if isinstance(mechanism, Mapping) and mechanism.get("handled") is False:
|
||||
crashed = True
|
||||
@@ -847,8 +909,135 @@ class _Client(BaseClient):
|
||||
|
||||
return return_value
|
||||
|
||||
def _capture_log(self, log):
|
||||
# type: (Optional[Log]) -> None
|
||||
if not has_logs_enabled(self.options) or log is None:
|
||||
return
|
||||
|
||||
current_scope = sentry_sdk.get_current_scope()
|
||||
isolation_scope = sentry_sdk.get_isolation_scope()
|
||||
|
||||
log["attributes"]["sentry.sdk.name"] = SDK_INFO["name"]
|
||||
log["attributes"]["sentry.sdk.version"] = SDK_INFO["version"]
|
||||
|
||||
server_name = self.options.get("server_name")
|
||||
if server_name is not None and SPANDATA.SERVER_ADDRESS not in log["attributes"]:
|
||||
log["attributes"][SPANDATA.SERVER_ADDRESS] = server_name
|
||||
|
||||
environment = self.options.get("environment")
|
||||
if environment is not None and "sentry.environment" not in log["attributes"]:
|
||||
log["attributes"]["sentry.environment"] = environment
|
||||
|
||||
release = self.options.get("release")
|
||||
if release is not None and "sentry.release" not in log["attributes"]:
|
||||
log["attributes"]["sentry.release"] = release
|
||||
|
||||
span = current_scope.span
|
||||
if span is not None and "sentry.trace.parent_span_id" not in log["attributes"]:
|
||||
log["attributes"]["sentry.trace.parent_span_id"] = span.span_id
|
||||
|
||||
if log.get("trace_id") is None:
|
||||
transaction = current_scope.transaction
|
||||
propagation_context = isolation_scope.get_active_propagation_context()
|
||||
if transaction is not None:
|
||||
log["trace_id"] = transaction.trace_id
|
||||
elif propagation_context is not None:
|
||||
log["trace_id"] = propagation_context.trace_id
|
||||
|
||||
# The user, if present, is always set on the isolation scope.
|
||||
if isolation_scope._user is not None:
|
||||
for log_attribute, user_attribute in (
|
||||
("user.id", "id"),
|
||||
("user.name", "username"),
|
||||
("user.email", "email"),
|
||||
):
|
||||
if (
|
||||
user_attribute in isolation_scope._user
|
||||
and log_attribute not in log["attributes"]
|
||||
):
|
||||
log["attributes"][log_attribute] = isolation_scope._user[
|
||||
user_attribute
|
||||
]
|
||||
|
||||
# If debug is enabled, log the log to the console
|
||||
debug = self.options.get("debug", False)
|
||||
if debug:
|
||||
logger.debug(
|
||||
f"[Sentry Logs] [{log.get('severity_text')}] {log.get('body')}"
|
||||
)
|
||||
|
||||
before_send_log = get_before_send_log(self.options)
|
||||
if before_send_log is not None:
|
||||
log = before_send_log(log, {})
|
||||
|
||||
if log is None:
|
||||
return
|
||||
|
||||
if self.log_batcher:
|
||||
self.log_batcher.add(log)
|
||||
|
||||
def _capture_metric(self, metric):
|
||||
# type: (Optional[Metric]) -> None
|
||||
if not has_metrics_enabled(self.options) or metric is None:
|
||||
return
|
||||
|
||||
isolation_scope = sentry_sdk.get_isolation_scope()
|
||||
|
||||
metric["attributes"]["sentry.sdk.name"] = SDK_INFO["name"]
|
||||
metric["attributes"]["sentry.sdk.version"] = SDK_INFO["version"]
|
||||
|
||||
environment = self.options.get("environment")
|
||||
if environment is not None and "sentry.environment" not in metric["attributes"]:
|
||||
metric["attributes"]["sentry.environment"] = environment
|
||||
|
||||
release = self.options.get("release")
|
||||
if release is not None and "sentry.release" not in metric["attributes"]:
|
||||
metric["attributes"]["sentry.release"] = release
|
||||
|
||||
span = sentry_sdk.get_current_span()
|
||||
metric["trace_id"] = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
if span:
|
||||
metric["trace_id"] = span.trace_id
|
||||
metric["span_id"] = span.span_id
|
||||
else:
|
||||
propagation_context = isolation_scope.get_active_propagation_context()
|
||||
if propagation_context and propagation_context.trace_id:
|
||||
metric["trace_id"] = propagation_context.trace_id
|
||||
|
||||
if isolation_scope._user is not None:
|
||||
for metric_attribute, user_attribute in (
|
||||
("user.id", "id"),
|
||||
("user.name", "username"),
|
||||
("user.email", "email"),
|
||||
):
|
||||
if (
|
||||
user_attribute in isolation_scope._user
|
||||
and metric_attribute not in metric["attributes"]
|
||||
):
|
||||
metric["attributes"][metric_attribute] = isolation_scope._user[
|
||||
user_attribute
|
||||
]
|
||||
|
||||
debug = self.options.get("debug", False)
|
||||
if debug:
|
||||
logger.debug(
|
||||
f"[Sentry Metrics] [{metric.get('type')}] {metric.get('name')}: {metric.get('value')}"
|
||||
)
|
||||
|
||||
before_send_metric = get_before_send_metric(self.options)
|
||||
if before_send_metric is not None:
|
||||
metric = before_send_metric(metric, {})
|
||||
|
||||
if metric is None:
|
||||
return
|
||||
|
||||
if self.metrics_batcher:
|
||||
self.metrics_batcher.add(metric)
|
||||
|
||||
def capture_session(
|
||||
self, session # type: Session
|
||||
self,
|
||||
session, # type: Session
|
||||
):
|
||||
# type: (...) -> None
|
||||
if not session.release:
|
||||
@@ -869,7 +1058,8 @@ class _Client(BaseClient):
|
||||
...
|
||||
|
||||
def get_integration(
|
||||
self, name_or_class # type: Union[str, Type[Integration]]
|
||||
self,
|
||||
name_or_class, # type: Union[str, Type[Integration]]
|
||||
):
|
||||
# type: (...) -> Optional[Integration]
|
||||
"""Returns the integration for this client by name or class.
|
||||
@@ -897,8 +1087,10 @@ class _Client(BaseClient):
|
||||
if self.transport is not None:
|
||||
self.flush(timeout=timeout, callback=callback)
|
||||
self.session_flusher.kill()
|
||||
if self.metrics_aggregator is not None:
|
||||
self.metrics_aggregator.kill()
|
||||
if self.log_batcher is not None:
|
||||
self.log_batcher.kill()
|
||||
if self.metrics_batcher is not None:
|
||||
self.metrics_batcher.kill()
|
||||
if self.monitor:
|
||||
self.monitor.kill()
|
||||
self.transport.kill()
|
||||
@@ -921,8 +1113,10 @@ class _Client(BaseClient):
|
||||
if timeout is None:
|
||||
timeout = self.options["shutdown_timeout"]
|
||||
self.session_flusher.flush()
|
||||
if self.metrics_aggregator is not None:
|
||||
self.metrics_aggregator.flush()
|
||||
if self.log_batcher is not None:
|
||||
self.log_batcher.flush()
|
||||
if self.metrics_batcher is not None:
|
||||
self.metrics_batcher.flush()
|
||||
self.transport.flush(timeout=timeout, callback=callback)
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.utils import logger
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -54,4 +55,8 @@ def capture_checkin(
|
||||
|
||||
sentry_sdk.capture_event(check_in_event)
|
||||
|
||||
logger.debug(
|
||||
f"[Crons] Captured check-in ({check_in_event.get('check_in_id')}): {check_in_event.get('monitor_slug')} -> {check_in_event.get('status')}"
|
||||
)
|
||||
|
||||
return check_in_event["check_in_id"]
|
||||
|
||||
@@ -57,39 +57,49 @@ class Envelope:
|
||||
)
|
||||
|
||||
def add_event(
|
||||
self, event # type: Event
|
||||
self,
|
||||
event, # type: Event
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.add_item(Item(payload=PayloadRef(json=event), type="event"))
|
||||
|
||||
def add_transaction(
|
||||
self, transaction # type: Event
|
||||
self,
|
||||
transaction, # type: Event
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.add_item(Item(payload=PayloadRef(json=transaction), type="transaction"))
|
||||
|
||||
def add_profile(
|
||||
self, profile # type: Any
|
||||
self,
|
||||
profile, # type: Any
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.add_item(Item(payload=PayloadRef(json=profile), type="profile"))
|
||||
|
||||
def add_profile_chunk(
|
||||
self, profile_chunk # type: Any
|
||||
self,
|
||||
profile_chunk, # type: Any
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.add_item(
|
||||
Item(payload=PayloadRef(json=profile_chunk), type="profile_chunk")
|
||||
Item(
|
||||
payload=PayloadRef(json=profile_chunk),
|
||||
type="profile_chunk",
|
||||
headers={"platform": profile_chunk.get("platform", "python")},
|
||||
)
|
||||
)
|
||||
|
||||
def add_checkin(
|
||||
self, checkin # type: Any
|
||||
self,
|
||||
checkin, # type: Any
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.add_item(Item(payload=PayloadRef(json=checkin), type="check_in"))
|
||||
|
||||
def add_session(
|
||||
self, session # type: Union[Session, Any]
|
||||
self,
|
||||
session, # type: Union[Session, Any]
|
||||
):
|
||||
# type: (...) -> None
|
||||
if isinstance(session, Session):
|
||||
@@ -97,13 +107,15 @@ class Envelope:
|
||||
self.add_item(Item(payload=PayloadRef(json=session), type="session"))
|
||||
|
||||
def add_sessions(
|
||||
self, sessions # type: Any
|
||||
self,
|
||||
sessions, # type: Any
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.add_item(Item(payload=PayloadRef(json=sessions), type="sessions"))
|
||||
|
||||
def add_item(
|
||||
self, item # type: Item
|
||||
self,
|
||||
item, # type: Item
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.items.append(item)
|
||||
@@ -129,7 +141,8 @@ class Envelope:
|
||||
return iter(self.items)
|
||||
|
||||
def serialize_into(
|
||||
self, f # type: Any
|
||||
self,
|
||||
f, # type: Any
|
||||
):
|
||||
# type: (...) -> None
|
||||
f.write(json_dumps(self.headers))
|
||||
@@ -145,7 +158,8 @@ class Envelope:
|
||||
|
||||
@classmethod
|
||||
def deserialize_from(
|
||||
cls, f # type: Any
|
||||
cls,
|
||||
f, # type: Any
|
||||
):
|
||||
# type: (...) -> Envelope
|
||||
headers = parse_json(f.readline())
|
||||
@@ -159,7 +173,8 @@ class Envelope:
|
||||
|
||||
@classmethod
|
||||
def deserialize(
|
||||
cls, bytes # type: bytes
|
||||
cls,
|
||||
bytes, # type: bytes
|
||||
):
|
||||
# type: (...) -> Envelope
|
||||
return cls.deserialize_from(io.BytesIO(bytes))
|
||||
@@ -268,14 +283,16 @@ class Item:
|
||||
return "transaction"
|
||||
elif ty == "event":
|
||||
return "error"
|
||||
elif ty == "log":
|
||||
return "log_item"
|
||||
elif ty == "trace_metric":
|
||||
return "trace_metric"
|
||||
elif ty == "client_report":
|
||||
return "internal"
|
||||
elif ty == "profile":
|
||||
return "profile"
|
||||
elif ty == "profile_chunk":
|
||||
return "profile_chunk"
|
||||
elif ty == "statsd":
|
||||
return "metric_bucket"
|
||||
elif ty == "check_in":
|
||||
return "monitor"
|
||||
else:
|
||||
@@ -301,7 +318,8 @@ class Item:
|
||||
return None
|
||||
|
||||
def serialize_into(
|
||||
self, f # type: Any
|
||||
self,
|
||||
f, # type: Any
|
||||
):
|
||||
# type: (...) -> None
|
||||
headers = dict(self.headers)
|
||||
@@ -320,7 +338,8 @@ class Item:
|
||||
|
||||
@classmethod
|
||||
def deserialize_from(
|
||||
cls, f # type: Any
|
||||
cls,
|
||||
f, # type: Any
|
||||
):
|
||||
# type: (...) -> Optional[Item]
|
||||
line = f.readline().rstrip()
|
||||
@@ -335,7 +354,7 @@ class Item:
|
||||
# if no length was specified we need to read up to the end of line
|
||||
# and remove it (if it is present, i.e. not the very last char in an eof terminated envelope)
|
||||
payload = f.readline().rstrip(b"\n")
|
||||
if headers.get("type") in ("event", "transaction", "metric_buckets"):
|
||||
if headers.get("type") in ("event", "transaction"):
|
||||
rv = cls(headers=headers, payload=PayloadRef(json=parse_json(payload)))
|
||||
else:
|
||||
rv = cls(headers=headers, payload=payload)
|
||||
@@ -343,7 +362,8 @@ class Item:
|
||||
|
||||
@classmethod
|
||||
def deserialize(
|
||||
cls, bytes # type: bytes
|
||||
cls,
|
||||
bytes, # type: bytes
|
||||
):
|
||||
# type: (...) -> Optional[Item]
|
||||
return cls.deserialize_from(io.BytesIO(bytes))
|
||||
|
||||
@@ -15,7 +15,6 @@ DEFAULT_FLAG_CAPACITY = 100
|
||||
|
||||
|
||||
class FlagBuffer:
|
||||
|
||||
def __init__(self, capacity):
|
||||
# type: (int) -> None
|
||||
self.capacity = capacity
|
||||
@@ -64,5 +63,9 @@ def add_feature_flag(flag, result):
|
||||
Records a flag and its value to be sent on subsequent error events.
|
||||
We recommend you do this on flag evaluations. Flags are buffered per Sentry scope.
|
||||
"""
|
||||
flags = sentry_sdk.get_current_scope().flags
|
||||
flags = sentry_sdk.get_isolation_scope().flags
|
||||
flags.set(flag, result)
|
||||
|
||||
span = sentry_sdk.get_current_span()
|
||||
if span:
|
||||
span.set_flag(f"flag.evaluation.{flag}", result)
|
||||
|
||||
@@ -205,7 +205,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
scope._isolation_scope.set(old_isolation_scope)
|
||||
|
||||
def run(
|
||||
self, callback # type: Callable[[], T]
|
||||
self,
|
||||
callback, # type: Callable[[], T]
|
||||
):
|
||||
# type: (...) -> T
|
||||
"""
|
||||
@@ -219,7 +220,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
return callback()
|
||||
|
||||
def get_integration(
|
||||
self, name_or_class # type: Union[str, Type[Integration]]
|
||||
self,
|
||||
name_or_class, # type: Union[str, Type[Integration]]
|
||||
):
|
||||
# type: (...) -> Any
|
||||
"""
|
||||
@@ -277,7 +279,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
return self._last_event_id
|
||||
|
||||
def bind_client(
|
||||
self, new # type: Optional[BaseClient]
|
||||
self,
|
||||
new, # type: Optional[BaseClient]
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
@@ -430,7 +433,7 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
transaction=None,
|
||||
instrumenter=INSTRUMENTER.SENTRY,
|
||||
custom_sampling_context=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# type: (Optional[Transaction], str, Optional[SamplingContext], Unpack[TransactionKwargs]) -> Union[Transaction, NoOpSpan]
|
||||
"""
|
||||
@@ -487,14 +490,16 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
|
||||
@overload
|
||||
def push_scope(
|
||||
self, callback=None # type: Optional[None]
|
||||
self,
|
||||
callback=None, # type: Optional[None]
|
||||
):
|
||||
# type: (...) -> ContextManager[Scope]
|
||||
pass
|
||||
|
||||
@overload
|
||||
def push_scope( # noqa: F811
|
||||
self, callback # type: Callable[[Scope], None]
|
||||
self,
|
||||
callback, # type: Callable[[Scope], None]
|
||||
):
|
||||
# type: (...) -> None
|
||||
pass
|
||||
@@ -540,14 +545,16 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
|
||||
@overload
|
||||
def configure_scope(
|
||||
self, callback=None # type: Optional[None]
|
||||
self,
|
||||
callback=None, # type: Optional[None]
|
||||
):
|
||||
# type: (...) -> ContextManager[Scope]
|
||||
pass
|
||||
|
||||
@overload
|
||||
def configure_scope( # noqa: F811
|
||||
self, callback # type: Callable[[Scope], None]
|
||||
self,
|
||||
callback, # type: Callable[[Scope], None]
|
||||
):
|
||||
# type: (...) -> None
|
||||
pass
|
||||
@@ -587,7 +594,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
|
||||
return inner()
|
||||
|
||||
def start_session(
|
||||
self, session_mode="application" # type: str
|
||||
self,
|
||||
session_mode="application", # type: str
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
|
||||
@@ -95,6 +95,7 @@ _AUTO_ENABLING_INTEGRATIONS = [
|
||||
"sentry_sdk.integrations.huey.HueyIntegration",
|
||||
"sentry_sdk.integrations.huggingface_hub.HuggingfaceHubIntegration",
|
||||
"sentry_sdk.integrations.langchain.LangchainIntegration",
|
||||
"sentry_sdk.integrations.langgraph.LanggraphIntegration",
|
||||
"sentry_sdk.integrations.litestar.LitestarIntegration",
|
||||
"sentry_sdk.integrations.loguru.LoguruIntegration",
|
||||
"sentry_sdk.integrations.openai.OpenAIIntegration",
|
||||
@@ -131,6 +132,7 @@ _MIN_VERSIONS = {
|
||||
"celery": (4, 4, 7),
|
||||
"chalice": (1, 16, 0),
|
||||
"clickhouse_driver": (0, 2, 0),
|
||||
"cohere": (5, 4, 0),
|
||||
"django": (1, 8),
|
||||
"dramatiq": (1, 9),
|
||||
"falcon": (1, 4),
|
||||
@@ -138,13 +140,20 @@ _MIN_VERSIONS = {
|
||||
"flask": (1, 1, 4),
|
||||
"gql": (3, 4, 1),
|
||||
"graphene": (3, 3),
|
||||
"google_genai": (1, 29, 0), # google-genai
|
||||
"grpc": (1, 32, 0), # grpcio
|
||||
"huggingface_hub": (0, 22),
|
||||
"langchain": (0, 0, 210),
|
||||
"httpx": (0, 16, 0),
|
||||
"huggingface_hub": (0, 24, 7),
|
||||
"langchain": (0, 1, 0),
|
||||
"langgraph": (0, 6, 6),
|
||||
"launchdarkly": (9, 8, 0),
|
||||
"litellm": (1, 77, 5),
|
||||
"loguru": (0, 7, 0),
|
||||
"mcp": (1, 15, 0),
|
||||
"openai": (1, 0, 0),
|
||||
"openai_agents": (0, 0, 19),
|
||||
"openfeature": (0, 7, 1),
|
||||
"pydantic_ai": (1, 0, 0),
|
||||
"quart": (0, 16, 0),
|
||||
"ray": (2, 7, 0),
|
||||
"requests": (2, 0, 0),
|
||||
|
||||
@@ -20,9 +20,9 @@ from sentry_sdk.integrations._wsgi_common import (
|
||||
from sentry_sdk.tracing import (
|
||||
BAGGAGE_HEADER_NAME,
|
||||
SOURCE_FOR_STYLE,
|
||||
TRANSACTION_SOURCE_ROUTE,
|
||||
TransactionSource,
|
||||
)
|
||||
from sentry_sdk.tracing_utils import should_propagate_trace
|
||||
from sentry_sdk.tracing_utils import should_propagate_trace, add_http_request_source
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
ensure_integration_enabled,
|
||||
@@ -129,7 +129,7 @@ class AioHttpIntegration(Integration):
|
||||
# If this transaction name makes it to the UI, AIOHTTP's
|
||||
# URL resolver did not find a route or died trying.
|
||||
name="generic AIOHTTP request",
|
||||
source=TRANSACTION_SOURCE_ROUTE,
|
||||
source=TransactionSource.ROUTE,
|
||||
origin=AioHttpIntegration.origin,
|
||||
)
|
||||
with sentry_sdk.start_transaction(
|
||||
@@ -279,6 +279,9 @@ def create_trace_config():
|
||||
span.set_data("reason", params.response.reason)
|
||||
span.finish()
|
||||
|
||||
with capture_internal_exceptions():
|
||||
add_http_request_source(span)
|
||||
|
||||
trace_config = TraceConfig()
|
||||
|
||||
trace_config.on_request_start.append(on_request_start)
|
||||
|
||||
+208
-73
@@ -3,16 +3,29 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.monitoring import record_token_usage
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.ai.utils import (
|
||||
set_data_normalized,
|
||||
normalize_message_roles,
|
||||
truncate_and_annotate_messages,
|
||||
get_start_span_function,
|
||||
)
|
||||
from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS
|
||||
from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
event_from_exception,
|
||||
package_version,
|
||||
safe_serialize,
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
from anthropic import NOT_GIVEN
|
||||
except ImportError:
|
||||
NOT_GIVEN = None
|
||||
|
||||
from anthropic.resources import AsyncMessages, Messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -45,6 +58,8 @@ class AnthropicIntegration(Integration):
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
set_span_errored()
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
@@ -53,8 +68,11 @@ def _capture_exception(exc):
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
|
||||
|
||||
def _calculate_token_usage(result, span):
|
||||
# type: (Messages, Span) -> None
|
||||
def _get_token_usage(result):
|
||||
# type: (Messages) -> tuple[int, int]
|
||||
"""
|
||||
Get token usage from the Anthropic response.
|
||||
"""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
if hasattr(result, "usage"):
|
||||
@@ -64,31 +82,13 @@ def _calculate_token_usage(result, span):
|
||||
if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
|
||||
output_tokens = usage.output_tokens
|
||||
|
||||
total_tokens = input_tokens + output_tokens
|
||||
record_token_usage(span, input_tokens, output_tokens, total_tokens)
|
||||
return input_tokens, output_tokens
|
||||
|
||||
|
||||
def _get_responses(content):
|
||||
# type: (list[Any]) -> list[dict[str, Any]]
|
||||
def _collect_ai_data(event, model, input_tokens, output_tokens, content_blocks):
|
||||
# type: (MessageStreamEvent, str | None, int, int, list[str]) -> tuple[str | None, int, int, list[str]]
|
||||
"""
|
||||
Get JSON of a Anthropic responses.
|
||||
"""
|
||||
responses = []
|
||||
for item in content:
|
||||
if hasattr(item, "text"):
|
||||
responses.append(
|
||||
{
|
||||
"type": item.type,
|
||||
"text": item.text,
|
||||
}
|
||||
)
|
||||
return responses
|
||||
|
||||
|
||||
def _collect_ai_data(event, input_tokens, output_tokens, content_blocks):
|
||||
# type: (MessageStreamEvent, int, int, list[str]) -> tuple[int, int, list[str]]
|
||||
"""
|
||||
Count token usage and collect content blocks from the AI streaming response.
|
||||
Collect model information, token usage, and collect content blocks from the AI streaming response.
|
||||
"""
|
||||
with capture_internal_exceptions():
|
||||
if hasattr(event, "type"):
|
||||
@@ -96,36 +96,135 @@ def _collect_ai_data(event, input_tokens, output_tokens, content_blocks):
|
||||
usage = event.message.usage
|
||||
input_tokens += usage.input_tokens
|
||||
output_tokens += usage.output_tokens
|
||||
model = event.message.model or model
|
||||
elif event.type == "content_block_start":
|
||||
pass
|
||||
elif event.type == "content_block_delta":
|
||||
if hasattr(event.delta, "text"):
|
||||
content_blocks.append(event.delta.text)
|
||||
elif hasattr(event.delta, "partial_json"):
|
||||
content_blocks.append(event.delta.partial_json)
|
||||
elif event.type == "content_block_stop":
|
||||
pass
|
||||
elif event.type == "message_delta":
|
||||
output_tokens += event.usage.output_tokens
|
||||
|
||||
return input_tokens, output_tokens, content_blocks
|
||||
return model, input_tokens, output_tokens, content_blocks
|
||||
|
||||
|
||||
def _add_ai_data_to_span(
|
||||
span, integration, input_tokens, output_tokens, content_blocks
|
||||
):
|
||||
# type: (Span, AnthropicIntegration, int, int, list[str]) -> None
|
||||
def _set_input_data(span, kwargs, integration):
|
||||
# type: (Span, dict[str, Any], AnthropicIntegration) -> None
|
||||
"""
|
||||
Add token usage and content blocks from the AI streaming response to the span.
|
||||
Set input data for the span based on the provided keyword arguments for the anthropic message creation.
|
||||
"""
|
||||
with capture_internal_exceptions():
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
complete_message = "".join(content_blocks)
|
||||
span.set_data(
|
||||
SPANDATA.AI_RESPONSES,
|
||||
[{"type": "text", "text": complete_message}],
|
||||
messages = kwargs.get("messages")
|
||||
if (
|
||||
messages is not None
|
||||
and len(messages) > 0
|
||||
and should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
):
|
||||
normalized_messages = []
|
||||
for message in messages:
|
||||
if (
|
||||
message.get("role") == "user"
|
||||
and "content" in message
|
||||
and isinstance(message["content"], (list, tuple))
|
||||
):
|
||||
for item in message["content"]:
|
||||
if item.get("type") == "tool_result":
|
||||
normalized_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": {
|
||||
"tool_use_id": item.get("tool_use_id"),
|
||||
"output": item.get("content"),
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
normalized_messages.append(message)
|
||||
|
||||
role_normalized_messages = normalize_message_roles(normalized_messages)
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
messages_data = truncate_and_annotate_messages(
|
||||
role_normalized_messages, span, scope
|
||||
)
|
||||
if messages_data is not None:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
|
||||
)
|
||||
total_tokens = input_tokens + output_tokens
|
||||
record_token_usage(span, input_tokens, output_tokens, total_tokens)
|
||||
span.set_data(SPANDATA.AI_STREAMING, True)
|
||||
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_STREAMING, kwargs.get("stream", False)
|
||||
)
|
||||
|
||||
kwargs_keys_to_attributes = {
|
||||
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
|
||||
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
|
||||
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
|
||||
"top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
|
||||
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
|
||||
}
|
||||
for key, attribute in kwargs_keys_to_attributes.items():
|
||||
value = kwargs.get(key)
|
||||
if value is not NOT_GIVEN and value is not None:
|
||||
set_data_normalized(span, attribute, value)
|
||||
|
||||
# Input attributes: Tools
|
||||
tools = kwargs.get("tools")
|
||||
if tools is not NOT_GIVEN and tools is not None and len(tools) > 0:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
|
||||
)
|
||||
|
||||
|
||||
def _set_output_data(
|
||||
span,
|
||||
integration,
|
||||
model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
content_blocks,
|
||||
finish_span=False,
|
||||
):
|
||||
# type: (Span, AnthropicIntegration, str | None, int | None, int | None, list[Any], bool) -> None
|
||||
"""
|
||||
Set output data for the span based on the AI response."""
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, model)
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
output_messages = {
|
||||
"response": [],
|
||||
"tool": [],
|
||||
} # type: (dict[str, list[Any]])
|
||||
|
||||
for output in content_blocks:
|
||||
if output["type"] == "text":
|
||||
output_messages["response"].append(output["text"])
|
||||
elif output["type"] == "tool_use":
|
||||
output_messages["tool"].append(output)
|
||||
|
||||
if len(output_messages["tool"]) > 0:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
|
||||
output_messages["tool"],
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
if len(output_messages["response"]) > 0:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
|
||||
)
|
||||
|
||||
record_token_usage(
|
||||
span,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
||||
if finish_span:
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
|
||||
def _sentry_patched_create_common(f, *args, **kwargs):
|
||||
@@ -142,31 +241,41 @@ def _sentry_patched_create_common(f, *args, **kwargs):
|
||||
except TypeError:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
span = sentry_sdk.start_span(
|
||||
op=OP.ANTHROPIC_MESSAGES_CREATE,
|
||||
description="Anthropic messages create",
|
||||
model = kwargs.get("model", "")
|
||||
|
||||
span = get_start_span_function()(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
name=f"chat {model}".strip(),
|
||||
origin=AnthropicIntegration.origin,
|
||||
)
|
||||
span.__enter__()
|
||||
|
||||
_set_input_data(span, kwargs, integration)
|
||||
|
||||
result = yield f, args, kwargs
|
||||
|
||||
# add data to span and finish it
|
||||
messages = list(kwargs["messages"])
|
||||
model = kwargs.get("model")
|
||||
|
||||
with capture_internal_exceptions():
|
||||
span.set_data(SPANDATA.AI_MODEL_ID, model)
|
||||
span.set_data(SPANDATA.AI_STREAMING, False)
|
||||
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages)
|
||||
|
||||
if hasattr(result, "content"):
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content))
|
||||
_calculate_token_usage(result, span)
|
||||
span.__exit__(None, None, None)
|
||||
input_tokens, output_tokens = _get_token_usage(result)
|
||||
|
||||
content_blocks = []
|
||||
for content_block in result.content:
|
||||
if hasattr(content_block, "to_dict"):
|
||||
content_blocks.append(content_block.to_dict())
|
||||
elif hasattr(content_block, "model_dump"):
|
||||
content_blocks.append(content_block.model_dump())
|
||||
elif hasattr(content_block, "text"):
|
||||
content_blocks.append({"type": "text", "text": content_block.text})
|
||||
|
||||
_set_output_data(
|
||||
span=span,
|
||||
integration=integration,
|
||||
model=getattr(result, "model", None),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
content_blocks=content_blocks,
|
||||
finish_span=True,
|
||||
)
|
||||
|
||||
# Streaming response
|
||||
elif hasattr(result, "_iterator"):
|
||||
@@ -174,39 +283,53 @@ def _sentry_patched_create_common(f, *args, **kwargs):
|
||||
|
||||
def new_iterator():
|
||||
# type: () -> Iterator[MessageStreamEvent]
|
||||
model = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
content_blocks = [] # type: list[str]
|
||||
|
||||
for event in old_iterator:
|
||||
input_tokens, output_tokens, content_blocks = _collect_ai_data(
|
||||
event, input_tokens, output_tokens, content_blocks
|
||||
model, input_tokens, output_tokens, content_blocks = (
|
||||
_collect_ai_data(
|
||||
event, model, input_tokens, output_tokens, content_blocks
|
||||
)
|
||||
)
|
||||
if event.type != "message_stop":
|
||||
yield event
|
||||
yield event
|
||||
|
||||
_add_ai_data_to_span(
|
||||
span, integration, input_tokens, output_tokens, content_blocks
|
||||
_set_output_data(
|
||||
span=span,
|
||||
integration=integration,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
|
||||
finish_span=True,
|
||||
)
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
async def new_iterator_async():
|
||||
# type: () -> AsyncIterator[MessageStreamEvent]
|
||||
model = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
content_blocks = [] # type: list[str]
|
||||
|
||||
async for event in old_iterator:
|
||||
input_tokens, output_tokens, content_blocks = _collect_ai_data(
|
||||
event, input_tokens, output_tokens, content_blocks
|
||||
model, input_tokens, output_tokens, content_blocks = (
|
||||
_collect_ai_data(
|
||||
event, model, input_tokens, output_tokens, content_blocks
|
||||
)
|
||||
)
|
||||
if event.type != "message_stop":
|
||||
yield event
|
||||
yield event
|
||||
|
||||
_add_ai_data_to_span(
|
||||
span, integration, input_tokens, output_tokens, content_blocks
|
||||
_set_output_data(
|
||||
span=span,
|
||||
integration=integration,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
|
||||
finish_span=True,
|
||||
)
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
if str(type(result._iterator)) == "<class 'async_generator'>":
|
||||
result._iterator = new_iterator_async()
|
||||
@@ -248,7 +371,13 @@ def _wrap_message_create(f):
|
||||
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
|
||||
kwargs["integration"] = integration
|
||||
|
||||
return _execute_sync(f, *args, **kwargs)
|
||||
try:
|
||||
return _execute_sync(f, *args, **kwargs)
|
||||
finally:
|
||||
span = sentry_sdk.get_current_span()
|
||||
if span is not None and span.status == SPANSTATUS.ERROR:
|
||||
with capture_internal_exceptions():
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
return _sentry_patched_create_sync
|
||||
|
||||
@@ -281,6 +410,12 @@ def _wrap_message_create_async(f):
|
||||
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
|
||||
kwargs["integration"] = integration
|
||||
|
||||
return await _execute_async(f, *args, **kwargs)
|
||||
try:
|
||||
return await _execute_async(f, *args, **kwargs)
|
||||
finally:
|
||||
span = sentry_sdk.get_current_span()
|
||||
if span is not None and span.status == SPANSTATUS.ERROR:
|
||||
with capture_internal_exceptions():
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
return _sentry_patched_create_async
|
||||
|
||||
@@ -5,7 +5,7 @@ from sentry_sdk.consts import OP, SPANSTATUS
|
||||
from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration
|
||||
from sentry_sdk.integrations.logging import ignore_logger
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_TASK
|
||||
from sentry_sdk.tracing import Transaction, TransactionSource
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
ensure_integration_enabled,
|
||||
@@ -102,7 +102,7 @@ def patch_run_job():
|
||||
name="unknown arq task",
|
||||
status="ok",
|
||||
op=OP.QUEUE_TASK_ARQ,
|
||||
source=TRANSACTION_SOURCE_TASK,
|
||||
source=TransactionSource.TASK,
|
||||
origin=ArqIntegration.origin,
|
||||
)
|
||||
|
||||
@@ -199,12 +199,13 @@ def patch_create_worker():
|
||||
if isinstance(settings_cls, dict):
|
||||
if "functions" in settings_cls:
|
||||
settings_cls["functions"] = [
|
||||
_get_arq_function(func) for func in settings_cls["functions"]
|
||||
_get_arq_function(func)
|
||||
for func in settings_cls.get("functions", [])
|
||||
]
|
||||
if "cron_jobs" in settings_cls:
|
||||
settings_cls["cron_jobs"] = [
|
||||
_get_arq_cron_job(cron_job)
|
||||
for cron_job in settings_cls["cron_jobs"]
|
||||
for cron_job in settings_cls.get("cron_jobs", [])
|
||||
]
|
||||
|
||||
if hasattr(settings_cls, "functions"):
|
||||
@@ -213,16 +214,17 @@ def patch_create_worker():
|
||||
]
|
||||
if hasattr(settings_cls, "cron_jobs"):
|
||||
settings_cls.cron_jobs = [
|
||||
_get_arq_cron_job(cron_job) for cron_job in settings_cls.cron_jobs
|
||||
_get_arq_cron_job(cron_job)
|
||||
for cron_job in (settings_cls.cron_jobs or [])
|
||||
]
|
||||
|
||||
if "functions" in kwargs:
|
||||
kwargs["functions"] = [
|
||||
_get_arq_function(func) for func in kwargs["functions"]
|
||||
_get_arq_function(func) for func in kwargs.get("functions", [])
|
||||
]
|
||||
if "cron_jobs" in kwargs:
|
||||
kwargs["cron_jobs"] = [
|
||||
_get_arq_cron_job(cron_job) for cron_job in kwargs["cron_jobs"]
|
||||
_get_arq_cron_job(cron_job) for cron_job in kwargs.get("cron_jobs", [])
|
||||
]
|
||||
|
||||
return old_create_worker(*args, **kwargs)
|
||||
|
||||
@@ -12,7 +12,6 @@ from functools import partial
|
||||
import sentry_sdk
|
||||
from sentry_sdk.api import continue_trace
|
||||
from sentry_sdk.consts import OP
|
||||
|
||||
from sentry_sdk.integrations._asgi_common import (
|
||||
_get_headers,
|
||||
_get_request_data,
|
||||
@@ -25,10 +24,7 @@ from sentry_sdk.integrations._wsgi_common import (
|
||||
from sentry_sdk.sessions import track_session
|
||||
from sentry_sdk.tracing import (
|
||||
SOURCE_FOR_STYLE,
|
||||
TRANSACTION_SOURCE_ROUTE,
|
||||
TRANSACTION_SOURCE_URL,
|
||||
TRANSACTION_SOURCE_COMPONENT,
|
||||
TRANSACTION_SOURCE_CUSTOM,
|
||||
TransactionSource,
|
||||
)
|
||||
from sentry_sdk.utils import (
|
||||
ContextVar,
|
||||
@@ -45,7 +41,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
@@ -105,6 +100,7 @@ class SentryAsgiMiddleware:
|
||||
mechanism_type="asgi", # type: str
|
||||
span_origin="manual", # type: str
|
||||
http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: Tuple[str, ...]
|
||||
asgi_version=None, # type: Optional[int]
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
@@ -143,10 +139,32 @@ class SentryAsgiMiddleware:
|
||||
self.app = app
|
||||
self.http_methods_to_capture = http_methods_to_capture
|
||||
|
||||
if _looks_like_asgi3(app):
|
||||
self.__call__ = self._run_asgi3 # type: Callable[..., Any]
|
||||
else:
|
||||
self.__call__ = self._run_asgi2
|
||||
if asgi_version is None:
|
||||
if _looks_like_asgi3(app):
|
||||
asgi_version = 3
|
||||
else:
|
||||
asgi_version = 2
|
||||
|
||||
if asgi_version == 3:
|
||||
self.__call__ = self._run_asgi3
|
||||
elif asgi_version == 2:
|
||||
self.__call__ = self._run_asgi2 # type: ignore
|
||||
|
||||
def _capture_lifespan_exception(self, exc):
|
||||
# type: (Exception) -> None
|
||||
"""Capture exceptions raise in application lifespan handlers.
|
||||
|
||||
The separate function is needed to support overriding in derived integrations that use different catching mechanisms.
|
||||
"""
|
||||
return _capture_exception(exc=exc, mechanism_type=self.mechanism_type)
|
||||
|
||||
def _capture_request_exception(self, exc):
|
||||
# type: (Exception) -> None
|
||||
"""Capture exceptions raised in incoming request handlers.
|
||||
|
||||
The separate function is needed to support overriding in derived integrations that use different catching mechanisms.
|
||||
"""
|
||||
return _capture_exception(exc=exc, mechanism_type=self.mechanism_type)
|
||||
|
||||
def _run_asgi2(self, scope):
|
||||
# type: (Any) -> Any
|
||||
@@ -161,7 +179,7 @@ class SentryAsgiMiddleware:
|
||||
return await self._run_app(scope, receive, send, asgi_version=3)
|
||||
|
||||
async def _run_app(self, scope, receive, send, asgi_version):
|
||||
# type: (Any, Any, Any, Any, int) -> Any
|
||||
# type: (Any, Any, Any, int) -> Any
|
||||
is_recursive_asgi_middleware = _asgi_middleware_applied.get(False)
|
||||
is_lifespan = scope["type"] == "lifespan"
|
||||
if is_recursive_asgi_middleware or is_lifespan:
|
||||
@@ -172,7 +190,7 @@ class SentryAsgiMiddleware:
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
except Exception as exc:
|
||||
_capture_exception(exc, mechanism_type=self.mechanism_type)
|
||||
self._capture_lifespan_exception(exc)
|
||||
raise exc from None
|
||||
|
||||
_asgi_middleware_applied.set(True)
|
||||
@@ -195,8 +213,8 @@ class SentryAsgiMiddleware:
|
||||
|
||||
method = scope.get("method", "").upper()
|
||||
transaction = None
|
||||
if method in self.http_methods_to_capture:
|
||||
if ty in ("http", "websocket"):
|
||||
if ty in ("http", "websocket"):
|
||||
if ty == "websocket" or method in self.http_methods_to_capture:
|
||||
transaction = continue_trace(
|
||||
_get_headers(scope),
|
||||
op="{}.server".format(ty),
|
||||
@@ -204,37 +222,26 @@ class SentryAsgiMiddleware:
|
||||
source=transaction_source,
|
||||
origin=self.span_origin,
|
||||
)
|
||||
logger.debug(
|
||||
"[ASGI] Created transaction (continuing trace): %s",
|
||||
transaction,
|
||||
)
|
||||
else:
|
||||
transaction = Transaction(
|
||||
op=OP.HTTP_SERVER,
|
||||
name=transaction_name,
|
||||
source=transaction_source,
|
||||
origin=self.span_origin,
|
||||
)
|
||||
logger.debug(
|
||||
"[ASGI] Created transaction (new): %s", transaction
|
||||
)
|
||||
|
||||
transaction.set_tag("asgi.type", ty)
|
||||
logger.debug(
|
||||
"[ASGI] Set transaction name and source on transaction: '%s' / '%s'",
|
||||
transaction.name,
|
||||
transaction.source,
|
||||
else:
|
||||
transaction = Transaction(
|
||||
op=OP.HTTP_SERVER,
|
||||
name=transaction_name,
|
||||
source=transaction_source,
|
||||
origin=self.span_origin,
|
||||
)
|
||||
|
||||
with (
|
||||
if transaction:
|
||||
transaction.set_tag("asgi.type", ty)
|
||||
|
||||
transaction_context = (
|
||||
sentry_sdk.start_transaction(
|
||||
transaction,
|
||||
custom_sampling_context={"asgi_scope": scope},
|
||||
)
|
||||
if transaction is not None
|
||||
else nullcontext()
|
||||
):
|
||||
logger.debug("[ASGI] Started transaction: %s", transaction)
|
||||
)
|
||||
with transaction_context:
|
||||
try:
|
||||
|
||||
async def _sentry_wrapped_send(event):
|
||||
@@ -258,7 +265,7 @@ class SentryAsgiMiddleware:
|
||||
scope, receive, _sentry_wrapped_send
|
||||
)
|
||||
except Exception as exc:
|
||||
_capture_exception(exc, mechanism_type=self.mechanism_type)
|
||||
self._capture_request_exception(exc)
|
||||
raise exc from None
|
||||
finally:
|
||||
_asgi_middleware_applied.set(False)
|
||||
@@ -270,13 +277,18 @@ class SentryAsgiMiddleware:
|
||||
event["request"] = deepcopy(request_data)
|
||||
|
||||
# Only set transaction name if not already set by Starlette or FastAPI (or other frameworks)
|
||||
already_set = event["transaction"] != _DEFAULT_TRANSACTION_NAME and event[
|
||||
"transaction_info"
|
||||
].get("source") in [
|
||||
TRANSACTION_SOURCE_COMPONENT,
|
||||
TRANSACTION_SOURCE_ROUTE,
|
||||
TRANSACTION_SOURCE_CUSTOM,
|
||||
]
|
||||
transaction = event.get("transaction")
|
||||
transaction_source = (event.get("transaction_info") or {}).get("source")
|
||||
already_set = (
|
||||
transaction is not None
|
||||
and transaction != _DEFAULT_TRANSACTION_NAME
|
||||
and transaction_source
|
||||
in [
|
||||
TransactionSource.COMPONENT,
|
||||
TransactionSource.ROUTE,
|
||||
TransactionSource.CUSTOM,
|
||||
]
|
||||
)
|
||||
if not already_set:
|
||||
name, source = self._get_transaction_name_and_source(
|
||||
self.transaction_style, asgi_scope
|
||||
@@ -284,12 +296,6 @@ class SentryAsgiMiddleware:
|
||||
event["transaction"] = name
|
||||
event["transaction_info"] = {"source": source}
|
||||
|
||||
logger.debug(
|
||||
"[ASGI] Set transaction name and source in event_processor: '%s' / '%s'",
|
||||
event["transaction"],
|
||||
event["transaction_info"]["source"],
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
# Helper functions.
|
||||
@@ -313,7 +319,7 @@ class SentryAsgiMiddleware:
|
||||
name = transaction_from_function(endpoint) or ""
|
||||
else:
|
||||
name = _get_url(asgi_scope, "http" if ty == "http" else "ws", host=None)
|
||||
source = TRANSACTION_SOURCE_URL
|
||||
source = TransactionSource.URL
|
||||
|
||||
elif transaction_style == "url":
|
||||
# FastAPI includes the route object in the scope to let Sentry extract the
|
||||
@@ -325,11 +331,11 @@ class SentryAsgiMiddleware:
|
||||
name = path
|
||||
else:
|
||||
name = _get_url(asgi_scope, "http" if ty == "http" else "ws", host=None)
|
||||
source = TRANSACTION_SOURCE_URL
|
||||
source = TransactionSource.URL
|
||||
|
||||
if name is None:
|
||||
name = _DEFAULT_TRANSACTION_NAME
|
||||
source = TRANSACTION_SOURCE_ROUTE
|
||||
source = TransactionSource.ROUTE
|
||||
return name, source
|
||||
|
||||
return name, source
|
||||
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP
|
||||
from sentry_sdk.integrations import Integration, DidNotEnable
|
||||
from sentry_sdk.utils import event_from_exception, reraise
|
||||
from sentry_sdk.utils import event_from_exception, logger, reraise
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
@@ -11,7 +11,7 @@ try:
|
||||
except ImportError:
|
||||
raise DidNotEnable("asyncio not available")
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import cast, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
@@ -39,7 +39,7 @@ def patch_asyncio():
|
||||
def _sentry_task_factory(loop, coro, **kwargs):
|
||||
# type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any]
|
||||
|
||||
async def _coro_creating_hub_and_span():
|
||||
async def _task_with_sentry_span_creation():
|
||||
# type: () -> Any
|
||||
result = None
|
||||
|
||||
@@ -51,32 +51,54 @@ def patch_asyncio():
|
||||
):
|
||||
try:
|
||||
result = await coro
|
||||
except StopAsyncIteration as e:
|
||||
raise e from None
|
||||
except Exception:
|
||||
reraise(*_capture_exception())
|
||||
|
||||
return result
|
||||
|
||||
task = None
|
||||
|
||||
# Trying to use user set task factory (if there is one)
|
||||
if orig_task_factory:
|
||||
return orig_task_factory(loop, _coro_creating_hub_and_span(), **kwargs)
|
||||
task = orig_task_factory(
|
||||
loop, _task_with_sentry_span_creation(), **kwargs
|
||||
)
|
||||
|
||||
# The default task factory in `asyncio` does not have its own function
|
||||
# but is just a couple of lines in `asyncio.base_events.create_task()`
|
||||
# Those lines are copied here.
|
||||
if task is None:
|
||||
# The default task factory in `asyncio` does not have its own function
|
||||
# but is just a couple of lines in `asyncio.base_events.create_task()`
|
||||
# Those lines are copied here.
|
||||
|
||||
# WARNING:
|
||||
# If the default behavior of the task creation in asyncio changes,
|
||||
# this will break!
|
||||
task = Task(_coro_creating_hub_and_span(), loop=loop, **kwargs)
|
||||
if task._source_traceback: # type: ignore
|
||||
del task._source_traceback[-1] # type: ignore
|
||||
# WARNING:
|
||||
# If the default behavior of the task creation in asyncio changes,
|
||||
# this will break!
|
||||
task = Task(_task_with_sentry_span_creation(), loop=loop, **kwargs)
|
||||
if task._source_traceback: # type: ignore
|
||||
del task._source_traceback[-1] # type: ignore
|
||||
|
||||
# Set the task name to include the original coroutine's name
|
||||
try:
|
||||
cast("asyncio.Task[Any]", task).set_name(
|
||||
f"{get_name(coro)} (Sentry-wrapped)"
|
||||
)
|
||||
except AttributeError:
|
||||
# set_name might not be available in all Python versions
|
||||
pass
|
||||
|
||||
return task
|
||||
|
||||
loop.set_task_factory(_sentry_task_factory) # type: ignore
|
||||
|
||||
except RuntimeError:
|
||||
# When there is no running loop, we have nothing to patch.
|
||||
pass
|
||||
logger.warning(
|
||||
"There is no running asyncio loop so there is nothing Sentry can patch. "
|
||||
"Please make sure you call sentry_sdk.init() within a running "
|
||||
"asyncio loop for the AsyncioIntegration to work. "
|
||||
"See https://docs.sentry.io/platforms/python/integrations/asyncio/"
|
||||
)
|
||||
|
||||
|
||||
def _capture_exception():
|
||||
|
||||
@@ -10,7 +10,7 @@ import sentry_sdk
|
||||
from sentry_sdk.api import continue_trace
|
||||
from sentry_sdk.consts import OP
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing import TRANSACTION_SOURCE_COMPONENT
|
||||
from sentry_sdk.tracing import TransactionSource
|
||||
from sentry_sdk.utils import (
|
||||
AnnotatedValue,
|
||||
capture_internal_exceptions,
|
||||
@@ -61,7 +61,10 @@ def _wrap_init_error(init_error):
|
||||
|
||||
else:
|
||||
# Fall back to AWS lambdas JSON representation of the error
|
||||
sentry_event = _event_from_error_json(json.loads(args[1]))
|
||||
error_info = args[1]
|
||||
if isinstance(error_info, str):
|
||||
error_info = json.loads(error_info)
|
||||
sentry_event = _event_from_error_json(error_info)
|
||||
sentry_sdk.capture_event(sentry_event)
|
||||
|
||||
return init_error(*args, **kwargs)
|
||||
@@ -135,6 +138,8 @@ def _wrap_handler(handler):
|
||||
timeout_thread = TimeoutThread(
|
||||
waiting_time,
|
||||
configured_time / MILLIS_TO_SECONDS,
|
||||
isolation_scope=scope,
|
||||
current_scope=sentry_sdk.get_current_scope(),
|
||||
)
|
||||
|
||||
# Starting the thread to raise timeout warning exception
|
||||
@@ -150,7 +155,7 @@ def _wrap_handler(handler):
|
||||
headers,
|
||||
op=OP.FUNCTION_AWS,
|
||||
name=aws_context.function_name,
|
||||
source=TRANSACTION_SOURCE_COMPONENT,
|
||||
source=TransactionSource.COMPONENT,
|
||||
origin=AwsLambdaIntegration.origin,
|
||||
)
|
||||
with sentry_sdk.start_transaction(
|
||||
|
||||
@@ -177,14 +177,20 @@ def _set_transaction_name_and_source(event, transaction_style, request):
|
||||
name = ""
|
||||
|
||||
if transaction_style == "url":
|
||||
name = request.route.rule or ""
|
||||
try:
|
||||
name = request.route.rule or ""
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
elif transaction_style == "endpoint":
|
||||
name = (
|
||||
request.route.name
|
||||
or transaction_from_function(request.route.callback)
|
||||
or ""
|
||||
)
|
||||
try:
|
||||
name = (
|
||||
request.route.name
|
||||
or transaction_from_function(request.route.callback)
|
||||
or ""
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
event["transaction"] = name
|
||||
event["transaction_info"] = {"source": SOURCE_FOR_STYLE[transaction_style]}
|
||||
|
||||
+5
-4
@@ -9,12 +9,12 @@ from sentry_sdk.consts import OP, SPANSTATUS, SPANDATA
|
||||
from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable
|
||||
from sentry_sdk.integrations.celery.beat import (
|
||||
_patch_beat_apply_entry,
|
||||
_patch_redbeat_maybe_due,
|
||||
_patch_redbeat_apply_async,
|
||||
_setup_celery_beat_signals,
|
||||
)
|
||||
from sentry_sdk.integrations.celery.utils import _now_seconds_since_epoch
|
||||
from sentry_sdk.integrations.logging import ignore_logger
|
||||
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME, TRANSACTION_SOURCE_TASK
|
||||
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME, TransactionSource
|
||||
from sentry_sdk.tracing_utils import Baggage
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
@@ -73,7 +73,7 @@ class CeleryIntegration(Integration):
|
||||
self.exclude_beat_tasks = exclude_beat_tasks
|
||||
|
||||
_patch_beat_apply_entry()
|
||||
_patch_redbeat_maybe_due()
|
||||
_patch_redbeat_apply_async()
|
||||
_setup_celery_beat_signals(monitor_beat_tasks)
|
||||
|
||||
@staticmethod
|
||||
@@ -319,7 +319,7 @@ def _wrap_tracer(task, f):
|
||||
headers,
|
||||
op=OP.QUEUE_TASK_CELERY,
|
||||
name="unknown celery task",
|
||||
source=TRANSACTION_SOURCE_TASK,
|
||||
source=TransactionSource.TASK,
|
||||
origin=CeleryIntegration.origin,
|
||||
)
|
||||
transaction.name = task.name
|
||||
@@ -391,6 +391,7 @@ def _wrap_task_call(task, f):
|
||||
)
|
||||
|
||||
if latency is not None:
|
||||
latency *= 1000 # milliseconds
|
||||
span.set_data(SPANDATA.MESSAGING_MESSAGE_RECEIVE_LATENCY, latency)
|
||||
|
||||
with capture_internal_exceptions():
|
||||
|
||||
+2
-2
@@ -202,12 +202,12 @@ def _patch_beat_apply_entry():
|
||||
Scheduler.apply_entry = _wrap_beat_scheduler(Scheduler.apply_entry)
|
||||
|
||||
|
||||
def _patch_redbeat_maybe_due():
|
||||
def _patch_redbeat_apply_async():
|
||||
# type: () -> None
|
||||
if RedBeatScheduler is None:
|
||||
return
|
||||
|
||||
RedBeatScheduler.maybe_due = _wrap_beat_scheduler(RedBeatScheduler.maybe_due)
|
||||
RedBeatScheduler.apply_async = _wrap_beat_scheduler(RedBeatScheduler.apply_async)
|
||||
|
||||
|
||||
def _setup_celery_beat_signals(monitor_beat_tasks):
|
||||
|
||||
@@ -4,7 +4,7 @@ from functools import wraps
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations import Integration, DidNotEnable
|
||||
from sentry_sdk.integrations.aws_lambda import _make_request_event_processor
|
||||
from sentry_sdk.tracing import TRANSACTION_SOURCE_COMPONENT
|
||||
from sentry_sdk.tracing import TransactionSource
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
event_from_exception,
|
||||
@@ -67,7 +67,7 @@ def _get_view_function_response(app, view_function, function_args):
|
||||
configured_time = app.lambda_context.get_remaining_time_in_millis()
|
||||
scope.set_transaction_name(
|
||||
app.lambda_context.function_name,
|
||||
source=TRANSACTION_SOURCE_COMPONENT,
|
||||
source=TransactionSource.COMPONENT,
|
||||
)
|
||||
|
||||
scope.add_event_processor(
|
||||
|
||||
+33
-13
@@ -11,7 +11,8 @@ from typing import TYPE_CHECKING, TypeVar
|
||||
# without introducing a hard dependency on `typing_extensions`
|
||||
# from: https://stackoverflow.com/a/71944042/300572
|
||||
if TYPE_CHECKING:
|
||||
from typing import ParamSpec, Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, ParamSpec, Callable
|
||||
else:
|
||||
# Fake ParamSpec
|
||||
class ParamSpec:
|
||||
@@ -49,9 +50,7 @@ class ClickhouseDriverIntegration(Integration):
|
||||
)
|
||||
|
||||
# If the query contains parameters then the send_data function is used to send those parameters to clickhouse
|
||||
clickhouse_driver.client.Client.send_data = _wrap_send_data(
|
||||
clickhouse_driver.client.Client.send_data
|
||||
)
|
||||
_wrap_send_data()
|
||||
|
||||
# Every query ends either with the Client's `receive_end_of_query` (no result expected)
|
||||
# or its `receive_result` (result expected)
|
||||
@@ -128,23 +127,44 @@ def _wrap_end(f: Callable[P, T]) -> Callable[P, T]:
|
||||
return _inner_end
|
||||
|
||||
|
||||
def _wrap_send_data(f: Callable[P, T]) -> Callable[P, T]:
|
||||
def _inner_send_data(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
instance = args[0] # type: clickhouse_driver.client.Client
|
||||
data = args[2]
|
||||
span = getattr(instance.connection, "_sentry_span", None)
|
||||
def _wrap_send_data() -> None:
|
||||
original_send_data = clickhouse_driver.client.Client.send_data
|
||||
|
||||
def _inner_send_data( # type: ignore[no-untyped-def] # clickhouse-driver does not type send_data
|
||||
self, sample_block, data, types_check=False, columnar=False, *args, **kwargs
|
||||
):
|
||||
span = getattr(self.connection, "_sentry_span", None)
|
||||
|
||||
if span is not None:
|
||||
_set_db_data(span, instance.connection)
|
||||
_set_db_data(span, self.connection)
|
||||
|
||||
if should_send_default_pii():
|
||||
db_params = span._data.get("db.params", [])
|
||||
db_params.extend(data)
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
db_params.extend(data)
|
||||
|
||||
else: # data is a generic iterator
|
||||
orig_data = data
|
||||
|
||||
# Wrap the generator to add items to db.params as they are yielded.
|
||||
# This allows us to send the params to Sentry without needing to allocate
|
||||
# memory for the entire generator at once.
|
||||
def wrapped_generator() -> "Iterator[Any]":
|
||||
for item in orig_data:
|
||||
db_params.append(item)
|
||||
yield item
|
||||
|
||||
# Replace the original iterator with the wrapped one.
|
||||
data = wrapped_generator()
|
||||
|
||||
span.set_data("db.params", db_params)
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return original_send_data(
|
||||
self, sample_block, data, types_check, columnar, *args, **kwargs
|
||||
)
|
||||
|
||||
return _inner_send_data
|
||||
clickhouse_driver.client.Client.send_data = _inner_send_data
|
||||
|
||||
|
||||
def _set_db_data(
|
||||
|
||||
+29
-7
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
|
||||
|
||||
CONTEXT_TYPE = "cloud_resource"
|
||||
|
||||
HTTP_TIMEOUT = 2.0
|
||||
|
||||
AWS_METADATA_HOST = "169.254.169.254"
|
||||
AWS_TOKEN_URL = "http://{}/latest/api/token".format(AWS_METADATA_HOST)
|
||||
AWS_METADATA_URL = "http://{}/latest/dynamic/instance-identity/document".format(
|
||||
@@ -59,7 +61,7 @@ class CloudResourceContextIntegration(Integration):
|
||||
cloud_provider = ""
|
||||
|
||||
aws_token = ""
|
||||
http = urllib3.PoolManager()
|
||||
http = urllib3.PoolManager(timeout=HTTP_TIMEOUT)
|
||||
|
||||
gcp_metadata = None
|
||||
|
||||
@@ -83,7 +85,13 @@ class CloudResourceContextIntegration(Integration):
|
||||
cls.aws_token = r.data.decode()
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
except urllib3.exceptions.TimeoutError:
|
||||
logger.debug(
|
||||
"AWS metadata service timed out after %s seconds", HTTP_TIMEOUT
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("Error checking AWS metadata service: %s", str(e))
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@@ -131,8 +139,12 @@ class CloudResourceContextIntegration(Integration):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
except urllib3.exceptions.TimeoutError:
|
||||
logger.debug(
|
||||
"AWS metadata service timed out after %s seconds", HTTP_TIMEOUT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Error fetching AWS metadata: %s", str(e))
|
||||
|
||||
return ctx
|
||||
|
||||
@@ -152,7 +164,13 @@ class CloudResourceContextIntegration(Integration):
|
||||
cls.gcp_metadata = json.loads(r.data.decode("utf-8"))
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
except urllib3.exceptions.TimeoutError:
|
||||
logger.debug(
|
||||
"GCP metadata service timed out after %s seconds", HTTP_TIMEOUT
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("Error checking GCP metadata service: %s", str(e))
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@@ -201,8 +219,12 @@ class CloudResourceContextIntegration(Integration):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
except urllib3.exceptions.TimeoutError:
|
||||
logger.debug(
|
||||
"GCP metadata service timed out after %s seconds", HTTP_TIMEOUT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Error fetching GCP metadata: %s", str(e))
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ from sentry_sdk.ai.utils import set_data_normalized
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Iterator
|
||||
from sentry_sdk.tracing import Span
|
||||
@@ -52,17 +54,17 @@ COLLECTED_PII_CHAT_PARAMS = {
|
||||
}
|
||||
|
||||
COLLECTED_CHAT_RESP_ATTRS = {
|
||||
"generation_id": "ai.generation_id",
|
||||
"is_search_required": "ai.is_search_required",
|
||||
"finish_reason": "ai.finish_reason",
|
||||
"generation_id": SPANDATA.AI_GENERATION_ID,
|
||||
"is_search_required": SPANDATA.AI_SEARCH_REQUIRED,
|
||||
"finish_reason": SPANDATA.AI_FINISH_REASON,
|
||||
}
|
||||
|
||||
COLLECTED_PII_CHAT_RESP_ATTRS = {
|
||||
"citations": "ai.citations",
|
||||
"documents": "ai.documents",
|
||||
"search_queries": "ai.search_queries",
|
||||
"search_results": "ai.search_results",
|
||||
"tool_calls": "ai.tool_calls",
|
||||
"citations": SPANDATA.AI_CITATIONS,
|
||||
"documents": SPANDATA.AI_DOCUMENTS,
|
||||
"search_queries": SPANDATA.AI_SEARCH_QUERIES,
|
||||
"search_results": SPANDATA.AI_SEARCH_RESULTS,
|
||||
"tool_calls": SPANDATA.AI_TOOL_CALLS,
|
||||
}
|
||||
|
||||
|
||||
@@ -84,6 +86,8 @@ class CohereIntegration(Integration):
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
set_span_errored()
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
@@ -116,18 +120,18 @@ def _wrap_chat(f, streaming):
|
||||
if hasattr(res.meta, "billed_units"):
|
||||
record_token_usage(
|
||||
span,
|
||||
prompt_tokens=res.meta.billed_units.input_tokens,
|
||||
completion_tokens=res.meta.billed_units.output_tokens,
|
||||
input_tokens=res.meta.billed_units.input_tokens,
|
||||
output_tokens=res.meta.billed_units.output_tokens,
|
||||
)
|
||||
elif hasattr(res.meta, "tokens"):
|
||||
record_token_usage(
|
||||
span,
|
||||
prompt_tokens=res.meta.tokens.input_tokens,
|
||||
completion_tokens=res.meta.tokens.output_tokens,
|
||||
input_tokens=res.meta.tokens.input_tokens,
|
||||
output_tokens=res.meta.tokens.output_tokens,
|
||||
)
|
||||
|
||||
if hasattr(res.meta, "warnings"):
|
||||
set_data_normalized(span, "ai.warnings", res.meta.warnings)
|
||||
set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings)
|
||||
|
||||
@wraps(f)
|
||||
def new_chat(*args, **kwargs):
|
||||
@@ -238,7 +242,7 @@ def _wrap_embed(f):
|
||||
should_send_default_pii() and integration.include_prompts
|
||||
):
|
||||
if isinstance(kwargs["texts"], str):
|
||||
set_data_normalized(span, "ai.texts", [kwargs["texts"]])
|
||||
set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]])
|
||||
elif (
|
||||
isinstance(kwargs["texts"], list)
|
||||
and len(kwargs["texts"]) > 0
|
||||
@@ -262,7 +266,7 @@ def _wrap_embed(f):
|
||||
):
|
||||
record_token_usage(
|
||||
span,
|
||||
prompt_tokens=res.meta.billed_units.input_tokens,
|
||||
input_tokens=res.meta.billed_units.input_tokens,
|
||||
total_tokens=res.meta.billed_units.input_tokens,
|
||||
)
|
||||
return res
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import weakref
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.utils import ContextVar
|
||||
from sentry_sdk.utils import ContextVar, logger
|
||||
from sentry_sdk.integrations import Integration
|
||||
from sentry_sdk.scope import add_global_event_processor
|
||||
|
||||
@@ -35,8 +37,31 @@ class DedupeIntegration(Integration):
|
||||
if exc_info is None:
|
||||
return event
|
||||
|
||||
last_seen = integration._last_seen.get(None)
|
||||
if last_seen is not None:
|
||||
# last_seen is either a weakref or the original instance
|
||||
last_seen = (
|
||||
last_seen() if isinstance(last_seen, weakref.ref) else last_seen
|
||||
)
|
||||
|
||||
exc = exc_info[1]
|
||||
if integration._last_seen.get(None) is exc:
|
||||
if last_seen is exc:
|
||||
logger.info("DedupeIntegration dropped duplicated error event %s", exc)
|
||||
return None
|
||||
integration._last_seen.set(exc)
|
||||
|
||||
# we can only weakref non builtin types
|
||||
try:
|
||||
integration._last_seen.set(weakref.ref(exc))
|
||||
except TypeError:
|
||||
integration._last_seen.set(exc)
|
||||
|
||||
return event
|
||||
|
||||
@staticmethod
|
||||
def reset_last_seen():
|
||||
# type: () -> None
|
||||
integration = sentry_sdk.get_client().get_integration(DedupeIntegration)
|
||||
if integration is None:
|
||||
return
|
||||
|
||||
integration._last_seen.set(None)
|
||||
|
||||
+15
-4
@@ -7,8 +7,8 @@ from importlib import import_module
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.scope import add_global_event_processor, should_send_default_pii
|
||||
from sentry_sdk.serializer import add_global_repr_processor
|
||||
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_URL
|
||||
from sentry_sdk.serializer import add_global_repr_processor, add_repr_sequence_type
|
||||
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TransactionSource
|
||||
from sentry_sdk.tracing_utils import add_query_source, record_sql_queries
|
||||
from sentry_sdk.utils import (
|
||||
AnnotatedValue,
|
||||
@@ -269,6 +269,7 @@ class DjangoIntegration(Integration):
|
||||
patch_views()
|
||||
patch_templates()
|
||||
patch_signals()
|
||||
add_template_context_repr_sequence()
|
||||
|
||||
if patch_caching is not None:
|
||||
patch_caching()
|
||||
@@ -398,7 +399,7 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
|
||||
|
||||
if transaction_name is None:
|
||||
transaction_name = request.path_info
|
||||
source = TRANSACTION_SOURCE_URL
|
||||
source = TransactionSource.URL
|
||||
else:
|
||||
source = SOURCE_FOR_STYLE[transaction_style]
|
||||
|
||||
@@ -584,7 +585,7 @@ class DjangoRequestExtractor(RequestExtractor):
|
||||
# type: () -> Optional[Dict[str, Any]]
|
||||
try:
|
||||
return self.request.data
|
||||
except AttributeError:
|
||||
except Exception:
|
||||
return RequestExtractor.parsed_body(self)
|
||||
|
||||
|
||||
@@ -745,3 +746,13 @@ def _set_db_data(span, cursor_or_db):
|
||||
server_socket_address = connection_params.get("unix_socket")
|
||||
if server_socket_address is not None:
|
||||
span.set_data(SPANDATA.SERVER_SOCKET_ADDRESS, server_socket_address)
|
||||
|
||||
|
||||
def add_template_context_repr_sequence():
|
||||
# type: () -> None
|
||||
try:
|
||||
from django.template.context import BaseContext
|
||||
|
||||
add_repr_sequence_type(BaseContext)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
+3
-3
@@ -155,7 +155,7 @@ def patch_channels_asgi_handler_impl(cls):
|
||||
http_methods_to_capture=integration.http_methods_to_capture,
|
||||
)
|
||||
|
||||
return await middleware(self.scope)(receive, send)
|
||||
return await middleware(self.scope)(receive, send) # type: ignore
|
||||
|
||||
cls.__call__ = sentry_patched_asgi_handler
|
||||
|
||||
@@ -237,9 +237,9 @@ def _asgi_middleware_mixin_factory(_check_middleware_span):
|
||||
middleware_span = _check_middleware_span(old_method=f)
|
||||
|
||||
if middleware_span is None:
|
||||
return await f(*args, **kwargs)
|
||||
return await f(*args, **kwargs) # type: ignore
|
||||
|
||||
with middleware_span:
|
||||
return await f(*args, **kwargs)
|
||||
return await f(*args, **kwargs) # type: ignore
|
||||
|
||||
return SentryASGIMixin
|
||||
|
||||
+16
-3
@@ -45,7 +45,8 @@ def _patch_cache_method(cache, method_name, address, port):
|
||||
):
|
||||
# type: (CacheHandler, str, Callable[..., Any], tuple[Any, ...], dict[str, Any], Optional[str], Optional[int]) -> Any
|
||||
is_set_operation = method_name.startswith("set")
|
||||
is_get_operation = not is_set_operation
|
||||
is_get_method = method_name == "get"
|
||||
is_get_many_method = method_name == "get_many"
|
||||
|
||||
op = OP.CACHE_PUT if is_set_operation else OP.CACHE_GET
|
||||
description = _get_span_description(method_name, args, kwargs)
|
||||
@@ -69,8 +70,20 @@ def _patch_cache_method(cache, method_name, address, port):
|
||||
span.set_data(SPANDATA.CACHE_KEY, key)
|
||||
|
||||
item_size = None
|
||||
if is_get_operation:
|
||||
if value:
|
||||
if is_get_many_method:
|
||||
if value != {}:
|
||||
item_size = len(str(value))
|
||||
span.set_data(SPANDATA.CACHE_HIT, True)
|
||||
else:
|
||||
span.set_data(SPANDATA.CACHE_HIT, False)
|
||||
elif is_get_method:
|
||||
default_value = None
|
||||
if len(args) >= 2:
|
||||
default_value = args[1]
|
||||
elif "default" in kwargs:
|
||||
default_value = kwargs["default"]
|
||||
|
||||
if value != default_value:
|
||||
item_size = len(str(value))
|
||||
span.set_data(SPANDATA.CACHE_HIT, True)
|
||||
else:
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
import json
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations import Integration
|
||||
from sentry_sdk.consts import OP, SPANSTATUS
|
||||
from sentry_sdk.api import continue_trace, get_baggage, get_traceparent
|
||||
from sentry_sdk.integrations import Integration, DidNotEnable
|
||||
from sentry_sdk.integrations._wsgi_common import request_body_within_bounds
|
||||
from sentry_sdk.tracing import (
|
||||
BAGGAGE_HEADER_NAME,
|
||||
SENTRY_TRACE_HEADER_NAME,
|
||||
TransactionSource,
|
||||
)
|
||||
from sentry_sdk.utils import (
|
||||
AnnotatedValue,
|
||||
capture_internal_exceptions,
|
||||
event_from_exception,
|
||||
)
|
||||
from typing import TypeVar
|
||||
|
||||
from dramatiq.broker import Broker # type: ignore
|
||||
from dramatiq.message import Message # type: ignore
|
||||
from dramatiq.middleware import Middleware, default_middleware # type: ignore
|
||||
from dramatiq.errors import Retry # type: ignore
|
||||
R = TypeVar("R")
|
||||
|
||||
try:
|
||||
from dramatiq.broker import Broker
|
||||
from dramatiq.middleware import Middleware, default_middleware
|
||||
from dramatiq.errors import Retry
|
||||
from dramatiq.message import Message
|
||||
except ImportError:
|
||||
raise DidNotEnable("Dramatiq is not installed")
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -34,10 +47,12 @@ class DramatiqIntegration(Integration):
|
||||
"""
|
||||
|
||||
identifier = "dramatiq"
|
||||
origin = f"auto.queue.{identifier}"
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
|
||||
_patch_dramatiq_broker()
|
||||
|
||||
|
||||
@@ -85,22 +100,54 @@ class SentryMiddleware(Middleware): # type: ignore[misc]
|
||||
DramatiqIntegration.
|
||||
"""
|
||||
|
||||
def before_process_message(self, broker, message):
|
||||
# type: (Broker, Message) -> None
|
||||
SENTRY_HEADERS_NAME = "_sentry_headers"
|
||||
|
||||
def before_enqueue(self, broker, message, delay):
|
||||
# type: (Broker, Message[R], int) -> None
|
||||
integration = sentry_sdk.get_client().get_integration(DramatiqIntegration)
|
||||
if integration is None:
|
||||
return
|
||||
|
||||
message._scope_manager = sentry_sdk.new_scope()
|
||||
message._scope_manager.__enter__()
|
||||
message.options[self.SENTRY_HEADERS_NAME] = {
|
||||
BAGGAGE_HEADER_NAME: get_baggage(),
|
||||
SENTRY_TRACE_HEADER_NAME: get_traceparent(),
|
||||
}
|
||||
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
scope.transaction = message.actor_name
|
||||
def before_process_message(self, broker, message):
|
||||
# type: (Broker, Message[R]) -> None
|
||||
integration = sentry_sdk.get_client().get_integration(DramatiqIntegration)
|
||||
if integration is None:
|
||||
return
|
||||
|
||||
message._scope_manager = sentry_sdk.isolation_scope()
|
||||
scope = message._scope_manager.__enter__()
|
||||
scope.clear_breadcrumbs()
|
||||
scope.set_extra("dramatiq_message_id", message.message_id)
|
||||
scope.add_event_processor(_make_message_event_processor(message, integration))
|
||||
|
||||
sentry_headers = message.options.get(self.SENTRY_HEADERS_NAME) or {}
|
||||
if "retries" in message.options:
|
||||
# start new trace in case of retrying
|
||||
sentry_headers = {}
|
||||
|
||||
transaction = continue_trace(
|
||||
sentry_headers,
|
||||
name=message.actor_name,
|
||||
op=OP.QUEUE_TASK_DRAMATIQ,
|
||||
source=TransactionSource.TASK,
|
||||
origin=DramatiqIntegration.origin,
|
||||
)
|
||||
transaction.set_status(SPANSTATUS.OK)
|
||||
sentry_sdk.start_transaction(
|
||||
transaction,
|
||||
name=message.actor_name,
|
||||
op=OP.QUEUE_TASK_DRAMATIQ,
|
||||
source=TransactionSource.TASK,
|
||||
)
|
||||
transaction.__enter__()
|
||||
|
||||
def after_process_message(self, broker, message, *, result=None, exception=None):
|
||||
# type: (Broker, Message, Any, Optional[Any], Optional[Exception]) -> None
|
||||
# type: (Broker, Message[R], Optional[Any], Optional[Exception]) -> None
|
||||
integration = sentry_sdk.get_client().get_integration(DramatiqIntegration)
|
||||
if integration is None:
|
||||
return
|
||||
@@ -108,27 +155,38 @@ class SentryMiddleware(Middleware): # type: ignore[misc]
|
||||
actor = broker.get_actor(message.actor_name)
|
||||
throws = message.options.get("throws") or actor.options.get("throws")
|
||||
|
||||
try:
|
||||
if (
|
||||
exception is not None
|
||||
and not (throws and isinstance(exception, throws))
|
||||
and not isinstance(exception, Retry)
|
||||
):
|
||||
event, hint = event_from_exception(
|
||||
exception,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
mechanism={
|
||||
"type": DramatiqIntegration.identifier,
|
||||
"handled": False,
|
||||
},
|
||||
)
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
finally:
|
||||
message._scope_manager.__exit__(None, None, None)
|
||||
scope_manager = message._scope_manager
|
||||
transaction = sentry_sdk.get_current_scope().transaction
|
||||
if not transaction:
|
||||
return None
|
||||
|
||||
is_event_capture_required = (
|
||||
exception is not None
|
||||
and not (throws and isinstance(exception, throws))
|
||||
and not isinstance(exception, Retry)
|
||||
)
|
||||
if not is_event_capture_required:
|
||||
# normal transaction finish
|
||||
transaction.__exit__(None, None, None)
|
||||
scope_manager.__exit__(None, None, None)
|
||||
return
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exception, # type: ignore[arg-type]
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
mechanism={
|
||||
"type": DramatiqIntegration.identifier,
|
||||
"handled": False,
|
||||
},
|
||||
)
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
# transaction error
|
||||
transaction.__exit__(type(exception), exception, None)
|
||||
scope_manager.__exit__(type(exception), exception, None)
|
||||
|
||||
|
||||
def _make_message_event_processor(message, integration):
|
||||
# type: (Message, DramatiqIntegration) -> Callable[[Event, Hint], Optional[Event]]
|
||||
# type: (Message[R], DramatiqIntegration) -> Callable[[Event, Hint], Optional[Event]]
|
||||
|
||||
def inner(event, hint):
|
||||
# type: (Event, Hint) -> Optional[Event]
|
||||
@@ -142,7 +200,7 @@ def _make_message_event_processor(message, integration):
|
||||
|
||||
class DramatiqMessageExtractor:
|
||||
def __init__(self, message):
|
||||
# type: (Message) -> None
|
||||
# type: (Message[R]) -> None
|
||||
self.message_data = dict(message.asdict())
|
||||
|
||||
def content_length(self):
|
||||
|
||||
@@ -5,11 +5,8 @@ from functools import wraps
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE
|
||||
from sentry_sdk.utils import (
|
||||
transaction_from_function,
|
||||
logger,
|
||||
)
|
||||
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TransactionSource
|
||||
from sentry_sdk.utils import transaction_from_function
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -61,14 +58,11 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
|
||||
|
||||
if not name:
|
||||
name = _DEFAULT_TRANSACTION_NAME
|
||||
source = TRANSACTION_SOURCE_ROUTE
|
||||
source = TransactionSource.ROUTE
|
||||
else:
|
||||
source = SOURCE_FOR_STYLE[transaction_style]
|
||||
|
||||
scope.set_transaction_name(name, source=source)
|
||||
logger.debug(
|
||||
"[FastAPI] Set transaction name and source on scope: %s / %s", name, source
|
||||
)
|
||||
|
||||
|
||||
def patch_get_request_handler():
|
||||
|
||||
@@ -72,6 +72,18 @@ class FlaskIntegration(Integration):
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
try:
|
||||
from quart import Quart # type: ignore
|
||||
|
||||
if Flask == Quart:
|
||||
# This is Quart masquerading as Flask, don't enable the Flask
|
||||
# integration. See https://github.com/getsentry/sentry-python/issues/2709
|
||||
raise DidNotEnable(
|
||||
"This is not a Flask app but rather Quart pretending to be Flask"
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
version = package_version("flask")
|
||||
_check_minimum_version(FlaskIntegration, version)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from sentry_sdk.consts import OP
|
||||
from sentry_sdk.integrations import Integration
|
||||
from sentry_sdk.integrations._wsgi_common import _filter_headers
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing import TRANSACTION_SOURCE_COMPONENT
|
||||
from sentry_sdk.tracing import TransactionSource
|
||||
from sentry_sdk.utils import (
|
||||
AnnotatedValue,
|
||||
capture_internal_exceptions,
|
||||
@@ -75,7 +75,12 @@ def _wrap_func(func):
|
||||
):
|
||||
waiting_time = configured_time - TIMEOUT_WARNING_BUFFER
|
||||
|
||||
timeout_thread = TimeoutThread(waiting_time, configured_time)
|
||||
timeout_thread = TimeoutThread(
|
||||
waiting_time,
|
||||
configured_time,
|
||||
isolation_scope=scope,
|
||||
current_scope=sentry_sdk.get_current_scope(),
|
||||
)
|
||||
|
||||
# Starting the thread to raise timeout warning exception
|
||||
timeout_thread.start()
|
||||
@@ -88,7 +93,7 @@ def _wrap_func(func):
|
||||
headers,
|
||||
op=OP.FUNCTION_GCP,
|
||||
name=environ.get("FUNCTION_NAME", ""),
|
||||
source=TRANSACTION_SOURCE_COMPONENT,
|
||||
source=TransactionSource.COMPONENT,
|
||||
origin=GcpIntegration.origin,
|
||||
)
|
||||
sampling_context = {
|
||||
|
||||
+7
-15
@@ -11,24 +11,16 @@ if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from sentry_sdk._types import Event
|
||||
|
||||
|
||||
MODULE_RE = r"[a-zA-Z0-9/._:\\-]+"
|
||||
TYPE_RE = r"[a-zA-Z0-9._:<>,-]+"
|
||||
HEXVAL_RE = r"[A-Fa-f0-9]+"
|
||||
|
||||
# function is everything between index at @
|
||||
# and then we match on the @ plus the hex val
|
||||
FUNCTION_RE = r"[^@]+?"
|
||||
HEX_ADDRESS = r"\s+@\s+0x[0-9a-fA-F]+"
|
||||
|
||||
FRAME_RE = r"""
|
||||
^(?P<index>\d+)\.\s
|
||||
(?P<package>{MODULE_RE})\(
|
||||
(?P<retval>{TYPE_RE}\ )?
|
||||
((?P<function>{TYPE_RE})
|
||||
(?P<args>\(.*\))?
|
||||
)?
|
||||
((?P<constoffset>\ const)?\+0x(?P<offset>{HEXVAL_RE}))?
|
||||
\)\s
|
||||
\[0x(?P<retaddr>{HEXVAL_RE})\]$
|
||||
^(?P<index>\d+)\.\s+(?P<function>{FUNCTION_RE}){HEX_ADDRESS}(?:\s+in\s+(?P<package>.+))?$
|
||||
""".format(
|
||||
MODULE_RE=MODULE_RE, HEXVAL_RE=HEXVAL_RE, TYPE_RE=TYPE_RE
|
||||
FUNCTION_RE=FUNCTION_RE,
|
||||
HEX_ADDRESS=HEX_ADDRESS,
|
||||
)
|
||||
|
||||
FRAME_RE = re.compile(FRAME_RE, re.MULTILINE | re.VERBOSE)
|
||||
|
||||
+301
@@ -0,0 +1,301 @@
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterator,
|
||||
List,
|
||||
)
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import get_start_span_function
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.tracing import SPANSTATUS
|
||||
|
||||
|
||||
try:
|
||||
from google.genai.models import Models, AsyncModels
|
||||
except ImportError:
|
||||
raise DidNotEnable("google-genai not installed")
|
||||
|
||||
|
||||
from .consts import IDENTIFIER, ORIGIN, GEN_AI_SYSTEM
|
||||
from .utils import (
|
||||
set_span_data_for_request,
|
||||
set_span_data_for_response,
|
||||
_capture_exception,
|
||||
prepare_generate_content_args,
|
||||
)
|
||||
from .streaming import (
|
||||
set_span_data_for_streaming_response,
|
||||
accumulate_streaming_response,
|
||||
)
|
||||
|
||||
|
||||
class GoogleGenAIIntegration(Integration):
|
||||
identifier = IDENTIFIER
|
||||
origin = ORIGIN
|
||||
|
||||
def __init__(self, include_prompts=True):
|
||||
# type: (GoogleGenAIIntegration, bool) -> None
|
||||
self.include_prompts = include_prompts
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
# Patch sync methods
|
||||
Models.generate_content = _wrap_generate_content(Models.generate_content)
|
||||
Models.generate_content_stream = _wrap_generate_content_stream(
|
||||
Models.generate_content_stream
|
||||
)
|
||||
|
||||
# Patch async methods
|
||||
AsyncModels.generate_content = _wrap_async_generate_content(
|
||||
AsyncModels.generate_content
|
||||
)
|
||||
AsyncModels.generate_content_stream = _wrap_async_generate_content_stream(
|
||||
AsyncModels.generate_content_stream
|
||||
)
|
||||
|
||||
|
||||
def _wrap_generate_content_stream(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
@wraps(f)
|
||||
def new_generate_content_stream(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
|
||||
if integration is None:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
_model, contents, model_name = prepare_generate_content_args(args, kwargs)
|
||||
|
||||
span = get_start_span_function()(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name="invoke_agent",
|
||||
origin=ORIGIN,
|
||||
)
|
||||
span.__enter__()
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
set_span_data_for_request(span, integration, model_name, contents, kwargs)
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
|
||||
|
||||
chat_span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
name=f"chat {model_name}",
|
||||
origin=ORIGIN,
|
||||
)
|
||||
chat_span.__enter__()
|
||||
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
|
||||
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
|
||||
set_span_data_for_request(chat_span, integration, model_name, contents, kwargs)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
|
||||
try:
|
||||
stream = f(self, *args, **kwargs)
|
||||
|
||||
# Create wrapper iterator to accumulate responses
|
||||
def new_iterator():
|
||||
# type: () -> Iterator[Any]
|
||||
chunks = [] # type: List[Any]
|
||||
try:
|
||||
for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
chat_span.set_status(SPANSTATUS.ERROR)
|
||||
raise
|
||||
finally:
|
||||
# Accumulate all chunks and set final response data on spans
|
||||
if chunks:
|
||||
accumulated_response = accumulate_streaming_response(chunks)
|
||||
set_span_data_for_streaming_response(
|
||||
chat_span, integration, accumulated_response
|
||||
)
|
||||
set_span_data_for_streaming_response(
|
||||
span, integration, accumulated_response
|
||||
)
|
||||
chat_span.__exit__(None, None, None)
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
return new_iterator()
|
||||
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
chat_span.__exit__(None, None, None)
|
||||
span.__exit__(None, None, None)
|
||||
raise
|
||||
|
||||
return new_generate_content_stream
|
||||
|
||||
|
||||
def _wrap_async_generate_content_stream(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
@wraps(f)
|
||||
async def new_async_generate_content_stream(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
|
||||
if integration is None:
|
||||
return await f(self, *args, **kwargs)
|
||||
|
||||
_model, contents, model_name = prepare_generate_content_args(args, kwargs)
|
||||
|
||||
span = get_start_span_function()(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name="invoke_agent",
|
||||
origin=ORIGIN,
|
||||
)
|
||||
span.__enter__()
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
set_span_data_for_request(span, integration, model_name, contents, kwargs)
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
|
||||
|
||||
chat_span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
name=f"chat {model_name}",
|
||||
origin=ORIGIN,
|
||||
)
|
||||
chat_span.__enter__()
|
||||
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
|
||||
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
|
||||
set_span_data_for_request(chat_span, integration, model_name, contents, kwargs)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
|
||||
try:
|
||||
stream = await f(self, *args, **kwargs)
|
||||
|
||||
# Create wrapper async iterator to accumulate responses
|
||||
async def new_async_iterator():
|
||||
# type: () -> AsyncIterator[Any]
|
||||
chunks = [] # type: List[Any]
|
||||
try:
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
chat_span.set_status(SPANSTATUS.ERROR)
|
||||
raise
|
||||
finally:
|
||||
# Accumulate all chunks and set final response data on spans
|
||||
if chunks:
|
||||
accumulated_response = accumulate_streaming_response(chunks)
|
||||
set_span_data_for_streaming_response(
|
||||
chat_span, integration, accumulated_response
|
||||
)
|
||||
set_span_data_for_streaming_response(
|
||||
span, integration, accumulated_response
|
||||
)
|
||||
chat_span.__exit__(None, None, None)
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
return new_async_iterator()
|
||||
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
chat_span.__exit__(None, None, None)
|
||||
span.__exit__(None, None, None)
|
||||
raise
|
||||
|
||||
return new_async_generate_content_stream
|
||||
|
||||
|
||||
def _wrap_generate_content(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
@wraps(f)
|
||||
def new_generate_content(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
|
||||
if integration is None:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
model, contents, model_name = prepare_generate_content_args(args, kwargs)
|
||||
|
||||
with get_start_span_function()(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name="invoke_agent",
|
||||
origin=ORIGIN,
|
||||
) as span:
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
set_span_data_for_request(span, integration, model_name, contents, kwargs)
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
name=f"chat {model_name}",
|
||||
origin=ORIGIN,
|
||||
) as chat_span:
|
||||
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
|
||||
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
set_span_data_for_request(
|
||||
chat_span, integration, model_name, contents, kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
response = f(self, *args, **kwargs)
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
chat_span.set_status(SPANSTATUS.ERROR)
|
||||
raise
|
||||
|
||||
set_span_data_for_response(chat_span, integration, response)
|
||||
set_span_data_for_response(span, integration, response)
|
||||
|
||||
return response
|
||||
|
||||
return new_generate_content
|
||||
|
||||
|
||||
def _wrap_async_generate_content(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
@wraps(f)
|
||||
async def new_async_generate_content(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
|
||||
if integration is None:
|
||||
return await f(self, *args, **kwargs)
|
||||
|
||||
model, contents, model_name = prepare_generate_content_args(args, kwargs)
|
||||
|
||||
with get_start_span_function()(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name="invoke_agent",
|
||||
origin=ORIGIN,
|
||||
) as span:
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
set_span_data_for_request(span, integration, model_name, contents, kwargs)
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
name=f"chat {model_name}",
|
||||
origin=ORIGIN,
|
||||
) as chat_span:
|
||||
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
|
||||
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
|
||||
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
|
||||
set_span_data_for_request(
|
||||
chat_span, integration, model_name, contents, kwargs
|
||||
)
|
||||
try:
|
||||
response = await f(self, *args, **kwargs)
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
chat_span.set_status(SPANSTATUS.ERROR)
|
||||
raise
|
||||
|
||||
set_span_data_for_response(chat_span, integration, response)
|
||||
set_span_data_for_response(span, integration, response)
|
||||
|
||||
return response
|
||||
|
||||
return new_async_generate_content
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
GEN_AI_SYSTEM = "gcp.gemini"
|
||||
|
||||
# Mapping of tool attributes to their descriptions
|
||||
# These are all tools that are available in the Google GenAI API
|
||||
TOOL_ATTRIBUTES_MAP = {
|
||||
"google_search_retrieval": "Google Search retrieval tool",
|
||||
"google_search": "Google Search tool",
|
||||
"retrieval": "Retrieval tool",
|
||||
"enterprise_web_search": "Enterprise web search tool",
|
||||
"google_maps": "Google Maps tool",
|
||||
"code_execution": "Code execution tool",
|
||||
"computer_use": "Computer use tool",
|
||||
}
|
||||
|
||||
IDENTIFIER = "google_genai"
|
||||
ORIGIN = f"auto.ai.{IDENTIFIER}"
|
||||
+155
@@ -0,0 +1,155 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
List,
|
||||
TypedDict,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from sentry_sdk.ai.utils import set_data_normalized
|
||||
from sentry_sdk.consts import SPANDATA
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.utils import (
|
||||
safe_serialize,
|
||||
)
|
||||
from .utils import (
|
||||
extract_tool_calls,
|
||||
extract_finish_reasons,
|
||||
extract_contents_text,
|
||||
extract_usage_data,
|
||||
UsageData,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentry_sdk.tracing import Span
|
||||
from google.genai.types import GenerateContentResponse
|
||||
|
||||
|
||||
class AccumulatedResponse(TypedDict):
|
||||
id: Optional[str]
|
||||
model: Optional[str]
|
||||
text: str
|
||||
finish_reasons: List[str]
|
||||
tool_calls: List[dict[str, Any]]
|
||||
usage_metadata: UsageData
|
||||
|
||||
|
||||
def accumulate_streaming_response(chunks):
|
||||
# type: (List[GenerateContentResponse]) -> AccumulatedResponse
|
||||
"""Accumulate streaming chunks into a single response-like object."""
|
||||
accumulated_text = []
|
||||
finish_reasons = []
|
||||
tool_calls = []
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
total_tokens = 0
|
||||
total_cached_tokens = 0
|
||||
total_reasoning_tokens = 0
|
||||
response_id = None
|
||||
model = None
|
||||
|
||||
for chunk in chunks:
|
||||
# Extract text and tool calls
|
||||
if getattr(chunk, "candidates", None):
|
||||
for candidate in getattr(chunk, "candidates", []):
|
||||
if hasattr(candidate, "content") and getattr(
|
||||
candidate.content, "parts", []
|
||||
):
|
||||
extracted_text = extract_contents_text(candidate.content)
|
||||
if extracted_text:
|
||||
accumulated_text.append(extracted_text)
|
||||
|
||||
extracted_finish_reasons = extract_finish_reasons(chunk)
|
||||
if extracted_finish_reasons:
|
||||
finish_reasons.extend(extracted_finish_reasons)
|
||||
|
||||
extracted_tool_calls = extract_tool_calls(chunk)
|
||||
if extracted_tool_calls:
|
||||
tool_calls.extend(extracted_tool_calls)
|
||||
|
||||
# Accumulate token usage
|
||||
extracted_usage_data = extract_usage_data(chunk)
|
||||
total_input_tokens += extracted_usage_data["input_tokens"]
|
||||
total_output_tokens += extracted_usage_data["output_tokens"]
|
||||
total_cached_tokens += extracted_usage_data["input_tokens_cached"]
|
||||
total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
|
||||
total_tokens += extracted_usage_data["total_tokens"]
|
||||
|
||||
accumulated_response = AccumulatedResponse(
|
||||
text="".join(accumulated_text),
|
||||
finish_reasons=finish_reasons,
|
||||
tool_calls=tool_calls,
|
||||
usage_metadata=UsageData(
|
||||
input_tokens=total_input_tokens,
|
||||
output_tokens=total_output_tokens,
|
||||
input_tokens_cached=total_cached_tokens,
|
||||
output_tokens_reasoning=total_reasoning_tokens,
|
||||
total_tokens=total_tokens,
|
||||
),
|
||||
id=response_id,
|
||||
model=model,
|
||||
)
|
||||
|
||||
return accumulated_response
|
||||
|
||||
|
||||
def set_span_data_for_streaming_response(span, integration, accumulated_response):
|
||||
# type: (Span, Any, AccumulatedResponse) -> None
|
||||
"""Set span data for accumulated streaming response."""
|
||||
if (
|
||||
should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
and accumulated_response.get("text")
|
||||
):
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_RESPONSE_TEXT,
|
||||
safe_serialize([accumulated_response["text"]]),
|
||||
)
|
||||
|
||||
if accumulated_response.get("finish_reasons"):
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
|
||||
accumulated_response["finish_reasons"],
|
||||
)
|
||||
|
||||
if accumulated_response.get("tool_calls"):
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
|
||||
safe_serialize(accumulated_response["tool_calls"]),
|
||||
)
|
||||
|
||||
if accumulated_response.get("id"):
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_ID, accumulated_response["id"])
|
||||
if accumulated_response.get("model"):
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
|
||||
|
||||
if accumulated_response["usage_metadata"]["input_tokens"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
|
||||
accumulated_response["usage_metadata"]["input_tokens"],
|
||||
)
|
||||
|
||||
if accumulated_response["usage_metadata"]["input_tokens_cached"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
|
||||
accumulated_response["usage_metadata"]["input_tokens_cached"],
|
||||
)
|
||||
|
||||
if accumulated_response["usage_metadata"]["output_tokens"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
accumulated_response["usage_metadata"]["output_tokens"],
|
||||
)
|
||||
|
||||
if accumulated_response["usage_metadata"]["output_tokens_reasoning"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
|
||||
accumulated_response["usage_metadata"]["output_tokens_reasoning"],
|
||||
)
|
||||
|
||||
if accumulated_response["usage_metadata"]["total_tokens"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
accumulated_response["usage_metadata"]["total_tokens"],
|
||||
)
|
||||
+576
@@ -0,0 +1,576 @@
|
||||
import copy
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from .consts import ORIGIN, TOOL_ATTRIBUTES_MAP, GEN_AI_SYSTEM
|
||||
from typing import (
|
||||
cast,
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
TypedDict,
|
||||
)
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import (
|
||||
set_data_normalized,
|
||||
truncate_and_annotate_messages,
|
||||
normalize_message_roles,
|
||||
)
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
event_from_exception,
|
||||
safe_serialize,
|
||||
)
|
||||
from google.genai.types import GenerateContentConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentry_sdk.tracing import Span
|
||||
from google.genai.types import (
|
||||
GenerateContentResponse,
|
||||
ContentListUnion,
|
||||
Tool,
|
||||
Model,
|
||||
)
|
||||
|
||||
|
||||
class UsageData(TypedDict):
|
||||
"""Structure for token usage data."""
|
||||
|
||||
input_tokens: int
|
||||
input_tokens_cached: int
|
||||
output_tokens: int
|
||||
output_tokens_reasoning: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
def extract_usage_data(response):
|
||||
# type: (Union[GenerateContentResponse, dict[str, Any]]) -> UsageData
|
||||
"""Extract usage data from response into a structured format.
|
||||
|
||||
Args:
|
||||
response: The GenerateContentResponse object or dictionary containing usage metadata
|
||||
|
||||
Returns:
|
||||
UsageData: Dictionary with input_tokens, input_tokens_cached,
|
||||
output_tokens, and output_tokens_reasoning fields
|
||||
"""
|
||||
usage_data = UsageData(
|
||||
input_tokens=0,
|
||||
input_tokens_cached=0,
|
||||
output_tokens=0,
|
||||
output_tokens_reasoning=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
# Handle dictionary response (from streaming)
|
||||
if isinstance(response, dict):
|
||||
usage = response.get("usage_metadata", {})
|
||||
if not usage:
|
||||
return usage_data
|
||||
|
||||
prompt_tokens = usage.get("prompt_token_count", 0) or 0
|
||||
tool_use_prompt_tokens = usage.get("tool_use_prompt_token_count", 0) or 0
|
||||
usage_data["input_tokens"] = prompt_tokens + tool_use_prompt_tokens
|
||||
|
||||
cached_tokens = usage.get("cached_content_token_count", 0) or 0
|
||||
usage_data["input_tokens_cached"] = cached_tokens
|
||||
|
||||
reasoning_tokens = usage.get("thoughts_token_count", 0) or 0
|
||||
usage_data["output_tokens_reasoning"] = reasoning_tokens
|
||||
|
||||
candidates_tokens = usage.get("candidates_token_count", 0) or 0
|
||||
# python-genai reports output and reasoning tokens separately
|
||||
# reasoning should be sub-category of output tokens
|
||||
usage_data["output_tokens"] = candidates_tokens + reasoning_tokens
|
||||
|
||||
total_tokens = usage.get("total_token_count", 0) or 0
|
||||
usage_data["total_tokens"] = total_tokens
|
||||
|
||||
return usage_data
|
||||
|
||||
if not hasattr(response, "usage_metadata"):
|
||||
return usage_data
|
||||
|
||||
usage = response.usage_metadata
|
||||
|
||||
# Input tokens include both prompt and tool use prompt tokens
|
||||
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
|
||||
tool_use_prompt_tokens = getattr(usage, "tool_use_prompt_token_count", 0) or 0
|
||||
usage_data["input_tokens"] = prompt_tokens + tool_use_prompt_tokens
|
||||
|
||||
# Cached input tokens
|
||||
cached_tokens = getattr(usage, "cached_content_token_count", 0) or 0
|
||||
usage_data["input_tokens_cached"] = cached_tokens
|
||||
|
||||
# Reasoning tokens
|
||||
reasoning_tokens = getattr(usage, "thoughts_token_count", 0) or 0
|
||||
usage_data["output_tokens_reasoning"] = reasoning_tokens
|
||||
|
||||
# output_tokens = candidates_tokens + reasoning_tokens
|
||||
# google-genai reports output and reasoning tokens separately
|
||||
candidates_tokens = getattr(usage, "candidates_token_count", 0) or 0
|
||||
usage_data["output_tokens"] = candidates_tokens + reasoning_tokens
|
||||
|
||||
total_tokens = getattr(usage, "total_token_count", 0) or 0
|
||||
usage_data["total_tokens"] = total_tokens
|
||||
|
||||
return usage_data
|
||||
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
"""Capture exception with Google GenAI mechanism."""
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
mechanism={"type": "google_genai", "handled": False},
|
||||
)
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
|
||||
|
||||
def get_model_name(model):
|
||||
# type: (Union[str, Model]) -> str
|
||||
"""Extract model name from model parameter."""
|
||||
if isinstance(model, str):
|
||||
return model
|
||||
# Handle case where model might be an object with a name attribute
|
||||
if hasattr(model, "name"):
|
||||
return str(model.name)
|
||||
return str(model)
|
||||
|
||||
|
||||
def extract_contents_text(contents):
|
||||
# type: (ContentListUnion) -> Optional[str]
|
||||
"""Extract text from contents parameter which can have various formats."""
|
||||
if contents is None:
|
||||
return None
|
||||
|
||||
# Simple string case
|
||||
if isinstance(contents, str):
|
||||
return contents
|
||||
|
||||
# List of contents or parts
|
||||
if isinstance(contents, list):
|
||||
texts = []
|
||||
for item in contents:
|
||||
# Recursively extract text from each item
|
||||
extracted = extract_contents_text(item)
|
||||
if extracted:
|
||||
texts.append(extracted)
|
||||
return " ".join(texts) if texts else None
|
||||
|
||||
# Dictionary case
|
||||
if isinstance(contents, dict):
|
||||
if "text" in contents:
|
||||
return contents["text"]
|
||||
# Try to extract from parts if present in dict
|
||||
if "parts" in contents:
|
||||
return extract_contents_text(contents["parts"])
|
||||
|
||||
# Content object with parts - recurse into parts
|
||||
if getattr(contents, "parts", None):
|
||||
return extract_contents_text(contents.parts)
|
||||
|
||||
# Direct text attribute
|
||||
if hasattr(contents, "text"):
|
||||
return contents.text
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _format_tools_for_span(tools):
|
||||
# type: (Iterable[Tool | Callable[..., Any]]) -> Optional[List[dict[str, Any]]]
|
||||
"""Format tools parameter for span data."""
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
if callable(tool):
|
||||
# Handle callable functions passed directly
|
||||
formatted_tools.append(
|
||||
{
|
||||
"name": getattr(tool, "__name__", "unknown"),
|
||||
"description": getattr(tool, "__doc__", None),
|
||||
}
|
||||
)
|
||||
elif (
|
||||
hasattr(tool, "function_declarations")
|
||||
and tool.function_declarations is not None
|
||||
):
|
||||
# Tool object with function declarations
|
||||
for func_decl in tool.function_declarations:
|
||||
formatted_tools.append(
|
||||
{
|
||||
"name": getattr(func_decl, "name", None),
|
||||
"description": getattr(func_decl, "description", None),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Check for predefined tool attributes - each of these tools
|
||||
# is an attribute of the tool object, by default set to None
|
||||
for attr_name, description in TOOL_ATTRIBUTES_MAP.items():
|
||||
if getattr(tool, attr_name, None):
|
||||
formatted_tools.append(
|
||||
{
|
||||
"name": attr_name,
|
||||
"description": description,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
return formatted_tools if formatted_tools else None
|
||||
|
||||
|
||||
def extract_tool_calls(response):
|
||||
# type: (GenerateContentResponse) -> Optional[List[dict[str, Any]]]
|
||||
"""Extract tool/function calls from response candidates and automatic function calling history."""
|
||||
|
||||
tool_calls = []
|
||||
|
||||
# Extract from candidates, sometimes tool calls are nested under the content.parts object
|
||||
if getattr(response, "candidates", []):
|
||||
for candidate in response.candidates:
|
||||
if not hasattr(candidate, "content") or not getattr(
|
||||
candidate.content, "parts", []
|
||||
):
|
||||
continue
|
||||
|
||||
for part in candidate.content.parts:
|
||||
if getattr(part, "function_call", None):
|
||||
function_call = part.function_call
|
||||
tool_call = {
|
||||
"name": getattr(function_call, "name", None),
|
||||
"type": "function_call",
|
||||
}
|
||||
|
||||
# Extract arguments if available
|
||||
if getattr(function_call, "args", None):
|
||||
tool_call["arguments"] = safe_serialize(function_call.args)
|
||||
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
# Extract from automatic_function_calling_history
|
||||
# This is the history of tool calls made by the model
|
||||
if getattr(response, "automatic_function_calling_history", None):
|
||||
for content in response.automatic_function_calling_history:
|
||||
if not getattr(content, "parts", None):
|
||||
continue
|
||||
|
||||
for part in getattr(content, "parts", []):
|
||||
if getattr(part, "function_call", None):
|
||||
function_call = part.function_call
|
||||
tool_call = {
|
||||
"name": getattr(function_call, "name", None),
|
||||
"type": "function_call",
|
||||
}
|
||||
|
||||
# Extract arguments if available
|
||||
if hasattr(function_call, "args"):
|
||||
tool_call["arguments"] = safe_serialize(function_call.args)
|
||||
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls if tool_calls else None
|
||||
|
||||
|
||||
def _capture_tool_input(args, kwargs, tool):
|
||||
# type: (tuple[Any, ...], dict[str, Any], Tool) -> dict[str, Any]
|
||||
"""Capture tool input from args and kwargs."""
|
||||
tool_input = kwargs.copy() if kwargs else {}
|
||||
|
||||
# If we have positional args, try to map them to the function signature
|
||||
if args:
|
||||
try:
|
||||
sig = inspect.signature(tool)
|
||||
param_names = list(sig.parameters.keys())
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(param_names):
|
||||
tool_input[param_names[i]] = arg
|
||||
except Exception:
|
||||
# Fallback if we can't get the signature
|
||||
tool_input["args"] = args
|
||||
|
||||
return tool_input
|
||||
|
||||
|
||||
def _create_tool_span(tool_name, tool_doc):
|
||||
# type: (str, Optional[str]) -> Span
|
||||
"""Create a span for tool execution."""
|
||||
span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_EXECUTE_TOOL,
|
||||
name=f"execute_tool {tool_name}",
|
||||
origin=ORIGIN,
|
||||
)
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_TYPE, "function")
|
||||
if tool_doc:
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_doc)
|
||||
return span
|
||||
|
||||
|
||||
def wrapped_tool(tool):
|
||||
# type: (Tool | Callable[..., Any]) -> Tool | Callable[..., Any]
|
||||
"""Wrap a tool to emit execute_tool spans when called."""
|
||||
if not callable(tool):
|
||||
# Not a callable function, return as-is (predefined tools)
|
||||
return tool
|
||||
|
||||
tool_name = getattr(tool, "__name__", "unknown")
|
||||
tool_doc = tool.__doc__
|
||||
|
||||
if inspect.iscoroutinefunction(tool):
|
||||
# Async function
|
||||
@wraps(tool)
|
||||
async def async_wrapped(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
with _create_tool_span(tool_name, tool_doc) as span:
|
||||
# Capture tool input
|
||||
tool_input = _capture_tool_input(args, kwargs, tool)
|
||||
with capture_internal_exceptions():
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_input)
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tool(*args, **kwargs)
|
||||
|
||||
# Capture tool output
|
||||
with capture_internal_exceptions():
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result)
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
raise
|
||||
|
||||
return async_wrapped
|
||||
else:
|
||||
# Sync function
|
||||
@wraps(tool)
|
||||
def sync_wrapped(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
with _create_tool_span(tool_name, tool_doc) as span:
|
||||
# Capture tool input
|
||||
tool_input = _capture_tool_input(args, kwargs, tool)
|
||||
with capture_internal_exceptions():
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_input)
|
||||
)
|
||||
|
||||
try:
|
||||
result = tool(*args, **kwargs)
|
||||
|
||||
# Capture tool output
|
||||
with capture_internal_exceptions():
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result)
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
raise
|
||||
|
||||
return sync_wrapped
|
||||
|
||||
|
||||
def wrapped_config_with_tools(config):
|
||||
# type: (GenerateContentConfig) -> GenerateContentConfig
|
||||
"""Wrap tools in config to emit execute_tool spans. Tools are sometimes passed directly as
|
||||
callable functions as a part of the config object."""
|
||||
|
||||
if not config or not getattr(config, "tools", None):
|
||||
return config
|
||||
|
||||
result = copy.copy(config)
|
||||
result.tools = [wrapped_tool(tool) for tool in config.tools]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_response_text(response):
|
||||
# type: (GenerateContentResponse) -> Optional[List[str]]
|
||||
"""Extract text from response candidates."""
|
||||
|
||||
if not response or not getattr(response, "candidates", []):
|
||||
return None
|
||||
|
||||
texts = []
|
||||
for candidate in response.candidates:
|
||||
if not hasattr(candidate, "content") or not hasattr(candidate.content, "parts"):
|
||||
continue
|
||||
|
||||
for part in candidate.content.parts:
|
||||
if getattr(part, "text", None):
|
||||
texts.append(part.text)
|
||||
|
||||
return texts if texts else None
|
||||
|
||||
|
||||
def extract_finish_reasons(response):
|
||||
# type: (GenerateContentResponse) -> Optional[List[str]]
|
||||
"""Extract finish reasons from response candidates."""
|
||||
if not response or not getattr(response, "candidates", []):
|
||||
return None
|
||||
|
||||
finish_reasons = []
|
||||
for candidate in response.candidates:
|
||||
if getattr(candidate, "finish_reason", None):
|
||||
# Convert enum value to string if necessary
|
||||
reason = str(candidate.finish_reason)
|
||||
# Remove enum prefix if present (e.g., "FinishReason.STOP" -> "STOP")
|
||||
if "." in reason:
|
||||
reason = reason.split(".")[-1]
|
||||
finish_reasons.append(reason)
|
||||
|
||||
return finish_reasons if finish_reasons else None
|
||||
|
||||
|
||||
def set_span_data_for_request(span, integration, model, contents, kwargs):
|
||||
# type: (Span, Any, str, ContentListUnion, dict[str, Any]) -> None
|
||||
"""Set span data for the request."""
|
||||
span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
|
||||
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
|
||||
|
||||
if kwargs.get("stream", False):
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
|
||||
|
||||
config = kwargs.get("config")
|
||||
|
||||
if config is None:
|
||||
return
|
||||
|
||||
config = cast(GenerateContentConfig, config)
|
||||
|
||||
# Set input messages/prompts if PII is allowed
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
messages = []
|
||||
|
||||
# Add system instruction if present
|
||||
if hasattr(config, "system_instruction"):
|
||||
system_instruction = config.system_instruction
|
||||
if system_instruction:
|
||||
system_text = extract_contents_text(system_instruction)
|
||||
if system_text:
|
||||
messages.append({"role": "system", "content": system_text})
|
||||
|
||||
# Add user message
|
||||
contents_text = extract_contents_text(contents)
|
||||
if contents_text:
|
||||
messages.append({"role": "user", "content": contents_text})
|
||||
|
||||
if messages:
|
||||
normalized_messages = normalize_message_roles(messages)
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
messages_data = truncate_and_annotate_messages(
|
||||
normalized_messages, span, scope
|
||||
)
|
||||
if messages_data is not None:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_REQUEST_MESSAGES,
|
||||
messages_data,
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
# Extract parameters directly from config (not nested under generation_config)
|
||||
for param, span_key in [
|
||||
("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
|
||||
("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
|
||||
("top_k", SPANDATA.GEN_AI_REQUEST_TOP_K),
|
||||
("max_output_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
|
||||
("presence_penalty", SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY),
|
||||
("frequency_penalty", SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY),
|
||||
("seed", SPANDATA.GEN_AI_REQUEST_SEED),
|
||||
]:
|
||||
if hasattr(config, param):
|
||||
value = getattr(config, param)
|
||||
if value is not None:
|
||||
span.set_data(span_key, value)
|
||||
|
||||
# Set tools if available
|
||||
if hasattr(config, "tools"):
|
||||
tools = config.tools
|
||||
if tools:
|
||||
formatted_tools = _format_tools_for_span(tools)
|
||||
if formatted_tools:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
|
||||
formatted_tools,
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
|
||||
def set_span_data_for_response(span, integration, response):
|
||||
# type: (Span, Any, GenerateContentResponse) -> None
|
||||
"""Set span data for the response."""
|
||||
if not response:
|
||||
return
|
||||
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
response_texts = _extract_response_text(response)
|
||||
if response_texts:
|
||||
# Format as JSON string array as per documentation
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, safe_serialize(response_texts))
|
||||
|
||||
tool_calls = extract_tool_calls(response)
|
||||
if tool_calls:
|
||||
# Tool calls should be JSON serialized
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(tool_calls))
|
||||
|
||||
finish_reasons = extract_finish_reasons(response)
|
||||
if finish_reasons:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reasons
|
||||
)
|
||||
|
||||
if getattr(response, "response_id", None):
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_ID, response.response_id)
|
||||
|
||||
if getattr(response, "model_version", None):
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response.model_version)
|
||||
|
||||
usage_data = extract_usage_data(response)
|
||||
|
||||
if usage_data["input_tokens"]:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage_data["input_tokens"])
|
||||
|
||||
if usage_data["input_tokens_cached"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
|
||||
usage_data["input_tokens_cached"],
|
||||
)
|
||||
|
||||
if usage_data["output_tokens"]:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage_data["output_tokens"])
|
||||
|
||||
if usage_data["output_tokens_reasoning"]:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
|
||||
usage_data["output_tokens_reasoning"],
|
||||
)
|
||||
|
||||
if usage_data["total_tokens"]:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage_data["total_tokens"])
|
||||
|
||||
|
||||
def prepare_generate_content_args(args, kwargs):
|
||||
# type: (tuple[Any, ...], dict[str, Any]) -> tuple[Any, Any, str]
|
||||
"""Extract and prepare common arguments for generate_content methods."""
|
||||
model = args[0] if args else kwargs.get("model", "unknown")
|
||||
contents = args[1] if len(args) > 1 else kwargs.get("contents")
|
||||
model_name = get_model_name(model)
|
||||
|
||||
config = kwargs.get("config")
|
||||
wrapped_config = wrapped_config_with_tools(config)
|
||||
if wrapped_config is not config:
|
||||
kwargs["config"] = wrapped_config
|
||||
|
||||
return model, contents, model_name
|
||||
@@ -18,6 +18,13 @@ try:
|
||||
)
|
||||
from gql.transport import Transport, AsyncTransport # type: ignore[import-not-found]
|
||||
from gql.transport.exceptions import TransportQueryError # type: ignore[import-not-found]
|
||||
|
||||
try:
|
||||
# gql 4.0+
|
||||
from gql import GraphQLRequest
|
||||
except ImportError:
|
||||
GraphQLRequest = None
|
||||
|
||||
except ImportError:
|
||||
raise DidNotEnable("gql is not installed")
|
||||
|
||||
@@ -92,13 +99,13 @@ def _patch_execute():
|
||||
real_execute = gql.Client.execute
|
||||
|
||||
@ensure_integration_enabled(GQLIntegration, real_execute)
|
||||
def sentry_patched_execute(self, document, *args, **kwargs):
|
||||
def sentry_patched_execute(self, document_or_request, *args, **kwargs):
|
||||
# type: (gql.Client, DocumentNode, Any, Any) -> Any
|
||||
scope = sentry_sdk.get_isolation_scope()
|
||||
scope.add_event_processor(_make_gql_event_processor(self, document))
|
||||
scope.add_event_processor(_make_gql_event_processor(self, document_or_request))
|
||||
|
||||
try:
|
||||
return real_execute(self, document, *args, **kwargs)
|
||||
return real_execute(self, document_or_request, *args, **kwargs)
|
||||
except TransportQueryError as e:
|
||||
event, hint = event_from_exception(
|
||||
e,
|
||||
@@ -112,8 +119,8 @@ def _patch_execute():
|
||||
gql.Client.execute = sentry_patched_execute
|
||||
|
||||
|
||||
def _make_gql_event_processor(client, document):
|
||||
# type: (gql.Client, DocumentNode) -> EventProcessor
|
||||
def _make_gql_event_processor(client, document_or_request):
|
||||
# type: (gql.Client, Union[DocumentNode, gql.GraphQLRequest]) -> EventProcessor
|
||||
def processor(event, hint):
|
||||
# type: (Event, dict[str, Any]) -> Event
|
||||
try:
|
||||
@@ -130,6 +137,16 @@ def _make_gql_event_processor(client, document):
|
||||
)
|
||||
|
||||
if should_send_default_pii():
|
||||
if GraphQLRequest is not None and isinstance(
|
||||
document_or_request, GraphQLRequest
|
||||
):
|
||||
# In v4.0.0, gql moved to using GraphQLRequest instead of
|
||||
# DocumentNode in execute
|
||||
# https://github.com/graphql-python/gql/pull/556
|
||||
document = document_or_request.document
|
||||
else:
|
||||
document = document_or_request
|
||||
|
||||
request["data"] = _data_from_document(document)
|
||||
contexts = event.setdefault("contexts", {})
|
||||
response = contexts.setdefault("response", {})
|
||||
|
||||
+18
-1
@@ -6,6 +6,7 @@ from grpc.aio import Channel as AsyncChannel
|
||||
from grpc.aio import Server as AsyncServer
|
||||
|
||||
from sentry_sdk.integrations import Integration
|
||||
from sentry_sdk.utils import parse_version
|
||||
|
||||
from .client import ClientInterceptor
|
||||
from .server import ServerInterceptor
|
||||
@@ -41,6 +42,8 @@ else:
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
GRPC_VERSION = parse_version(grpc.__version__)
|
||||
|
||||
|
||||
def _wrap_channel_sync(func: Callable[P, Channel]) -> Callable[P, Channel]:
|
||||
"Wrapper for synchronous secure and insecure channel."
|
||||
@@ -127,7 +130,21 @@ def _wrap_async_server(func: Callable[P, AsyncServer]) -> Callable[P, AsyncServe
|
||||
**kwargs: P.kwargs,
|
||||
) -> Server:
|
||||
server_interceptor = AsyncServerInterceptor()
|
||||
interceptors = (server_interceptor, *(interceptors or []))
|
||||
interceptors = [
|
||||
server_interceptor,
|
||||
*(interceptors or []),
|
||||
] # type: Sequence[grpc.ServerInterceptor]
|
||||
|
||||
try:
|
||||
# We prefer interceptors as a list because of compatibility with
|
||||
# opentelemetry https://github.com/getsentry/sentry-python/issues/4389
|
||||
# However, prior to grpc 1.42.0, only tuples were accepted, so we
|
||||
# have no choice there.
|
||||
if GRPC_VERSION is not None and GRPC_VERSION < (1, 42, 0):
|
||||
interceptors = tuple(interceptors)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return func(*args, interceptors=interceptors, **kwargs) # type: ignore
|
||||
|
||||
return patched_aio_server # type: ignore
|
||||
|
||||
+2
-1
@@ -65,7 +65,8 @@ class SentryUnaryUnaryClientInterceptor(ClientInterceptor, UnaryUnaryClientInter
|
||||
|
||||
|
||||
class SentryUnaryStreamClientInterceptor(
|
||||
ClientInterceptor, UnaryStreamClientInterceptor # type: ignore
|
||||
ClientInterceptor,
|
||||
UnaryStreamClientInterceptor, # type: ignore
|
||||
):
|
||||
async def intercept_unary_stream(
|
||||
self,
|
||||
|
||||
+2
-2
@@ -2,7 +2,7 @@ import sentry_sdk
|
||||
from sentry_sdk.consts import OP
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
from sentry_sdk.integrations.grpc.consts import SPAN_ORIGIN
|
||||
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_CUSTOM
|
||||
from sentry_sdk.tracing import Transaction, TransactionSource
|
||||
from sentry_sdk.utils import event_from_exception
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -48,7 +48,7 @@ class ServerInterceptor(grpc.aio.ServerInterceptor): # type: ignore
|
||||
dict(context.invocation_metadata()),
|
||||
op=OP.GRPC_SERVER,
|
||||
name=name,
|
||||
source=TRANSACTION_SOURCE_CUSTOM,
|
||||
source=TransactionSource.CUSTOM,
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
|
||||
+3
-4
@@ -19,7 +19,8 @@ except ImportError:
|
||||
|
||||
|
||||
class ClientInterceptor(
|
||||
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor # type: ignore
|
||||
grpc.UnaryUnaryClientInterceptor, # type: ignore
|
||||
grpc.UnaryStreamClientInterceptor, # type: ignore
|
||||
):
|
||||
_is_intercepted = False
|
||||
|
||||
@@ -60,9 +61,7 @@ class ClientInterceptor(
|
||||
client_call_details
|
||||
)
|
||||
|
||||
response = continuation(
|
||||
client_call_details, request
|
||||
) # type: UnaryStreamCall
|
||||
response = continuation(client_call_details, request) # type: UnaryStreamCall
|
||||
# Setting code on unary-stream leads to execution getting stuck
|
||||
# span.set_data("code", response.code().name)
|
||||
|
||||
|
||||
+2
-2
@@ -2,7 +2,7 @@ import sentry_sdk
|
||||
from sentry_sdk.consts import OP
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
from sentry_sdk.integrations.grpc.consts import SPAN_ORIGIN
|
||||
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_CUSTOM
|
||||
from sentry_sdk.tracing import Transaction, TransactionSource
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -42,7 +42,7 @@ class ServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
||||
metadata,
|
||||
op=OP.GRPC_SERVER,
|
||||
name=name,
|
||||
source=TRANSACTION_SOURCE_CUSTOM,
|
||||
source=TransactionSource.CUSTOM,
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.integrations import Integration, DidNotEnable
|
||||
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME
|
||||
from sentry_sdk.tracing_utils import Baggage, should_propagate_trace
|
||||
from sentry_sdk.tracing_utils import (
|
||||
Baggage,
|
||||
should_propagate_trace,
|
||||
add_http_request_source,
|
||||
)
|
||||
from sentry_sdk.utils import (
|
||||
SENSITIVE_DATA_SUBSTITUTE,
|
||||
capture_internal_exceptions,
|
||||
@@ -52,7 +57,7 @@ def _install_httpx_client():
|
||||
with capture_internal_exceptions():
|
||||
parsed_url = parse_url(str(request.url), sanitize=False)
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
with start_span(
|
||||
op=OP.HTTP_CLIENT,
|
||||
name="%s %s"
|
||||
% (
|
||||
@@ -88,7 +93,10 @@ def _install_httpx_client():
|
||||
span.set_http_status(rv.status_code)
|
||||
span.set_data("reason", rv.reason_phrase)
|
||||
|
||||
return rv
|
||||
with capture_internal_exceptions():
|
||||
add_http_request_source(span)
|
||||
|
||||
return rv
|
||||
|
||||
Client.send = send
|
||||
|
||||
@@ -106,7 +114,7 @@ def _install_httpx_async_client():
|
||||
with capture_internal_exceptions():
|
||||
parsed_url = parse_url(str(request.url), sanitize=False)
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
with start_span(
|
||||
op=OP.HTTP_CLIENT,
|
||||
name="%s %s"
|
||||
% (
|
||||
@@ -144,7 +152,10 @@ def _install_httpx_async_client():
|
||||
span.set_http_status(rv.status_code)
|
||||
span.set_data("reason", rv.reason_phrase)
|
||||
|
||||
return rv
|
||||
with capture_internal_exceptions():
|
||||
add_http_request_source(span)
|
||||
|
||||
return rv
|
||||
|
||||
AsyncClient.send = send
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing import (
|
||||
BAGGAGE_HEADER_NAME,
|
||||
SENTRY_TRACE_HEADER_NAME,
|
||||
TRANSACTION_SOURCE_TASK,
|
||||
TransactionSource,
|
||||
)
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
@@ -159,7 +159,7 @@ def patch_execute():
|
||||
sentry_headers or {},
|
||||
name=task.name,
|
||||
op=OP.QUEUE_TASK_HUEY,
|
||||
source=TRANSACTION_SOURCE_TASK,
|
||||
source=TransactionSource.TASK,
|
||||
origin=HueyIntegration.origin,
|
||||
)
|
||||
transaction.set_status(SPANSTATUS.OK)
|
||||
|
||||
+284
-81
@@ -1,24 +1,25 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from sentry_sdk import consts
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.monitoring import record_token_usage
|
||||
from sentry_sdk.ai.utils import set_data_normalized
|
||||
from sentry_sdk.consts import SPANDATA
|
||||
|
||||
from typing import Any, Iterable, Callable
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
event_from_exception,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
try:
|
||||
import huggingface_hub.inference._client
|
||||
|
||||
from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
|
||||
except ImportError:
|
||||
raise DidNotEnable("Huggingface not installed")
|
||||
|
||||
@@ -34,15 +35,26 @@ class HuggingfaceHubIntegration(Integration):
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
|
||||
# Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks
|
||||
huggingface_hub.inference._client.InferenceClient.text_generation = (
|
||||
_wrap_text_generation(
|
||||
huggingface_hub.inference._client.InferenceClient.text_generation
|
||||
_wrap_huggingface_task(
|
||||
huggingface_hub.inference._client.InferenceClient.text_generation,
|
||||
OP.GEN_AI_GENERATE_TEXT,
|
||||
)
|
||||
)
|
||||
huggingface_hub.inference._client.InferenceClient.chat_completion = (
|
||||
_wrap_huggingface_task(
|
||||
huggingface_hub.inference._client.InferenceClient.chat_completion,
|
||||
OP.GEN_AI_CHAT,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
set_span_errored()
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
@@ -51,34 +63,70 @@ def _capture_exception(exc):
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
|
||||
|
||||
def _wrap_text_generation(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
def _wrap_huggingface_task(f, op):
|
||||
# type: (Callable[..., Any], str) -> Callable[..., Any]
|
||||
@wraps(f)
|
||||
def new_text_generation(*args, **kwargs):
|
||||
def new_huggingface_task(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
|
||||
if integration is None:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
prompt = None
|
||||
if "prompt" in kwargs:
|
||||
prompt = kwargs["prompt"]
|
||||
elif "messages" in kwargs:
|
||||
prompt = kwargs["messages"]
|
||||
elif len(args) >= 2:
|
||||
kwargs["prompt"] = args[1]
|
||||
prompt = kwargs["prompt"]
|
||||
args = (args[0],) + args[2:]
|
||||
else:
|
||||
# invalid call, let it return error
|
||||
if isinstance(args[1], str) or isinstance(args[1], list):
|
||||
prompt = args[1]
|
||||
|
||||
if prompt is None:
|
||||
# invalid call, dont instrument, let it return error
|
||||
return f(*args, **kwargs)
|
||||
|
||||
model = kwargs.get("model")
|
||||
streaming = kwargs.get("stream")
|
||||
client = args[0]
|
||||
model = client.model or kwargs.get("model") or ""
|
||||
operation_name = op.split(".")[-1]
|
||||
|
||||
span = sentry_sdk.start_span(
|
||||
op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
|
||||
name="Text Generation",
|
||||
op=op,
|
||||
name=f"{operation_name} {model}",
|
||||
origin=HuggingfaceHubIntegration.origin,
|
||||
)
|
||||
span.__enter__()
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name)
|
||||
|
||||
if model:
|
||||
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
|
||||
|
||||
# Input attributes
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
|
||||
)
|
||||
|
||||
attribute_mapping = {
|
||||
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
|
||||
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
|
||||
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
|
||||
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
|
||||
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
|
||||
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
|
||||
"top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
|
||||
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
|
||||
}
|
||||
|
||||
for attribute, span_attribute in attribute_mapping.items():
|
||||
value = kwargs.get(attribute, None)
|
||||
if value is not None:
|
||||
if isinstance(value, (int, float, bool, str)):
|
||||
span.set_data(span_attribute, value)
|
||||
else:
|
||||
set_data_normalized(span, span_attribute, value, unpack=False)
|
||||
|
||||
# LLM Execution
|
||||
try:
|
||||
res = f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
@@ -86,90 +134,245 @@ def _wrap_text_generation(f):
|
||||
span.__exit__(None, None, None)
|
||||
raise e from None
|
||||
|
||||
# Output attributes
|
||||
finish_reason = None
|
||||
response_model = None
|
||||
response_text_buffer: list[str] = []
|
||||
tokens_used = 0
|
||||
tool_calls = None
|
||||
usage = None
|
||||
|
||||
with capture_internal_exceptions():
|
||||
if isinstance(res, str) and res is not None:
|
||||
response_text_buffer.append(res)
|
||||
|
||||
if hasattr(res, "generated_text") and res.generated_text is not None:
|
||||
response_text_buffer.append(res.generated_text)
|
||||
|
||||
if hasattr(res, "model") and res.model is not None:
|
||||
response_model = res.model
|
||||
|
||||
if hasattr(res, "details") and hasattr(res.details, "finish_reason"):
|
||||
finish_reason = res.details.finish_reason
|
||||
|
||||
if (
|
||||
hasattr(res, "details")
|
||||
and hasattr(res.details, "generated_tokens")
|
||||
and res.details.generated_tokens is not None
|
||||
):
|
||||
tokens_used = res.details.generated_tokens
|
||||
|
||||
if hasattr(res, "usage") and res.usage is not None:
|
||||
usage = res.usage
|
||||
|
||||
if hasattr(res, "choices") and res.choices is not None:
|
||||
for choice in res.choices:
|
||||
if hasattr(choice, "finish_reason"):
|
||||
finish_reason = choice.finish_reason
|
||||
if hasattr(choice, "message") and hasattr(
|
||||
choice.message, "tool_calls"
|
||||
):
|
||||
tool_calls = choice.message.tool_calls
|
||||
if (
|
||||
hasattr(choice, "message")
|
||||
and hasattr(choice.message, "content")
|
||||
and choice.message.content is not None
|
||||
):
|
||||
response_text_buffer.append(choice.message.content)
|
||||
|
||||
if response_model is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
|
||||
|
||||
if finish_reason is not None:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
|
||||
finish_reason,
|
||||
)
|
||||
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
|
||||
|
||||
set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
|
||||
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
|
||||
|
||||
if isinstance(res, str):
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
if tool_calls is not None and len(tool_calls) > 0:
|
||||
set_data_normalized(
|
||||
span,
|
||||
"ai.responses",
|
||||
[res],
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
|
||||
tool_calls,
|
||||
unpack=False,
|
||||
)
|
||||
span.__exit__(None, None, None)
|
||||
return res
|
||||
|
||||
if isinstance(res, TextGenerationOutput):
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span,
|
||||
"ai.responses",
|
||||
[res.generated_text],
|
||||
)
|
||||
if res.details is not None and res.details.generated_tokens > 0:
|
||||
record_token_usage(span, total_tokens=res.details.generated_tokens)
|
||||
span.__exit__(None, None, None)
|
||||
return res
|
||||
if len(response_text_buffer) > 0:
|
||||
text_response = "".join(response_text_buffer)
|
||||
if text_response:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TEXT,
|
||||
text_response,
|
||||
)
|
||||
|
||||
if not isinstance(res, Iterable):
|
||||
# we only know how to deal with strings and iterables, ignore
|
||||
set_data_normalized(span, "unknown_response", True)
|
||||
if usage is not None:
|
||||
record_token_usage(
|
||||
span,
|
||||
input_tokens=usage.prompt_tokens,
|
||||
output_tokens=usage.completion_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
)
|
||||
elif tokens_used > 0:
|
||||
record_token_usage(
|
||||
span,
|
||||
total_tokens=tokens_used,
|
||||
)
|
||||
|
||||
# If the response is not a generator (meaning a streaming response)
|
||||
# we are done and can return the response
|
||||
if not inspect.isgenerator(res):
|
||||
span.__exit__(None, None, None)
|
||||
return res
|
||||
|
||||
if kwargs.get("details", False):
|
||||
# res is Iterable[TextGenerationStreamOutput]
|
||||
# text-generation stream output
|
||||
def new_details_iterator():
|
||||
# type: () -> Iterable[ChatCompletionStreamOutput]
|
||||
# type: () -> Iterable[Any]
|
||||
finish_reason = None
|
||||
response_text_buffer: list[str] = []
|
||||
tokens_used = 0
|
||||
|
||||
with capture_internal_exceptions():
|
||||
tokens_used = 0
|
||||
data_buf: list[str] = []
|
||||
for x in res:
|
||||
if hasattr(x, "token") and hasattr(x.token, "text"):
|
||||
data_buf.append(x.token.text)
|
||||
if hasattr(x, "details") and hasattr(
|
||||
x.details, "generated_tokens"
|
||||
for chunk in res:
|
||||
if (
|
||||
hasattr(chunk, "token")
|
||||
and hasattr(chunk.token, "text")
|
||||
and chunk.token.text is not None
|
||||
):
|
||||
tokens_used = x.details.generated_tokens
|
||||
yield x
|
||||
if (
|
||||
len(data_buf) > 0
|
||||
and should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
):
|
||||
response_text_buffer.append(chunk.token.text)
|
||||
|
||||
if hasattr(chunk, "details") and hasattr(
|
||||
chunk.details, "finish_reason"
|
||||
):
|
||||
finish_reason = chunk.details.finish_reason
|
||||
|
||||
if (
|
||||
hasattr(chunk, "details")
|
||||
and hasattr(chunk.details, "generated_tokens")
|
||||
and chunk.details.generated_tokens is not None
|
||||
):
|
||||
tokens_used = chunk.details.generated_tokens
|
||||
|
||||
yield chunk
|
||||
|
||||
if finish_reason is not None:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.AI_RESPONSES, "".join(data_buf)
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
|
||||
finish_reason,
|
||||
)
|
||||
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
if len(response_text_buffer) > 0:
|
||||
text_response = "".join(response_text_buffer)
|
||||
if text_response:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TEXT,
|
||||
text_response,
|
||||
)
|
||||
|
||||
if tokens_used > 0:
|
||||
record_token_usage(span, total_tokens=tokens_used)
|
||||
record_token_usage(
|
||||
span,
|
||||
total_tokens=tokens_used,
|
||||
)
|
||||
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
return new_details_iterator()
|
||||
else:
|
||||
# res is Iterable[str]
|
||||
|
||||
else:
|
||||
# chat-completion stream output
|
||||
def new_iterator():
|
||||
# type: () -> Iterable[str]
|
||||
data_buf: list[str] = []
|
||||
finish_reason = None
|
||||
response_model = None
|
||||
response_text_buffer: list[str] = []
|
||||
tool_calls = None
|
||||
usage = None
|
||||
|
||||
with capture_internal_exceptions():
|
||||
for s in res:
|
||||
if isinstance(s, str):
|
||||
data_buf.append(s)
|
||||
yield s
|
||||
if (
|
||||
len(data_buf) > 0
|
||||
and should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
):
|
||||
set_data_normalized(
|
||||
span, SPANDATA.AI_RESPONSES, "".join(data_buf)
|
||||
for chunk in res:
|
||||
if hasattr(chunk, "model") and chunk.model is not None:
|
||||
response_model = chunk.model
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage = chunk.usage
|
||||
|
||||
if isinstance(chunk, str):
|
||||
if chunk is not None:
|
||||
response_text_buffer.append(chunk)
|
||||
|
||||
if hasattr(chunk, "choices") and chunk.choices is not None:
|
||||
for choice in chunk.choices:
|
||||
if (
|
||||
hasattr(choice, "delta")
|
||||
and hasattr(choice.delta, "content")
|
||||
and choice.delta.content is not None
|
||||
):
|
||||
response_text_buffer.append(
|
||||
choice.delta.content
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(choice, "finish_reason")
|
||||
and choice.finish_reason is not None
|
||||
):
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if (
|
||||
hasattr(choice, "delta")
|
||||
and hasattr(choice.delta, "tool_calls")
|
||||
and choice.delta.tool_calls is not None
|
||||
):
|
||||
tool_calls = choice.delta.tool_calls
|
||||
|
||||
yield chunk
|
||||
|
||||
if response_model is not None:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_RESPONSE_MODEL, response_model
|
||||
)
|
||||
|
||||
if finish_reason is not None:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
|
||||
finish_reason,
|
||||
)
|
||||
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
if tool_calls is not None and len(tool_calls) > 0:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
|
||||
tool_calls,
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
if len(response_text_buffer) > 0:
|
||||
text_response = "".join(response_text_buffer)
|
||||
if text_response:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TEXT,
|
||||
text_response,
|
||||
)
|
||||
|
||||
if usage is not None:
|
||||
record_token_usage(
|
||||
span,
|
||||
input_tokens=usage.prompt_tokens,
|
||||
output_tokens=usage.completion_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
)
|
||||
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
return new_iterator()
|
||||
|
||||
return new_text_generation
|
||||
return new_huggingface_task
|
||||
|
||||
+678
-244
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,337 @@
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import (
|
||||
set_data_normalized,
|
||||
normalize_message_roles,
|
||||
truncate_and_annotate_messages,
|
||||
)
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.utils import safe_serialize
|
||||
|
||||
|
||||
try:
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.pregel import Pregel
|
||||
except ImportError:
|
||||
raise DidNotEnable("langgraph not installed")
|
||||
|
||||
|
||||
class LanggraphIntegration(Integration):
|
||||
identifier = "langgraph"
|
||||
origin = f"auto.ai.{identifier}"
|
||||
|
||||
def __init__(self, include_prompts=True):
|
||||
# type: (LanggraphIntegration, bool) -> None
|
||||
self.include_prompts = include_prompts
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
# LangGraph lets users create agents using a StateGraph or the Functional API.
|
||||
# StateGraphs are then compiled to a CompiledStateGraph. Both CompiledStateGraph and
|
||||
# the functional API execute on a Pregel instance. Pregel is the runtime for the graph
|
||||
# and the invocation happens on Pregel, so patching the invoke methods takes care of both.
|
||||
# The streaming methods are not patched, because due to some internal reasons, LangGraph
|
||||
# will automatically patch the streaming methods to run through invoke, and by doing this
|
||||
# we prevent duplicate spans for invocations.
|
||||
StateGraph.compile = _wrap_state_graph_compile(StateGraph.compile)
|
||||
if hasattr(Pregel, "invoke"):
|
||||
Pregel.invoke = _wrap_pregel_invoke(Pregel.invoke)
|
||||
if hasattr(Pregel, "ainvoke"):
|
||||
Pregel.ainvoke = _wrap_pregel_ainvoke(Pregel.ainvoke)
|
||||
|
||||
|
||||
def _get_graph_name(graph_obj):
|
||||
# type: (Any) -> Optional[str]
|
||||
for attr in ["name", "graph_name", "__name__", "_name"]:
|
||||
if hasattr(graph_obj, attr):
|
||||
name = getattr(graph_obj, attr)
|
||||
if name and isinstance(name, str):
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_langgraph_message(message):
|
||||
# type: (Any) -> Any
|
||||
if not hasattr(message, "content"):
|
||||
return None
|
||||
|
||||
parsed = {"role": getattr(message, "type", None), "content": message.content}
|
||||
|
||||
for attr in ["name", "tool_calls", "function_call", "tool_call_id"]:
|
||||
if hasattr(message, attr):
|
||||
value = getattr(message, attr)
|
||||
if value is not None:
|
||||
parsed[attr] = value
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _parse_langgraph_messages(state):
|
||||
# type: (Any) -> Optional[List[Any]]
|
||||
if not state:
|
||||
return None
|
||||
|
||||
messages = None
|
||||
|
||||
if isinstance(state, dict):
|
||||
messages = state.get("messages")
|
||||
elif hasattr(state, "messages"):
|
||||
messages = state.messages
|
||||
elif hasattr(state, "get") and callable(state.get):
|
||||
try:
|
||||
messages = state.get("messages")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not messages or not isinstance(messages, (list, tuple)):
|
||||
return None
|
||||
|
||||
normalized_messages = []
|
||||
for message in messages:
|
||||
try:
|
||||
normalized = _normalize_langgraph_message(message)
|
||||
if normalized:
|
||||
normalized_messages.append(normalized)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return normalized_messages if normalized_messages else None
|
||||
|
||||
|
||||
def _wrap_state_graph_compile(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
@wraps(f)
|
||||
def new_compile(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
|
||||
if integration is None:
|
||||
return f(self, *args, **kwargs)
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CREATE_AGENT,
|
||||
origin=LanggraphIntegration.origin,
|
||||
) as span:
|
||||
compiled_graph = f(self, *args, **kwargs)
|
||||
|
||||
compiled_graph_name = getattr(compiled_graph, "name", None)
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "create_agent")
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, compiled_graph_name)
|
||||
|
||||
if compiled_graph_name:
|
||||
span.description = f"create_agent {compiled_graph_name}"
|
||||
else:
|
||||
span.description = "create_agent"
|
||||
|
||||
if kwargs.get("model", None) is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, kwargs.get("model"))
|
||||
|
||||
tools = None
|
||||
get_graph = getattr(compiled_graph, "get_graph", None)
|
||||
if get_graph and callable(get_graph):
|
||||
graph_obj = compiled_graph.get_graph()
|
||||
nodes = getattr(graph_obj, "nodes", None)
|
||||
if nodes and isinstance(nodes, dict):
|
||||
tools_node = nodes.get("tools")
|
||||
if tools_node:
|
||||
data = getattr(tools_node, "data", None)
|
||||
if data and hasattr(data, "tools_by_name"):
|
||||
tools = list(data.tools_by_name.keys())
|
||||
|
||||
if tools is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, tools)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
return new_compile
|
||||
|
||||
|
||||
def _wrap_pregel_invoke(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
|
||||
@wraps(f)
|
||||
def new_invoke(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
|
||||
if integration is None:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
graph_name = _get_graph_name(self)
|
||||
span_name = (
|
||||
f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
|
||||
)
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name=span_name,
|
||||
origin=LanggraphIntegration.origin,
|
||||
) as span:
|
||||
if graph_name:
|
||||
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
|
||||
# Store input messages to later compare with output
|
||||
input_messages = None
|
||||
if (
|
||||
len(args) > 0
|
||||
and should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
):
|
||||
input_messages = _parse_langgraph_messages(args[0])
|
||||
if input_messages:
|
||||
normalized_input_messages = normalize_message_roles(input_messages)
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
messages_data = truncate_and_annotate_messages(
|
||||
normalized_input_messages, span, scope
|
||||
)
|
||||
if messages_data is not None:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_REQUEST_MESSAGES,
|
||||
messages_data,
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
result = f(self, *args, **kwargs)
|
||||
|
||||
_set_response_attributes(span, input_messages, result, integration)
|
||||
|
||||
return result
|
||||
|
||||
return new_invoke
|
||||
|
||||
|
||||
def _wrap_pregel_ainvoke(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
|
||||
@wraps(f)
|
||||
async def new_ainvoke(self, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
|
||||
if integration is None:
|
||||
return await f(self, *args, **kwargs)
|
||||
|
||||
graph_name = _get_graph_name(self)
|
||||
span_name = (
|
||||
f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
|
||||
)
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name=span_name,
|
||||
origin=LanggraphIntegration.origin,
|
||||
) as span:
|
||||
if graph_name:
|
||||
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
|
||||
input_messages = None
|
||||
if (
|
||||
len(args) > 0
|
||||
and should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
):
|
||||
input_messages = _parse_langgraph_messages(args[0])
|
||||
if input_messages:
|
||||
normalized_input_messages = normalize_message_roles(input_messages)
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
messages_data = truncate_and_annotate_messages(
|
||||
normalized_input_messages, span, scope
|
||||
)
|
||||
if messages_data is not None:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_REQUEST_MESSAGES,
|
||||
messages_data,
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
result = await f(self, *args, **kwargs)
|
||||
|
||||
_set_response_attributes(span, input_messages, result, integration)
|
||||
|
||||
return result
|
||||
|
||||
return new_ainvoke
|
||||
|
||||
|
||||
def _get_new_messages(input_messages, output_messages):
|
||||
# type: (Optional[List[Any]], Optional[List[Any]]) -> Optional[List[Any]]
|
||||
"""Extract only the new messages added during this invocation."""
|
||||
if not output_messages:
|
||||
return None
|
||||
|
||||
if not input_messages:
|
||||
return output_messages
|
||||
|
||||
# only return the new messages, aka the output messages that are not in the input messages
|
||||
input_count = len(input_messages)
|
||||
new_messages = (
|
||||
output_messages[input_count:] if len(output_messages) > input_count else []
|
||||
)
|
||||
|
||||
return new_messages if new_messages else None
|
||||
|
||||
|
||||
def _extract_llm_response_text(messages):
|
||||
# type: (Optional[List[Any]]) -> Optional[str]
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, dict):
|
||||
role = message.get("role")
|
||||
if role in ["assistant", "ai"]:
|
||||
content = message.get("content")
|
||||
if content and isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_tool_calls(messages):
|
||||
# type: (Optional[List[Any]]) -> Optional[List[Any]]
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
tool_calls = []
|
||||
for message in messages:
|
||||
if isinstance(message, dict):
|
||||
msg_tool_calls = message.get("tool_calls")
|
||||
if msg_tool_calls and isinstance(msg_tool_calls, list):
|
||||
tool_calls.extend(msg_tool_calls)
|
||||
|
||||
return tool_calls if tool_calls else None
|
||||
|
||||
|
||||
def _set_response_attributes(span, input_messages, result, integration):
|
||||
# type: (Any, Optional[List[Any]], Any, LanggraphIntegration) -> None
|
||||
if not (should_send_default_pii() and integration.include_prompts):
|
||||
return
|
||||
|
||||
parsed_response_messages = _parse_langgraph_messages(result)
|
||||
new_messages = _get_new_messages(input_messages, parsed_response_messages)
|
||||
|
||||
llm_response_text = _extract_llm_response_text(new_messages)
|
||||
if llm_response_text:
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, llm_response_text)
|
||||
elif new_messages:
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, new_messages)
|
||||
else:
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
|
||||
|
||||
tool_calls = _extract_tool_calls(new_messages)
|
||||
if tool_calls:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
|
||||
safe_serialize(tool_calls),
|
||||
unpack=False,
|
||||
)
|
||||
+3
-4
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import sentry_sdk
|
||||
|
||||
from sentry_sdk.feature_flags import add_feature_flag
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
|
||||
try:
|
||||
@@ -44,7 +44,6 @@ class LaunchDarklyIntegration(Integration):
|
||||
|
||||
|
||||
class LaunchDarklyHook(Hook):
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
# type: () -> Metadata
|
||||
@@ -53,8 +52,8 @@ class LaunchDarklyHook(Hook):
|
||||
def after_evaluation(self, series_context, data, detail):
|
||||
# type: (EvaluationSeriesContext, dict[Any, Any], EvaluationDetail) -> dict[Any, Any]
|
||||
if isinstance(detail.value, bool):
|
||||
flags = sentry_sdk.get_current_scope().flags
|
||||
flags.set(series_context.key, detail.value)
|
||||
add_feature_flag(series_context.key, detail.value)
|
||||
|
||||
return data
|
||||
|
||||
def before_evaluation(self, series_context, data):
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk import consts
|
||||
from sentry_sdk.ai.monitoring import record_token_usage
|
||||
from sentry_sdk.ai.utils import (
|
||||
get_start_span_function,
|
||||
set_data_normalized,
|
||||
truncate_and_annotate_messages,
|
||||
)
|
||||
from sentry_sdk.consts import SPANDATA
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.utils import event_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Dict
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
import litellm # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
raise DidNotEnable("LiteLLM not installed")
|
||||
|
||||
|
||||
def _get_metadata_dict(kwargs):
|
||||
# type: (Dict[str, Any]) -> Dict[str, Any]
|
||||
"""Get the metadata dictionary from the kwargs."""
|
||||
litellm_params = kwargs.setdefault("litellm_params", {})
|
||||
|
||||
# we need this weird little dance, as metadata might be set but may be None initially
|
||||
metadata = litellm_params.get("metadata")
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
litellm_params["metadata"] = metadata
|
||||
return metadata
|
||||
|
||||
|
||||
def _input_callback(kwargs):
|
||||
# type: (Dict[str, Any]) -> None
|
||||
"""Handle the start of a request."""
|
||||
integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
|
||||
|
||||
if integration is None:
|
||||
return
|
||||
|
||||
# Get key parameters
|
||||
full_model = kwargs.get("model", "")
|
||||
try:
|
||||
model, provider, _, _ = litellm.get_llm_provider(full_model)
|
||||
except Exception:
|
||||
model = full_model
|
||||
provider = "unknown"
|
||||
|
||||
call_type = kwargs.get("call_type", None)
|
||||
if call_type == "embedding":
|
||||
operation = "embeddings"
|
||||
else:
|
||||
operation = "chat"
|
||||
|
||||
# Start a new span/transaction
|
||||
span = get_start_span_function()(
|
||||
op=(
|
||||
consts.OP.GEN_AI_CHAT
|
||||
if operation == "chat"
|
||||
else consts.OP.GEN_AI_EMBEDDINGS
|
||||
),
|
||||
name=f"{operation} {model}",
|
||||
origin=LiteLLMIntegration.origin,
|
||||
)
|
||||
span.__enter__()
|
||||
|
||||
# Store span for later
|
||||
_get_metadata_dict(kwargs)["_sentry_span"] = span
|
||||
|
||||
# Set basic data
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider)
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
|
||||
|
||||
# Record messages if allowed
|
||||
messages = kwargs.get("messages", [])
|
||||
if messages and should_send_default_pii() and integration.include_prompts:
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
messages_data = truncate_and_annotate_messages(messages, span, scope)
|
||||
if messages_data is not None:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
|
||||
)
|
||||
|
||||
# Record other parameters
|
||||
params = {
|
||||
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
|
||||
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
|
||||
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
|
||||
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
|
||||
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
|
||||
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
|
||||
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
|
||||
}
|
||||
for key, attribute in params.items():
|
||||
value = kwargs.get(key)
|
||||
if value is not None:
|
||||
set_data_normalized(span, attribute, value)
|
||||
|
||||
# Record LiteLLM-specific parameters
|
||||
litellm_params = {
|
||||
"api_base": kwargs.get("api_base"),
|
||||
"api_version": kwargs.get("api_version"),
|
||||
"custom_llm_provider": kwargs.get("custom_llm_provider"),
|
||||
}
|
||||
for key, value in litellm_params.items():
|
||||
if value is not None:
|
||||
set_data_normalized(span, f"gen_ai.litellm.{key}", value)
|
||||
|
||||
|
||||
def _success_callback(kwargs, completion_response, start_time, end_time):
|
||||
# type: (Dict[str, Any], Any, datetime, datetime) -> None
|
||||
"""Handle successful completion."""
|
||||
|
||||
span = _get_metadata_dict(kwargs).get("_sentry_span")
|
||||
if span is None:
|
||||
return
|
||||
|
||||
integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
|
||||
if integration is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Record model information
|
||||
if hasattr(completion_response, "model"):
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_MODEL, completion_response.model
|
||||
)
|
||||
|
||||
# Record response content if allowed
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
if hasattr(completion_response, "choices"):
|
||||
response_messages = []
|
||||
for choice in completion_response.choices:
|
||||
if hasattr(choice, "message"):
|
||||
if hasattr(choice.message, "model_dump"):
|
||||
response_messages.append(choice.message.model_dump())
|
||||
elif hasattr(choice.message, "dict"):
|
||||
response_messages.append(choice.message.dict())
|
||||
else:
|
||||
# Fallback for basic message objects
|
||||
msg = {}
|
||||
if hasattr(choice.message, "role"):
|
||||
msg["role"] = choice.message.role
|
||||
if hasattr(choice.message, "content"):
|
||||
msg["content"] = choice.message.content
|
||||
if hasattr(choice.message, "tool_calls"):
|
||||
msg["tool_calls"] = choice.message.tool_calls
|
||||
response_messages.append(msg)
|
||||
|
||||
if response_messages:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_messages
|
||||
)
|
||||
|
||||
# Record token usage
|
||||
if hasattr(completion_response, "usage"):
|
||||
usage = completion_response.usage
|
||||
record_token_usage(
|
||||
span,
|
||||
input_tokens=getattr(usage, "prompt_tokens", None),
|
||||
output_tokens=getattr(usage, "completion_tokens", None),
|
||||
total_tokens=getattr(usage, "total_tokens", None),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Always finish the span and clean up
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
|
||||
def _failure_callback(kwargs, exception, start_time, end_time):
|
||||
# type: (Dict[str, Any], Exception, datetime, datetime) -> None
|
||||
"""Handle request failure."""
|
||||
span = _get_metadata_dict(kwargs).get("_sentry_span")
|
||||
if span is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Capture the exception
|
||||
event, hint = event_from_exception(
|
||||
exception,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
mechanism={"type": "litellm", "handled": False},
|
||||
)
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
finally:
|
||||
# Always finish the span and clean up
|
||||
span.__exit__(type(exception), exception, None)
|
||||
|
||||
|
||||
class LiteLLMIntegration(Integration):
|
||||
"""
|
||||
LiteLLM integration for Sentry.
|
||||
|
||||
This integration automatically captures LiteLLM API calls and sends them to Sentry
|
||||
for monitoring and error tracking. It supports all 100+ LLM providers that LiteLLM
|
||||
supports, including OpenAI, Anthropic, Google, Cohere, and many others.
|
||||
|
||||
Features:
|
||||
- Automatic exception capture for all LiteLLM calls
|
||||
- Token usage tracking across all providers
|
||||
- Provider detection and attribution
|
||||
- Input/output message capture (configurable)
|
||||
- Streaming response support
|
||||
- Cost tracking integration
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
import sentry_sdk
|
||||
|
||||
# Initialize Sentry with the LiteLLM integration
|
||||
sentry_sdk.init(
|
||||
dsn="your-dsn",
|
||||
send_default_pii=True
|
||||
integrations=[
|
||||
sentry_sdk.integrations.LiteLLMIntegration(
|
||||
include_prompts=True # Set to False to exclude message content
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# All LiteLLM calls will now be monitored
|
||||
response = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello!"}]
|
||||
)
|
||||
```
|
||||
|
||||
Configuration:
|
||||
- include_prompts (bool): Whether to include prompts and responses in spans.
|
||||
Defaults to True. Set to False to exclude potentially sensitive data.
|
||||
"""
|
||||
|
||||
identifier = "litellm"
|
||||
origin = f"auto.ai.{identifier}"
|
||||
|
||||
def __init__(self, include_prompts=True):
|
||||
# type: (LiteLLMIntegration, bool) -> None
|
||||
self.include_prompts = include_prompts
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
"""Set up LiteLLM callbacks for monitoring."""
|
||||
litellm.input_callback = litellm.input_callback or []
|
||||
if _input_callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(_input_callback)
|
||||
|
||||
litellm.success_callback = litellm.success_callback or []
|
||||
if _success_callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(_success_callback)
|
||||
|
||||
litellm.failure_callback = litellm.failure_callback or []
|
||||
if _failure_callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(_failure_callback)
|
||||
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Set
|
||||
from copy import deepcopy
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP
|
||||
from sentry_sdk.integrations import (
|
||||
@@ -9,7 +11,7 @@ from sentry_sdk.integrations import (
|
||||
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
|
||||
from sentry_sdk.integrations.logging import ignore_logger
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE
|
||||
from sentry_sdk.tracing import TransactionSource, SOURCE_FOR_STYLE
|
||||
from sentry_sdk.utils import (
|
||||
ensure_integration_enabled,
|
||||
event_from_exception,
|
||||
@@ -85,8 +87,18 @@ class SentryLitestarASGIMiddleware(SentryAsgiMiddleware):
|
||||
transaction_style="endpoint",
|
||||
mechanism_type="asgi",
|
||||
span_origin=span_origin,
|
||||
asgi_version=3,
|
||||
)
|
||||
|
||||
def _capture_request_exception(self, exc):
|
||||
# type: (Exception) -> None
|
||||
"""Avoid catching exceptions from request handlers.
|
||||
|
||||
Those exceptions are already handled in Litestar.after_exception handler.
|
||||
We still catch exceptions from application lifespan handlers.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def patch_app_init():
|
||||
# type: () -> None
|
||||
@@ -107,7 +119,6 @@ def patch_app_init():
|
||||
*(kwargs.get("after_exception") or []),
|
||||
]
|
||||
|
||||
SentryLitestarASGIMiddleware.__call__ = SentryLitestarASGIMiddleware._run_asgi3 # type: ignore
|
||||
middleware = kwargs.get("middleware") or []
|
||||
kwargs["middleware"] = [SentryLitestarASGIMiddleware, *middleware]
|
||||
old__init__(self, *args, **kwargs)
|
||||
@@ -213,9 +224,7 @@ def patch_http_route_handle():
|
||||
return await old_handle(self, scope, receive, send)
|
||||
|
||||
sentry_scope = sentry_sdk.get_isolation_scope()
|
||||
request = scope["app"].request_class(
|
||||
scope=scope, receive=receive, send=send
|
||||
) # type: Request[Any, Any]
|
||||
request = scope["app"].request_class(scope=scope, receive=receive, send=send) # type: Request[Any, Any]
|
||||
extracted_request_data = ConnectionDataExtractor(
|
||||
parse_body=True, parse_query=True
|
||||
)(request)
|
||||
@@ -249,11 +258,11 @@ def patch_http_route_handle():
|
||||
|
||||
if not tx_name:
|
||||
tx_name = _DEFAULT_TRANSACTION_NAME
|
||||
tx_info = {"source": TRANSACTION_SOURCE_ROUTE}
|
||||
tx_info = {"source": TransactionSource.ROUTE}
|
||||
|
||||
event.update(
|
||||
{
|
||||
"request": request_info,
|
||||
"request": deepcopy(request_info),
|
||||
"transaction": tx_name,
|
||||
"transaction_info": tx_info,
|
||||
}
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from fnmatch import fnmatch
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.client import BaseClient
|
||||
from sentry_sdk.logger import _log_level_to_otel
|
||||
from sentry_sdk.utils import (
|
||||
safe_repr,
|
||||
to_string,
|
||||
event_from_exception,
|
||||
current_stacktrace,
|
||||
capture_internal_exceptions,
|
||||
has_logs_enabled,
|
||||
)
|
||||
from sentry_sdk.integrations import Integration
|
||||
|
||||
@@ -33,6 +38,16 @@ LOGGING_TO_EVENT_LEVEL = {
|
||||
logging.CRITICAL: "fatal", # CRITICAL is same as FATAL
|
||||
}
|
||||
|
||||
# Map logging level numbers to corresponding OTel level numbers
|
||||
SEVERITY_TO_OTEL_SEVERITY = {
|
||||
logging.CRITICAL: 21, # fatal
|
||||
logging.ERROR: 17, # error
|
||||
logging.WARNING: 13, # warn
|
||||
logging.INFO: 9, # info
|
||||
logging.DEBUG: 5, # debug
|
||||
}
|
||||
|
||||
|
||||
# Capturing events from those loggers causes recursion errors. We cannot allow
|
||||
# the user to unconditionally create events from those loggers under any
|
||||
# circumstances.
|
||||
@@ -61,14 +76,23 @@ def ignore_logger(
|
||||
class LoggingIntegration(Integration):
|
||||
identifier = "logging"
|
||||
|
||||
def __init__(self, level=DEFAULT_LEVEL, event_level=DEFAULT_EVENT_LEVEL):
|
||||
# type: (Optional[int], Optional[int]) -> None
|
||||
def __init__(
|
||||
self,
|
||||
level=DEFAULT_LEVEL,
|
||||
event_level=DEFAULT_EVENT_LEVEL,
|
||||
sentry_logs_level=DEFAULT_LEVEL,
|
||||
):
|
||||
# type: (Optional[int], Optional[int], Optional[int]) -> None
|
||||
self._handler = None
|
||||
self._breadcrumb_handler = None
|
||||
self._sentry_logs_handler = None
|
||||
|
||||
if level is not None:
|
||||
self._breadcrumb_handler = BreadcrumbHandler(level=level)
|
||||
|
||||
if sentry_logs_level is not None:
|
||||
self._sentry_logs_handler = SentryLogsHandler(level=sentry_logs_level)
|
||||
|
||||
if event_level is not None:
|
||||
self._handler = EventHandler(level=event_level)
|
||||
|
||||
@@ -83,6 +107,12 @@ class LoggingIntegration(Integration):
|
||||
):
|
||||
self._breadcrumb_handler.handle(record)
|
||||
|
||||
if (
|
||||
self._sentry_logs_handler is not None
|
||||
and record.levelno >= self._sentry_logs_handler.level
|
||||
):
|
||||
self._sentry_logs_handler.handle(record)
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
@@ -101,7 +131,10 @@ class LoggingIntegration(Integration):
|
||||
# the integration. Otherwise we have a high chance of getting
|
||||
# into a recursion error when the integration is resolved
|
||||
# (this also is slower).
|
||||
if ignored_loggers is not None and record.name not in ignored_loggers:
|
||||
if (
|
||||
ignored_loggers is not None
|
||||
and record.name.strip() not in ignored_loggers
|
||||
):
|
||||
integration = sentry_sdk.get_client().get_integration(
|
||||
LoggingIntegration
|
||||
)
|
||||
@@ -146,7 +179,7 @@ class _BaseHandler(logging.Handler):
|
||||
# type: (LogRecord) -> bool
|
||||
"""Prevents ignored loggers from recording"""
|
||||
for logger in _IGNORED_LOGGERS:
|
||||
if fnmatch(record.name, logger):
|
||||
if fnmatch(record.name.strip(), logger):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -231,25 +264,25 @@ class EventHandler(_BaseHandler):
|
||||
event["level"] = level # type: ignore[typeddict-item]
|
||||
event["logger"] = record.name
|
||||
|
||||
# Log records from `warnings` module as separate issues
|
||||
record_caputured_from_warnings_module = (
|
||||
record.name == "py.warnings" and record.msg == "%s"
|
||||
)
|
||||
if record_caputured_from_warnings_module:
|
||||
# use the actual message and not "%s" as the message
|
||||
# this prevents grouping all warnings under one "%s" issue
|
||||
msg = record.args[0] # type: ignore
|
||||
|
||||
event["logentry"] = {
|
||||
"message": msg,
|
||||
"params": (),
|
||||
}
|
||||
|
||||
if (
|
||||
sys.version_info < (3, 11)
|
||||
and record.name == "py.warnings"
|
||||
and record.msg == "%s"
|
||||
):
|
||||
# warnings module on Python 3.10 and below sets record.msg to "%s"
|
||||
# and record.args[0] to the actual warning message.
|
||||
# This was fixed in https://github.com/python/cpython/pull/30975.
|
||||
message = record.args[0]
|
||||
params = ()
|
||||
else:
|
||||
event["logentry"] = {
|
||||
"message": to_string(record.msg),
|
||||
"params": record.args,
|
||||
}
|
||||
message = record.msg
|
||||
params = record.args
|
||||
|
||||
event["logentry"] = {
|
||||
"message": to_string(message),
|
||||
"formatted": record.getMessage(),
|
||||
"params": params,
|
||||
}
|
||||
|
||||
event["extra"] = self._extra_from_record(record)
|
||||
|
||||
@@ -292,3 +325,97 @@ class BreadcrumbHandler(_BaseHandler):
|
||||
"timestamp": datetime.fromtimestamp(record.created, timezone.utc),
|
||||
"data": self._extra_from_record(record),
|
||||
}
|
||||
|
||||
|
||||
class SentryLogsHandler(_BaseHandler):
|
||||
"""
|
||||
A logging handler that records Sentry logs for each Python log record.
|
||||
|
||||
Note that you do not have to use this class if the logging integration is enabled, which it is by default.
|
||||
"""
|
||||
|
||||
def emit(self, record):
|
||||
# type: (LogRecord) -> Any
|
||||
with capture_internal_exceptions():
|
||||
self.format(record)
|
||||
if not self._can_record(record):
|
||||
return
|
||||
|
||||
client = sentry_sdk.get_client()
|
||||
if not client.is_active():
|
||||
return
|
||||
|
||||
if not has_logs_enabled(client.options):
|
||||
return
|
||||
|
||||
self._capture_log_from_record(client, record)
|
||||
|
||||
def _capture_log_from_record(self, client, record):
|
||||
# type: (BaseClient, LogRecord) -> None
|
||||
otel_severity_number, otel_severity_text = _log_level_to_otel(
|
||||
record.levelno, SEVERITY_TO_OTEL_SEVERITY
|
||||
)
|
||||
project_root = client.options["project_root"]
|
||||
|
||||
attrs = self._extra_from_record(record) # type: Any
|
||||
attrs["sentry.origin"] = "auto.logger.log"
|
||||
|
||||
parameters_set = False
|
||||
if record.args is not None:
|
||||
if isinstance(record.args, tuple):
|
||||
parameters_set = bool(record.args)
|
||||
for i, arg in enumerate(record.args):
|
||||
attrs[f"sentry.message.parameter.{i}"] = (
|
||||
arg
|
||||
if isinstance(arg, (str, float, int, bool))
|
||||
else safe_repr(arg)
|
||||
)
|
||||
elif isinstance(record.args, dict):
|
||||
parameters_set = bool(record.args)
|
||||
for key, value in record.args.items():
|
||||
attrs[f"sentry.message.parameter.{key}"] = (
|
||||
value
|
||||
if isinstance(value, (str, float, int, bool))
|
||||
else safe_repr(value)
|
||||
)
|
||||
|
||||
if parameters_set and isinstance(record.msg, str):
|
||||
# only include template if there is at least one
|
||||
# sentry.message.parameter.X set
|
||||
attrs["sentry.message.template"] = record.msg
|
||||
|
||||
if record.lineno:
|
||||
attrs["code.line.number"] = record.lineno
|
||||
|
||||
if record.pathname:
|
||||
if project_root is not None and record.pathname.startswith(project_root):
|
||||
attrs["code.file.path"] = record.pathname[len(project_root) + 1 :]
|
||||
else:
|
||||
attrs["code.file.path"] = record.pathname
|
||||
|
||||
if record.funcName:
|
||||
attrs["code.function.name"] = record.funcName
|
||||
|
||||
if record.thread:
|
||||
attrs["thread.id"] = record.thread
|
||||
if record.threadName:
|
||||
attrs["thread.name"] = record.threadName
|
||||
|
||||
if record.process:
|
||||
attrs["process.pid"] = record.process
|
||||
if record.processName:
|
||||
attrs["process.executable.name"] = record.processName
|
||||
if record.name:
|
||||
attrs["logger.name"] = record.name
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
client._capture_log(
|
||||
{
|
||||
"severity_text": otel_severity_text,
|
||||
"severity_number": otel_severity_number,
|
||||
"body": record.message,
|
||||
"attributes": attrs,
|
||||
"time_unix_nano": int(record.created * 1e9),
|
||||
"trace_id": None,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
import enum
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations import Integration, DidNotEnable
|
||||
from sentry_sdk.integrations.logging import (
|
||||
BreadcrumbHandler,
|
||||
EventHandler,
|
||||
_BaseHandler,
|
||||
)
|
||||
from sentry_sdk.logger import _log_level_to_otel
|
||||
from sentry_sdk.utils import has_logs_enabled
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import LogRecord
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
import loguru
|
||||
from loguru import logger
|
||||
from loguru._defaults import LOGURU_FORMAT as DEFAULT_FORMAT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Message
|
||||
except ImportError:
|
||||
raise DidNotEnable("LOGURU is not installed")
|
||||
|
||||
@@ -33,68 +39,167 @@ class LoggingLevels(enum.IntEnum):
|
||||
|
||||
DEFAULT_LEVEL = LoggingLevels.INFO.value
|
||||
DEFAULT_EVENT_LEVEL = LoggingLevels.ERROR.value
|
||||
# We need to save the handlers to be able to remove them later
|
||||
# in tests (they call `LoguruIntegration.__init__` multiple times,
|
||||
# and we can't use `setup_once` because it's called before
|
||||
# than we get configuration).
|
||||
_ADDED_HANDLERS = (None, None) # type: Tuple[Optional[int], Optional[int]]
|
||||
|
||||
|
||||
SENTRY_LEVEL_FROM_LOGURU_LEVEL = {
|
||||
"TRACE": "DEBUG",
|
||||
"DEBUG": "DEBUG",
|
||||
"INFO": "INFO",
|
||||
"SUCCESS": "INFO",
|
||||
"WARNING": "WARNING",
|
||||
"ERROR": "ERROR",
|
||||
"CRITICAL": "CRITICAL",
|
||||
}
|
||||
|
||||
# Map Loguru level numbers to corresponding OTel level numbers
|
||||
SEVERITY_TO_OTEL_SEVERITY = {
|
||||
LoggingLevels.CRITICAL: 21, # fatal
|
||||
LoggingLevels.ERROR: 17, # error
|
||||
LoggingLevels.WARNING: 13, # warn
|
||||
LoggingLevels.SUCCESS: 11, # info
|
||||
LoggingLevels.INFO: 9, # info
|
||||
LoggingLevels.DEBUG: 5, # debug
|
||||
LoggingLevels.TRACE: 1, # trace
|
||||
}
|
||||
|
||||
|
||||
class LoguruIntegration(Integration):
|
||||
identifier = "loguru"
|
||||
|
||||
level = DEFAULT_LEVEL # type: Optional[int]
|
||||
event_level = DEFAULT_EVENT_LEVEL # type: Optional[int]
|
||||
breadcrumb_format = DEFAULT_FORMAT
|
||||
event_format = DEFAULT_FORMAT
|
||||
sentry_logs_level = DEFAULT_LEVEL # type: Optional[int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level=DEFAULT_LEVEL,
|
||||
event_level=DEFAULT_EVENT_LEVEL,
|
||||
breadcrumb_format=DEFAULT_FORMAT,
|
||||
event_format=DEFAULT_FORMAT,
|
||||
sentry_logs_level=DEFAULT_LEVEL,
|
||||
):
|
||||
# type: (Optional[int], Optional[int], str | loguru.FormatFunction, str | loguru.FormatFunction) -> None
|
||||
global _ADDED_HANDLERS
|
||||
breadcrumb_handler, event_handler = _ADDED_HANDLERS
|
||||
|
||||
if breadcrumb_handler is not None:
|
||||
logger.remove(breadcrumb_handler)
|
||||
breadcrumb_handler = None
|
||||
if event_handler is not None:
|
||||
logger.remove(event_handler)
|
||||
event_handler = None
|
||||
|
||||
if level is not None:
|
||||
breadcrumb_handler = logger.add(
|
||||
LoguruBreadcrumbHandler(level=level),
|
||||
level=level,
|
||||
format=breadcrumb_format,
|
||||
)
|
||||
|
||||
if event_level is not None:
|
||||
event_handler = logger.add(
|
||||
LoguruEventHandler(level=event_level),
|
||||
level=event_level,
|
||||
format=event_format,
|
||||
)
|
||||
|
||||
_ADDED_HANDLERS = (breadcrumb_handler, event_handler)
|
||||
# type: (Optional[int], Optional[int], str | loguru.FormatFunction, str | loguru.FormatFunction, Optional[int]) -> None
|
||||
LoguruIntegration.level = level
|
||||
LoguruIntegration.event_level = event_level
|
||||
LoguruIntegration.breadcrumb_format = breadcrumb_format
|
||||
LoguruIntegration.event_format = event_format
|
||||
LoguruIntegration.sentry_logs_level = sentry_logs_level
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
pass # we do everything in __init__
|
||||
if LoguruIntegration.level is not None:
|
||||
logger.add(
|
||||
LoguruBreadcrumbHandler(level=LoguruIntegration.level),
|
||||
level=LoguruIntegration.level,
|
||||
format=LoguruIntegration.breadcrumb_format,
|
||||
)
|
||||
|
||||
if LoguruIntegration.event_level is not None:
|
||||
logger.add(
|
||||
LoguruEventHandler(level=LoguruIntegration.event_level),
|
||||
level=LoguruIntegration.event_level,
|
||||
format=LoguruIntegration.event_format,
|
||||
)
|
||||
|
||||
if LoguruIntegration.sentry_logs_level is not None:
|
||||
logger.add(
|
||||
loguru_sentry_logs_handler,
|
||||
level=LoguruIntegration.sentry_logs_level,
|
||||
)
|
||||
|
||||
|
||||
class _LoguruBaseHandler(_BaseHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
# type: (*Any, **Any) -> None
|
||||
if kwargs.get("level"):
|
||||
kwargs["level"] = SENTRY_LEVEL_FROM_LOGURU_LEVEL.get(
|
||||
kwargs.get("level", ""), DEFAULT_LEVEL
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _logging_to_event_level(self, record):
|
||||
# type: (LogRecord) -> str
|
||||
try:
|
||||
return LoggingLevels(record.levelno).name.lower()
|
||||
except ValueError:
|
||||
return SENTRY_LEVEL_FROM_LOGURU_LEVEL[
|
||||
LoggingLevels(record.levelno).name
|
||||
].lower()
|
||||
except (ValueError, KeyError):
|
||||
return record.levelname.lower() if record.levelname else ""
|
||||
|
||||
|
||||
class LoguruEventHandler(_LoguruBaseHandler, EventHandler):
|
||||
"""Modified version of :class:`sentry_sdk.integrations.logging.EventHandler` to use loguru's level names."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LoguruBreadcrumbHandler(_LoguruBaseHandler, BreadcrumbHandler):
|
||||
"""Modified version of :class:`sentry_sdk.integrations.logging.BreadcrumbHandler` to use loguru's level names."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def loguru_sentry_logs_handler(message):
|
||||
# type: (Message) -> None
|
||||
# This is intentionally a callable sink instead of a standard logging handler
|
||||
# since otherwise we wouldn't get direct access to message.record
|
||||
client = sentry_sdk.get_client()
|
||||
|
||||
if not client.is_active():
|
||||
return
|
||||
|
||||
if not has_logs_enabled(client.options):
|
||||
return
|
||||
|
||||
record = message.record
|
||||
|
||||
if (
|
||||
LoguruIntegration.sentry_logs_level is None
|
||||
or record["level"].no < LoguruIntegration.sentry_logs_level
|
||||
):
|
||||
return
|
||||
|
||||
otel_severity_number, otel_severity_text = _log_level_to_otel(
|
||||
record["level"].no, SEVERITY_TO_OTEL_SEVERITY
|
||||
)
|
||||
|
||||
attrs = {"sentry.origin": "auto.logger.loguru"} # type: dict[str, Any]
|
||||
|
||||
project_root = client.options["project_root"]
|
||||
if record.get("file"):
|
||||
if project_root is not None and record["file"].path.startswith(project_root):
|
||||
attrs["code.file.path"] = record["file"].path[len(project_root) + 1 :]
|
||||
else:
|
||||
attrs["code.file.path"] = record["file"].path
|
||||
|
||||
if record.get("line") is not None:
|
||||
attrs["code.line.number"] = record["line"]
|
||||
|
||||
if record.get("function"):
|
||||
attrs["code.function.name"] = record["function"]
|
||||
|
||||
if record.get("thread"):
|
||||
attrs["thread.name"] = record["thread"].name
|
||||
attrs["thread.id"] = record["thread"].id
|
||||
|
||||
if record.get("process"):
|
||||
attrs["process.pid"] = record["process"].id
|
||||
attrs["process.executable.name"] = record["process"].name
|
||||
|
||||
if record.get("name"):
|
||||
attrs["logger.name"] = record["name"]
|
||||
|
||||
client._capture_log(
|
||||
{
|
||||
"severity_text": otel_severity_text,
|
||||
"severity_number": otel_severity_number,
|
||||
"body": record["message"],
|
||||
"attributes": attrs,
|
||||
"time_unix_nano": int(record["time"].timestamp() * 1e9),
|
||||
"trace_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -0,0 +1,552 @@
|
||||
"""
|
||||
Sentry integration for MCP (Model Context Protocol) servers.
|
||||
|
||||
This integration instruments MCP servers to create spans for tool, prompt,
|
||||
and resource handler execution, and captures errors that occur during execution.
|
||||
|
||||
Supports the low-level `mcp.server.lowlevel.Server` API.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import get_start_span_function
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.integrations import Integration, DidNotEnable
|
||||
from sentry_sdk.utils import safe_serialize
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
|
||||
try:
|
||||
from mcp.server.lowlevel import Server # type: ignore[import-not-found]
|
||||
from mcp.server.lowlevel.server import request_ctx # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
raise DidNotEnable("MCP SDK not installed")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
class MCPIntegration(Integration):
|
||||
identifier = "mcp"
|
||||
origin = "auto.ai.mcp"
|
||||
|
||||
def __init__(self, include_prompts=True):
|
||||
# type: (bool) -> None
|
||||
"""
|
||||
Initialize the MCP integration.
|
||||
|
||||
Args:
|
||||
include_prompts: Whether to include prompts (tool results and prompt content)
|
||||
in span data. Requires send_default_pii=True. Default is True.
|
||||
"""
|
||||
self.include_prompts = include_prompts
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches MCP server classes to instrument handler execution.
|
||||
"""
|
||||
_patch_lowlevel_server()
|
||||
|
||||
|
||||
def _get_request_context_data():
|
||||
# type: () -> tuple[Optional[str], Optional[str], str]
|
||||
"""
|
||||
Extract request ID, session ID, and transport type from the MCP request context.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_id, session_id, transport).
|
||||
- request_id: May be None if not available
|
||||
- session_id: May be None if not available
|
||||
- transport: "tcp" for HTTP-based, "pipe" for stdio
|
||||
"""
|
||||
request_id = None # type: Optional[str]
|
||||
session_id = None # type: Optional[str]
|
||||
transport = "pipe" # type: str
|
||||
|
||||
try:
|
||||
ctx = request_ctx.get()
|
||||
|
||||
if ctx is not None:
|
||||
request_id = ctx.request_id
|
||||
if hasattr(ctx, "request") and ctx.request is not None:
|
||||
transport = "tcp"
|
||||
request = ctx.request
|
||||
if hasattr(request, "headers"):
|
||||
session_id = request.headers.get("mcp-session-id")
|
||||
|
||||
except LookupError:
|
||||
# No request context available - default to pipe
|
||||
pass
|
||||
|
||||
return request_id, session_id, transport
|
||||
|
||||
|
||||
def _get_span_config(handler_type, item_name):
|
||||
# type: (str, str) -> tuple[str, str, str, Optional[str]]
|
||||
"""
|
||||
Get span configuration based on handler type.
|
||||
|
||||
Returns:
|
||||
Tuple of (span_data_key, span_name, mcp_method_name, result_data_key)
|
||||
Note: result_data_key is None for resources
|
||||
"""
|
||||
if handler_type == "tool":
|
||||
span_data_key = SPANDATA.MCP_TOOL_NAME
|
||||
mcp_method_name = "tools/call"
|
||||
result_data_key = SPANDATA.MCP_TOOL_RESULT_CONTENT
|
||||
elif handler_type == "prompt":
|
||||
span_data_key = SPANDATA.MCP_PROMPT_NAME
|
||||
mcp_method_name = "prompts/get"
|
||||
result_data_key = SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT
|
||||
else: # resource
|
||||
span_data_key = SPANDATA.MCP_RESOURCE_URI
|
||||
mcp_method_name = "resources/read"
|
||||
result_data_key = None # Resources don't capture result content
|
||||
|
||||
span_name = f"{mcp_method_name} {item_name}"
|
||||
return span_data_key, span_name, mcp_method_name, result_data_key
|
||||
|
||||
|
||||
def _set_span_input_data(
|
||||
span,
|
||||
handler_name,
|
||||
span_data_key,
|
||||
mcp_method_name,
|
||||
arguments,
|
||||
request_id,
|
||||
session_id,
|
||||
transport,
|
||||
):
|
||||
# type: (Any, str, str, str, dict[str, Any], Optional[str], Optional[str], str) -> None
|
||||
"""Set input span data for MCP handlers."""
|
||||
# Set handler identifier
|
||||
span.set_data(span_data_key, handler_name)
|
||||
span.set_data(SPANDATA.MCP_METHOD_NAME, mcp_method_name)
|
||||
|
||||
# Set transport type
|
||||
span.set_data(SPANDATA.MCP_TRANSPORT, transport)
|
||||
|
||||
# Set request_id if provided
|
||||
if request_id:
|
||||
span.set_data(SPANDATA.MCP_REQUEST_ID, request_id)
|
||||
|
||||
# Set session_id if provided
|
||||
if session_id:
|
||||
span.set_data(SPANDATA.MCP_SESSION_ID, session_id)
|
||||
|
||||
# Set request arguments (excluding common request context objects)
|
||||
for k, v in arguments.items():
|
||||
span.set_data(f"mcp.request.argument.{k}", safe_serialize(v))
|
||||
|
||||
|
||||
def _extract_tool_result_content(result):
|
||||
# type: (Any) -> Any
|
||||
"""
|
||||
Extract meaningful content from MCP tool result.
|
||||
|
||||
Tool handlers can return:
|
||||
- tuple (UnstructuredContent, StructuredContent): Return the structured content (dict)
|
||||
- dict (StructuredContent): Return as-is
|
||||
- Iterable (UnstructuredContent): Extract text from content blocks
|
||||
"""
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
# Handle CombinationContent: tuple of (UnstructuredContent, StructuredContent)
|
||||
if isinstance(result, tuple) and len(result) == 2:
|
||||
# Return the structured content (2nd element)
|
||||
return result[1]
|
||||
|
||||
# Handle StructuredContent: dict
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
|
||||
# Handle UnstructuredContent: iterable of ContentBlock objects
|
||||
# Try to extract text content
|
||||
if hasattr(result, "__iter__") and not isinstance(result, (str, bytes, dict)):
|
||||
texts = []
|
||||
try:
|
||||
for item in result:
|
||||
# Try to get text attribute from ContentBlock objects
|
||||
if hasattr(item, "text"):
|
||||
texts.append(item.text)
|
||||
elif isinstance(item, dict) and "text" in item:
|
||||
texts.append(item["text"])
|
||||
except Exception:
|
||||
# If extraction fails, return the original
|
||||
return result
|
||||
return " ".join(texts) if texts else result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _set_span_output_data(span, result, result_data_key, handler_type):
|
||||
# type: (Any, Any, Optional[str], str) -> None
|
||||
"""Set output span data for MCP handlers."""
|
||||
if result is None:
|
||||
return
|
||||
|
||||
# Get integration to check PII settings
|
||||
integration = sentry_sdk.get_client().get_integration(MCPIntegration)
|
||||
if integration is None:
|
||||
return
|
||||
|
||||
# Check if we should include sensitive data
|
||||
should_include_data = should_send_default_pii() and integration.include_prompts
|
||||
|
||||
# For tools, extract the meaningful content
|
||||
if handler_type == "tool":
|
||||
extracted = _extract_tool_result_content(result)
|
||||
if extracted is not None and should_include_data:
|
||||
span.set_data(result_data_key, safe_serialize(extracted))
|
||||
# Set content count if result is a dict
|
||||
if isinstance(extracted, dict):
|
||||
span.set_data(SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT, len(extracted))
|
||||
elif handler_type == "prompt":
|
||||
# For prompts, count messages and set role/content only for single-message prompts
|
||||
try:
|
||||
messages = None # type: Optional[list[str]]
|
||||
message_count = 0
|
||||
|
||||
# Check if result has messages attribute (GetPromptResult)
|
||||
if hasattr(result, "messages") and result.messages:
|
||||
messages = result.messages
|
||||
message_count = len(messages)
|
||||
# Also check if result is a dict with messages
|
||||
elif isinstance(result, dict) and result.get("messages"):
|
||||
messages = result["messages"]
|
||||
message_count = len(messages)
|
||||
|
||||
# Always set message count if we found messages
|
||||
if message_count > 0:
|
||||
span.set_data(SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT, message_count)
|
||||
|
||||
# Only set role and content for single-message prompts if PII is allowed
|
||||
if message_count == 1 and should_include_data and messages:
|
||||
first_message = messages[0]
|
||||
# Extract role
|
||||
role = None
|
||||
if hasattr(first_message, "role"):
|
||||
role = first_message.role
|
||||
elif isinstance(first_message, dict) and "role" in first_message:
|
||||
role = first_message["role"]
|
||||
|
||||
if role:
|
||||
span.set_data(SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE, role)
|
||||
|
||||
# Extract content text
|
||||
content_text = None
|
||||
if hasattr(first_message, "content"):
|
||||
msg_content = first_message.content
|
||||
# Content can be a TextContent object or similar
|
||||
if hasattr(msg_content, "text"):
|
||||
content_text = msg_content.text
|
||||
elif isinstance(msg_content, dict) and "text" in msg_content:
|
||||
content_text = msg_content["text"]
|
||||
elif isinstance(msg_content, str):
|
||||
content_text = msg_content
|
||||
elif isinstance(first_message, dict) and "content" in first_message:
|
||||
msg_content = first_message["content"]
|
||||
if isinstance(msg_content, dict) and "text" in msg_content:
|
||||
content_text = msg_content["text"]
|
||||
elif isinstance(msg_content, str):
|
||||
content_text = msg_content
|
||||
|
||||
if content_text:
|
||||
span.set_data(result_data_key, content_text)
|
||||
except Exception:
|
||||
# Silently ignore if we can't extract message info
|
||||
pass
|
||||
# Resources don't capture result content (result_data_key is None)
|
||||
|
||||
|
||||
# Handler data preparation and wrapping
|
||||
|
||||
|
||||
def _prepare_handler_data(handler_type, original_args):
|
||||
# type: (str, tuple[Any, ...]) -> tuple[str, dict[str, Any], str, str, str, Optional[str]]
|
||||
"""
|
||||
Prepare common handler data for both async and sync wrappers.
|
||||
|
||||
Returns:
|
||||
Tuple of (handler_name, arguments, span_data_key, span_name, mcp_method_name, result_data_key)
|
||||
"""
|
||||
# Extract handler-specific data based on handler type
|
||||
if handler_type == "tool":
|
||||
handler_name = original_args[0] # tool_name
|
||||
arguments = original_args[1] if len(original_args) > 1 else {}
|
||||
elif handler_type == "prompt":
|
||||
handler_name = original_args[0] # name
|
||||
arguments = original_args[1] if len(original_args) > 1 else {}
|
||||
# Include name in arguments dict for span data
|
||||
arguments = {"name": handler_name, **(arguments or {})}
|
||||
else: # resource
|
||||
uri = original_args[0]
|
||||
handler_name = str(uri) if uri else "unknown"
|
||||
arguments = {}
|
||||
|
||||
# Get span configuration
|
||||
span_data_key, span_name, mcp_method_name, result_data_key = _get_span_config(
|
||||
handler_type, handler_name
|
||||
)
|
||||
|
||||
return (
|
||||
handler_name,
|
||||
arguments,
|
||||
span_data_key,
|
||||
span_name,
|
||||
mcp_method_name,
|
||||
result_data_key,
|
||||
)
|
||||
|
||||
|
||||
async def _async_handler_wrapper(handler_type, func, original_args):
|
||||
# type: (str, Callable[..., Any], tuple[Any, ...]) -> Any
|
||||
"""
|
||||
Async wrapper for MCP handlers.
|
||||
|
||||
Args:
|
||||
handler_type: "tool", "prompt", or "resource"
|
||||
func: The async handler function to wrap
|
||||
original_args: Original arguments passed to the handler
|
||||
"""
|
||||
(
|
||||
handler_name,
|
||||
arguments,
|
||||
span_data_key,
|
||||
span_name,
|
||||
mcp_method_name,
|
||||
result_data_key,
|
||||
) = _prepare_handler_data(handler_type, original_args)
|
||||
|
||||
# Start span and execute
|
||||
with get_start_span_function()(
|
||||
op=OP.MCP_SERVER,
|
||||
name=span_name,
|
||||
origin=MCPIntegration.origin,
|
||||
) as span:
|
||||
# Get request ID, session ID, and transport from context
|
||||
request_id, session_id, transport = _get_request_context_data()
|
||||
|
||||
# Set input span data
|
||||
_set_span_input_data(
|
||||
span,
|
||||
handler_name,
|
||||
span_data_key,
|
||||
mcp_method_name,
|
||||
arguments,
|
||||
request_id,
|
||||
session_id,
|
||||
transport,
|
||||
)
|
||||
|
||||
# For resources, extract and set protocol
|
||||
if handler_type == "resource":
|
||||
uri = original_args[0]
|
||||
protocol = None
|
||||
if hasattr(uri, "scheme"):
|
||||
protocol = uri.scheme
|
||||
elif handler_name and "://" in handler_name:
|
||||
protocol = handler_name.split("://")[0]
|
||||
if protocol:
|
||||
span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol)
|
||||
|
||||
try:
|
||||
# Execute the async handler
|
||||
result = await func(*original_args)
|
||||
except Exception as e:
|
||||
# Set error flag for tools
|
||||
if handler_type == "tool":
|
||||
span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True)
|
||||
sentry_sdk.capture_exception(e)
|
||||
raise
|
||||
|
||||
_set_span_output_data(span, result, result_data_key, handler_type)
|
||||
return result
|
||||
|
||||
|
||||
def _sync_handler_wrapper(handler_type, func, original_args):
|
||||
# type: (str, Callable[..., Any], tuple[Any, ...]) -> Any
|
||||
"""
|
||||
Sync wrapper for MCP handlers.
|
||||
|
||||
Args:
|
||||
handler_type: "tool", "prompt", or "resource"
|
||||
func: The sync handler function to wrap
|
||||
original_args: Original arguments passed to the handler
|
||||
"""
|
||||
(
|
||||
handler_name,
|
||||
arguments,
|
||||
span_data_key,
|
||||
span_name,
|
||||
mcp_method_name,
|
||||
result_data_key,
|
||||
) = _prepare_handler_data(handler_type, original_args)
|
||||
|
||||
# Start span and execute
|
||||
with get_start_span_function()(
|
||||
op=OP.MCP_SERVER,
|
||||
name=span_name,
|
||||
origin=MCPIntegration.origin,
|
||||
) as span:
|
||||
# Get request ID, session ID, and transport from context
|
||||
request_id, session_id, transport = _get_request_context_data()
|
||||
|
||||
# Set input span data
|
||||
_set_span_input_data(
|
||||
span,
|
||||
handler_name,
|
||||
span_data_key,
|
||||
mcp_method_name,
|
||||
arguments,
|
||||
request_id,
|
||||
session_id,
|
||||
transport,
|
||||
)
|
||||
|
||||
# For resources, extract and set protocol
|
||||
if handler_type == "resource":
|
||||
uri = original_args[0]
|
||||
protocol = None
|
||||
if hasattr(uri, "scheme"):
|
||||
protocol = uri.scheme
|
||||
elif handler_name and "://" in handler_name:
|
||||
protocol = handler_name.split("://")[0]
|
||||
if protocol:
|
||||
span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol)
|
||||
|
||||
try:
|
||||
# Execute the sync handler
|
||||
result = func(*original_args)
|
||||
except Exception as e:
|
||||
# Set error flag for tools
|
||||
if handler_type == "tool":
|
||||
span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True)
|
||||
sentry_sdk.capture_exception(e)
|
||||
raise
|
||||
|
||||
_set_span_output_data(span, result, result_data_key, handler_type)
|
||||
return result
|
||||
|
||||
|
||||
def _create_instrumented_handler(handler_type, func):
|
||||
# type: (str, Callable[..., Any]) -> Callable[..., Any]
|
||||
"""
|
||||
Create an instrumented version of a handler function (async or sync).
|
||||
|
||||
This function wraps the user's handler with a runtime wrapper that will create
|
||||
Sentry spans and capture metrics when the handler is actually called.
|
||||
|
||||
The wrapper preserves the async/sync nature of the original function, which is
|
||||
critical for Python's async/await to work correctly.
|
||||
|
||||
Args:
|
||||
handler_type: "tool", "prompt", or "resource" - determines span configuration
|
||||
func: The handler function to instrument (async or sync)
|
||||
|
||||
Returns:
|
||||
A wrapped version of func that creates Sentry spans on execution
|
||||
"""
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args):
|
||||
# type: (*Any) -> Any
|
||||
return await _async_handler_wrapper(handler_type, func, args)
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args):
|
||||
# type: (*Any) -> Any
|
||||
return _sync_handler_wrapper(handler_type, func, args)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def _create_instrumented_decorator(
|
||||
original_decorator, handler_type, *decorator_args, **decorator_kwargs
|
||||
):
|
||||
# type: (Callable[..., Any], str, *Any, **Any) -> Callable[..., Any]
|
||||
"""
|
||||
Create an instrumented version of an MCP decorator.
|
||||
|
||||
This function intercepts MCP decorators (like @server.call_tool()) and injects
|
||||
Sentry instrumentation into the handler registration flow. The returned decorator
|
||||
will:
|
||||
1. Receive the user's handler function
|
||||
2. Wrap it with instrumentation via _create_instrumented_handler
|
||||
3. Pass the instrumented version to the original MCP decorator
|
||||
|
||||
This ensures that when the handler is called at runtime, it's already wrapped
|
||||
with Sentry spans and metrics collection.
|
||||
|
||||
Args:
|
||||
original_decorator: The original MCP decorator method (e.g., Server.call_tool)
|
||||
handler_type: "tool", "prompt", or "resource" - determines span configuration
|
||||
decorator_args: Positional arguments to pass to the original decorator (e.g., self)
|
||||
decorator_kwargs: Keyword arguments to pass to the original decorator
|
||||
|
||||
Returns:
|
||||
A decorator function that instruments handlers before registering them
|
||||
"""
|
||||
|
||||
def instrumented_decorator(func):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
# First wrap the handler with instrumentation
|
||||
instrumented_func = _create_instrumented_handler(handler_type, func)
|
||||
# Then register it with the original MCP decorator
|
||||
return original_decorator(*decorator_args, **decorator_kwargs)(
|
||||
instrumented_func
|
||||
)
|
||||
|
||||
return instrumented_decorator
|
||||
|
||||
|
||||
def _patch_lowlevel_server():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches the mcp.server.lowlevel.Server class to instrument handler execution.
|
||||
"""
|
||||
# Patch call_tool decorator
|
||||
original_call_tool = Server.call_tool
|
||||
|
||||
def patched_call_tool(self, **kwargs):
|
||||
# type: (Server, **Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]
|
||||
"""Patched version of Server.call_tool that adds Sentry instrumentation."""
|
||||
return lambda func: _create_instrumented_decorator(
|
||||
original_call_tool, "tool", self, **kwargs
|
||||
)(func)
|
||||
|
||||
Server.call_tool = patched_call_tool
|
||||
|
||||
# Patch get_prompt decorator
|
||||
original_get_prompt = Server.get_prompt
|
||||
|
||||
def patched_get_prompt(self):
|
||||
# type: (Server) -> Callable[[Callable[..., Any]], Callable[..., Any]]
|
||||
"""Patched version of Server.get_prompt that adds Sentry instrumentation."""
|
||||
return lambda func: _create_instrumented_decorator(
|
||||
original_get_prompt, "prompt", self
|
||||
)(func)
|
||||
|
||||
Server.get_prompt = patched_get_prompt
|
||||
|
||||
# Patch read_resource decorator
|
||||
original_read_resource = Server.read_resource
|
||||
|
||||
def patched_read_resource(self):
|
||||
# type: (Server) -> Callable[[Callable[..., Any]], Callable[..., Any]]
|
||||
"""Patched version of Server.read_resource that adds Sentry instrumentation."""
|
||||
return lambda func: _create_instrumented_decorator(
|
||||
original_read_resource, "resource", self
|
||||
)(func)
|
||||
|
||||
Server.read_resource = patched_read_resource
|
||||
@@ -3,13 +3,19 @@ from functools import wraps
|
||||
import sentry_sdk
|
||||
from sentry_sdk import consts
|
||||
from sentry_sdk.ai.monitoring import record_token_usage
|
||||
from sentry_sdk.ai.utils import set_data_normalized
|
||||
from sentry_sdk.ai.utils import (
|
||||
set_data_normalized,
|
||||
normalize_message_roles,
|
||||
truncate_and_annotate_messages,
|
||||
)
|
||||
from sentry_sdk.consts import SPANDATA
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
from sentry_sdk.utils import (
|
||||
capture_internal_exceptions,
|
||||
event_from_exception,
|
||||
safe_serialize,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -19,6 +25,16 @@ if TYPE_CHECKING:
|
||||
from sentry_sdk.tracing import Span
|
||||
|
||||
try:
|
||||
try:
|
||||
from openai import NotGiven
|
||||
except ImportError:
|
||||
NotGiven = None
|
||||
|
||||
try:
|
||||
from openai import Omit
|
||||
except ImportError:
|
||||
Omit = None
|
||||
|
||||
from openai.resources.chat.completions import Completions, AsyncCompletions
|
||||
from openai.resources import Embeddings, AsyncEmbeddings
|
||||
|
||||
@@ -27,6 +43,14 @@ try:
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenAI not installed")
|
||||
|
||||
RESPONSES_API_ENABLED = True
|
||||
try:
|
||||
# responses API support was introduced in v1.66.0
|
||||
from openai.resources.responses import Responses, AsyncResponses
|
||||
from openai.types.responses.response_completed_event import ResponseCompletedEvent
|
||||
except ImportError:
|
||||
RESPONSES_API_ENABLED = False
|
||||
|
||||
|
||||
class OpenAIIntegration(Integration):
|
||||
identifier = "openai"
|
||||
@@ -46,13 +70,17 @@ class OpenAIIntegration(Integration):
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
Completions.create = _wrap_chat_completion_create(Completions.create)
|
||||
Embeddings.create = _wrap_embeddings_create(Embeddings.create)
|
||||
|
||||
AsyncCompletions.create = _wrap_async_chat_completion_create(
|
||||
AsyncCompletions.create
|
||||
)
|
||||
|
||||
Embeddings.create = _wrap_embeddings_create(Embeddings.create)
|
||||
AsyncEmbeddings.create = _wrap_async_embeddings_create(AsyncEmbeddings.create)
|
||||
|
||||
if RESPONSES_API_ENABLED:
|
||||
Responses.create = _wrap_responses_create(Responses.create)
|
||||
AsyncResponses.create = _wrap_async_responses_create(AsyncResponses.create)
|
||||
|
||||
def count_tokens(self, s):
|
||||
# type: (OpenAIIntegration, str) -> int
|
||||
if self.tiktoken_encoding is not None:
|
||||
@@ -60,8 +88,16 @@ class OpenAIIntegration(Integration):
|
||||
return 0
|
||||
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
def _capture_exception(exc, manual_span_cleanup=True):
|
||||
# type: (Any, bool) -> None
|
||||
# Close an eventually open span
|
||||
# We need to do this by hand because we are not using the start_span context manager
|
||||
current_span = sentry_sdk.get_current_span()
|
||||
set_span_errored(current_span)
|
||||
|
||||
if manual_span_cleanup and current_span is not None:
|
||||
current_span.__exit__(None, None, None)
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
@@ -70,52 +106,316 @@ def _capture_exception(exc):
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
|
||||
|
||||
def _calculate_chat_completion_usage(
|
||||
def _get_usage(usage, names):
|
||||
# type: (Any, List[str]) -> int
|
||||
for name in names:
|
||||
if hasattr(usage, name) and isinstance(getattr(usage, name), int):
|
||||
return getattr(usage, name)
|
||||
return 0
|
||||
|
||||
|
||||
def _calculate_token_usage(
|
||||
messages, response, span, streaming_message_responses, count_tokens
|
||||
):
|
||||
# type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
|
||||
completion_tokens = 0 # type: Optional[int]
|
||||
prompt_tokens = 0 # type: Optional[int]
|
||||
# type: (Optional[Iterable[ChatCompletionMessageParam]], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
|
||||
input_tokens = 0 # type: Optional[int]
|
||||
input_tokens_cached = 0 # type: Optional[int]
|
||||
output_tokens = 0 # type: Optional[int]
|
||||
output_tokens_reasoning = 0 # type: Optional[int]
|
||||
total_tokens = 0 # type: Optional[int]
|
||||
|
||||
if hasattr(response, "usage"):
|
||||
if hasattr(response.usage, "completion_tokens") and isinstance(
|
||||
response.usage.completion_tokens, int
|
||||
):
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
if hasattr(response.usage, "prompt_tokens") and isinstance(
|
||||
response.usage.prompt_tokens, int
|
||||
):
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
if hasattr(response.usage, "total_tokens") and isinstance(
|
||||
response.usage.total_tokens, int
|
||||
):
|
||||
total_tokens = response.usage.total_tokens
|
||||
input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
|
||||
if hasattr(response.usage, "input_tokens_details"):
|
||||
input_tokens_cached = _get_usage(
|
||||
response.usage.input_tokens_details, ["cached_tokens"]
|
||||
)
|
||||
|
||||
if prompt_tokens == 0:
|
||||
for message in messages:
|
||||
if "content" in message:
|
||||
prompt_tokens += count_tokens(message["content"])
|
||||
output_tokens = _get_usage(
|
||||
response.usage, ["output_tokens", "completion_tokens"]
|
||||
)
|
||||
if hasattr(response.usage, "output_tokens_details"):
|
||||
output_tokens_reasoning = _get_usage(
|
||||
response.usage.output_tokens_details, ["reasoning_tokens"]
|
||||
)
|
||||
|
||||
if completion_tokens == 0:
|
||||
total_tokens = _get_usage(response.usage, ["total_tokens"])
|
||||
|
||||
# Manually count tokens
|
||||
if input_tokens == 0:
|
||||
for message in messages or []:
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
input_tokens += count_tokens(message["content"])
|
||||
elif isinstance(message, str):
|
||||
input_tokens += count_tokens(message)
|
||||
|
||||
if output_tokens == 0:
|
||||
if streaming_message_responses is not None:
|
||||
for message in streaming_message_responses:
|
||||
completion_tokens += count_tokens(message)
|
||||
output_tokens += count_tokens(message)
|
||||
elif hasattr(response, "choices"):
|
||||
for choice in response.choices:
|
||||
if hasattr(choice, "message"):
|
||||
completion_tokens += count_tokens(choice.message)
|
||||
output_tokens += count_tokens(choice.message)
|
||||
|
||||
if prompt_tokens == 0:
|
||||
prompt_tokens = None
|
||||
if completion_tokens == 0:
|
||||
completion_tokens = None
|
||||
if total_tokens == 0:
|
||||
total_tokens = None
|
||||
record_token_usage(span, prompt_tokens, completion_tokens, total_tokens)
|
||||
# Do not set token data if it is 0
|
||||
input_tokens = input_tokens or None
|
||||
input_tokens_cached = input_tokens_cached or None
|
||||
output_tokens = output_tokens or None
|
||||
output_tokens_reasoning = output_tokens_reasoning or None
|
||||
total_tokens = total_tokens or None
|
||||
|
||||
record_token_usage(
|
||||
span,
|
||||
input_tokens=input_tokens,
|
||||
input_tokens_cached=input_tokens_cached,
|
||||
output_tokens=output_tokens,
|
||||
output_tokens_reasoning=output_tokens_reasoning,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _set_input_data(span, kwargs, operation, integration):
|
||||
# type: (Span, dict[str, Any], str, OpenAIIntegration) -> None
|
||||
# Input messages (the prompt or data sent to the model)
|
||||
messages = kwargs.get("messages")
|
||||
if messages is None:
|
||||
messages = kwargs.get("input")
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [messages]
|
||||
|
||||
if (
|
||||
messages is not None
|
||||
and len(messages) > 0
|
||||
and should_send_default_pii()
|
||||
and integration.include_prompts
|
||||
):
|
||||
normalized_messages = normalize_message_roles(messages)
|
||||
scope = sentry_sdk.get_current_scope()
|
||||
messages_data = truncate_and_annotate_messages(normalized_messages, span, scope)
|
||||
if messages_data is not None:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
|
||||
)
|
||||
|
||||
# Input attributes: Common
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
|
||||
|
||||
# Input attributes: Optional
|
||||
kwargs_keys_to_attributes = {
|
||||
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
|
||||
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
|
||||
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
|
||||
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
|
||||
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
|
||||
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
|
||||
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
|
||||
}
|
||||
for key, attribute in kwargs_keys_to_attributes.items():
|
||||
value = kwargs.get(key)
|
||||
|
||||
if value is not None and _is_given(value):
|
||||
set_data_normalized(span, attribute, value)
|
||||
|
||||
# Input attributes: Tools
|
||||
tools = kwargs.get("tools")
|
||||
if tools is not None and _is_given(tools) and len(tools) > 0:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
|
||||
)
|
||||
|
||||
|
||||
def _set_output_data(span, response, kwargs, integration, finish_span=True):
|
||||
# type: (Span, Any, dict[str, Any], OpenAIIntegration, bool) -> None
|
||||
if hasattr(response, "model"):
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)
|
||||
|
||||
# Input messages (the prompt or data sent to the model)
|
||||
# used for the token usage calculation
|
||||
messages = kwargs.get("messages")
|
||||
if messages is None:
|
||||
messages = kwargs.get("input")
|
||||
|
||||
if messages is not None and isinstance(messages, str):
|
||||
messages = [messages]
|
||||
|
||||
if hasattr(response, "choices"):
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
response_text = [choice.message.model_dump() for choice in response.choices]
|
||||
if len(response_text) > 0:
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)
|
||||
|
||||
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
|
||||
|
||||
if finish_span:
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
elif hasattr(response, "output"):
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
output_messages = {
|
||||
"response": [],
|
||||
"tool": [],
|
||||
} # type: (dict[str, list[Any]])
|
||||
|
||||
for output in response.output:
|
||||
if output.type == "function_call":
|
||||
output_messages["tool"].append(output.dict())
|
||||
elif output.type == "message":
|
||||
for output_message in output.content:
|
||||
try:
|
||||
output_messages["response"].append(output_message.text)
|
||||
except AttributeError:
|
||||
# Unknown output message type, just return the json
|
||||
output_messages["response"].append(output_message.dict())
|
||||
|
||||
if len(output_messages["tool"]) > 0:
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
|
||||
output_messages["tool"],
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
if len(output_messages["response"]) > 0:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
|
||||
)
|
||||
|
||||
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
|
||||
|
||||
if finish_span:
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
elif hasattr(response, "_iterator"):
|
||||
data_buf: list[list[str]] = [] # one for each choice
|
||||
|
||||
old_iterator = response._iterator
|
||||
|
||||
def new_iterator():
|
||||
# type: () -> Iterator[ChatCompletionChunk]
|
||||
count_tokens_manually = True
|
||||
for x in old_iterator:
|
||||
with capture_internal_exceptions():
|
||||
# OpenAI chat completion API
|
||||
if hasattr(x, "choices"):
|
||||
choice_index = 0
|
||||
for choice in x.choices:
|
||||
if hasattr(choice, "delta") and hasattr(
|
||||
choice.delta, "content"
|
||||
):
|
||||
content = choice.delta.content
|
||||
if len(data_buf) <= choice_index:
|
||||
data_buf.append([])
|
||||
data_buf[choice_index].append(content or "")
|
||||
choice_index += 1
|
||||
|
||||
# OpenAI responses API
|
||||
elif hasattr(x, "delta"):
|
||||
if len(data_buf) == 0:
|
||||
data_buf.append([])
|
||||
data_buf[0].append(x.delta or "")
|
||||
|
||||
# OpenAI responses API end of streaming response
|
||||
if RESPONSES_API_ENABLED and isinstance(x, ResponseCompletedEvent):
|
||||
_calculate_token_usage(
|
||||
messages,
|
||||
x.response,
|
||||
span,
|
||||
None,
|
||||
integration.count_tokens,
|
||||
)
|
||||
count_tokens_manually = False
|
||||
|
||||
yield x
|
||||
|
||||
with capture_internal_exceptions():
|
||||
if len(data_buf) > 0:
|
||||
all_responses = ["".join(chunk) for chunk in data_buf]
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
|
||||
)
|
||||
if count_tokens_manually:
|
||||
_calculate_token_usage(
|
||||
messages,
|
||||
response,
|
||||
span,
|
||||
all_responses,
|
||||
integration.count_tokens,
|
||||
)
|
||||
|
||||
if finish_span:
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
async def new_iterator_async():
|
||||
# type: () -> AsyncIterator[ChatCompletionChunk]
|
||||
count_tokens_manually = True
|
||||
async for x in old_iterator:
|
||||
with capture_internal_exceptions():
|
||||
# OpenAI chat completion API
|
||||
if hasattr(x, "choices"):
|
||||
choice_index = 0
|
||||
for choice in x.choices:
|
||||
if hasattr(choice, "delta") and hasattr(
|
||||
choice.delta, "content"
|
||||
):
|
||||
content = choice.delta.content
|
||||
if len(data_buf) <= choice_index:
|
||||
data_buf.append([])
|
||||
data_buf[choice_index].append(content or "")
|
||||
choice_index += 1
|
||||
|
||||
# OpenAI responses API
|
||||
elif hasattr(x, "delta"):
|
||||
if len(data_buf) == 0:
|
||||
data_buf.append([])
|
||||
data_buf[0].append(x.delta or "")
|
||||
|
||||
# OpenAI responses API end of streaming response
|
||||
if RESPONSES_API_ENABLED and isinstance(x, ResponseCompletedEvent):
|
||||
_calculate_token_usage(
|
||||
messages,
|
||||
x.response,
|
||||
span,
|
||||
None,
|
||||
integration.count_tokens,
|
||||
)
|
||||
count_tokens_manually = False
|
||||
|
||||
yield x
|
||||
|
||||
with capture_internal_exceptions():
|
||||
if len(data_buf) > 0:
|
||||
all_responses = ["".join(chunk) for chunk in data_buf]
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
|
||||
)
|
||||
if count_tokens_manually:
|
||||
_calculate_token_usage(
|
||||
messages,
|
||||
response,
|
||||
span,
|
||||
all_responses,
|
||||
integration.count_tokens,
|
||||
)
|
||||
if finish_span:
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
if str(type(response._iterator)) == "<class 'async_generator'>":
|
||||
response._iterator = new_iterator_async()
|
||||
else:
|
||||
response._iterator = new_iterator()
|
||||
else:
|
||||
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
|
||||
if finish_span:
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
|
||||
def _new_chat_completion_common(f, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return f(*args, **kwargs)
|
||||
@@ -130,124 +430,29 @@ def _new_chat_completion_common(f, *args, **kwargs):
|
||||
# invalid call (in all versions), messages must be iterable
|
||||
return f(*args, **kwargs)
|
||||
|
||||
kwargs["messages"] = list(kwargs["messages"])
|
||||
messages = kwargs["messages"]
|
||||
model = kwargs.get("model")
|
||||
streaming = kwargs.get("stream")
|
||||
operation = "chat"
|
||||
|
||||
span = sentry_sdk.start_span(
|
||||
op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE,
|
||||
name="Chat Completion",
|
||||
op=consts.OP.GEN_AI_CHAT,
|
||||
name=f"{operation} {model}",
|
||||
origin=OpenAIIntegration.origin,
|
||||
)
|
||||
span.__enter__()
|
||||
|
||||
res = yield f, args, kwargs
|
||||
_set_input_data(span, kwargs, operation, integration)
|
||||
|
||||
with capture_internal_exceptions():
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, messages)
|
||||
response = yield f, args, kwargs
|
||||
|
||||
set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
|
||||
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
|
||||
_set_output_data(span, response, kwargs, integration, finish_span=True)
|
||||
|
||||
if hasattr(res, "choices"):
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span,
|
||||
"ai.responses",
|
||||
list(map(lambda x: x.message, res.choices)),
|
||||
)
|
||||
_calculate_chat_completion_usage(
|
||||
messages, res, span, None, integration.count_tokens
|
||||
)
|
||||
span.__exit__(None, None, None)
|
||||
elif hasattr(res, "_iterator"):
|
||||
data_buf: list[list[str]] = [] # one for each choice
|
||||
|
||||
old_iterator = res._iterator
|
||||
|
||||
def new_iterator():
|
||||
# type: () -> Iterator[ChatCompletionChunk]
|
||||
with capture_internal_exceptions():
|
||||
for x in old_iterator:
|
||||
if hasattr(x, "choices"):
|
||||
choice_index = 0
|
||||
for choice in x.choices:
|
||||
if hasattr(choice, "delta") and hasattr(
|
||||
choice.delta, "content"
|
||||
):
|
||||
content = choice.delta.content
|
||||
if len(data_buf) <= choice_index:
|
||||
data_buf.append([])
|
||||
data_buf[choice_index].append(content or "")
|
||||
choice_index += 1
|
||||
yield x
|
||||
if len(data_buf) > 0:
|
||||
all_responses = list(
|
||||
map(lambda chunk: "".join(chunk), data_buf)
|
||||
)
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.AI_RESPONSES, all_responses
|
||||
)
|
||||
_calculate_chat_completion_usage(
|
||||
messages,
|
||||
res,
|
||||
span,
|
||||
all_responses,
|
||||
integration.count_tokens,
|
||||
)
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
async def new_iterator_async():
|
||||
# type: () -> AsyncIterator[ChatCompletionChunk]
|
||||
with capture_internal_exceptions():
|
||||
async for x in old_iterator:
|
||||
if hasattr(x, "choices"):
|
||||
choice_index = 0
|
||||
for choice in x.choices:
|
||||
if hasattr(choice, "delta") and hasattr(
|
||||
choice.delta, "content"
|
||||
):
|
||||
content = choice.delta.content
|
||||
if len(data_buf) <= choice_index:
|
||||
data_buf.append([])
|
||||
data_buf[choice_index].append(content or "")
|
||||
choice_index += 1
|
||||
yield x
|
||||
if len(data_buf) > 0:
|
||||
all_responses = list(
|
||||
map(lambda chunk: "".join(chunk), data_buf)
|
||||
)
|
||||
if should_send_default_pii() and integration.include_prompts:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.AI_RESPONSES, all_responses
|
||||
)
|
||||
_calculate_chat_completion_usage(
|
||||
messages,
|
||||
res,
|
||||
span,
|
||||
all_responses,
|
||||
integration.count_tokens,
|
||||
)
|
||||
span.__exit__(None, None, None)
|
||||
|
||||
if str(type(res._iterator)) == "<class 'async_generator'>":
|
||||
res._iterator = new_iterator_async()
|
||||
else:
|
||||
res._iterator = new_iterator()
|
||||
|
||||
else:
|
||||
set_data_normalized(span, "unknown_response", True)
|
||||
span.__exit__(None, None, None)
|
||||
return res
|
||||
return response
|
||||
|
||||
|
||||
def _wrap_chat_completion_create(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
def _execute_sync(f, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# type: (Any, Any, Any) -> Any
|
||||
gen = _new_chat_completion_common(f, *args, **kwargs)
|
||||
|
||||
try:
|
||||
@@ -268,7 +473,7 @@ def _wrap_chat_completion_create(f):
|
||||
|
||||
@wraps(f)
|
||||
def _sentry_patched_create_sync(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
# type: (Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None or "messages" not in kwargs:
|
||||
# no "messages" means invalid call (in all versions of openai), let it return error
|
||||
@@ -282,7 +487,7 @@ def _wrap_chat_completion_create(f):
|
||||
def _wrap_async_chat_completion_create(f):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
async def _execute_async(f, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# type: (Any, Any, Any) -> Any
|
||||
gen = _new_chat_completion_common(f, *args, **kwargs)
|
||||
|
||||
try:
|
||||
@@ -303,7 +508,7 @@ def _wrap_async_chat_completion_create(f):
|
||||
|
||||
@wraps(f)
|
||||
async def _sentry_patched_create_async(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
# type: (Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None or "messages" not in kwargs:
|
||||
# no "messages" means invalid call (in all versions of openai), let it return error
|
||||
@@ -315,48 +520,24 @@ def _wrap_async_chat_completion_create(f):
|
||||
|
||||
|
||||
def _new_embeddings_create_common(f, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
model = kwargs.get("model")
|
||||
operation = "embeddings"
|
||||
|
||||
with sentry_sdk.start_span(
|
||||
op=consts.OP.OPENAI_EMBEDDINGS_CREATE,
|
||||
description="OpenAI Embedding Creation",
|
||||
op=consts.OP.GEN_AI_EMBEDDINGS,
|
||||
name=f"{operation} {model}",
|
||||
origin=OpenAIIntegration.origin,
|
||||
) as span:
|
||||
if "input" in kwargs and (
|
||||
should_send_default_pii() and integration.include_prompts
|
||||
):
|
||||
if isinstance(kwargs["input"], str):
|
||||
set_data_normalized(span, "ai.input_messages", [kwargs["input"]])
|
||||
elif (
|
||||
isinstance(kwargs["input"], list)
|
||||
and len(kwargs["input"]) > 0
|
||||
and isinstance(kwargs["input"][0], str)
|
||||
):
|
||||
set_data_normalized(span, "ai.input_messages", kwargs["input"])
|
||||
if "model" in kwargs:
|
||||
set_data_normalized(span, "ai.model_id", kwargs["model"])
|
||||
_set_input_data(span, kwargs, operation, integration)
|
||||
|
||||
response = yield f, args, kwargs
|
||||
|
||||
prompt_tokens = 0
|
||||
total_tokens = 0
|
||||
if hasattr(response, "usage"):
|
||||
if hasattr(response.usage, "prompt_tokens") and isinstance(
|
||||
response.usage.prompt_tokens, int
|
||||
):
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
if hasattr(response.usage, "total_tokens") and isinstance(
|
||||
response.usage.total_tokens, int
|
||||
):
|
||||
total_tokens = response.usage.total_tokens
|
||||
|
||||
if prompt_tokens == 0:
|
||||
prompt_tokens = integration.count_tokens(kwargs["input"] or "")
|
||||
|
||||
record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
|
||||
_set_output_data(span, response, kwargs, integration, finish_span=False)
|
||||
|
||||
return response
|
||||
|
||||
@@ -364,9 +545,102 @@ def _new_embeddings_create_common(f, *args, **kwargs):
|
||||
def _wrap_embeddings_create(f):
|
||||
# type: (Any) -> Any
|
||||
def _execute_sync(f, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# type: (Any, Any, Any) -> Any
|
||||
gen = _new_embeddings_create_common(f, *args, **kwargs)
|
||||
|
||||
try:
|
||||
f, args, kwargs = next(gen)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
try:
|
||||
try:
|
||||
result = f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
_capture_exception(e, manual_span_cleanup=False)
|
||||
raise e from None
|
||||
|
||||
return gen.send(result)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
@wraps(f)
|
||||
def _sentry_patched_create_sync(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return _execute_sync(f, *args, **kwargs)
|
||||
|
||||
return _sentry_patched_create_sync
|
||||
|
||||
|
||||
def _wrap_async_embeddings_create(f):
|
||||
# type: (Any) -> Any
|
||||
async def _execute_async(f, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
gen = _new_embeddings_create_common(f, *args, **kwargs)
|
||||
|
||||
try:
|
||||
f, args, kwargs = next(gen)
|
||||
except StopIteration as e:
|
||||
return await e.value
|
||||
|
||||
try:
|
||||
try:
|
||||
result = await f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
_capture_exception(e, manual_span_cleanup=False)
|
||||
raise e from None
|
||||
|
||||
return gen.send(result)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
@wraps(f)
|
||||
async def _sentry_patched_create_async(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return await _execute_async(f, *args, **kwargs)
|
||||
|
||||
return _sentry_patched_create_async
|
||||
|
||||
|
||||
def _new_responses_create_common(f, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
model = kwargs.get("model")
|
||||
operation = "responses"
|
||||
|
||||
span = sentry_sdk.start_span(
|
||||
op=consts.OP.GEN_AI_RESPONSES,
|
||||
name=f"{operation} {model}",
|
||||
origin=OpenAIIntegration.origin,
|
||||
)
|
||||
span.__enter__()
|
||||
|
||||
_set_input_data(span, kwargs, operation, integration)
|
||||
|
||||
response = yield f, args, kwargs
|
||||
|
||||
_set_output_data(span, response, kwargs, integration, finish_span=True)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _wrap_responses_create(f):
|
||||
# type: (Any) -> Any
|
||||
def _execute_sync(f, *args, **kwargs):
|
||||
# type: (Any, Any, Any) -> Any
|
||||
gen = _new_responses_create_common(f, *args, **kwargs)
|
||||
|
||||
try:
|
||||
f, args, kwargs = next(gen)
|
||||
except StopIteration as e:
|
||||
@@ -385,7 +659,7 @@ def _wrap_embeddings_create(f):
|
||||
|
||||
@wraps(f)
|
||||
def _sentry_patched_create_sync(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
# type: (Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return f(*args, **kwargs)
|
||||
@@ -395,11 +669,11 @@ def _wrap_embeddings_create(f):
|
||||
return _sentry_patched_create_sync
|
||||
|
||||
|
||||
def _wrap_async_embeddings_create(f):
|
||||
def _wrap_async_responses_create(f):
|
||||
# type: (Any) -> Any
|
||||
async def _execute_async(f, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
gen = _new_embeddings_create_common(f, *args, **kwargs)
|
||||
# type: (Any, Any, Any) -> Any
|
||||
gen = _new_responses_create_common(f, *args, **kwargs)
|
||||
|
||||
try:
|
||||
f, args, kwargs = next(gen)
|
||||
@@ -418,12 +692,24 @@ def _wrap_async_embeddings_create(f):
|
||||
return e.value
|
||||
|
||||
@wraps(f)
|
||||
async def _sentry_patched_create_async(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
async def _sentry_patched_responses_async(*args, **kwargs):
|
||||
# type: (Any, Any) -> Any
|
||||
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
|
||||
if integration is None:
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return await _execute_async(f, *args, **kwargs)
|
||||
|
||||
return _sentry_patched_create_async
|
||||
return _sentry_patched_responses_async
|
||||
|
||||
|
||||
def _is_given(obj):
|
||||
# type: (Any) -> bool
|
||||
"""
|
||||
Check for givenness safely across different openai versions.
|
||||
"""
|
||||
if NotGiven is not None and isinstance(obj, NotGiven):
|
||||
return False
|
||||
if Omit is not None and isinstance(obj, Omit):
|
||||
return False
|
||||
return True
|
||||
|
||||
+55
@@ -0,0 +1,55 @@
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
|
||||
from .patches import (
|
||||
_create_get_model_wrapper,
|
||||
_create_get_all_tools_wrapper,
|
||||
_create_run_wrapper,
|
||||
_patch_agent_run,
|
||||
_patch_error_tracing,
|
||||
)
|
||||
|
||||
try:
|
||||
import agents
|
||||
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenAI Agents not installed")
|
||||
|
||||
|
||||
def _patch_runner():
|
||||
# type: () -> None
|
||||
# Create the root span for one full agent run (including eventual handoffs)
|
||||
# Note agents.run.DEFAULT_AGENT_RUNNER.run_sync is a wrapper around
|
||||
# agents.run.DEFAULT_AGENT_RUNNER.run. It does not need to be wrapped separately.
|
||||
# TODO-anton: Also patch streaming runner: agents.Runner.run_streamed
|
||||
agents.run.DEFAULT_AGENT_RUNNER.run = _create_run_wrapper(
|
||||
agents.run.DEFAULT_AGENT_RUNNER.run
|
||||
)
|
||||
|
||||
# Creating the actual spans for each agent run.
|
||||
_patch_agent_run()
|
||||
|
||||
|
||||
def _patch_model():
|
||||
# type: () -> None
|
||||
agents.run.AgentRunner._get_model = classmethod(
|
||||
_create_get_model_wrapper(agents.run.AgentRunner._get_model),
|
||||
)
|
||||
|
||||
|
||||
def _patch_tools():
|
||||
# type: () -> None
|
||||
agents.run.AgentRunner._get_all_tools = classmethod(
|
||||
_create_get_all_tools_wrapper(agents.run.AgentRunner._get_all_tools),
|
||||
)
|
||||
|
||||
|
||||
class OpenAIAgentsIntegration(Integration):
|
||||
identifier = "openai_agents"
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
_patch_error_tracing()
|
||||
_patch_tools()
|
||||
_patch_model()
|
||||
_patch_runner()
|
||||
+1
@@ -0,0 +1 @@
|
||||
SPAN_ORIGIN = "auto.ai.openai_agents"
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
from .models import _create_get_model_wrapper # noqa: F401
|
||||
from .tools import _create_get_all_tools_wrapper # noqa: F401
|
||||
from .runner import _create_run_wrapper # noqa: F401
|
||||
from .agent_run import _patch_agent_run # noqa: F401
|
||||
from .error_tracing import _patch_error_tracing # noqa: F401
|
||||
+140
@@ -0,0 +1,140 @@
|
||||
from functools import wraps
|
||||
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
from ..spans import invoke_agent_span, update_invoke_agent_span, handoff_span
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
import agents
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenAI Agents not installed")
|
||||
|
||||
|
||||
def _patch_agent_run():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches AgentRunner methods to create agent invocation spans.
|
||||
This directly patches the execution flow to track when agents start and stop.
|
||||
"""
|
||||
|
||||
# Store original methods
|
||||
original_run_single_turn = agents.run.AgentRunner._run_single_turn
|
||||
original_execute_handoffs = agents._run_impl.RunImpl.execute_handoffs
|
||||
original_execute_final_output = agents._run_impl.RunImpl.execute_final_output
|
||||
|
||||
def _start_invoke_agent_span(context_wrapper, agent, kwargs):
|
||||
# type: (agents.RunContextWrapper, agents.Agent, dict[str, Any]) -> None
|
||||
"""Start an agent invocation span"""
|
||||
# Store the agent on the context wrapper so we can access it later
|
||||
context_wrapper._sentry_current_agent = agent
|
||||
invoke_agent_span(context_wrapper, agent, kwargs)
|
||||
|
||||
def _end_invoke_agent_span(context_wrapper, agent, output=None):
|
||||
# type: (agents.RunContextWrapper, agents.Agent, Optional[Any]) -> None
|
||||
"""End the agent invocation span"""
|
||||
# Clear the stored agent
|
||||
if hasattr(context_wrapper, "_sentry_current_agent"):
|
||||
delattr(context_wrapper, "_sentry_current_agent")
|
||||
|
||||
update_invoke_agent_span(context_wrapper, agent, output)
|
||||
|
||||
def _has_active_agent_span(context_wrapper):
|
||||
# type: (agents.RunContextWrapper) -> bool
|
||||
"""Check if there's an active agent span for this context"""
|
||||
return getattr(context_wrapper, "_sentry_current_agent", None) is not None
|
||||
|
||||
def _get_current_agent(context_wrapper):
|
||||
# type: (agents.RunContextWrapper) -> Optional[agents.Agent]
|
||||
"""Get the current agent from context wrapper"""
|
||||
return getattr(context_wrapper, "_sentry_current_agent", None)
|
||||
|
||||
@wraps(
|
||||
original_run_single_turn.__func__
|
||||
if hasattr(original_run_single_turn, "__func__")
|
||||
else original_run_single_turn
|
||||
)
|
||||
async def patched_run_single_turn(cls, *args, **kwargs):
|
||||
# type: (agents.Runner, *Any, **Any) -> Any
|
||||
"""Patched _run_single_turn that creates agent invocation spans"""
|
||||
agent = kwargs.get("agent")
|
||||
context_wrapper = kwargs.get("context_wrapper")
|
||||
should_run_agent_start_hooks = kwargs.get("should_run_agent_start_hooks")
|
||||
|
||||
# Start agent span when agent starts (but only once per agent)
|
||||
if should_run_agent_start_hooks and agent and context_wrapper:
|
||||
# End any existing span for a different agent
|
||||
if _has_active_agent_span(context_wrapper):
|
||||
current_agent = _get_current_agent(context_wrapper)
|
||||
if current_agent and current_agent != agent:
|
||||
_end_invoke_agent_span(context_wrapper, current_agent)
|
||||
|
||||
_start_invoke_agent_span(context_wrapper, agent, kwargs)
|
||||
|
||||
# Call original method with all the correct parameters
|
||||
result = await original_run_single_turn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
@wraps(
|
||||
original_execute_handoffs.__func__
|
||||
if hasattr(original_execute_handoffs, "__func__")
|
||||
else original_execute_handoffs
|
||||
)
|
||||
async def patched_execute_handoffs(cls, *args, **kwargs):
|
||||
# type: (agents.Runner, *Any, **Any) -> Any
|
||||
"""Patched execute_handoffs that creates handoff spans and ends agent span for handoffs"""
|
||||
|
||||
context_wrapper = kwargs.get("context_wrapper")
|
||||
run_handoffs = kwargs.get("run_handoffs")
|
||||
agent = kwargs.get("agent")
|
||||
|
||||
# Create Sentry handoff span for the first handoff (agents library only processes the first one)
|
||||
if run_handoffs:
|
||||
first_handoff = run_handoffs[0]
|
||||
handoff_agent_name = first_handoff.handoff.agent_name
|
||||
handoff_span(context_wrapper, agent, handoff_agent_name)
|
||||
|
||||
# Call original method with all parameters
|
||||
try:
|
||||
result = await original_execute_handoffs(*args, **kwargs)
|
||||
|
||||
finally:
|
||||
# End span for current agent after handoff processing is complete
|
||||
if agent and context_wrapper and _has_active_agent_span(context_wrapper):
|
||||
_end_invoke_agent_span(context_wrapper, agent)
|
||||
|
||||
return result
|
||||
|
||||
@wraps(
|
||||
original_execute_final_output.__func__
|
||||
if hasattr(original_execute_final_output, "__func__")
|
||||
else original_execute_final_output
|
||||
)
|
||||
async def patched_execute_final_output(cls, *args, **kwargs):
|
||||
# type: (agents.Runner, *Any, **Any) -> Any
|
||||
"""Patched execute_final_output that ends agent span for final outputs"""
|
||||
|
||||
agent = kwargs.get("agent")
|
||||
context_wrapper = kwargs.get("context_wrapper")
|
||||
final_output = kwargs.get("final_output")
|
||||
|
||||
# Call original method with all parameters
|
||||
try:
|
||||
result = await original_execute_final_output(*args, **kwargs)
|
||||
finally:
|
||||
# End span for current agent after final output processing is complete
|
||||
if agent and context_wrapper and _has_active_agent_span(context_wrapper):
|
||||
_end_invoke_agent_span(context_wrapper, agent, final_output)
|
||||
|
||||
return result
|
||||
|
||||
# Apply patches
|
||||
agents.run.AgentRunner._run_single_turn = classmethod(patched_run_single_turn)
|
||||
agents._run_impl.RunImpl.execute_handoffs = classmethod(patched_execute_handoffs)
|
||||
agents._run_impl.RunImpl.execute_final_output = classmethod(
|
||||
patched_execute_final_output
|
||||
)
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
from functools import wraps
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import SPANSTATUS
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
def _patch_error_tracing():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches agents error tracing function to inject our span error logic
|
||||
when a tool execution fails.
|
||||
|
||||
In newer versions, the function is at: agents.util._error_tracing.attach_error_to_current_span
|
||||
In older versions, it was at: agents._utils.attach_error_to_current_span
|
||||
|
||||
This works even when the module or function doesn't exist.
|
||||
"""
|
||||
error_tracing_module = None
|
||||
|
||||
# Try newer location first (agents.util._error_tracing)
|
||||
try:
|
||||
from agents.util import _error_tracing
|
||||
|
||||
error_tracing_module = _error_tracing
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
# Try older location (agents._utils)
|
||||
if error_tracing_module is None:
|
||||
try:
|
||||
import agents._utils
|
||||
|
||||
error_tracing_module = agents._utils
|
||||
except (ImportError, AttributeError):
|
||||
# Module doesn't exist in either location, nothing to patch
|
||||
return
|
||||
|
||||
# Check if the function exists
|
||||
if not hasattr(error_tracing_module, "attach_error_to_current_span"):
|
||||
return
|
||||
|
||||
original_attach_error = error_tracing_module.attach_error_to_current_span
|
||||
|
||||
@wraps(original_attach_error)
|
||||
def sentry_attach_error_to_current_span(error, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
"""
|
||||
Wraps agents' error attachment to also set Sentry span status to error.
|
||||
This allows us to properly track tool execution errors even though
|
||||
the agents library swallows exceptions.
|
||||
"""
|
||||
# Set the current Sentry span to errored
|
||||
current_span = sentry_sdk.get_current_span()
|
||||
if current_span is not None:
|
||||
set_span_errored(current_span)
|
||||
current_span.set_data("span.status", "error")
|
||||
|
||||
# Optionally capture the error details if we have them
|
||||
if hasattr(error, "__class__"):
|
||||
current_span.set_data("error.type", error.__class__.__name__)
|
||||
if hasattr(error, "__str__"):
|
||||
error_message = str(error)
|
||||
if error_message:
|
||||
current_span.set_data("error.message", error_message)
|
||||
|
||||
# Call the original function
|
||||
return original_attach_error(error, *args, **kwargs)
|
||||
|
||||
error_tracing_module.attach_error_to_current_span = (
|
||||
sentry_attach_error_to_current_span
|
||||
)
|
||||
+50
@@ -0,0 +1,50 @@
|
||||
from functools import wraps
|
||||
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
|
||||
from ..spans import ai_client_span, update_ai_client_span
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
try:
|
||||
import agents
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenAI Agents not installed")
|
||||
|
||||
|
||||
def _create_get_model_wrapper(original_get_model):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
"""
|
||||
Wraps the agents.Runner._get_model method to wrap the get_response method of the model to create a AI client span.
|
||||
"""
|
||||
|
||||
@wraps(
|
||||
original_get_model.__func__
|
||||
if hasattr(original_get_model, "__func__")
|
||||
else original_get_model
|
||||
)
|
||||
def wrapped_get_model(cls, agent, run_config):
|
||||
# type: (agents.Runner, agents.Agent, agents.RunConfig) -> agents.Model
|
||||
|
||||
model = original_get_model(agent, run_config)
|
||||
original_get_response = model.get_response
|
||||
|
||||
@wraps(original_get_response)
|
||||
async def wrapped_get_response(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
with ai_client_span(agent, kwargs) as span:
|
||||
result = await original_get_response(*args, **kwargs)
|
||||
|
||||
update_ai_client_span(span, agent, kwargs, result)
|
||||
|
||||
return result
|
||||
|
||||
model.get_response = wrapped_get_response
|
||||
|
||||
return model
|
||||
|
||||
return wrapped_get_model
|
||||
+45
@@ -0,0 +1,45 @@
|
||||
from functools import wraps
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from ..spans import agent_workflow_span
|
||||
from ..utils import _capture_exception
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def _create_run_wrapper(original_func):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
"""
|
||||
Wraps the agents.Runner.run methods to create a root span for the agent workflow runs.
|
||||
|
||||
Note agents.Runner.run_sync() is a wrapper around agents.Runner.run(),
|
||||
so it does not need to be wrapped separately.
|
||||
"""
|
||||
|
||||
@wraps(original_func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
# Isolate each workflow so that when agents are run in asyncio tasks they
|
||||
# don't touch each other's scopes
|
||||
with sentry_sdk.isolation_scope():
|
||||
agent = args[0]
|
||||
with agent_workflow_span(agent):
|
||||
result = None
|
||||
try:
|
||||
result = await original_func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
|
||||
# It could be that there is a "invoke agent" span still open
|
||||
current_span = sentry_sdk.get_current_span()
|
||||
if current_span is not None and current_span.timestamp is None:
|
||||
current_span.__exit__(None, None, None)
|
||||
|
||||
raise exc from None
|
||||
|
||||
return wrapper
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
from functools import wraps
|
||||
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
|
||||
from ..spans import execute_tool_span, update_execute_tool_span
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
import agents
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenAI Agents not installed")
|
||||
|
||||
|
||||
def _create_get_all_tools_wrapper(original_get_all_tools):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
"""
|
||||
Wraps the agents.Runner._get_all_tools method of the Runner class to wrap all function tools with Sentry instrumentation.
|
||||
"""
|
||||
|
||||
@wraps(
|
||||
original_get_all_tools.__func__
|
||||
if hasattr(original_get_all_tools, "__func__")
|
||||
else original_get_all_tools
|
||||
)
|
||||
async def wrapped_get_all_tools(cls, agent, context_wrapper):
|
||||
# type: (agents.Runner, agents.Agent, agents.RunContextWrapper) -> list[agents.Tool]
|
||||
|
||||
# Get the original tools
|
||||
tools = await original_get_all_tools(agent, context_wrapper)
|
||||
|
||||
wrapped_tools = []
|
||||
for tool in tools:
|
||||
# Wrap only the function tools (for now)
|
||||
if tool.__class__.__name__ != "FunctionTool":
|
||||
wrapped_tools.append(tool)
|
||||
continue
|
||||
|
||||
# Create a new FunctionTool with our wrapped invoke method
|
||||
original_on_invoke = tool.on_invoke_tool
|
||||
|
||||
def create_wrapped_invoke(current_tool, current_on_invoke):
|
||||
# type: (agents.Tool, Callable[..., Any]) -> Callable[..., Any]
|
||||
@wraps(current_on_invoke)
|
||||
async def sentry_wrapped_on_invoke_tool(*args, **kwargs):
|
||||
# type: (*Any, **Any) -> Any
|
||||
with execute_tool_span(current_tool, *args, **kwargs) as span:
|
||||
# We can not capture exceptions in tool execution here because
|
||||
# `_on_invoke_tool` is swallowing the exception here:
|
||||
# https://github.com/openai/openai-agents-python/blob/main/src/agents/tool.py#L409-L422
|
||||
# And because function_tool is a decorator with `default_tool_error_function` set as a default parameter
|
||||
# I was unable to monkey patch it because those are evaluated at module import time
|
||||
# and the SDK is too late to patch it. I was also unable to patch `_on_invoke_tool_impl`
|
||||
# because it is nested inside this import time code. As if they made it hard to patch on purpose...
|
||||
result = await current_on_invoke(*args, **kwargs)
|
||||
update_execute_tool_span(span, agent, current_tool, result)
|
||||
|
||||
return result
|
||||
|
||||
return sentry_wrapped_on_invoke_tool
|
||||
|
||||
wrapped_tool = agents.FunctionTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
params_json_schema=tool.params_json_schema,
|
||||
on_invoke_tool=create_wrapped_invoke(tool, original_on_invoke),
|
||||
strict_json_schema=tool.strict_json_schema,
|
||||
is_enabled=tool.is_enabled,
|
||||
)
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
return wrapped_get_all_tools
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
from .agent_workflow import agent_workflow_span # noqa: F401
|
||||
from .ai_client import ai_client_span, update_ai_client_span # noqa: F401
|
||||
from .execute_tool import execute_tool_span, update_execute_tool_span # noqa: F401
|
||||
from .handoff import handoff_span # noqa: F401
|
||||
from .invoke_agent import invoke_agent_span, update_invoke_agent_span # noqa: F401
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import get_start_span_function
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import agents
|
||||
|
||||
|
||||
def agent_workflow_span(agent):
|
||||
# type: (agents.Agent) -> sentry_sdk.tracing.Span
|
||||
|
||||
# Create a transaction or a span if an transaction is already active
|
||||
span = get_start_span_function()(
|
||||
name=f"{agent.name} workflow",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
return span
|
||||
+42
@@ -0,0 +1,42 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
from ..utils import (
|
||||
_set_agent_data,
|
||||
_set_input_data,
|
||||
_set_output_data,
|
||||
_set_usage_data,
|
||||
_create_mcp_execute_tool_spans,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agents import Agent
|
||||
from typing import Any
|
||||
|
||||
|
||||
def ai_client_span(agent, get_response_kwargs):
|
||||
# type: (Agent, dict[str, Any]) -> sentry_sdk.tracing.Span
|
||||
# TODO-anton: implement other types of operations. Now "chat" is hardcoded.
|
||||
model_name = agent.model.model if hasattr(agent.model, "model") else agent.model
|
||||
span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
description=f"chat {model_name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
# TODO-anton: remove hardcoded stuff and replace something that also works for embedding and so on
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
|
||||
|
||||
_set_agent_data(span, agent)
|
||||
|
||||
return span
|
||||
|
||||
|
||||
def update_ai_client_span(span, agent, get_response_kwargs, result):
|
||||
# type: (sentry_sdk.tracing.Span, Agent, dict[str, Any], Any) -> None
|
||||
_set_usage_data(span, result.usage)
|
||||
_set_input_data(span, get_response_kwargs)
|
||||
_set_output_data(span, result)
|
||||
_create_mcp_execute_tool_spans(span, result)
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
from ..utils import _set_agent_data
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import agents
|
||||
from typing import Any
|
||||
|
||||
|
||||
def execute_tool_span(tool, *args, **kwargs):
|
||||
# type: (agents.Tool, *Any, **Any) -> sentry_sdk.tracing.Span
|
||||
span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_EXECUTE_TOOL,
|
||||
name=f"execute_tool {tool.name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
|
||||
|
||||
if tool.__class__.__name__ == "FunctionTool":
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_TYPE, "function")
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool.name)
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool.description)
|
||||
|
||||
if should_send_default_pii():
|
||||
input = args[1]
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_INPUT, input)
|
||||
|
||||
return span
|
||||
|
||||
|
||||
def update_execute_tool_span(span, agent, tool, result):
|
||||
# type: (sentry_sdk.tracing.Span, agents.Agent, agents.Tool, Any) -> None
|
||||
_set_agent_data(span, agent)
|
||||
|
||||
if isinstance(result, str) and result.startswith(
|
||||
"An error occurred while running the tool"
|
||||
):
|
||||
span.set_status(SPANSTATUS.ERROR)
|
||||
|
||||
if should_send_default_pii():
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_OUTPUT, result)
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import agents
|
||||
|
||||
|
||||
def handoff_span(context, from_agent, to_agent_name):
|
||||
# type: (agents.RunContextWrapper, agents.Agent, str) -> None
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_HANDOFF,
|
||||
name=f"handoff from {from_agent.name} to {to_agent_name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
) as span:
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "handoff")
|
||||
+86
@@ -0,0 +1,86 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import (
|
||||
get_start_span_function,
|
||||
set_data_normalized,
|
||||
normalize_message_roles,
|
||||
)
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.utils import safe_serialize
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
from ..utils import _set_agent_data
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import agents
|
||||
from typing import Any
|
||||
|
||||
|
||||
def invoke_agent_span(context, agent, kwargs):
|
||||
# type: (agents.RunContextWrapper, agents.Agent, dict[str, Any]) -> sentry_sdk.tracing.Span
|
||||
start_span_function = get_start_span_function()
|
||||
span = start_span_function(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name=f"invoke_agent {agent.name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
span.__enter__()
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
|
||||
if should_send_default_pii():
|
||||
messages = []
|
||||
if agent.instructions:
|
||||
message = (
|
||||
agent.instructions
|
||||
if isinstance(agent.instructions, str)
|
||||
else safe_serialize(agent.instructions)
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"content": [{"text": message, "type": "text"}],
|
||||
"role": "system",
|
||||
}
|
||||
)
|
||||
|
||||
original_input = kwargs.get("original_input")
|
||||
if original_input is not None:
|
||||
message = (
|
||||
original_input
|
||||
if isinstance(original_input, str)
|
||||
else safe_serialize(original_input)
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"content": [{"text": message, "type": "text"}],
|
||||
"role": "user",
|
||||
}
|
||||
)
|
||||
|
||||
if len(messages) > 0:
|
||||
normalized_messages = normalize_message_roles(messages)
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_REQUEST_MESSAGES,
|
||||
normalized_messages,
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
_set_agent_data(span, agent)
|
||||
|
||||
return span
|
||||
|
||||
|
||||
def update_invoke_agent_span(context, agent, output):
|
||||
# type: (agents.RunContextWrapper, agents.Agent, Any) -> None
|
||||
span = sentry_sdk.get_current_span()
|
||||
|
||||
if span:
|
||||
if should_send_default_pii():
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output, unpack=False
|
||||
)
|
||||
|
||||
span.__exit__(None, None, None)
|
||||
+199
@@ -0,0 +1,199 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import (
|
||||
GEN_AI_ALLOWED_MESSAGE_ROLES,
|
||||
normalize_message_roles,
|
||||
set_data_normalized,
|
||||
normalize_message_role,
|
||||
)
|
||||
from sentry_sdk.consts import SPANDATA, SPANSTATUS, OP
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
from sentry_sdk.scope import should_send_default_pii
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
from sentry_sdk.utils import event_from_exception, safe_serialize
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from agents import Usage
|
||||
|
||||
try:
|
||||
import agents
|
||||
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenAI Agents not installed")
|
||||
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
set_span_errored()
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
mechanism={"type": "openai_agents", "handled": False},
|
||||
)
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
|
||||
|
||||
def _set_agent_data(span, agent):
|
||||
# type: (sentry_sdk.tracing.Span, agents.Agent) -> None
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_SYSTEM, "openai"
|
||||
) # See footnote for https://opentelemetry.io/docs/specs/semconv/registry/attributes/gen-ai/#gen-ai-system for explanation why.
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent.name)
|
||||
|
||||
if agent.model_settings.max_tokens:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, agent.model_settings.max_tokens
|
||||
)
|
||||
|
||||
if agent.model:
|
||||
model_name = agent.model.model if hasattr(agent.model, "model") else agent.model
|
||||
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
|
||||
|
||||
if agent.model_settings.presence_penalty:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
|
||||
agent.model_settings.presence_penalty,
|
||||
)
|
||||
|
||||
if agent.model_settings.temperature:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_REQUEST_TEMPERATURE, agent.model_settings.temperature
|
||||
)
|
||||
|
||||
if agent.model_settings.top_p:
|
||||
span.set_data(SPANDATA.GEN_AI_REQUEST_TOP_P, agent.model_settings.top_p)
|
||||
|
||||
if agent.model_settings.frequency_penalty:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
|
||||
agent.model_settings.frequency_penalty,
|
||||
)
|
||||
|
||||
if len(agent.tools) > 0:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
|
||||
safe_serialize([vars(tool) for tool in agent.tools]),
|
||||
)
|
||||
|
||||
|
||||
def _set_usage_data(span, usage):
|
||||
# type: (sentry_sdk.tracing.Span, Usage) -> None
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
|
||||
usage.input_tokens_details.cached_tokens,
|
||||
)
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
|
||||
usage.output_tokens_details.reasoning_tokens,
|
||||
)
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
|
||||
|
||||
|
||||
def _set_input_data(span, get_response_kwargs):
|
||||
# type: (sentry_sdk.tracing.Span, dict[str, Any]) -> None
|
||||
if not should_send_default_pii():
|
||||
return
|
||||
request_messages = []
|
||||
|
||||
system_instructions = get_response_kwargs.get("system_instructions")
|
||||
if system_instructions:
|
||||
request_messages.append(
|
||||
{
|
||||
"role": GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM,
|
||||
"content": [{"type": "text", "text": system_instructions}],
|
||||
}
|
||||
)
|
||||
|
||||
for message in get_response_kwargs.get("input", []):
|
||||
if "role" in message:
|
||||
normalized_role = normalize_message_role(message.get("role"))
|
||||
request_messages.append(
|
||||
{
|
||||
"role": normalized_role,
|
||||
"content": [{"type": "text", "text": message.get("content")}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
if message.get("type") == "function_call":
|
||||
request_messages.append(
|
||||
{
|
||||
"role": GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT,
|
||||
"content": [message],
|
||||
}
|
||||
)
|
||||
elif message.get("type") == "function_call_output":
|
||||
request_messages.append(
|
||||
{
|
||||
"role": GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL,
|
||||
"content": [message],
|
||||
}
|
||||
)
|
||||
|
||||
set_data_normalized(
|
||||
span,
|
||||
SPANDATA.GEN_AI_REQUEST_MESSAGES,
|
||||
normalize_message_roles(request_messages),
|
||||
unpack=False,
|
||||
)
|
||||
|
||||
|
||||
def _set_output_data(span, result):
|
||||
# type: (sentry_sdk.tracing.Span, Any) -> None
|
||||
if not should_send_default_pii():
|
||||
return
|
||||
|
||||
output_messages = {
|
||||
"response": [],
|
||||
"tool": [],
|
||||
} # type: (dict[str, list[Any]])
|
||||
|
||||
for output in result.output:
|
||||
if output.type == "function_call":
|
||||
output_messages["tool"].append(output.dict())
|
||||
elif output.type == "message":
|
||||
for output_message in output.content:
|
||||
try:
|
||||
output_messages["response"].append(output_message.text)
|
||||
except AttributeError:
|
||||
# Unknown output message type, just return the json
|
||||
output_messages["response"].append(output_message.dict())
|
||||
|
||||
if len(output_messages["tool"]) > 0:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(output_messages["tool"])
|
||||
)
|
||||
|
||||
if len(output_messages["response"]) > 0:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
|
||||
)
|
||||
|
||||
|
||||
def _create_mcp_execute_tool_spans(span, result):
|
||||
# type: (sentry_sdk.tracing.Span, agents.Result) -> None
|
||||
for output in result.output:
|
||||
if output.__class__.__name__ == "McpCall":
|
||||
with sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_EXECUTE_TOOL,
|
||||
description=f"execute_tool {output.name}",
|
||||
start_timestamp=span.start_timestamp,
|
||||
) as execute_tool_span:
|
||||
set_data_normalized(execute_tool_span, SPANDATA.GEN_AI_TOOL_TYPE, "mcp")
|
||||
set_data_normalized(
|
||||
execute_tool_span, SPANDATA.GEN_AI_TOOL_NAME, output.name
|
||||
)
|
||||
if should_send_default_pii():
|
||||
execute_tool_span.set_data(
|
||||
SPANDATA.GEN_AI_TOOL_INPUT, output.arguments
|
||||
)
|
||||
execute_tool_span.set_data(
|
||||
SPANDATA.GEN_AI_TOOL_OUTPUT, output.output
|
||||
)
|
||||
if output.error:
|
||||
execute_tool_span.set_status(SPANSTATUS.ERROR)
|
||||
+5
-9
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import sentry_sdk
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sentry_sdk.feature_flags import add_feature_flag
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
|
||||
try:
|
||||
@@ -8,7 +8,6 @@ try:
|
||||
from openfeature.hook import Hook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openfeature.flag_evaluation import FlagEvaluationDetails
|
||||
from openfeature.hook import HookContext, HookHints
|
||||
except ImportError:
|
||||
raise DidNotEnable("OpenFeature is not installed")
|
||||
@@ -25,15 +24,12 @@ class OpenFeatureIntegration(Integration):
|
||||
|
||||
|
||||
class OpenFeatureHook(Hook):
|
||||
|
||||
def after(self, hook_context, details, hints):
|
||||
# type: (HookContext, FlagEvaluationDetails[bool], HookHints) -> None
|
||||
# type: (Any, Any, Any) -> None
|
||||
if isinstance(details.value, bool):
|
||||
flags = sentry_sdk.get_current_scope().flags
|
||||
flags.set(details.flag_key, details.value)
|
||||
add_feature_flag(details.flag_key, details.value)
|
||||
|
||||
def error(self, hook_context, exception, hints):
|
||||
# type: (HookContext, Exception, HookHints) -> None
|
||||
if isinstance(hook_context.default_value, bool):
|
||||
flags = sentry_sdk.get_current_scope().flags
|
||||
flags.set(hook_context.flag_key, hook_context.default_value)
|
||||
add_feature_flag(hook_context.flag_key, hook_context.default_value)
|
||||
|
||||
@@ -116,7 +116,9 @@ def pure_eval_frame(frame):
|
||||
return (n.lineno, n.col_offset)
|
||||
|
||||
nodes_before_stmt = [
|
||||
node for node in nodes if start(node) < stmt.last_token.end # type: ignore
|
||||
node
|
||||
for node in nodes
|
||||
if start(node) < stmt.last_token.end # type: ignore
|
||||
]
|
||||
if nodes_before_stmt:
|
||||
# The position of the last node before or in the statement
|
||||
|
||||
+47
@@ -0,0 +1,47 @@
|
||||
from sentry_sdk.integrations import DidNotEnable, Integration
|
||||
|
||||
|
||||
try:
|
||||
import pydantic_ai # type: ignore
|
||||
except ImportError:
|
||||
raise DidNotEnable("pydantic-ai not installed")
|
||||
|
||||
|
||||
from .patches import (
|
||||
_patch_agent_run,
|
||||
_patch_graph_nodes,
|
||||
_patch_model_request,
|
||||
_patch_tool_execution,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAIIntegration(Integration):
|
||||
identifier = "pydantic_ai"
|
||||
origin = f"auto.ai.{identifier}"
|
||||
|
||||
def __init__(self, include_prompts=True):
|
||||
# type: (bool) -> None
|
||||
"""
|
||||
Initialize the Pydantic AI integration.
|
||||
|
||||
Args:
|
||||
include_prompts: Whether to include prompts and messages in span data.
|
||||
Requires send_default_pii=True. Defaults to True.
|
||||
"""
|
||||
self.include_prompts = include_prompts
|
||||
|
||||
@staticmethod
|
||||
def setup_once():
|
||||
# type: () -> None
|
||||
"""
|
||||
Set up the pydantic-ai integration.
|
||||
|
||||
This patches the key methods in pydantic-ai to create Sentry spans for:
|
||||
- Agent invocations (Agent.run methods)
|
||||
- Model requests (AI client calls)
|
||||
- Tool executions
|
||||
"""
|
||||
_patch_agent_run()
|
||||
_patch_graph_nodes()
|
||||
_patch_model_request()
|
||||
_patch_tool_execution()
|
||||
+1
@@ -0,0 +1 @@
|
||||
SPAN_ORIGIN = "auto.ai.pydantic_ai"
|
||||
+4
@@ -0,0 +1,4 @@
|
||||
from .agent_run import _patch_agent_run # noqa: F401
|
||||
from .graph_nodes import _patch_graph_nodes # noqa: F401
|
||||
from .model_request import _patch_model_request # noqa: F401
|
||||
from .tools import _patch_tool_execution # noqa: F401
|
||||
+217
@@ -0,0 +1,217 @@
|
||||
from functools import wraps
|
||||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.tracing_utils import set_span_errored
|
||||
from sentry_sdk.utils import event_from_exception
|
||||
|
||||
from ..spans import invoke_agent_span, update_invoke_agent_span
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from pydantic_ai.agent import Agent # type: ignore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
def _capture_exception(exc):
|
||||
# type: (Any) -> None
|
||||
set_span_errored()
|
||||
|
||||
event, hint = event_from_exception(
|
||||
exc,
|
||||
client_options=sentry_sdk.get_client().options,
|
||||
mechanism={"type": "pydantic_ai", "handled": False},
|
||||
)
|
||||
sentry_sdk.capture_event(event, hint=hint)
|
||||
|
||||
|
||||
class _StreamingContextManagerWrapper:
|
||||
"""Wrapper for streaming methods that return async context managers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent,
|
||||
original_ctx_manager,
|
||||
user_prompt,
|
||||
model,
|
||||
model_settings,
|
||||
is_streaming=True,
|
||||
):
|
||||
# type: (Any, Any, Any, Any, Any, bool) -> None
|
||||
self.agent = agent
|
||||
self.original_ctx_manager = original_ctx_manager
|
||||
self.user_prompt = user_prompt
|
||||
self.model = model
|
||||
self.model_settings = model_settings
|
||||
self.is_streaming = is_streaming
|
||||
self._isolation_scope = None # type: Any
|
||||
self._span = None # type: Optional[sentry_sdk.tracing.Span]
|
||||
self._result = None # type: Any
|
||||
|
||||
async def __aenter__(self):
|
||||
# type: () -> Any
|
||||
# Set up isolation scope and invoke_agent span
|
||||
self._isolation_scope = sentry_sdk.isolation_scope()
|
||||
self._isolation_scope.__enter__()
|
||||
|
||||
# Store agent reference and streaming flag
|
||||
sentry_sdk.get_current_scope().set_context(
|
||||
"pydantic_ai_agent", {"_agent": self.agent, "_streaming": self.is_streaming}
|
||||
)
|
||||
|
||||
# Create invoke_agent span (will be closed in __aexit__)
|
||||
self._span = invoke_agent_span(
|
||||
self.user_prompt, self.agent, self.model, self.model_settings
|
||||
)
|
||||
self._span.__enter__()
|
||||
|
||||
# Enter the original context manager
|
||||
result = await self.original_ctx_manager.__aenter__()
|
||||
self._result = result
|
||||
return result
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# type: (Any, Any, Any) -> None
|
||||
try:
|
||||
# Exit the original context manager first
|
||||
await self.original_ctx_manager.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
# Update span with output if successful
|
||||
if exc_type is None and self._result and hasattr(self._result, "output"):
|
||||
output = (
|
||||
self._result.output if hasattr(self._result, "output") else None
|
||||
)
|
||||
if self._span is not None:
|
||||
update_invoke_agent_span(self._span, output)
|
||||
finally:
|
||||
sentry_sdk.get_current_scope().remove_context("pydantic_ai_agent")
|
||||
# Clean up invoke span
|
||||
if self._span:
|
||||
self._span.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
# Clean up isolation scope
|
||||
if self._isolation_scope:
|
||||
self._isolation_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
def _create_run_wrapper(original_func, is_streaming=False):
|
||||
# type: (Callable[..., Any], bool) -> Callable[..., Any]
|
||||
"""
|
||||
Wraps the Agent.run method to create an invoke_agent span.
|
||||
|
||||
Args:
|
||||
original_func: The original run method
|
||||
is_streaming: Whether this is a streaming method (for future use)
|
||||
"""
|
||||
|
||||
@wraps(original_func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# Isolate each workflow so that when agents are run in asyncio tasks they
|
||||
# don't touch each other's scopes
|
||||
with sentry_sdk.isolation_scope():
|
||||
# Store agent reference and streaming flag in Sentry scope for access in nested spans
|
||||
# We store the full agent to allow access to tools and system prompts
|
||||
sentry_sdk.get_current_scope().set_context(
|
||||
"pydantic_ai_agent", {"_agent": self, "_streaming": is_streaming}
|
||||
)
|
||||
|
||||
# Extract parameters for the span
|
||||
user_prompt = kwargs.get("user_prompt") or (args[0] if args else None)
|
||||
model = kwargs.get("model")
|
||||
model_settings = kwargs.get("model_settings")
|
||||
|
||||
# Create invoke_agent span
|
||||
with invoke_agent_span(user_prompt, self, model, model_settings) as span:
|
||||
try:
|
||||
result = await original_func(self, *args, **kwargs)
|
||||
|
||||
# Update span with output
|
||||
output = result.output if hasattr(result, "output") else None
|
||||
update_invoke_agent_span(span, output)
|
||||
|
||||
return result
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
raise exc from None
|
||||
finally:
|
||||
sentry_sdk.get_current_scope().remove_context("pydantic_ai_agent")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _create_streaming_wrapper(original_func):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
"""
|
||||
Wraps run_stream method that returns an async context manager.
|
||||
"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# Extract parameters for the span
|
||||
user_prompt = kwargs.get("user_prompt") or (args[0] if args else None)
|
||||
model = kwargs.get("model")
|
||||
model_settings = kwargs.get("model_settings")
|
||||
|
||||
# Call original function to get the context manager
|
||||
original_ctx_manager = original_func(self, *args, **kwargs)
|
||||
|
||||
# Wrap it with our instrumentation
|
||||
return _StreamingContextManagerWrapper(
|
||||
agent=self,
|
||||
original_ctx_manager=original_ctx_manager,
|
||||
user_prompt=user_prompt,
|
||||
model=model,
|
||||
model_settings=model_settings,
|
||||
is_streaming=True,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _create_streaming_events_wrapper(original_func):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
"""
|
||||
Wraps run_stream_events method - no span needed as it delegates to run().
|
||||
|
||||
Note: run_stream_events internally calls self.run() with an event_stream_handler,
|
||||
so the invoke_agent span will be created by the run() wrapper.
|
||||
"""
|
||||
|
||||
@wraps(original_func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
# type: (Any, *Any, **Any) -> Any
|
||||
# Just call the original generator - it will call run() which has the instrumentation
|
||||
try:
|
||||
async for event in original_func(self, *args, **kwargs):
|
||||
yield event
|
||||
except Exception as exc:
|
||||
_capture_exception(exc)
|
||||
raise exc from None
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _patch_agent_run():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches the Agent run methods to create spans for agent execution.
|
||||
|
||||
This patches both non-streaming (run, run_sync) and streaming
|
||||
(run_stream, run_stream_events) methods.
|
||||
"""
|
||||
|
||||
# Store original methods
|
||||
original_run = Agent.run
|
||||
original_run_stream = Agent.run_stream
|
||||
original_run_stream_events = Agent.run_stream_events
|
||||
|
||||
# Wrap and apply patches for non-streaming methods
|
||||
Agent.run = _create_run_wrapper(original_run, is_streaming=False)
|
||||
|
||||
# Wrap and apply patches for streaming methods
|
||||
Agent.run_stream = _create_streaming_wrapper(original_run_stream)
|
||||
Agent.run_stream_events = _create_streaming_events_wrapper(
|
||||
original_run_stream_events
|
||||
)
|
||||
+105
@@ -0,0 +1,105 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from ..spans import (
|
||||
ai_client_span,
|
||||
update_ai_client_span,
|
||||
)
|
||||
from pydantic_ai._agent_graph import ModelRequestNode # type: ignore
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def _extract_span_data(node, ctx):
|
||||
# type: (Any, Any) -> tuple[list[Any], Any, Any]
|
||||
"""Extract common data needed for creating chat spans.
|
||||
|
||||
Returns:
|
||||
Tuple of (messages, model, model_settings)
|
||||
"""
|
||||
# Extract model and settings from context
|
||||
model = None
|
||||
model_settings = None
|
||||
if hasattr(ctx, "deps"):
|
||||
model = getattr(ctx.deps, "model", None)
|
||||
model_settings = getattr(ctx.deps, "model_settings", None)
|
||||
|
||||
# Build full message list: history + current request
|
||||
messages = []
|
||||
if hasattr(ctx, "state") and hasattr(ctx.state, "message_history"):
|
||||
messages.extend(ctx.state.message_history)
|
||||
|
||||
current_request = getattr(node, "request", None)
|
||||
if current_request:
|
||||
messages.append(current_request)
|
||||
|
||||
return messages, model, model_settings
|
||||
|
||||
|
||||
def _patch_graph_nodes():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches the graph node execution to create appropriate spans.
|
||||
|
||||
ModelRequestNode -> Creates ai_client span for model requests
|
||||
CallToolsNode -> Handles tool calls (spans created in tool patching)
|
||||
"""
|
||||
|
||||
# Patch ModelRequestNode to create ai_client spans
|
||||
original_model_request_run = ModelRequestNode.run
|
||||
|
||||
@wraps(original_model_request_run)
|
||||
async def wrapped_model_request_run(self, ctx):
|
||||
# type: (Any, Any) -> Any
|
||||
messages, model, model_settings = _extract_span_data(self, ctx)
|
||||
|
||||
with ai_client_span(messages, None, model, model_settings) as span:
|
||||
result = await original_model_request_run(self, ctx)
|
||||
|
||||
# Extract response from result if available
|
||||
model_response = None
|
||||
if hasattr(result, "model_response"):
|
||||
model_response = result.model_response
|
||||
|
||||
update_ai_client_span(span, model_response)
|
||||
return result
|
||||
|
||||
ModelRequestNode.run = wrapped_model_request_run
|
||||
|
||||
# Patch ModelRequestNode.stream for streaming requests
|
||||
original_model_request_stream = ModelRequestNode.stream
|
||||
|
||||
def create_wrapped_stream(original_stream_method):
|
||||
# type: (Callable[..., Any]) -> Callable[..., Any]
|
||||
"""Create a wrapper for ModelRequestNode.stream that creates chat spans."""
|
||||
|
||||
@asynccontextmanager
|
||||
@wraps(original_stream_method)
|
||||
async def wrapped_model_request_stream(self, ctx):
|
||||
# type: (Any, Any) -> Any
|
||||
messages, model, model_settings = _extract_span_data(self, ctx)
|
||||
|
||||
# Create chat span for streaming request
|
||||
with ai_client_span(messages, None, model, model_settings) as span:
|
||||
# Call the original stream method
|
||||
async with original_stream_method(self, ctx) as stream:
|
||||
yield stream
|
||||
|
||||
# After streaming completes, update span with response data
|
||||
# The ModelRequestNode stores the final response in _result
|
||||
model_response = None
|
||||
if hasattr(self, "_result") and self._result is not None:
|
||||
# _result is a NextNode containing the model_response
|
||||
if hasattr(self._result, "model_response"):
|
||||
model_response = self._result.model_response
|
||||
|
||||
update_ai_client_span(span, model_response)
|
||||
|
||||
return wrapped_model_request_stream
|
||||
|
||||
ModelRequestNode.stream = create_wrapped_stream(original_model_request_stream)
|
||||
+35
@@ -0,0 +1,35 @@
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic_ai import models # type: ignore
|
||||
|
||||
from ..spans import ai_client_span, update_ai_client_span
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _patch_model_request():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patches model request execution to create AI client spans.
|
||||
|
||||
In pydantic-ai, model requests are handled through the Model interface.
|
||||
We need to patch the request method on models to create spans.
|
||||
"""
|
||||
|
||||
# Patch the base Model class's request method
|
||||
if hasattr(models, "Model"):
|
||||
original_request = models.Model.request
|
||||
|
||||
@wraps(original_request)
|
||||
async def wrapped_request(self, messages, *args, **kwargs):
|
||||
# type: (Any, Any, *Any, **Any) -> Any
|
||||
# Pass all messages (full conversation history)
|
||||
with ai_client_span(messages, None, self, None) as span:
|
||||
result = await original_request(self, messages, *args, **kwargs)
|
||||
update_ai_client_span(span, result)
|
||||
return result
|
||||
|
||||
models.Model.request = wrapped_request
|
||||
+75
@@ -0,0 +1,75 @@
|
||||
from functools import wraps
|
||||
|
||||
from pydantic_ai._tool_manager import ToolManager # type: ignore
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from ..spans import execute_tool_span, update_execute_tool_span
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from pydantic_ai.mcp import MCPServer # type: ignore
|
||||
|
||||
HAS_MCP = True
|
||||
except ImportError:
|
||||
HAS_MCP = False
|
||||
|
||||
|
||||
def _patch_tool_execution():
|
||||
# type: () -> None
|
||||
"""
|
||||
Patch ToolManager._call_tool to create execute_tool spans.
|
||||
|
||||
This is the single point where ALL tool calls flow through in pydantic_ai,
|
||||
regardless of toolset type (function, MCP, combined, wrapper, etc.).
|
||||
|
||||
By patching here, we avoid:
|
||||
- Patching multiple toolset classes
|
||||
- Dealing with signature mismatches from instrumented MCP servers
|
||||
- Complex nested toolset handling
|
||||
"""
|
||||
|
||||
original_call_tool = ToolManager._call_tool
|
||||
|
||||
@wraps(original_call_tool)
|
||||
async def wrapped_call_tool(self, call, allow_partial, wrap_validation_errors):
|
||||
# type: (Any, Any, bool, bool) -> Any
|
||||
|
||||
# Extract tool info before calling original
|
||||
name = call.tool_name
|
||||
tool = self.tools.get(name) if self.tools else None
|
||||
|
||||
# Determine tool type by checking tool.toolset
|
||||
tool_type = "function" # default
|
||||
if tool and HAS_MCP and isinstance(tool.toolset, MCPServer):
|
||||
tool_type = "mcp"
|
||||
|
||||
# Get agent from Sentry scope
|
||||
current_span = sentry_sdk.get_current_span()
|
||||
if current_span and tool:
|
||||
agent_data = (
|
||||
sentry_sdk.get_current_scope()._contexts.get("pydantic_ai_agent") or {}
|
||||
)
|
||||
agent = agent_data.get("_agent")
|
||||
|
||||
# Get args for span (before validation)
|
||||
# call.args can be a string (JSON) or dict
|
||||
args_dict = call.args if isinstance(call.args, dict) else {}
|
||||
|
||||
with execute_tool_span(name, args_dict, agent, tool_type=tool_type) as span:
|
||||
result = await original_call_tool(
|
||||
self, call, allow_partial, wrap_validation_errors
|
||||
)
|
||||
update_execute_tool_span(span, result)
|
||||
return result
|
||||
|
||||
# No span context - just call original
|
||||
return await original_call_tool(
|
||||
self, call, allow_partial, wrap_validation_errors
|
||||
)
|
||||
|
||||
ToolManager._call_tool = wrapped_call_tool
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
from .ai_client import ai_client_span, update_ai_client_span # noqa: F401
|
||||
from .execute_tool import execute_tool_span, update_execute_tool_span # noqa: F401
|
||||
from .invoke_agent import invoke_agent_span, update_invoke_agent_span # noqa: F401
|
||||
+253
@@ -0,0 +1,253 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import set_data_normalized
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.utils import safe_serialize
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
from ..utils import (
|
||||
_set_agent_data,
|
||||
_set_available_tools,
|
||||
_set_model_data,
|
||||
_should_send_prompts,
|
||||
_get_model_name,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, List, Dict
|
||||
from pydantic_ai.usage import RequestUsage # type: ignore
|
||||
|
||||
try:
|
||||
from pydantic_ai.messages import ( # type: ignore
|
||||
BaseToolCallPart,
|
||||
BaseToolReturnPart,
|
||||
SystemPromptPart,
|
||||
UserPromptPart,
|
||||
TextPart,
|
||||
ThinkingPart,
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback if these classes are not available
|
||||
BaseToolCallPart = None
|
||||
BaseToolReturnPart = None
|
||||
SystemPromptPart = None
|
||||
UserPromptPart = None
|
||||
TextPart = None
|
||||
ThinkingPart = None
|
||||
|
||||
|
||||
def _set_usage_data(span, usage):
|
||||
# type: (sentry_sdk.tracing.Span, RequestUsage) -> None
|
||||
"""Set token usage data on a span."""
|
||||
if usage is None:
|
||||
return
|
||||
|
||||
if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
|
||||
|
||||
if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
|
||||
|
||||
if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
|
||||
|
||||
|
||||
def _set_input_messages(span, messages):
|
||||
# type: (sentry_sdk.tracing.Span, Any) -> None
|
||||
"""Set input messages data on a span."""
|
||||
if not _should_send_prompts():
|
||||
return
|
||||
|
||||
if not messages:
|
||||
return
|
||||
|
||||
try:
|
||||
formatted_messages = []
|
||||
system_prompt = None
|
||||
|
||||
# Extract system prompt from any ModelRequest with instructions
|
||||
for msg in messages:
|
||||
if hasattr(msg, "instructions") and msg.instructions:
|
||||
system_prompt = msg.instructions
|
||||
break
|
||||
|
||||
# Add system prompt as first message if present
|
||||
if system_prompt:
|
||||
formatted_messages.append(
|
||||
{"role": "system", "content": [{"type": "text", "text": system_prompt}]}
|
||||
)
|
||||
|
||||
for msg in messages:
|
||||
if hasattr(msg, "parts"):
|
||||
for part in msg.parts:
|
||||
role = "user"
|
||||
# Use isinstance checks with proper base classes
|
||||
if SystemPromptPart and isinstance(part, SystemPromptPart):
|
||||
role = "system"
|
||||
elif (
|
||||
(TextPart and isinstance(part, TextPart))
|
||||
or (ThinkingPart and isinstance(part, ThinkingPart))
|
||||
or (BaseToolCallPart and isinstance(part, BaseToolCallPart))
|
||||
):
|
||||
role = "assistant"
|
||||
elif BaseToolReturnPart and isinstance(part, BaseToolReturnPart):
|
||||
role = "tool"
|
||||
|
||||
content = [] # type: List[Dict[str, Any] | str]
|
||||
tool_calls = None
|
||||
tool_call_id = None
|
||||
|
||||
# Handle ToolCallPart (assistant requesting tool use)
|
||||
if BaseToolCallPart and isinstance(part, BaseToolCallPart):
|
||||
tool_call_data = {}
|
||||
if hasattr(part, "tool_name"):
|
||||
tool_call_data["name"] = part.tool_name
|
||||
if hasattr(part, "args"):
|
||||
tool_call_data["arguments"] = safe_serialize(part.args)
|
||||
if tool_call_data:
|
||||
tool_calls = [tool_call_data]
|
||||
# Handle ToolReturnPart (tool result)
|
||||
elif BaseToolReturnPart and isinstance(part, BaseToolReturnPart):
|
||||
if hasattr(part, "tool_name"):
|
||||
tool_call_id = part.tool_name
|
||||
if hasattr(part, "content"):
|
||||
content.append({"type": "text", "text": str(part.content)})
|
||||
# Handle regular content
|
||||
elif hasattr(part, "content"):
|
||||
if isinstance(part.content, str):
|
||||
content.append({"type": "text", "text": part.content})
|
||||
elif isinstance(part.content, list):
|
||||
for item in part.content:
|
||||
if isinstance(item, str):
|
||||
content.append({"type": "text", "text": item})
|
||||
else:
|
||||
content.append(safe_serialize(item))
|
||||
else:
|
||||
content.append({"type": "text", "text": str(part.content)})
|
||||
|
||||
# Add message if we have content or tool calls
|
||||
if content or tool_calls:
|
||||
message = {"role": role} # type: Dict[str, Any]
|
||||
if content:
|
||||
message["content"] = content
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
if tool_call_id:
|
||||
message["tool_call_id"] = tool_call_id
|
||||
formatted_messages.append(message)
|
||||
|
||||
if formatted_messages:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, formatted_messages, unpack=False
|
||||
)
|
||||
except Exception:
|
||||
# If we fail to format messages, just skip it
|
||||
pass
|
||||
|
||||
|
||||
def _set_output_data(span, response):
|
||||
# type: (sentry_sdk.tracing.Span, Any) -> None
|
||||
"""Set output data on a span."""
|
||||
if not _should_send_prompts():
|
||||
return
|
||||
|
||||
if not response:
|
||||
return
|
||||
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model_name)
|
||||
try:
|
||||
# Extract text from ModelResponse
|
||||
if hasattr(response, "parts"):
|
||||
texts = []
|
||||
tool_calls = []
|
||||
|
||||
for part in response.parts:
|
||||
if TextPart and isinstance(part, TextPart) and hasattr(part, "content"):
|
||||
texts.append(part.content)
|
||||
elif BaseToolCallPart and isinstance(part, BaseToolCallPart):
|
||||
tool_call_data = {
|
||||
"type": "function",
|
||||
}
|
||||
if hasattr(part, "tool_name"):
|
||||
tool_call_data["name"] = part.tool_name
|
||||
if hasattr(part, "args"):
|
||||
tool_call_data["arguments"] = safe_serialize(part.args)
|
||||
tool_calls.append(tool_call_data)
|
||||
|
||||
if texts:
|
||||
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, texts)
|
||||
|
||||
if tool_calls:
|
||||
span.set_data(
|
||||
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(tool_calls)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# If we fail to format output, just skip it
|
||||
pass
|
||||
|
||||
|
||||
def ai_client_span(messages, agent, model, model_settings):
|
||||
# type: (Any, Any, Any, Any) -> sentry_sdk.tracing.Span
|
||||
"""Create a span for an AI client call (model request).
|
||||
|
||||
Args:
|
||||
messages: Full conversation history (list of messages)
|
||||
agent: Agent object
|
||||
model: Model object
|
||||
model_settings: Model settings
|
||||
"""
|
||||
# Determine model name for span name
|
||||
model_obj = model
|
||||
if agent and hasattr(agent, "model"):
|
||||
model_obj = agent.model
|
||||
|
||||
model_name = _get_model_name(model_obj) or "unknown"
|
||||
|
||||
span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_CHAT,
|
||||
name=f"chat {model_name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
|
||||
|
||||
_set_agent_data(span, agent)
|
||||
_set_model_data(span, model, model_settings)
|
||||
|
||||
# Set streaming flag
|
||||
agent_data = sentry_sdk.get_current_scope()._contexts.get("pydantic_ai_agent") or {}
|
||||
is_streaming = agent_data.get("_streaming", False)
|
||||
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, is_streaming)
|
||||
|
||||
# Add available tools if agent is available
|
||||
agent_obj = agent
|
||||
if not agent_obj:
|
||||
# Try to get from Sentry scope
|
||||
agent_data = (
|
||||
sentry_sdk.get_current_scope()._contexts.get("pydantic_ai_agent") or {}
|
||||
)
|
||||
agent_obj = agent_data.get("_agent")
|
||||
|
||||
_set_available_tools(span, agent_obj)
|
||||
|
||||
# Set input messages (full conversation history)
|
||||
if messages:
|
||||
_set_input_messages(span, messages)
|
||||
|
||||
return span
|
||||
|
||||
|
||||
def update_ai_client_span(span, model_response):
|
||||
# type: (sentry_sdk.tracing.Span, Any) -> None
|
||||
"""Update the AI client span with response data."""
|
||||
if not span:
|
||||
return
|
||||
|
||||
# Set usage data if available
|
||||
if model_response and hasattr(model_response, "usage"):
|
||||
_set_usage_data(span, model_response.usage)
|
||||
|
||||
# Set output data
|
||||
_set_output_data(span, model_response)
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
from sentry_sdk.utils import safe_serialize
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
from ..utils import _set_agent_data, _should_send_prompts
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def execute_tool_span(tool_name, tool_args, agent, tool_type="function"):
|
||||
# type: (str, Any, Any, str) -> sentry_sdk.tracing.Span
|
||||
"""Create a span for tool execution.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool being executed
|
||||
tool_args: The arguments passed to the tool
|
||||
agent: The agent executing the tool
|
||||
tool_type: The type of tool ("function" for regular tools, "mcp" for MCP services)
|
||||
"""
|
||||
span = sentry_sdk.start_span(
|
||||
op=OP.GEN_AI_EXECUTE_TOOL,
|
||||
name=f"execute_tool {tool_name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_TYPE, tool_type)
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
|
||||
|
||||
_set_agent_data(span, agent)
|
||||
|
||||
if _should_send_prompts() and tool_args is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_args))
|
||||
|
||||
return span
|
||||
|
||||
|
||||
def update_execute_tool_span(span, result):
|
||||
# type: (sentry_sdk.tracing.Span, Any) -> None
|
||||
"""Update the execute tool span with the result."""
|
||||
if not span:
|
||||
return
|
||||
|
||||
if _should_send_prompts() and result is not None:
|
||||
span.set_data(SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result))
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.ai.utils import get_start_span_function, set_data_normalized
|
||||
from sentry_sdk.consts import OP, SPANDATA
|
||||
|
||||
from ..consts import SPAN_ORIGIN
|
||||
from ..utils import (
|
||||
_set_agent_data,
|
||||
_set_available_tools,
|
||||
_set_model_data,
|
||||
_should_send_prompts,
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def invoke_agent_span(user_prompt, agent, model, model_settings):
|
||||
# type: (Any, Any, Any, Any) -> sentry_sdk.tracing.Span
|
||||
"""Create a span for invoking the agent."""
|
||||
# Determine agent name for span
|
||||
name = "agent"
|
||||
if agent and getattr(agent, "name", None):
|
||||
name = agent.name
|
||||
|
||||
span = get_start_span_function()(
|
||||
op=OP.GEN_AI_INVOKE_AGENT,
|
||||
name=f"invoke_agent {name}",
|
||||
origin=SPAN_ORIGIN,
|
||||
)
|
||||
|
||||
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
|
||||
|
||||
_set_agent_data(span, agent)
|
||||
_set_model_data(span, model, model_settings)
|
||||
_set_available_tools(span, agent)
|
||||
|
||||
# Add user prompt and system prompts if available and prompts are enabled
|
||||
if _should_send_prompts():
|
||||
messages = []
|
||||
|
||||
# Add system prompts (both instructions and system_prompt)
|
||||
system_texts = []
|
||||
|
||||
if agent:
|
||||
# Check for system_prompt
|
||||
system_prompts = getattr(agent, "_system_prompts", None) or []
|
||||
for prompt in system_prompts:
|
||||
if isinstance(prompt, str):
|
||||
system_texts.append(prompt)
|
||||
|
||||
# Check for instructions (stored in _instructions)
|
||||
instructions = getattr(agent, "_instructions", None)
|
||||
if instructions:
|
||||
if isinstance(instructions, str):
|
||||
system_texts.append(instructions)
|
||||
elif isinstance(instructions, (list, tuple)):
|
||||
for instr in instructions:
|
||||
if isinstance(instr, str):
|
||||
system_texts.append(instr)
|
||||
elif callable(instr):
|
||||
# Skip dynamic/callable instructions
|
||||
pass
|
||||
|
||||
# Add all system texts as system messages
|
||||
for system_text in system_texts:
|
||||
messages.append(
|
||||
{
|
||||
"content": [{"text": system_text, "type": "text"}],
|
||||
"role": "system",
|
||||
}
|
||||
)
|
||||
|
||||
# Add user prompt
|
||||
if user_prompt:
|
||||
if isinstance(user_prompt, str):
|
||||
messages.append(
|
||||
{
|
||||
"content": [{"text": user_prompt, "type": "text"}],
|
||||
"role": "user",
|
||||
}
|
||||
)
|
||||
elif isinstance(user_prompt, list):
|
||||
# Handle list of user content
|
||||
content = []
|
||||
for item in user_prompt:
|
||||
if isinstance(item, str):
|
||||
content.append({"text": item, "type": "text"})
|
||||
if content:
|
||||
messages.append(
|
||||
{
|
||||
"content": content,
|
||||
"role": "user",
|
||||
}
|
||||
)
|
||||
|
||||
if messages:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages, unpack=False
|
||||
)
|
||||
|
||||
return span
|
||||
|
||||
|
||||
def update_invoke_agent_span(span, output):
|
||||
# type: (sentry_sdk.tracing.Span, Any) -> None
|
||||
"""Update and close the invoke agent span."""
|
||||
if span and _should_send_prompts() and output:
|
||||
set_data_normalized(
|
||||
span, SPANDATA.GEN_AI_RESPONSE_TEXT, str(output), unpack=False
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user