2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -16,10 +16,8 @@
#
# ##### END GPL LICENSE BLOCK #####
from concurrent.futures import Future
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional
from errno import EACCES, ENOSPC
import functools
import os
@@ -30,8 +28,6 @@ from .assets import (AssetType,
AssetData,
ModelType,
SIZES)
from .user import (PoliigonUser,
PoliigonSubscription)
from .plan_manager import SubscriptionState, PoliigonPlanUpgradeManager
@@ -44,7 +40,7 @@ from .logger import (DEBUG, # noqa F401, allowing downstream const usage
get_addon_logger,
NOT_SET,
WARNING)
from .notifications import NotificationSystem
from .notifications import NotificationSystem, Notification
from . import settings
from . import updater
from .multilingual import Multilingual
@@ -66,6 +62,22 @@ class PoliigonAddon():
library_paths: List = []
# Variables stored in the addon class to handle the top-level assets
# requests. Values are set on api remote control jobs (also api_rc params).
# NOTE: These values should only be changed in addon-core side;
# DO NOT CHANGE IT ON ANY DCC ADDON CODE;
all_assets_fetched: bool = False
my_assets_fetched: bool = False
recent_downloads_fetched: bool = False
# Variables to control inject token from web process
install_emitted: bool = False
is_user_injected: bool = False
_injected_token: Optional[str] = None
# WM Onboarding notice
wm_onboarding_notice: Optional[Notification] = None
def __init__(self,
addon_name: str,
addon_version: tuple,
@@ -78,7 +90,10 @@ class PoliigonAddon():
language: str = "en-US",
# See ThreadManager.__init__ for signature below,
# e.g. print_exc(fut: Future, key_pool: PoolKeys)
callback_print_exc: Optional[Callable] = None):
callback_print_exc: Optional[Callable] = None,
# Used to get access token file - if None, token file will
# need to be injected in the addon side;
addon_root_path: Optional[str] = None):
self.log_manager = get_addon_logger(env=addon_env)
if addon_env.env_name == "prod":
@@ -105,6 +120,7 @@ class PoliigonAddon():
self.software_version = software_version
self.addon_convention = addon_convention
self.addon_root_path = addon_root_path
self.user = None
self.login_error = None
self.api_rc = None # To be set on the DCC side
@@ -116,6 +132,7 @@ class PoliigonAddon():
self.set_logger_verbose(verbose=False)
self._settings = addon_settings
self.settings_config = self._settings.config
self._api = api.PoliigonConnector(
env=self._env,
software=software_source,
@@ -129,6 +146,8 @@ class PoliigonAddon():
".".join([str(x) for x in addon_version]),
".".join([str(x) for x in software_version])
)
self.set_api_token()
self._tm = tm.ThreadManager(callback_print_exc=callback_print_exc)
self.notify = NotificationSystem(self)
self._api.notification_system = self.notify
@@ -140,8 +159,6 @@ class PoliigonAddon():
local_json=self._env.local_updater_json
)
self.settings_config = self._settings.config
self.user_addon_dir = os.path.join(
os.path.expanduser("~"),
"Poliigon"
@@ -175,12 +192,19 @@ class PoliigonAddon():
report_exception: Callable,
report_thread: Callable,
status_listener: Callable,
get_renderer_name: Callable,
urls_dcc: Dict[str, str],
notify_icon_info: Any,
notify_icon_no_connection: Any,
notify_icon_survey: Any,
notify_icon_warn: Any,
notify_update_body: str
notify_icon_wm_onboarding: Any,
notify_update_body: str,
onboarding_wm_title: Optional[str] = None,
onboarding_wm_label: Optional[str] = None,
onboarding_wm_tooltip: Optional[str] = None,
allow_banner_notice: bool = False
# TODO(Andreas): Once API RC gets instanced here, add:
# page_size_online_assets: int,
# page_size_my_assets: int,
@@ -194,6 +218,7 @@ class PoliigonAddon():
self._api.get_optin = get_optin
self._api.set_on_invalidated(callback_on_invalidated_token)
self._api._status_listener = status_listener
self._api.get_current_renderer_name = get_renderer_name
self._api.add_poliigon_urls(urls_dcc)
self._api._report_message = report_message
self._api._report_exception = report_exception
@@ -204,8 +229,17 @@ class PoliigonAddon():
icon_info=notify_icon_info,
icon_no_connection=notify_icon_no_connection,
icon_survey=notify_icon_survey,
icon_warn=notify_icon_warn)
icon_warn=notify_icon_warn,
notify_icon_wm_onboarding=notify_icon_wm_onboarding)
self.notify.addon_params.update_body = notify_update_body
self.notify.addon_params.allow_banner_notice = allow_banner_notice
if onboarding_wm_title is not None:
self.notify.addon_params.onboarding_wm_title = onboarding_wm_title
if onboarding_wm_label is not None:
self.notify.addon_params.onboarding_wm_label = onboarding_wm_label
if onboarding_wm_tooltip is not None:
self.notify.addon_params.onboarding_wm_tooltip = onboarding_wm_tooltip
# TODO(Andreas): Once API RC gets instanced in constructor,
# add the following here:
@@ -232,6 +266,196 @@ class PoliigonAddon():
return wrapped_func_call
return wrapped_func
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_preview_asset(self, asset_id: int) -> None:
"""Signals an asset preview in the background if user opted in."""
self._api.signal_preview_asset(asset_id)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_import_asset(self, asset_id: int) -> None:
"""Signals an asset imported in the background if user opted in."""
self._api.signal_import_asset(asset_id)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_view_screen(self, screen_name: str) -> None:
"""Signals a screen tab was viewed."""
self._api.signal_view_screen(screen_name)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_search(self, search: str) -> None:
"""Signals a search text was triggered."""
if search != "":
self._api.signal_search(search)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_category_filter(self, categories: str) -> None:
"""Signals a category was clicked."""
self._api.signal_view_category(categories)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_view_notification(self, notification_id: str) -> None:
"""Signals a notification was viewed."""
self._api.signal_view_notification(notification_id)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def signal_click_notification(
self, notification_id: str, action: str) -> None:
"""Signals a notification click event."""
self._api.signal_click_notification(notification_id, action)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def update_user_use(self) -> None:
"""Update user profile on api."""
user_use = self.user.user_profile
if user_use is None:
return
# Check if user already has primary_3d_software set
# Only assign DCC if not already set
assign_dcc = self.user.primary_3d_software is None
user_use_value = self.user.user_profile.value
email_preference = self.user.email_preference
render_engine = self.user.primary_rendering_engine
self._api.update_user_profile(user_use=user_use_value,
email_preference=email_preference,
primary_render_engine=render_engine,
assign_dcc=assign_dcc)
def check_install_event(self):
""" Triggers an install event. Should only be called when the automated
login (Injected token) happened.
If api token is not set, means that the user will have to log in and then
the backend will emit the event using time_since_enabled (server side).
Install event should be emmit only once;"""
if self._api.token is None or self.install_emitted:
return
same_injected_token = self._api.token == self._injected_token
if self.is_user_injected and same_injected_token and not self._api.invalidated:
self.install_emitted = True
self._api.signal_install()
def set_api_token(self) -> None:
if self.addon_root_path is None:
return
access_token_file = os.path.join(self.addon_root_path, "access_token.txt")
self.inject_token(access_token_file)
def inject_token(self,
downloaded_token_file: str,
callback: Optional[callable] = None,
delete_token_file: bool = True
) -> None:
valid_file = os.path.isfile(downloaded_token_file)
if not valid_file:
self.logger.debug("Token file not found.")
return
txt_file = downloaded_token_file.endswith(".txt")
if not txt_file:
self.logger.error("Invalid token file provided. Ignoring auto login;")
return
try:
with open(downloaded_token_file, "r") as f:
_token_content = f.read()
except Exception as e:
self.logger.error(f"Error parsing login token: {e}")
return
if delete_token_file:
try:
os.remove(downloaded_token_file)
except Exception as e:
self.logger.error(f"Unable to delete token file. {e};")
local_token = self.settings_config.get("user", "token", fallback=None)
if local_token:
self._api.token = local_token
return
# After splitting, two fields are expected: the token and the opt-in flag.
# If any of them is missing, do not inject the token;
split_token = _token_content.split(",")
if len(_token_content) >= 2 and _token_content[-2] != "," or len(split_token) != 2:
self.logger.error(f"Unexpected token format: {_token_content}")
return
self._injected_token, _opt_in = split_token
try:
opt_in_int = int(_opt_in)
except ValueError:
self.logger.error(f"OptIn value is not an Integer: {_opt_in};")
return
opt_in_str = "true" if bool(opt_in_int) else "false"
self.logger.debug(f"Injecting Token\n"
f"Token: {self._injected_token} | OptIn: {opt_in_str}")
self._api.token = self._injected_token
self.settings_config.set("user", "token", self._injected_token)
self.settings_config.set("logging", "reporting_opt_in", opt_in_str)
# User Confirmed flag is currently only being used by P4Max
self.settings_config.set("logging", "user_confirmed", "true")
self._settings.save_settings()
self.is_user_injected = True
if callback is not None:
callback(self._injected_token)
def is_addon_first_onboarding_done(self) -> bool:
return self.settings_config.getboolean("onboarding",
"addon_first_completed",
fallback=False)
def refresh_top_level_queries_flags(self):
self.all_assets_fetched = False
self.my_assets_fetched = False
self.recent_downloads_fetched = False
def are_user_assets_fetched(self) -> bool:
if self.user_legacy_own_assets():
return self.my_assets_fetched and self.recent_downloads_fetched
return self.recent_downloads_fetched
def set_addon_first_onboarding_done(self) -> None:
self.settings_config.set("onboarding", "addon_first_completed", "true")
self._settings.save_settings()
def is_onboarding_wm_preview_done(self) -> bool:
# This parameter should be set as True in the dcc side, everytime a
# watermarked preview is imported;
return self.settings_config.getboolean("onboarding",
"did_watermark_preview",
fallback=False)
def dismiss_onboarding_wm_notice(self):
if self.wm_onboarding_notice:
self.notify.dismiss_notice(self.wm_onboarding_notice, force=True)
self.wm_onboarding_notice = None
def set_onboarding_wm_preview_done(self, check_banner: bool = False):
self.settings_config.set("onboarding", "did_watermark_preview", "true")
self._settings.save_settings()
self.dismiss_onboarding_wm_notice()
if check_banner:
self.upgrade_manager.check_show_banner()
def check_onboarding_notice(self):
did_wm = self.is_onboarding_wm_preview_done()
if self.wm_onboarding_notice is not None:
self.dismiss_onboarding_wm_notice()
if not did_wm and not self.is_unlimited_user():
self.wm_onboarding_notice = self.notify.create_watermarked_onboarding()
def setup_libraries(self):
default_lib_path = os.path.join(self.user_addon_dir, "Library")
multi_dir = self.settings_config["directories"]
@@ -334,35 +558,6 @@ class PoliigonAddon():
"""Clears any invalidation flag for a user."""
self._api.invalidated = False
@run_threaded(tm.PoolKeys.INTERACTIVE)
def log_in_with_credentials(self,
email: str,
password: str,
*,
wait_for_user: bool = False) -> Future:
self.clear_user_invalidated()
req = self._api.log_in(
email,
password
)
if req.ok:
user_data = req.body.get("user", {})
fut = self.create_user(user_data.get("name"), user_data.get("id"))
if wait_for_user:
fut.result(timeout=api.TIMEOUT)
self.login_error = None
else:
self.login_error = req.error
return req
def log_in_with_website(self):
pass
def check_for_survey_notice(
self,
free_user_url: str,
@@ -403,19 +598,6 @@ class PoliigonAddon():
on_dismiss_callable=set_user_survey_flag
)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def log_out(self):
req = self._api.log_out()
if req.ok:
print("Logout success")
else:
print(req.error)
self._api.token = None
# Clear out user on logout.
self.user = None
def add_library_path(self,
path: str,
primary: bool = True,
@@ -494,79 +676,6 @@ class PoliigonAddon():
else:
return None
def _get_user_info(self) -> Tuple:
req = self._api.get_user_info()
user_name = None
user_id = None
if req.ok:
data = req.body
user_name = data["user"]["name"]
user_id = data["user"]["id"]
self.login_error = None
else:
# TODO(SOFT-1029): Create an error log for fail in get user info
self.login_error = req.error
return user_name, user_id
def _get_credits(self):
if self.user is None:
msg = "_get_credits() called without user."
self._api.report_message(
"addon_get_credits", msg, "error")
return
req = self._api.get_user_balance()
if req.ok:
data = req.body
self.user.credits = data.get("subscription_balance")
self.user.credits_od = data.get("ondemand_balance")
else:
self.user.credits = None
self.user.credits_od = None
msg = f"ERROR: {req.error}"
self._api.report_message(
"addon_get_credits", msg, "error")
def _get_subscription_details(self):
"""Fetches the current user's subscription status."""
req = self._api.get_subscription_details()
if req.ok:
plan = req.body
self.user.plan.update_from_dict(plan)
@run_threaded(tm.PoolKeys.INTERACTIVE)
def update_plan_data(self, done_callback: Optional[Callable] = None) -> None:
# TODO(Joao): sub thread the two private functions
self._get_credits()
self._get_subscription_details()
if done_callback is not None:
done_callback()
def create_user(
self,
user_name: Optional[str] = None,
user_id: Optional[int] = None,
done_callback: Optional[Callable] = None) -> Optional[Future]:
if user_name is None or user_id is None:
user_name, user_id = self._get_user_info()
if user_name is None or user_id is None:
return None
self.user = PoliigonUser(
user_name=user_name,
user_id=user_id,
plan=PoliigonSubscription(
subscription_state=SubscriptionState.NOT_POPULATED)
)
future = self.update_plan_data(done_callback)
return future
def is_free_user(self) -> bool:
"""Identifies a free user which neither
has a plan nor on demand credits."""
@@ -591,6 +700,15 @@ class PoliigonAddon():
return False
return self.user.plan.is_unlimited
def is_legacy_limited_user(self) -> bool:
if self.user is None:
return False
elif self.user.plan in [None, SubscriptionState.NOT_POPULATED]:
return False
elif self.user.plan.is_limited_legacy is None:
return False
return self.user.plan.is_limited_legacy
def is_paused_subscription(self) -> Optional[bool]:
"""Return True, if the Subscription is in paused state.
@@ -601,6 +719,11 @@ class PoliigonAddon():
return None
return self.user.plan.subscription_state == SubscriptionState.PAUSED
def user_legacy_own_assets(self):
"""Return True if the user has any Owned Asset (Legacy)"""
return self.user is not None and self.user.count_assets_owned
def get_user_credits(self, incl_od: bool = True) -> int:
"""Returns the number of _spendable_ credits."""
@@ -718,23 +841,6 @@ class PoliigonAddon():
"user", "first_purchase", str(time_stamp))
self._settings.save_settings()
def print_debug(self, *args, dbg=False, bg=True):
"""Print out a debug statement with no separator line.
Cache based on args up to a limit, to avoid excessive repeat prints.
All args must be flat values, such as already casted to strings, else
an error will be thrown.
"""
if dbg:
# Ensure all inputs are hashable, otherwise lru_cache fails.
stringified = [str(arg) for arg in args]
self._cached_print(*stringified, bg=bg)
@lru_cache(maxsize=32)
def _cached_print(self, *args, bg: bool):
"""A safe-to-cache function for printing."""
print(*args)
def open_asset_url(self, asset_id: int) -> None:
asset_data = self._asset_index.get_asset(asset_id)
url = self._api.add_utm_suffix(asset_data.url)
@@ -759,50 +865,6 @@ class PoliigonAddon():
path_wm_previews = os.path.join(path_thumbs, asset_name)
return path_wm_previews
def download_material_wm(
self, files_to_download: List[Tuple[str, str]]) -> None:
"""Synchronous function to download material preview."""
urls = []
files_dl = []
for _url_wm, _filename_wm_dl in files_to_download:
urls.append(_url_wm)
files_dl.append(_filename_wm_dl)
resp = self._api.pooled_preview_download(urls, files_dl)
if not resp.ok:
msg = f"Failed to download WM preview\n{resp}"
self._api.report_message(
"download_mat_preview_dl_failed", msg, "error")
# Continue, as some may have worked.
for _filename_wm_dl in files_dl:
filename_wm = _filename_wm_dl[:-3] # cut of _dl
try:
file_exists = os.path.exists(filename_wm)
dl_exists = os.path.exists(_filename_wm_dl)
if file_exists and dl_exists:
os.remove(filename_wm)
elif not file_exists and not dl_exists:
raise FileNotFoundError
if dl_exists:
os.rename(_filename_wm_dl, filename_wm)
except FileNotFoundError:
msg = f"Neither {filename_wm}, nor {_filename_wm_dl} exist"
self._api.report_message(
"download_mat_existing_file", msg, "error")
except FileExistsError:
msg = f"File {filename_wm} already exists, failed to rename"
self._api.report_message(
"download_mat_rename", msg, "error")
except Exception as e:
self.logger.exception("Unexpected exception while renaming WM preview")
msg = f"Unexpected exception while renaming {_filename_wm_dl}\n{e}"
self._api.report_message(
"download_wm_exception", msg, "error")
return resp
def get_config_param(self,
name_param: str,
name_group: str = "DEFAULT",
File diff suppressed because it is too large Load Diff
@@ -47,8 +47,10 @@ from .api_remote_control_params import (
ApiJobParamsPutUpgradePlan,
ApiJobParamsResumePlan,
ApiJobParamsGetAssets,
ApiJobParamsGetAllAssets,
ApiJobParamsLogin,
ApiJobParamsPurchaseAsset,
CATEGORY_ALL,
CmdLoginMode
)
from .assets import AssetData
@@ -82,11 +84,12 @@ class JobType(IntEnum):
PUT_UPGRADE_PLAN = 6
RESUME_PLAN = 7
GET_ASSETS = 10
DOWNLOAD_THUMB = 11
PURCHASE_ASSET = 12
DOWNLOAD_ASSET = 13
DOWNLOAD_WM_PREVIEW = 14,
UNIT_TEST = 15,
GET_ALL_ASSETS = 11
DOWNLOAD_THUMB = 12
PURCHASE_ASSET = 13
DOWNLOAD_ASSET = 14
DOWNLOAD_WM_PREVIEW = 15,
UNIT_TEST = 16,
EXIT = 99999
@@ -242,6 +245,9 @@ class ApiRemoteControl():
def add_job_get_user_data(self,
user_name: str,
user_id: str,
do_fetch_plans: bool = True,
do_fetch_categories: bool = True,
do_fetch_asset_data: bool = True,
callback_cancel: Optional[Callable] = None,
callback_progress: Optional[Callable] = None,
callback_done: Optional[Callable] = None,
@@ -249,7 +255,12 @@ class ApiRemoteControl():
) -> None:
"""Convenience function to add a get user data job."""
params = ApiJobParamsGetUserData(user_name, user_id)
params = ApiJobParamsGetUserData(
user_name=user_name,
user_id=user_id,
do_fetch_plans=do_fetch_plans,
do_fetch_categories=do_fetch_categories,
do_fetch_asset_data=do_fetch_asset_data)
self.add_job(
job_type=JobType.GET_USER_DATA,
params=params,
@@ -363,7 +374,7 @@ class ApiRemoteControl():
def add_job_get_assets(self,
library_paths: List[str],
tab: str, # KEY_TAB_ONLINE, KEY_TAB_MY_ASSETS
category_list: List[str] = ["All Assets"],
category_list: List[str] = [CATEGORY_ALL],
search: str = "",
idx_page: int = 1,
page_size: int = 10,
@@ -373,10 +384,14 @@ class ApiRemoteControl():
callback_progress: Optional[Callable] = None,
callback_done: Optional[Callable] = None,
force: bool = True,
ignore_old_names: bool = True
do_my_assets: bool = False
) -> None:
"""Convenience function to add a get assets job."""
if search != "":
# With API V2 it's either search or categories, not both
category_list = [CATEGORY_ALL]
params = ApiJobParamsGetAssets(library_paths,
tab,
category_list,
@@ -385,7 +400,8 @@ class ApiRemoteControl():
page_size,
force_request,
do_get_all,
ignore_old_names)
do_my_assets)
self.add_job(
job_type=JobType.GET_ASSETS,
params=params,
@@ -395,6 +411,31 @@ class ApiRemoteControl():
force=force,
timeout=TIMEOUT)
def add_job_get_all_assets(
self,
library_paths: List[str],
do_my_assets: bool = False,
force_request: bool = False,
callback_cancel: Optional[Callable] = None,
callback_progress: Optional[Callable] = None,
callback_done: Optional[Callable] = None,
force: bool = True
) -> None:
"""Convenience function to add a get assets job."""
params = ApiJobParamsGetAllAssets(
library_paths=library_paths,
force_request=force_request,
do_my_assets=do_my_assets)
self.add_job(
job_type=JobType.GET_ALL_ASSETS,
params=params,
callback_cancel=callback_cancel,
callback_progress=callback_progress,
callback_done=callback_done,
force=force,
timeout=TIMEOUT)
def add_job_download_thumb(self,
asset_id: int,
url: str,
@@ -0,0 +1,103 @@
# #### BEGIN GPL LICENSE BLOCK #####
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####
import enum
from typing import List, Optional, Any
from .multilingual import _m, _t
from .user import PoliigonUser
from .api_remote_control_params import (KEY_TAB_ONLINE,
KEY_TAB_MY_ASSETS,
KEY_TAB_RECENT_DOWNLOADS,
KEY_TAB_IMPORTED,
KEY_TAB_LOCAL)
class FilterOptions(enum.Enum):
ALL_ASSETS = _m("All Assets")
OWNED_ASSETS = _m("Owned")
RECENT_DOWNLOADED_ASSETS = _m("Downloads")
LOCAL_ASSETS = _m("Local")
IMPORTED_ASSETS = _m("Imported")
@classmethod
def get_list(cls):
return [member for member in FilterOptions]
@classmethod
def get_list_for_user(cls, user: PoliigonUser):
# If a user doesn't own any asset, the owned filter is not shown
if user.count_assets_owned is not None and user.count_assets_owned == 0:
return [member for member in FilterOptions
if member != cls.OWNED_ASSETS]
return cls.get_list()
@classmethod
def get_string_list(cls, short_string: bool = False) -> List[str]:
if short_string:
return [member.map_to_short_value() for member in FilterOptions]
return [member.value for member in FilterOptions]
@classmethod
def get_translated_string_list(cls, short_string: bool = False) -> List[str]:
if short_string:
return [_t(member.map_to_short_value()) for member in FilterOptions]
return [_t(member.value) for member in FilterOptions]
@classmethod
def get_from_query(cls, key_query: str) -> Optional[Any]:
if key_query == KEY_TAB_ONLINE:
return cls.ALL_ASSETS
elif key_query == KEY_TAB_MY_ASSETS:
return cls.OWNED_ASSETS
elif key_query == KEY_TAB_RECENT_DOWNLOADS:
return cls.RECENT_DOWNLOADED_ASSETS
elif key_query == KEY_TAB_LOCAL:
return cls.LOCAL_ASSETS
elif key_query == KEY_TAB_IMPORTED:
return cls.IMPORTED_ASSETS
else:
return None
def map_to_query(self) -> Optional[str]:
if self == self.ALL_ASSETS:
return KEY_TAB_ONLINE
elif self == self.OWNED_ASSETS:
return KEY_TAB_MY_ASSETS
elif self == self.RECENT_DOWNLOADED_ASSETS:
return KEY_TAB_RECENT_DOWNLOADS
elif self == self.LOCAL_ASSETS:
return KEY_TAB_LOCAL
elif self == self.IMPORTED_ASSETS:
return KEY_TAB_IMPORTED
else:
return None
def map_to_short_value(self) -> Optional[str]:
if self == self.ALL_ASSETS:
return _m("All")
elif self == self.OWNED_ASSETS:
return self.value
elif self == self.RECENT_DOWNLOADED_ASSETS:
return self.value
elif self == self.LOCAL_ASSETS:
return self.value
elif self == self.IMPORTED_ASSETS:
return self.value
else:
return None
@@ -34,13 +34,15 @@ from . import logger
from .maps import MAPS_TYPE_NAMES, MapType
ISOFORMAT_EPOCH_BEGIN = "1970-01-01T00:00:00.000Z"
REGEX_DATETIME_ISOFORMAT = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}\.[0-9]{3}Z$"
# Compiled regex to avoid re-instancing each time.
# Checks for preview being in the last section of a filename split into _'s
# The [^_]* means match any character except another _, and $ asserts it's at
# the end of the match string.
_PREVIEW_PATTERN = re.compile(r"_[^_]*preview[^_]*$", re.IGNORECASE)
# Used to report unknown asset type exactly once per session
g_asset_unsupported_type_reported: bool = False
@@ -100,6 +102,7 @@ class AssetIndex():
self.all_assets = {}
self.cached_queries = {}
self.reported_asset_ids = []
self.reported_map_codes = []
self.use_lod_extras = use_lod_extras
if use_lod_extras:
@@ -115,9 +118,33 @@ class AssetIndex():
return
self.reporting_callable(message, code_msg, level, max_reports)
@staticmethod
def _filter_image_urls(urls: List[str]) -> List[str]:
return [url for url in urls if ".png" in url.lower() or ".jpg" in url.lower()]
def _convert_api_time_to_timestamp(self, time_api: str) -> float:
try:
dt = datetime.fromisoformat(time_api)
except ValueError:
dt = datetime.fromisoformat(ISOFORMAT_EPOCH_BEGIN)
warn = f"Time string from API off expected format: {time_api}"
self.logger.warning(warn)
t_tuple = dt.utctimetuple()
seconds_since_epoch = time.mktime(t_tuple)
return seconds_since_epoch
def _get_api_time(self, data: Dict, key: str) -> float:
t = data.get(key, ISOFORMAT_EPOCH_BEGIN)
if t is None:
# TODO(Andreas): Not sure how we would want to report this, here
return 0.0
if not re.match(REGEX_DATETIME_ISOFORMAT, t):
# TODO(Andreas): Not sure how we would want to report this, here
return 0.0
# We know our times are UTC, yet, for unknown reasons Python's
# datetime.fromisoformat() does not allow for the ISO conform
# timezone qualifier ('Z' for UTC at the end).
# So, we need to split it off.
t = t[:-1]
secs_since_epoch = self._convert_api_time_to_timestamp(t)
return secs_since_epoch
def _get_cloudflare_thumbnails(
self,
@@ -129,27 +156,18 @@ class AssetIndex():
return []
for thumb in cloudflare_thumbs:
filename = thumb.get("file_name", None)
thumb_time = thumb.get("time", None)
filename = thumb.get("fileName", None)
if filename is None:
warn = "Filename not available in Cloudflare Thumbnail."
self.logger.warning(warn)
continue
filename = filename.split("?")[0]
if thumb_time is not None:
try:
thumb_time = datetime.strptime(
thumb_time, "%Y-%m-%d %H:%M:%S")
thumb_time = thumb_time.astimezone(timezone.utc).timestamp()
except (ValueError, OSError):
warn = "Thumbnail time string off expected format."
self.logger.warning(warn)
thumb_time = self._get_api_time(data=thumb, key="time")
thumb_class = assets.AssetThumbnail(
filename=filename,
base_url=thumb.get("base_url", None),
base_url=thumb.get("baseUrl", None),
index=thumb.get("position", None),
time=thumb_time,
type=thumb.get("type", None)
@@ -160,7 +178,7 @@ class AssetIndex():
@staticmethod
def _get_texture_real_world_dimension(
asset_dict: Dict) -> Optional[Tuple[float, float]]:
""" Method to get Real World Dimensions of an asset in Convention 0.
"""Method to get Real World Dimensions of an asset in Convention 0.
For convention 1, use _decode_tex_convention_1 to get this info. """
dimension_height = None
@@ -168,27 +186,23 @@ class AssetIndex():
dimension_unit_str = ""
dimension_dict = {}
asset_name = asset_dict.get("asset_name", "")
asset_name = asset_dict.get("AssetName", "")
# Ignore dimensions for Atlas Textures
if (asset_name.lower()).startswith("atlas"):
return None
dimension_str = None
render_custom_schema = asset_dict.get("render_custom_schema", None)
if render_custom_schema is not None:
dimension_str = render_custom_schema.get("dimensions", None)
technical_desc = render_custom_schema.get("technical_description", {})
if type(technical_desc) is dict:
dimension_dict = technical_desc.get("Dimensions", {})
if dimension_dict is None:
dimension_dict = {}
specs = asset_dict.get("Specifications", None)
if specs is not None:
dimension_str = specs.get("dimensions", None)
dimension_dict = specs.get("physical_size_cm", {})
if dimension_dict is None:
dimension_dict = {}
for key in dimension_dict.keys():
dimension_key = key.split(" ")
if len(dimension_key) < 2:
continue
dimension_unit_str = dimension_key[1]
dimension_unit_str = "cm"
if "height" in key.lower():
dimension_height = float(dimension_dict[key])
elif "width" in key.lower():
@@ -203,7 +217,7 @@ class AssetIndex():
dimension_width = float(dimension_val[2])
dimension_unit_str = str(dimension_val[3])
except (ValueError, IndexError):
# if the first value is not integer, consider the format invalid
# if the first value is not float, consider the format invalid
return None
if dimension_height is None and dimension_width is None:
@@ -233,19 +247,29 @@ class AssetIndex():
# Always get Metalness as Workflow for Convention 1
workflow = "METALNESS"
asset_maps = asset_dict.get("maps", [])
asset_maps = asset_dict.get("Maps", [])
all_sizes = asset_dict.get("resolutions", [])
all_sizes = asset_dict.get("Resolutions", [])
sizes = [_size for _size in all_sizes if self._is_valid_size(_size)]
asset_id = asset_dict.get("AssetID", -1)
maps_dictionary = {workflow: []}
for _map_desc in asset_maps:
try:
map_code = _map_desc.get("type", "UNKNOWN")
file_formats = _map_desc.get("file_formats", [])
if MapType.from_type_code(map_code) == MapType.UNKNOWN:
msg = (f"AssetId: {asset_id} Unknown Map "
f"type from server: {map_code}")
self.logger.warning(msg)
if map_code not in self.reported_map_codes:
self.reported_map_codes.append(map_code)
self.capture_message("unknown_map_type_found", msg)
continue
except Exception:
map_code = "UNKNOWN"
file_formats = []
asset_id = asset_dict.get("id", -1)
msg = f"Asset {asset_id}: Invalid maps data\n{asset_maps}"
self.capture_message("assetindex_invalid_maps_data", msg)
@@ -257,7 +281,7 @@ class AssetIndex():
variants=[])
maps_dictionary[workflow].append(description)
specs = asset_dict.get("specifications", {})
specs = asset_dict.get("Specifications", {})
creation_str = specs.get("creation_method", "")
creation_method = assets.CreationMethodId.from_string(creation_str)
@@ -271,8 +295,8 @@ class AssetIndex():
return maps_dictionary, sizes, dimensions, creation_method
@staticmethod
def _decode_render_schema_tex(asset_dict: Dict
def _decode_render_schema_tex(self,
asset_dict: Dict
) -> Tuple[Dict[str, assets.TextureMapDesc],
List[str],
List[str]]:
@@ -286,26 +310,34 @@ class AssetIndex():
NOTE: This is the default Texture decode function for convention 0
"""
if "render_schema" not in asset_dict.keys():
if "Maps" not in asset_dict.keys():
return ({}, [], [])
all_sizes = []
asset_id = asset_dict.get("AssetID", -1)
all_sizes = asset_dict.get("Resolutions", [])
all_variants = []
tex_desc_dict = {} # {workflow: List[TextureMapDesc]
for schema in asset_dict["render_schema"]:
if "types" not in schema.keys():
continue
workflow = schema.get("name", "REGULAR")
dict_maps = asset_dict.get("Maps", [])
for _maps_dict in dict_maps:
workflow = _maps_dict.get("workflow", "REGULAR")
tex_descs = []
for tex_type in schema.get("types", []):
tex_code = tex_type.get("type_code", "")
for _tex_code in _maps_dict.get("maps", []):
variant = None
if "_" in tex_code:
tex_code, variant = tex_code.split("_")
if tex_code not in MAPS_TYPE_NAMES:
tex_code = MapType.UNKNOWN.name
map_type = MapType[tex_code]
if "_" in _tex_code:
_tex_code, variant = _tex_code.split("_")
map_type = MapType.from_type_code(_tex_code)
if MapType.from_type_code(_tex_code) == MapType.UNKNOWN:
msg = (f"AssetId: {asset_id} Unknown Map "
f"type from server: {_tex_code}")
self.logger.warning(msg)
if _tex_code not in self.reported_map_codes:
self.reported_map_codes.append(_tex_code)
self.capture_message("unknown_map_type_found", msg)
continue
if variant is not None:
variants = [variant]
@@ -313,16 +345,15 @@ class AssetIndex():
else:
variants = None
type_code = tex_type.get("type_code", "")
type_name = tex_type.get("type_name", "")
type_options = tex_type.get("type_options", [])
type_preview = tex_type.get("type_preview", "")
tex_desc = assets.TextureMapDesc(map_type_code=type_code,
file_formats=[], # convention 0 assets don't have formats
display_name=type_name,
sizes=type_options,
filename_preview=type_preview,
variants=variants)
type_preview = "" # TODO(Andreas): no longer available?
tex_desc = assets.TextureMapDesc(
map_type_code=_tex_code,
# TODO(Andreas): It seems with API V2 we would have formats available
file_formats=[], # convention 0 assets don't have formats
display_name=_tex_code,
sizes=all_sizes,
filename_preview=type_preview,
variants=variants)
tex_desc_variant = None
if variant is None:
for tex_desc_prev in tex_descs:
@@ -337,8 +368,6 @@ class AssetIndex():
# variants are otherwise identical
tex_desc_variant.variants.extend(tex_desc.variants)
all_sizes.extend(tex_desc.sizes)
tex_desc_dict[workflow] = tex_descs
# consolidate all sizes and variants for use in menus
@@ -367,31 +396,27 @@ class AssetIndex():
tuple[1] - Default size from included_resolution (if available)
"""
if "render_schema" not in asset_dict.keys():
msg = f"'render_schema' missing in asset dict\n{asset_dict}"
self.capture_message("assetindex_no_renderschema", msg)
return [], ""
render_schema = asset_dict.get("render_schema", {})
if "options" not in render_schema.keys():
msg = f"'options' missing in 'render_schema'\n{render_schema}"
self.capture_message("assetindex_no_renderschema_options", msg)
return [], ""
all_sizes = render_schema.get("options", [])
all_sizes = asset_dict.get("Resolutions", [])
all_sizes = [size for size in all_sizes if self._is_valid_size(size)]
render_custom_schema = asset_dict.get("render_custom_schema", {})
incl_size = None
if "included_resolution" in render_custom_schema.keys():
incl_size = render_custom_schema.get("included_resolution", "")
if "DefaultResolution" in asset_dict.keys():
incl_size = asset_dict.get("DefaultResolution", "")
if self._is_valid_size(incl_size):
all_sizes.extend([incl_size])
all_sizes = sorted(list(set(all_sizes)),
key=lambda s: int(s[:-1]))
return all_sizes, incl_size
def _construct_watermarked_urls(self, asset_dict: Dict) -> List[str]:
toolbox_previews = asset_dict.get("ToolboxPreviews", [])
urls_wm = []
for _preview in toolbox_previews:
url_base = _preview.get("baseUrl")
filename = _preview.get("fileName")
urls_wm.append(f"{url_base}/{filename}")
return urls_wm
def _construct_brush(self,
asset_dict: Dict,
convention: int
@@ -406,8 +431,8 @@ class AssetIndex():
"""Constructs a Model"""
model = assets.Model()
if "lods" in asset_dict.keys():
model.lods = asset_dict["lods"]
if "LODs" in asset_dict.keys():
model.lods = asset_dict.get("LODs", [])
model.sizes, model.size_default = self._decode_render_schema_model(
asset_dict)
return model
@@ -438,8 +463,7 @@ class AssetIndex():
tex_bg = assets.Texture(map_descs=tex_map_descs_bg,
sizes=sizes,
variants=variants)
tex_bg.watermarked_urls = self._filter_image_urls(
asset_dict["toolbox_previews"])
tex_bg.watermarked_urls = []
tex_bg.maps = {}
tex_light = assets.Texture(map_descs=tex_map_descs_light,
@@ -459,7 +483,10 @@ class AssetIndex():
creation_method = None
if convention == 1:
maps, sizes, dimension, creation_method = self._decode_tex_convention_1(asset_dict)
(maps,
sizes,
dimension,
creation_method) = self._decode_tex_convention_1(asset_dict)
variants = []
else:
maps, sizes, variants = self._decode_render_schema_tex(asset_dict)
@@ -470,16 +497,17 @@ class AssetIndex():
variants=variants,
real_world_dimension=dimension,
creation_method=creation_method)
tex.watermarked_urls = self._filter_image_urls(
asset_dict["toolbox_previews"])
tex.watermarked_urls = self._construct_watermarked_urls(asset_dict)
tex.maps = {}
return tex
def _construct_asset_base(self, asset_dict: Dict) -> assets.AssetData:
global g_asset_unsupported_type_reported
asset_name = asset_dict["name"]
asset_type_api = asset_dict["type"]
asset_id = asset_dict.get("AssetID", 0)
asset_name = asset_dict.get("AssetName", "NoName")
display_name = asset_dict.get("Name", "No Name")
asset_type_api = asset_dict.get("Type", "Unsupported")
try:
asset_type = assets.API_TYPE_TO_ASSET_TYPE[asset_type_api]
@@ -487,18 +515,17 @@ class AssetIndex():
asset_type = assets.AssetType.UNSUPPORTED
if asset_type == assets.AssetType.SUBSTANCE:
msg = f"{asset_name}: {asset_type_api} not supported, yet"
msg = f"{display_name}: {asset_type_api} not supported, yet"
raise NotImplementedError(msg)
elif asset_type == assets.AssetType.UNSUPPORTED:
msg = f"{asset_name}: {asset_type_api} not supported, yet"
msg = f"{display_name}: {asset_type_api} not supported, yet"
if g_asset_unsupported_type_reported is False:
self.capture_message("assetindex_unsupported_type", msg)
g_asset_unsupported_type_reported = True
raise NotImplementedError(msg)
return assets.AssetData(asset_id=asset_dict["id"],
asset_type=asset_type,
asset_name=asset_dict["asset_name"])
return assets.AssetData(
asset_id=asset_id, asset_type=asset_type, asset_name=asset_name)
def _populate_default_asset_info(self,
asset_data: assets.AssetData,
@@ -507,39 +534,49 @@ class AssetIndex():
) -> None:
"""Constructs AssetData part common to all types"""
asset_data.display_name = asset_dict["name"]
asset_data.display_name = asset_dict.get("Name", "No Name")
asset_data.categories = []
for category in asset_dict["categories"]:
category = category.title()
if category in assets.CATEGORY_TRANSLATION:
category = assets.CATEGORY_TRANSLATION[category]
asset_id = asset_data.asset_id
for category in asset_dict.get("Categories", []):
asset_data.categories.append(category)
asset_data.url = asset_dict["url"]
asset_data.slug = asset_dict["slug"]
asset_data.credits = asset_dict["credit"]
slug = asset_dict.get("Slug", "no-name")
asset_data.slug = slug
asset_data.url = asset_dict.get("WebsiteUrl", "")
asset_data.credits = asset_dict.get("Credit", 1)
try:
asset_data.api_convention = int(asset_dict.get("convention", 0))
asset_data.api_convention = int(asset_dict.get("Convention", 0))
except ValueError:
asset_data.api_convention = 0
asset_data.local_convention = None # to be determined during update_from_directory()
dict_previews = asset_dict.get("Previews", [])
asset_data.cloudflare_thumb_urls = self._get_cloudflare_thumbnails(
asset_dict.get("cloudflare_previews", None))
dict_previews)
asset_data.thumb_urls = self._filter_image_urls(asset_dict["previews"])
published_at = asset_dict["published_at"]
t_published_at = time.strptime(published_at, "%Y-%m-%d %H:%M:%S")
seconds_since_epoch = time.mktime(t_published_at)
asset_data.published_at = seconds_since_epoch # TODO(Andreas): need to take timezone into account
seconds_since_epoch = self._get_api_time(
data=asset_dict, key="PublishedAt")
asset_data.published_at = seconds_since_epoch
asset_data.is_local = None
asset_data.downloaded_at = None
asset_data.is_purchased = purchased
asset_data.purchased_at = None
asset_data.render_custom_schema = asset_dict.get(
"render_custom_schema", {})
asset_data.old_asset_names = asset_data.render_custom_schema.get(
"previous_filenames", None)
rcs = asset_dict.get("RenderCustomSchema", {})
specs = asset_dict.get("Specifications", {})
if type(rcs) is list:
asset_data.status.has_error = True
asset_data.status.error = "No Render Custom Schema"
rcs = {}
if asset_id not in self.reported_asset_ids:
msg = f"AssetId: {asset_id} Error: {asset_data.status.error}"
self.capture_message("fail_asset_data_populate_rcs", msg)
self.reported_asset_ids.append(asset_id)
asset_data.render_custom_schema = rcs
asset_data.specifications = specs
def construct_asset(self,
asset_dict: Dict,
@@ -570,6 +607,7 @@ class AssetIndex():
except Exception as e:
asset_data.state.has_error = True
asset_data.state.error = e
asset_data.create_dummy_type_data()
if asset_data.asset_id not in self.reported_asset_ids:
msg = f"AssetId: {asset_data.asset_id} Error: {e}"
self.capture_message("fail_asset_data_populate", msg)
@@ -623,6 +661,13 @@ class AssetIndex():
utc_s_since_epoch = datetime.now(timezone.utc).timestamp()
self.all_assets[asset_id].purchased_at = utc_s_since_epoch
def mark_recent_downloaded(self, asset_id: int) -> None:
"""Marks an AssetData as Recent Downloaded"""
if asset_id not in self.all_assets:
return
self.all_assets[asset_id].is_recent_downloaded = True
def _map_type_from_filename_parts(self, filename_parts: List[str]):
"""Gets a MapType (and its workflow) from a list of parts of a filename.
@@ -787,7 +832,6 @@ class AssetIndex():
if lod is not None:
lods.append(lod)
else:
lods.append("NONE")
lod = "NONE"
if size is not None:
sizes.append(size)
@@ -1108,10 +1152,8 @@ class AssetIndex():
asset_data.is_local = False
return files_found
def gather_asset_name_dict(self,
asset_id_list: List[int],
ignore_old_names: bool = True
) -> Dict[str, int]:
def gather_asset_name_dict(
self, asset_id_list: List[int]) -> Dict[str, int]:
"""Returns a dictionary with asset name keys (including optional old
names) and asset ID values for all asset IDs in list."""
@@ -1124,16 +1166,6 @@ class AssetIndex():
asset_name = asset_data.asset_name
asset_name_dict[asset_name] = asset_id
if ignore_old_names:
continue
old_asset_names = asset_data.old_asset_names
if old_asset_names is None or len(old_asset_names) == 0:
continue
for _old_name in old_asset_names:
asset_name_dict[_old_name] = asset_id
return asset_name_dict
def update_all_local_assets(self,
@@ -1179,7 +1211,7 @@ class AssetIndex():
asset_type=None, purchased=purchased)
asset_name_dict = self.gather_asset_name_dict(
asset_id_list, ignore_old_names=ignore_old_names)
asset_id_list)
# update_from_directory() overwrites file reference entries.
# Thus the primary library directory has to be scanned last.
@@ -1333,7 +1365,7 @@ class AssetIndex():
except NotImplementedError:
# Silence Substance exceptions
self.logger.info(("Unsupported asset type encountered.\n"
f" {asset_dict['name']}: {asset_dict['type']}"))
f" {asset_dict['Name']}: {asset_dict['Type']}"))
if query_tuple in self.cached_queries:
self.cached_queries[query_tuple].extend(tmp_cached_query)
return tmp_cached_query
@@ -1412,12 +1444,16 @@ class AssetIndex():
asset_ids: List[int],
key_query: str,
chunk: Optional[int] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
do_append: bool = False
) -> None:
"""Stores a list of asset IDs in query cache."""
query_tuple = self._query_key_to_tuple(key_query, chunk, chunk_size)
self.cached_queries[query_tuple] = asset_ids
if do_append:
self.cached_queries[query_tuple].extend(asset_ids)
else:
self.cached_queries[query_tuple] = asset_ids
def query_exists(self,
key_query: str,
@@ -1615,7 +1651,8 @@ class AssetIndex():
lod: Optional[str] = None,
model_type: Optional[assets.ModelType] = None,
native_only: bool = False,
renderer: Optional[str] = None
renderer: Optional[str] = None,
check_map_preferences: bool = False
) -> bool:
"""Checks if an asset (or a flavor thereof) has been downloaded.
@@ -1646,7 +1683,7 @@ class AssetIndex():
incl_watermarked = size == "WM"
local_sizes = self.check_asset_local_sizes(
asset_id, workflow, incl_watermarked)
asset_id, workflow, incl_watermarked, check_map_preferences)
if size is None:
tex_is_local = any(local_sizes.values())
else:
@@ -1682,7 +1719,8 @@ class AssetIndex():
def check_asset_local_sizes(self,
asset_id: int,
workflow: Optional[str] = "REGULAR",
incl_watermarked: bool = False
incl_watermarked: bool = False,
check_map_preferences: bool = False
) -> Dict[str, bool]:
"""Returns texture 'locality' by size.
@@ -1700,13 +1738,24 @@ class AssetIndex():
asset_data = self.all_assets[asset_id]
type_data = asset_data.get_type_data()
local_sizes = {}
all_sizes = type_data.get_size_list(incl_watermarked) # local only False, no conv needed
local_convention = asset_data.get_convention(local=True)
if local_convention is not None:
convention_min = min(local_convention, self.addon_convention)
else:
convention_min = self.addon_convention
map_prefs = None
if check_map_preferences:
map_prefs = self.addon.user.map_preferences
local_sizes = {}
all_sizes = type_data.get_size_list(incl_watermarked=incl_watermarked,
# Local only set as false to get
# False values on return dict;
local_only=False,
local_convention=local_convention,
addon_convention=self.addon_convention,
map_preferences=map_prefs)
for size in all_sizes:
if workflow is None:
workflow_list = self.get_asset_workflow_list(asset_id)
@@ -1773,37 +1822,6 @@ class AssetIndex():
return local_lods
# TODO(Joao): Deprecated - Delete function and use cloudflare previews
def get_thumbnail_url_list(self, asset_id: int) -> List[str]:
"""Gets _all_ URLs for an asset's thumbnails"""
if asset_id not in self.all_assets:
return None
ad = self.all_assets[asset_id]
return ad.thumb_urls
# TODO(Joao): Deprecated - Delete function and use cloudflare previews
def get_thumbnail_url_by_index(self,
asset_id: int,
index: int = 0) -> Optional[str]:
"""Returns preview url via index, if index exists,
otherwise the first preview url will be returned.
Return value may be None, e.g. in case of dummy entries.
"""
if index < 0:
raise ValueError
if asset_id not in self.all_assets:
return None
ad = self.all_assets[asset_id]
if ad.thumb_urls is None or len(ad.thumb_urls) == 0:
return None
elif index < len(ad.thumb_urls):
return ad.thumb_urls[index]
else:
return ad.thumb_urls[0]
def get_asset_cf_thumbnails(self, asset_id):
if asset_id not in self.all_assets:
return None
@@ -1880,44 +1898,6 @@ class AssetIndex():
return path_thumb, url
return None, None
# TODO(Joao): Deprecated - Delete function and use cloudflare previews
def get_thumbnail_url_by_name(self,
asset_id: int,
name: str = "sphere") -> Optional[str]:
"""Returns preview url via name extension, if it exists.
Return value may be None, e.g. in case name not found.
"""
if asset_id not in self.all_assets:
return None
asset_data = self.all_assets[asset_id]
if asset_data.thumb_urls is None or len(asset_data.thumb_urls) == 0:
return None
name = name.lower()
result_url = None
for url in asset_data.thumb_urls:
if name in url.lower():
result_url = url
break
return result_url
# TODO(Andreas): maybe not URLs...
def get_large_preview_url_list(self, asset_id: int) -> List[str]:
"""Gets _all_ URLs for an asset's large previews"""
# TODO(Andreas)
return []
def get_large_preview_url(self,
asset_id: int,
index: int = 0
) -> Optional[str]:
"""Gets URL for an asset's larrge preview"""
# TODO(Andreas)
return ""
def get_watermark_preview_url_list(self,
asset_id: int
) -> Optional[List[str]]:
@@ -2046,7 +2026,8 @@ class AssetIndex():
def get_asset_id_list(self,
asset_type: Optional[assets.AssetType] = None,
purchased: bool = None,
local: bool = None
local: bool = None,
check_map_prefs: bool = False
) -> List[int]:
"""Return a list of asset IDs in AssetIndex.
Optionally restricted by per type and/or is_purchased flag.
@@ -2055,30 +2036,31 @@ class AssetIndex():
asset_type: Restrict list to a specific type. Use None for any type.
purchased: Restrict list to (non-)purchased assets. Use None for both.
local: Restrict list to (non-)local assets. Use None for both.
check_map_prefs: Restrict list to ready to import assets (considering map prefs).
"""
# TODO(Andreas): Just realized, we could likely speed this up
# by considering only asset_ids in "my_assets"
# (from query cache) in case purchased == True
asset_id_list = [
all_asset_id_list = [
asset_data.asset_id for asset_data in self.all_assets.values()
if asset_type is None or asset_data.asset_type == asset_type
]
if purchased is None and local is None:
return asset_id_list
return all_asset_id_list
asset_id_list = [
asset_id for asset_id in asset_id_list
purchased_asset_id_list = [
asset_id for asset_id in all_asset_id_list
if self.all_assets[asset_id].is_purchased == purchased
]
if local is None:
return asset_id_list
return purchased_asset_id_list
asset_id_list = [
asset_id for asset_id in asset_id_list
if self.all_assets[asset_id].is_local == local
filter_local_list = all_asset_id_list if purchased is None else purchased_asset_id_list
local_asset_id_list = [
asset_id for asset_id in filter_local_list
if self.check_asset_is_local(asset_id,
check_map_preferences=check_map_prefs
) == local
]
return asset_id_list
return local_asset_id_list
def num_assets(self, asset_type: Optional[assets.AssetType] = None) -> int:
"""Returns the number of assets, optionally per type"""
@@ -2104,19 +2086,19 @@ class AssetIndex():
def _init_categories(self, categories):
for category in categories:
category["asset_count"] = 0
self._init_categories(category["children"])
self._init_categories(category.get("children", []))
def _count_asset(self, categories, asset_categories):
num_asset_categories = len(asset_categories)
for category in categories:
category_name = category["name"]
if category_name not in asset_categories:
id_category = category.get("id", -1)
if id_category not in asset_categories:
continue
asset_categories.remove(category_name)
category["asset_count"] += 1
self._count_asset(category["children"], asset_categories)
asset_categories.remove(id_category)
category["asset_count"] = category.get("asset_count", 0) + 1
self._count_asset(category.get("children", []), asset_categories)
break
if len(asset_categories) > 0 and len(asset_categories) < num_asset_categories:
if 0 < len(asset_categories) < num_asset_categories:
self._count_asset(categories, asset_categories)
def get_asset_count_per_category(self,
@@ -2133,13 +2115,14 @@ class AssetIndex():
# Top level is different,
# as it actually contains AssetTypes, not categories
for category in categories:
asset_type_name = category["name"]
id_cat_asset_type = category.get("id", -1)
asset_type_name = category.get("name", "Unknown Category")
if asset_type_name == CATEGORY_FREE:
continue
asset_type = assets.AssetType.type_from_api(asset_type_name)
# filter depending on purchased and downloaded
# Filter depending on purchased and downloaded
if purchased:
asset_ids_per_type[asset_type] = [
asset_id for asset_id in asset_ids_per_type[asset_type]
@@ -2158,10 +2141,11 @@ class AssetIndex():
# important copy(), as we remove categories from the list
asset_categories = asset_data.categories.copy()
if asset_type_name in asset_categories:
asset_categories.remove(asset_type_name)
if id_cat_asset_type in asset_categories:
asset_categories.remove(id_cat_asset_type)
self._count_asset(category["children"], asset_categories)
self._count_asset(
category.get("children", []), asset_categories)
if len(asset_categories):
msg_warn = (f"Did not count all categories ({asset_id})!\n"
f"Left over: {asset_categories}\n"
@@ -2180,6 +2164,12 @@ class AssetIndex():
type_data.get_files(files_dict)
return files_dict
def flush_state(self) -> None:
"""Flushes all Downloads and Purchases params"""
for asset_data in self.all_assets.values():
asset_data.flush_state()
def flush_is_local(self) -> None:
"""Flushes all is_local flags"""
@@ -2224,9 +2214,9 @@ class AssetIndex():
raise ValueError("Asset ID already in use")
if len(asset_name) == 0:
raise ValueError("Please specify an asset name")
if asset_type not in ["HDRIs", "Models", "Textures"]:
if asset_type not in ["HDRIS", "Models", "Textures"]:
msg = (f"Unknown asset type: {asset_type}\n"
"Known types: HDRIs, Models, Textures")
"Known types: HDRIS, Models, Textures")
raise ValueError(msg)
return True
@@ -2351,13 +2341,13 @@ class AssetIndex():
asset_data.url = None
asset_data.slug = None
asset_data.credits = 0
asset_data.thumb_urls = None
asset_data.published_at = date_now
asset_data.is_local = True
asset_data.downloaded_at = None
asset_data.is_purchased = True
asset_data.purchased_at = None
asset_data.render_custom_schema = None
asset_data.specifications = None
asset_data.api_convention = convention
asset_data.local_convention = convention
@@ -42,11 +42,10 @@ PREVIEWS = ["_atlas",
"_grid",
]
PREVIEW_EXTS_LOWER = [".jpg", ".jpeg", ".png"]
MAP_EXT_LOWER = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".exr", ".hdr", ".psd"]
MAP_EXT_LOWER = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".exr", ".hdr", ".webp", ".psd"]
SIZES = ["256", "512"] + [f"{idx+1}K" for idx in range(18)] + ["HIRES", "WM"]
VARIANTS = [f"VAR{idx}" for idx in range(1, 10)]
WORKFLOWS = ["REGULAR", "METALNESS", "SPECULAR"]
CATEGORY_TRANSLATION = {"Hdrs": "HDRIs"}
class ModelType(IntEnum):
@@ -138,7 +137,7 @@ CREATION_METHODS = {
CATEGORY_NAME_TO_ASSET_TYPE = {
"All Assets": None,
"Brushes": AssetType.BRUSH,
"HDRIs": AssetType.HDRI,
"HDRIS": AssetType.HDRI,
"Models": AssetType.MODEL,
"Substances": AssetType.SUBSTANCE,
"Textures": AssetType.TEXTURE,
@@ -162,9 +161,24 @@ class AssetThumbnail():
filename: str
base_url: str
index: int
time: datetime
time: datetime # TODO(Andreas): This seems rather float!
type: str
@classmethod
def _from_dict(cls, d: Dict):
"""Alternate constructor,
used after loading AssetIndex from JSON to reconstruct class.
"""
# TODO(Andreas): If time really was a datetime (see above),
# something like this would be needed:
# if "time" not in d:
# raise KeyError("time")
# d["time"] = create datetime from value
new = cls(**d)
return new
# NOT a data class
# Reason: this way it does not get saved with the asset index.
@@ -327,6 +341,33 @@ class AssetState():
self.purchase = AssetStatePurchase()
self.dl = AssetStateDownload()
def has_any_error(self) -> bool:
"""Returns True, if any of the state classes signals an error."""
has_dl_error = self.dl.has_error()
has_purchase_error = self.purchase.has_error()
return self.has_error or has_dl_error or has_purchase_error
def get_any_error(
self,
force_string: bool = False
) -> Optional[Union[str, Exception]]:
"""Returns any error inside of state with following priority:
Asset error -> Purchase error -> Download error
"""
error = None
if self.has_error:
error = self.error
elif self.purchase.has_error():
error = self.purchase.error
elif self.dl.has_error():
error = self.dl.error
if force_string and error is not None:
error = str(error)
return error
class AssetDataRuntime():
"""Stores additional asset data, which will only be valid during runtime.
@@ -595,6 +636,18 @@ class TextureMap:
def get_path(self):
return os.path.join(self.directory, self.filename)
def copy(self):
"""Creates a copy of the TextureMap instance"""
return TextureMap(
directory=self.directory,
filename=self.filename,
file_format=self.file_format,
lod=self.lod,
map_type=self.map_type,
size=self.size,
variant=self.variant,
)
@dataclass
class TextureMapDesc:
@@ -730,9 +783,9 @@ class Texture(BaseTex):
return asset_map_list
@staticmethod
def _convention_0_filter_16bit(tex_map_dict: Dict[MapType, List[TextureMap]],
prefer_16_bit: bool = False
) -> Dict[MapType, List[TextureMap]]:
def _convention_0_filter_maps(tex_map_dict: Dict[MapType, List[TextureMap]],
prefer_16_bit: bool = False
) -> Dict[MapType, List[TextureMap]]:
# Decide between 8-Bit and 16-Bit, if both are available
if MapType.BUMP in tex_map_dict and MapType.BUMP16 in tex_map_dict:
if prefer_16_bit:
@@ -749,6 +802,20 @@ class Texture(BaseTex):
del tex_map_dict[MapType.NRM]
else:
del tex_map_dict[MapType.NRM16]
# TODO(Joao): Implement this in the dccs side once all the map
# management is being done there;
# If ALBEDO map is found, we replace the COL list with the ALBEDO maps;
# To be compatible with the addons, we have to keep using the MapType.COL
# as key;
albedo_list = tex_map_dict.get(MapType.ALBEDO, [])
if len(albedo_list) > 0:
tex_map_dict[MapType.COL] = []
for albedo_map in albedo_list:
albedo_map = albedo_map.copy()
albedo_map.map_type = MapType.COL
tex_map_dict[MapType.COL].append(albedo_map)
return tex_map_dict
def update(self, type_data_new, purge_maps: bool = False) -> None:
@@ -959,7 +1026,10 @@ class Texture(BaseTex):
raise KeyError(f"Workflow not found: {workflow}")
map_descs = self.map_descs[workflow]
return [map_desc.map_type_code for map_desc in map_descs]
# No UNKNOWN map should get here - they are filtered on the map_descs constructor ()
return [map_desc.map_type_code for map_desc in map_descs
if MapType.from_type_code(map_desc.map_type_code) != MapType.UNKNOWN]
def get_maps_per_preferences(self,
map_preferences: Any,
@@ -1049,8 +1119,8 @@ class Texture(BaseTex):
tex_map_dict[_map_type] = tex_map_dict.get(_map_type, []) + [tex_map]
if map_type_list is None:
tex_map_dict = self._convention_0_filter_16bit(tex_map_dict,
prefer_16_bit)
tex_map_dict = self._convention_0_filter_maps(tex_map_dict,
prefer_16_bit)
tex_maps = []
# Get rid of multiple files for the same texture (e.g. .png and .psd)
@@ -1880,11 +1950,18 @@ class AssetData:
downloaded_at: Optional[int] = None
# is_purchased: None until proven true or false.
is_purchased: Optional[bool] = None
# is_recent_downloaded: None until proven true or false
# (new flag after implementing recent downloads filter in addon-first)
is_recent_downloaded: Optional[bool] = None
# UTC, seconds since epoch
purchased_at: Optional[int] = None
# render_custom_schema: Filled with what ever meta data
# ApiResponse contains for this key.
# TODO(Joao): Remove Render Custom Schema from this class.
# Render custom schema was removed from asset dict api, specifications
# key was added here as substitute - to discuss if we should keep both;
render_custom_schema: Optional[Dict] = None
specifications: Optional[Dict] = None
api_convention: Optional[int] = None
local_convention: Optional[int] = None
@@ -1922,6 +1999,11 @@ class AssetData:
elif new.texture is not None:
new.texture = Texture._from_dict(new.texture)
cf_thumb_urls = []
for _dict in new.cloudflare_thumb_urls:
cf_thumb_urls.append(AssetThumbnail._from_dict(_dict))
new.cloudflare_thumb_urls = cf_thumb_urls
new.state = AssetState()
new.runtime = AssetDataRuntime()
@@ -1978,12 +2060,13 @@ class AssetData:
data_dict[_t("Creation Method")] = method_str
elif self.asset_type == AssetType.MODEL:
tech_data = self.render_custom_schema.get("technical_description", {})
lod_tri_count = tech_data.get("LODs", {})
specs = self.specifications
tech_data = specs if specs is not None else {}
lod_tri_count = tech_data.get("lods", {})
meshes = tech_data.get("Meshes", None)
meshes = tech_data.get("meshes", None)
source_tri_count = lod_tri_count.get("SOURCE", None)
dimensions = tech_data.get("Dimensions", {})
dimensions = tech_data.get("physical_size_cm", {})
incl_res = self.render_custom_schema.get("included_resolution")
# TODO(Joao): The following lines are for mark dimensions strings
@@ -1993,7 +2076,8 @@ class AssetData:
depth_str = _m("Depth (cm)") # noqa
for key, value in dimensions.items():
data_dict[_t(key)] = value
key_title = _t(key.title())
data_dict[f"{key_title} (cm)"] = value
if meshes is not None:
data_dict[_t("Mesh Count")] = meshes
@@ -2060,9 +2144,10 @@ class AssetData:
self.categories = _cond_set(asset_data_new.categories, self.categories)
self.url = _cond_set(asset_data_new.url, self.url)
self.slug = _cond_set(asset_data_new.slug, self.slug)
self.thumb_urls = _cond_set(asset_data_new.thumb_urls, self.thumb_urls)
self.published_at = _cond_set(asset_data_new.thumb_urls,
self.thumb_urls)
self.published_at = _cond_set(asset_data_new.published_at,
self.published_at)
self.thumb_urls = _cond_set(asset_data_new.thumb_urls,
self.thumb_urls)
self.is_local = _cond_set(asset_data_new.is_local, self.is_local)
self.downloaded_at = _cond_set(asset_data_new.downloaded_at,
self.downloaded_at)
@@ -2072,9 +2157,15 @@ class AssetData:
self.purchased_at)
self.render_custom_schema = _cond_set(
asset_data_new.render_custom_schema, self.render_custom_schema)
self.specifications = _cond_set(asset_data_new.specifications,
self.specifications)
self.get_type_data().update(asset_data_new.get_type_data(), purge_maps)
def flush_state(self):
"""Reset all Download and Purchase states."""
self.state = AssetState()
def flush_local(self):
"""Resets all information about local files."""
@@ -2187,6 +2278,43 @@ class AssetData:
name_mat += f"_{renderer}"
return name_mat
def create_dummy_type_data(
self, asset_type: Optional[AssetType] = None) -> None:
"""Creates a dummy type data. For example to be used in case we run
into an exception during evaluation of server side asset data and thus
can't build complete data structures.
"""
if self.get_type_data() is not None:
return
if asset_type is None:
asset_type = self.asset_type
tex_dummy = Texture(maps={}, map_descs={}, sizes=["1K"])
if asset_type == AssetType.BRUSH:
self.brush = Brush(alpha=tex_dummy)
elif asset_type == AssetType.HDRI:
tex_dummy_2 = Texture(maps={}, map_descs={}, sizes=["1K"])
self.hdri = Hdri(light=tex_dummy, bg=tex_dummy_2)
elif asset_type == AssetType.MODEL:
self.model = Model(texture=tex_dummy)
elif asset_type == AssetType.TEXTURE:
self.texture = tex_dummy
def is_new(self) -> bool:
"""Returns True if the asset was published within the last 30 days.
Returns:
bool: True if the asset is new (< 30 days old), False otherwise
"""
if not self.published_at:
return False
days_threshold = 30
# Convert timestamp to datetime object
published_datetime = datetime.fromtimestamp(self.published_at, tz=timezone.utc)
days_old = (datetime.now(timezone.utc) - published_datetime).days
return days_old <= days_threshold
# Currently constants are defined here at the end,
# as some require above classes to be defined.
@@ -2205,3 +2333,11 @@ ASSET_TYPE_TO_CATEGORY_NAME = {AssetType.BRUSH: "Brushes",
AssetType.TEXTURE: "Textures",
AssetType.UNSUPPORTED: "Unsupported",
}
ASSET_TYPE_TO_URL_PATH = {AssetType.BRUSH: "brush",
AssetType.HDRI: "hdr",
AssetType.MODEL: "model",
AssetType.SUBSTANCE: "substance",
AssetType.TEXTURE: "texture",
AssetType.UNSUPPORTED: "unsupported",
}
@@ -50,7 +50,7 @@ MAX_RETRIES_PER_ASSET = 2
# This list defines a priority to fallback available formats
# NOTE: This is only for Convention 1 downloads
SUPPORTED_TEX_FORMATS = ["jpg", "png", "tiff", "exr"]
SUPPORTED_TEX_FORMATS = ["jpg", "png", "tiff", "exr", "hdr"]
MODEL_FILE_EXT = ["fbx", "blend", "max", "c4d", "skp", "ma"]
@@ -166,20 +166,30 @@ class AssetDownload:
asset_data: AssetData,
size: str,
dir_target: str,
lod: str = "NONE",
download_lods: bool = False,
native_mesh: bool = False,
renderer: Optional[str] = None,
update_callback: Optional[Callable] = None
update_callback: Optional[Callable] = None,
get_download_data_callback: Optional[Callable] = None
) -> None:
self.addon = addon
self.asset_data = asset_data
self.size = size
self.lod = lod
# TODO(Joao): Add lods list parameter to the class and assign to
# self.lods when select download lods feature is available.
# Currently if the download is set do download lods, all the lods will
# be downloaded. There is no way to download only some given lods;
self.download_lods = download_lods
self.lods = None
self.native_mesh = native_mesh
self.renderer = renderer
self.update_callback = update_callback
self.get_download_data_callback = get_download_data_callback
self.get_download_data_callback_called = False
self.dir_target = os.path.join(dir_target, asset_data.asset_name)
self.download_list = []
@@ -234,15 +244,18 @@ class AssetDownload:
]
}
if self.convention == 0:
self.data_payload["assets"][0]["sizes"] = [self.size]
elif self.convention == 1:
self.data_payload["assets"][0]["resolution"] = self.size
self.data_payload["assets"][0]["sizes"] = [self.size]
if self.asset_data.asset_type in [AssetType.HDRI, AssetType.TEXTURE]:
if self.asset_data.asset_type == AssetType.TEXTURE:
self.set_texture_payload()
elif self.asset_data.asset_type == AssetType.MODEL:
self.set_model_payload()
elif self.asset_data.asset_type == AssetType.HDRI:
# HDRIs don't require any additional payload. The logic for defining
# which maps should be downloaded is implemented on the server backend.
# This change was made to support .hdr files and specify which DCCs
# they should be downloaded/imported for;
pass
def set_texture_payload(self) -> None:
if self.convention == 0:
@@ -282,13 +295,15 @@ class AssetDownload:
self.data_payload["assets"][0]["maps"] = map_list
def set_model_payload(self) -> None:
self.data_payload["assets"][0]["lods"] = int(self.download_lods)
if self.download_lods:
self.lods = self.asset_data.get_type_data().get_lod_list()
if self.lods not in [[], ["NONE"]]:
self.data_payload["assets"][0]["lods"] = self.lods
if self.native_mesh and self.renderer is not None:
self.data_payload["assets"][0]["softwares"] = [self.addon._api.software_dl_dcc]
self.data_payload["assets"][0]["renders"] = [self.renderer]
else:
self.data_payload["assets"][0]["softwares"] = ["ALL_OTHERS"]
self.data_payload["assets"][0]["softwares"] = ["Generic"]
def create_download_folder(self) -> bool:
try:
@@ -354,8 +369,23 @@ class AssetDownload:
return dl_folder
@staticmethod
def get_files_list(res) -> Tuple[List, List]:
files_dict_list = [_files_dict
for _files_dict in res.body.get("payload", [])
if "files" in _files_dict.keys()]
files_list = None
dynamic_files_list = None
if len(files_dict_list) > 0:
files_list = files_dict_list[0].get("files")
# Dynamic files expected only for textures
if files_dict_list[0].get("dynamic_files"):
dynamic_files_list = files_dict_list[0].get("dynamic_files")
return files_list, dynamic_files_list
def build_download_list(self, res: ApiResponse) -> None:
files_list = res.body.get("files", [])
files_list, dynamic_files = self.get_files_list(res)
self.uuid = res.body.get("uuid", None)
if self.uuid in [None, ""]:
self.addon.logger_dl.error("No UUID for download")
@@ -408,7 +438,7 @@ class AssetDownload:
self.addon._api.report_message(
"model_with_only_source_lod", msg, level="info")
self.set_dynamic_files(res.body.get("dynamic_files", None))
self.set_dynamic_files(dynamic_files)
def set_dynamic_files(self,
dynamic_files_api: Optional[List[Dict]]
@@ -597,6 +627,14 @@ class AssetDownload:
progress = self.size_asset_bytes_downloaded / max(self.size_asset_bytes_expected, 1)
self.asset_data.state.dl.set_progress(max(progress, 0.001))
self.asset_data.state.dl.set_downloaded_bytes(self.size_asset_bytes_expected)
try:
if not self.get_download_data_callback_called:
self.get_download_data_callback()
self.get_download_data_callback_called = True
except TypeError:
pass # No Download data callback
try: # Init progress bar
self.update_callback()
except TypeError:
@@ -77,6 +77,11 @@ class MapType(IntEnum):
NA_ORM = 50
NA_VERTEXBLEND = 51
# Legacy maps - used to identify and download - not yet applied to the materials;
ALBEDO = 52
OVERLAY16 = 53
DIRECTION = 54
# Convention 1, values 100 to 149 (150 and up for convention 1 only maps)
# NOTE: Value - 100 should match convention 0
AmbientOcclusion = 104
@@ -40,19 +40,27 @@ NOTICE_PRIO_NO_INET = NOTICE_PRIO_LOW # Show, but other errors have precedent
NOTICE_PRIO_PROXY = NOTICE_PRIO_MEDIUM
NOTICE_PRIO_SETTINGS_WRITE = NOTICE_PRIO_HIGH
NOTICE_PRIO_SURVEY = NOTICE_PRIO_LOWEST
NOTICE_PRIO_UPDATE = NOTICE_PRIO_HIGH + 5 # urgent, but room for "more urgent"
NOTICE_PRIO_RESTART = NOTICE_PRIO_LOW
# unsupported renderer is low (only onboarding is lower), once is not dismissible
NOTICE_PRIO_UNSUPPORTED_RENDERER = NOTICE_PRIO_LOWEST + 5
# Onboarding WM is now the lowest, once is not dismissible and less prio than warnings
NOTICE_PRIO_WM_ONBOARDING = NOTICE_PRIO_LOWEST + 6
# Predefined notice IDs
NOTICE_ID_MAT_TEMPLATE = "MATERIAL_TEMPLATE_ERROR"
NOTICE_ID_NO_INET = "NO_INTERNET_CONNECTION"
NOTICE_ID_PROXY = "PROXY_CONNECTION_ERROR"
NOTICE_ID_SETTINGS_WRITE = "SETTINGS_WRITE_ERROR"
NOTICE_ID_UNSUPPORTED_RENDERER = "UNSUPPORTED_RENDERER"
NOTICE_ID_SURVEY_FREE = "NPS_INAPP_FREE"
NOTICE_ID_SURVEY_ACTIVE = "NPS_INAPP_ACTIVE"
NOTICE_ID_UPDATE = "UPDATE_READY_MANUAL_INSTALL"
NOTICE_ID_VERSION_ALERT = "ADDON_VERSION_ALERT"
NOTICE_ID_RESTART_ALERT = "NOTICE_ID_RESTART_ALERT"
NOTICE_ID_WM_ONBOARDING = "NOTICE_WM_ONBOARDING"
# Predefined notice titles
# Used a default param in create functions, but should usually be overridden
@@ -61,10 +69,12 @@ NOTICE_TITLE_MAT_TEMPLATE = _m("Material template error")
NOTICE_TITLE_NO_INET = _m("No internet access")
NOTICE_TITLE_PROXY = _m("Encountered proxy error")
NOTICE_TITLE_SETTINGS_WRITE = _m("Failed to write settings")
NOTICE_UNSUPPORTED_RENDERER = _m("Renderer not supported")
NOTICE_TITLE_SURVEY = _m("How's the addon?")
NOTICE_TITLE_UPDATE = _m("Update ready")
NOTICE_TITLE_DEPRECATED = _m("Deprecated version")
NOTICE_TITLE_RESTART = _m("Restart needed")
NOTICE_TITLE_WM_ONBOARDING = _m("Watermarked Previews")
# Predefined notice Labels (text to be displayed on the notification Banner)
# Used a default param in create functions, but should usually be overridden
@@ -87,6 +97,7 @@ NOTICE_ICON_WARN = "ICON_WARN"
NOTICE_ICON_INFO = "ICON_INFO"
NOTICE_ICON_SURVEY = "ICON_SURVEY"
NOTICE_ICON_NO_CONNECTION = "ICON_NO_CONNECTION"
NOTICE_ICON_WM_ONBOARDING = "ICON_WM_ONBOARDING"
class ActionType(IntEnum):
@@ -97,6 +108,9 @@ class ActionType(IntEnum):
RUN_OPERATOR = 4
UPDATE_READY = 2
# Adding None for notifications not related with any further actions (e.g. onboarding)
NONE = 99
class SignalType(IntEnum):
"""Types of each interaction with the notifications."""
@@ -147,6 +161,10 @@ class Notification():
# function to be called when the notification is dismissed (viewed or not)
on_dismiss_callable: Optional[Callable] = None
# Parameter for notifications that can be shown as banners on the asset
# browser (only available for P4Max and P4C)
display_as_banner: bool = False
viewed: bool = False # False until actually drawn
clicked: bool = False # False until user interact with the notice
@@ -167,6 +185,12 @@ class AddonNotificationsParameters:
update_action_text: str = NOTICE_TITLE_UPDATE
update_body: str = ""
onboarding_wm_title: str = NOTICE_TITLE_WM_ONBOARDING
onboarding_wm_label: str = ""
onboarding_wm_tooltip: str = ""
allow_banner_notice: bool = False
@dataclass
class NotificationOpenUrl(Notification):
@@ -217,7 +241,16 @@ class NotificationUpdateReady(Notification):
self.action = ActionType.UPDATE_READY
def get_key(self) -> str:
return "".join([self.action.name, self.download_url, self.download_label])
return "".join([self.action.name, self.id_notice])
@dataclass
class NotificationNoAction(Notification):
def __post_init__(self):
self.action = ActionType.NONE
def get_key(self) -> str:
return "".join([self.action.name, self.id_notice])
class NotificationSystem():
@@ -239,7 +272,8 @@ class NotificationSystem():
NOTICE_ICON_WARN: None,
NOTICE_ICON_INFO: None,
NOTICE_ICON_SURVEY: None,
NOTICE_ICON_NO_CONNECTION: None
NOTICE_ICON_NO_CONNECTION: None,
NOTICE_ICON_WM_ONBOARDING: None
}
addon_params = AddonNotificationsParameters()
@@ -255,12 +289,14 @@ class NotificationSystem():
icon_warn: Optional[Any] = None,
icon_info: Optional[Any] = None,
icon_survey: Optional[Any] = None,
icon_no_connection: Optional[Any] = None
icon_no_connection: Optional[Any] = None,
notify_icon_wm_onboarding: Optional[Any] = None
) -> None:
self.icon_dcc_map[NOTICE_ICON_WARN] = icon_warn
self.icon_dcc_map[NOTICE_ICON_INFO] = icon_info
self.icon_dcc_map[NOTICE_ICON_SURVEY] = icon_survey
self.icon_dcc_map[NOTICE_ICON_NO_CONNECTION] = icon_no_connection
self.icon_dcc_map[NOTICE_ICON_WM_ONBOARDING] = notify_icon_wm_onboarding
def _run_threaded(key_pool: PoolKeys,
max_threads: Optional[int] = None,
@@ -351,7 +387,9 @@ class NotificationSystem():
if not notice.allow_dismiss and not force:
return
if not notice.clicked:
if not notice.clicked and force is False:
# Don't signal dismiss if it's a forced removal, since there's no
# human interaction.
self._signal_dismiss(notice)
if notice.on_dismiss_callable is not None:
@@ -630,6 +668,32 @@ class NotificationSystem():
self.enqueue_notice(notice)
return notice
def create_unsupported_renderer(self,
title: str = NOTICE_UNSUPPORTED_RENDERER,
*,
label: str = NOTICE_UNSUPPORTED_RENDERER,
tooltip: str,
body: str,
auto_enqueue: bool = True
) -> Notification:
"""Returns an Unsupported Render Engine notice."""
notice = NotificationPopup(
id_notice=NOTICE_ID_UNSUPPORTED_RENDERER,
title=title,
priority=NOTICE_PRIO_UNSUPPORTED_RENDERER,
allow_dismiss=False,
tooltip=tooltip,
icon=self.icon_dcc_map[NOTICE_ICON_WARN],
body=body,
label=label,
open_popup=True,
alert=True
)
if auto_enqueue:
self.enqueue_notice(notice)
return notice
def create_update(self,
title: str = NOTICE_TITLE_UPDATE,
*,
@@ -676,3 +740,30 @@ class NotificationSystem():
if auto_enqueue:
self.enqueue_notice(notice)
return notice
def create_watermarked_onboarding(self,
auto_enqueue: bool = True,
) -> NotificationNoAction:
"""Returns a pre-built 'WM Onboarding' notice."""
title = NOTICE_TITLE_WM_ONBOARDING
if self.addon_params.onboarding_wm_title is not None:
title = self.addon_params.onboarding_wm_title
notice = NotificationNoAction(
id_notice=NOTICE_ID_WM_ONBOARDING,
title=title,
priority=NOTICE_PRIO_WM_ONBOARDING,
display_as_banner=True,
label=self.addon_params.onboarding_wm_label,
allow_dismiss=False,
tooltip=self.addon_params.onboarding_wm_tooltip,
icon=self.icon_dcc_map[NOTICE_ICON_WM_ONBOARDING]
)
allow_banner = self.addon_params.allow_banner_notice
handle_as_banner_notice = allow_banner and notice.display_as_banner
if auto_enqueue and not handle_as_banner_notice:
self.enqueue_notice(notice)
return notice
@@ -154,6 +154,7 @@ class PoliigonSubscription:
base_price: Optional[float] = None # e.g. 123.45
currency_symbol: Optional[str] = None # e.g. "$" (special character)
is_unlimited: Optional[bool] = None
is_limited_legacy: Optional[bool] = None
has_team: Optional[bool] = None
@staticmethod
@@ -275,6 +276,10 @@ class PoliigonSubscription:
if unlimited is not None:
self.is_unlimited = bool(unlimited)
legacy = plan_dictionary.get("limited_legacy", None)
if legacy is not None:
self.is_limited_legacy = bool(legacy)
has_team = bool(plan_dictionary.get("team_id", None))
if has_team is not None:
self.has_team = bool(has_team)
@@ -403,6 +408,10 @@ class PoliigonPlanUpgradeManager:
self.addon.notify._signal_clicked(signal_notice)
def check_show_banner(self) -> bool:
# Checks if the wm onboarding banner is drawn or not
if not self.addon.is_onboarding_wm_preview_done():
return False
if self.user is None:
return False
do_show_banner = self.show_banner
@@ -485,11 +494,13 @@ class PoliigonPlanUpgradeManager:
if self.user.plan.is_unlimited:
# The only benefit to upgrade is to get more downloads, if you're
# already unlimited, there's nothing to upgrade to
self.upgrade_plan = None
return
if self.user.plan.has_team:
# Let's not offer updates to team members, since these contracts
# are handled separately
self.upgrade_plan = None
return
upgrade_pro_plan = None
@@ -128,7 +128,7 @@ class UpgradeContent:
self.populate()
def student_discount(self, is_teacher: bool = False) -> Any:
primary = _t("Access the entire library by joining Pro")
primary = _t("Access the entire library with a Poliigon Plan")
secondary = _t("{0} can claim a 50% discount".format(
_t("Students") if not is_teacher else _t("Teachers")))
if self.as_single_paragraph:
@@ -144,7 +144,7 @@ class UpgradeContent:
self.icon_path = self.icons.check
def become_pro(self) -> Any:
primary = _t("Access the entire library by joining Pro")
primary = _t("Access the entire library with a Poliigon Plan")
secondary = _t("Download and import from the entire Poliigon library")
if self.as_single_paragraph:
# To keep it slimmer, do just the primary text.
@@ -17,11 +17,12 @@
# ##### END GPL LICENSE BLOCK #####
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Any
from enum import Enum
from .assets import (MapType)
from .plan_manager import PoliigonSubscription
from .multilingual import _m, _t
from .logger import (DEBUG, # noqa F401, allowing downstream const usage
ERROR,
INFO,
@@ -30,6 +31,42 @@ from .logger import (DEBUG, # noqa F401, allowing downstream const usage
WARNING)
class PoliigonUserProfiles(Enum):
USER_HOBBYIST = _m("As a hobbyist")
USER_STUDENT = _m("As a student")
USER_PROFESSIONAL = _m("As a professional (individual)")
USER_STUDIO = _m("As a professional (studio/team)")
USER_TEACHER = _m("As a teacher")
def to_ui_display(self) -> Optional[str]:
if self == self.USER_STUDENT:
return _t("Student")
elif self == self.USER_TEACHER:
return _t("Teacher")
elif self == self.USER_HOBBYIST:
return _t("Hobbyist")
elif self == self.USER_PROFESSIONAL:
return _t("Freelancer")
elif self == self.USER_STUDIO:
return _t("Team")
else:
return None
@classmethod
def from_string(cls, use_string: str) -> Optional[Any]:
try:
return cls(use_string)
except ValueError:
return None
@classmethod
def get_list(cls, addon_display: bool = False) -> List[str]:
if not addon_display:
return [member.value for member in PoliigonUserProfiles]
else:
return [member.to_ui_display() for member in PoliigonUserProfiles]
@dataclass
class MapFormats:
map_type: MapType
@@ -123,7 +160,15 @@ class PoliigonUser:
is_teacher: Optional[bool] = False
credits: Optional[int] = None
credits_od: Optional[int] = None
count_assets_owned: Optional[int] = None
count_assets_downloads: Optional[int] = None
plan: Optional[PoliigonSubscription] = None
map_preferences: Optional[UserDownloadPreferences] = None
user_profile: PoliigonUserProfiles = None
email_preference: Optional[bool] = None
primary_3d_software: Optional[str] = None
# Primary Render engine might be changed/decided in the Dcc side if
# not yet available (None) in user data;
primary_rendering_engine: Optional[str] = None
# Todo(Joao): remove this flag when all addons are using map prefs
use_preferences_on_download: bool = False
@@ -1,10 +1,11 @@
from sentry_sdk import profiler
from sentry_sdk import metrics
from sentry_sdk.scope import Scope
from sentry_sdk.transport import Transport, HttpTransport
from sentry_sdk.client import Client
from sentry_sdk.api import * # noqa
from sentry_sdk.consts import VERSION # noqa
from sentry_sdk.consts import VERSION
__all__ = [ # noqa
"Hub",
@@ -12,9 +13,11 @@ __all__ = [ # noqa
"Client",
"Transport",
"HttpTransport",
"VERSION",
"integrations",
# From sentry_sdk.api
"init",
"add_attachment",
"add_breadcrumb",
"capture_event",
"capture_exception",
@@ -45,6 +48,13 @@ __all__ = [ # noqa
"start_transaction",
"trace",
"monitor",
"logger",
"metrics",
"profiler",
"start_session",
"end_session",
"set_transaction_name",
"update_current_span",
]
# Initialize the debug support after everything is loaded
@@ -0,0 +1,161 @@
import os
import random
import threading
from datetime import datetime, timezone
from typing import Optional, List, Callable, TYPE_CHECKING, Any
from sentry_sdk.utils import format_timestamp, safe_repr
from sentry_sdk.envelope import Envelope, Item, PayloadRef
if TYPE_CHECKING:
from sentry_sdk._types import Log
class LogBatcher:
MAX_LOGS_BEFORE_FLUSH = 100
FLUSH_WAIT_TIME = 5.0
def __init__(
self,
capture_func, # type: Callable[[Envelope], None]
):
# type: (...) -> None
self._log_buffer = [] # type: List[Log]
self._capture_func = capture_func
self._running = True
self._lock = threading.Lock()
self._flush_event = threading.Event() # type: threading.Event
self._flusher = None # type: Optional[threading.Thread]
self._flusher_pid = None # type: Optional[int]
def _ensure_thread(self):
# type: (...) -> bool
"""For forking processes we might need to restart this thread.
This ensures that our process actually has that thread running.
"""
if not self._running:
return False
pid = os.getpid()
if self._flusher_pid == pid:
return True
with self._lock:
# Recheck to make sure another thread didn't get here and start the
# the flusher in the meantime
if self._flusher_pid == pid:
return True
self._flusher_pid = pid
self._flusher = threading.Thread(target=self._flush_loop)
self._flusher.daemon = True
try:
self._flusher.start()
except RuntimeError:
# Unfortunately at this point the interpreter is in a state that no
# longer allows us to spawn a thread and we have to bail.
self._running = False
return False
return True
def _flush_loop(self):
# type: (...) -> None
while self._running:
self._flush_event.wait(self.FLUSH_WAIT_TIME + random.random())
self._flush_event.clear()
self._flush()
def add(
self,
log, # type: Log
):
# type: (...) -> None
if not self._ensure_thread() or self._flusher is None:
return None
with self._lock:
self._log_buffer.append(log)
if len(self._log_buffer) >= self.MAX_LOGS_BEFORE_FLUSH:
self._flush_event.set()
def kill(self):
# type: (...) -> None
if self._flusher is None:
return
self._running = False
self._flush_event.set()
self._flusher = None
def flush(self):
# type: (...) -> None
self._flush()
@staticmethod
def _log_to_transport_format(log):
# type: (Log) -> Any
def format_attribute(val):
# type: (int | float | str | bool) -> Any
if isinstance(val, bool):
return {"value": val, "type": "boolean"}
if isinstance(val, int):
return {"value": val, "type": "integer"}
if isinstance(val, float):
return {"value": val, "type": "double"}
if isinstance(val, str):
return {"value": val, "type": "string"}
return {"value": safe_repr(val), "type": "string"}
if "sentry.severity_number" not in log["attributes"]:
log["attributes"]["sentry.severity_number"] = log["severity_number"]
if "sentry.severity_text" not in log["attributes"]:
log["attributes"]["sentry.severity_text"] = log["severity_text"]
res = {
"timestamp": int(log["time_unix_nano"]) / 1.0e9,
"trace_id": log.get("trace_id", "00000000-0000-0000-0000-000000000000"),
"level": str(log["severity_text"]),
"body": str(log["body"]),
"attributes": {
k: format_attribute(v) for (k, v) in log["attributes"].items()
},
}
return res
def _flush(self):
# type: (...) -> Optional[Envelope]
envelope = Envelope(
headers={"sent_at": format_timestamp(datetime.now(timezone.utc))}
)
with self._lock:
if len(self._log_buffer) == 0:
return None
envelope.add_item(
Item(
type="log",
content_type="application/vnd.sentry.items.log+json",
headers={
"item_count": len(self._log_buffer),
},
payload=PayloadRef(
json={
"items": [
self._log_to_transport_format(log)
for log in self._log_buffer
]
}
),
)
)
self._log_buffer.clear()
self._capture_func(envelope)
return envelope
@@ -0,0 +1,156 @@
import os
import random
import threading
from datetime import datetime, timezone
from typing import Optional, List, Callable, TYPE_CHECKING, Any, Union
from sentry_sdk.utils import format_timestamp, safe_repr
from sentry_sdk.envelope import Envelope, Item, PayloadRef
if TYPE_CHECKING:
from sentry_sdk._types import Metric
class MetricsBatcher:
MAX_METRICS_BEFORE_FLUSH = 1000
FLUSH_WAIT_TIME = 5.0
def __init__(
self,
capture_func, # type: Callable[[Envelope], None]
):
# type: (...) -> None
self._metric_buffer = [] # type: List[Metric]
self._capture_func = capture_func
self._running = True
self._lock = threading.Lock()
self._flush_event = threading.Event() # type: threading.Event
self._flusher = None # type: Optional[threading.Thread]
self._flusher_pid = None # type: Optional[int]
def _ensure_thread(self):
# type: (...) -> bool
if not self._running:
return False
pid = os.getpid()
if self._flusher_pid == pid:
return True
with self._lock:
if self._flusher_pid == pid:
return True
self._flusher_pid = pid
self._flusher = threading.Thread(target=self._flush_loop)
self._flusher.daemon = True
try:
self._flusher.start()
except RuntimeError:
self._running = False
return False
return True
def _flush_loop(self):
# type: (...) -> None
while self._running:
self._flush_event.wait(self.FLUSH_WAIT_TIME + random.random())
self._flush_event.clear()
self._flush()
def add(
self,
metric, # type: Metric
):
# type: (...) -> None
if not self._ensure_thread() or self._flusher is None:
return None
with self._lock:
self._metric_buffer.append(metric)
if len(self._metric_buffer) >= self.MAX_METRICS_BEFORE_FLUSH:
self._flush_event.set()
def kill(self):
# type: (...) -> None
if self._flusher is None:
return
self._running = False
self._flush_event.set()
self._flusher = None
def flush(self):
# type: (...) -> None
self._flush()
@staticmethod
def _metric_to_transport_format(metric):
# type: (Metric) -> Any
def format_attribute(val):
# type: (Union[int, float, str, bool]) -> Any
if isinstance(val, bool):
return {"value": val, "type": "boolean"}
if isinstance(val, int):
return {"value": val, "type": "integer"}
if isinstance(val, float):
return {"value": val, "type": "double"}
if isinstance(val, str):
return {"value": val, "type": "string"}
return {"value": safe_repr(val), "type": "string"}
res = {
"timestamp": metric["timestamp"],
"trace_id": metric["trace_id"],
"name": metric["name"],
"type": metric["type"],
"value": metric["value"],
"attributes": {
k: format_attribute(v) for (k, v) in metric["attributes"].items()
},
}
if metric.get("span_id") is not None:
res["span_id"] = metric["span_id"]
if metric.get("unit") is not None:
res["unit"] = metric["unit"]
return res
def _flush(self):
# type: (...) -> Optional[Envelope]
envelope = Envelope(
headers={"sent_at": format_timestamp(datetime.now(timezone.utc))}
)
with self._lock:
if len(self._metric_buffer) == 0:
return None
envelope.add_item(
Item(
type="trace_metric",
content_type="application/vnd.sentry.items.trace-metric+json",
headers={
"item_count": len(self._metric_buffer),
},
payload=PayloadRef(
json={
"items": [
self._metric_to_transport_format(metric)
for metric in self._metric_buffer
]
}
),
)
)
self._metric_buffer.clear()
self._capture_func(envelope)
return envelope
@@ -30,6 +30,17 @@ class AnnotatedValue:
return self.value == other.value and self.metadata == other.metadata
def __str__(self):
# type: (AnnotatedValue) -> str
return str({"value": str(self.value), "metadata": str(self.metadata)})
def __len__(self):
# type: (AnnotatedValue) -> int
if self.value is not None:
return len(self.value)
else:
return 0
@classmethod
def removed_because_raw_data(cls):
# type: () -> AnnotatedValue
@@ -47,11 +58,14 @@ class AnnotatedValue:
)
@classmethod
def removed_because_over_size_limit(cls):
# type: () -> AnnotatedValue
"""The actual value was removed because the size of the field exceeded the configured maximum size (specified with the max_request_body_size sdk option)"""
def removed_because_over_size_limit(cls, value=""):
# type: (Any) -> AnnotatedValue
"""
The actual value was removed because the size of the field exceeded the configured maximum size,
for example specified with the max_request_body_size sdk option.
"""
return AnnotatedValue(
value="",
value=value,
metadata={
"rem": [ # Remark
[
@@ -149,14 +163,14 @@ if TYPE_CHECKING:
Event = TypedDict(
"Event",
{
"breadcrumbs": dict[
Literal["values"], list[dict[str, Any]]
"breadcrumbs": Annotated[
dict[Literal["values"], list[dict[str, Any]]]
], # TODO: We can expand on this type
"check_in_id": str,
"contexts": dict[str, dict[str, object]],
"dist": str,
"duration": Optional[float],
"environment": str,
"environment": Optional[str],
"errors": list[dict[str, Any]], # TODO: We can expand on this type
"event_id": str,
"exception": dict[
@@ -174,7 +188,7 @@ if TYPE_CHECKING:
"monitor_slug": Optional[str],
"platform": Literal["python"],
"profile": object, # Should be sentry_sdk.profiler.Profile, but we can't import that here due to circular imports
"release": str,
"release": Optional[str],
"request": dict[str, object],
"sdk": Mapping[str, object],
"server_name": str,
@@ -196,7 +210,6 @@ if TYPE_CHECKING:
"type": Literal["check_in", "transaction"],
"user": dict[str, object],
"_dropped_spans": int,
"_metrics_summary": dict[str, object],
},
total=False,
)
@@ -206,17 +219,61 @@ if TYPE_CHECKING:
tuple[None, None, None],
]
# TODO: Make a proper type definition for this (PRs welcome!)
Hint = Dict[str, Any]
Log = TypedDict(
"Log",
{
"severity_text": str,
"severity_number": int,
"body": str,
"attributes": dict[str, str | bool | float | int],
"time_unix_nano": int,
"trace_id": Optional[str],
},
)
MetricType = Literal["counter", "gauge", "distribution"]
MetricAttributeValue = TypedDict(
"MetricAttributeValue",
{
"value": Union[str, bool, float, int],
"type": Literal["string", "boolean", "double", "integer"],
},
)
Metric = TypedDict(
"Metric",
{
"timestamp": float,
"trace_id": Optional[str],
"span_id": Optional[str],
"name": str,
"type": MetricType,
"value": float,
"unit": Optional[str],
"attributes": dict[str, str | bool | float | int],
},
)
MetricProcessor = Callable[[Metric, Hint], Optional[Metric]]
# TODO: Make a proper type definition for this (PRs welcome!)
Breadcrumb = Dict[str, Any]
# TODO: Make a proper type definition for this (PRs welcome!)
BreadcrumbHint = Dict[str, Any]
# TODO: Make a proper type definition for this (PRs welcome!)
SamplingContext = Dict[str, Any]
EventProcessor = Callable[[Event, Hint], Optional[Event]]
ErrorProcessor = Callable[[Event, ExcInfo], Optional[Event]]
BreadcrumbProcessor = Callable[[Breadcrumb, BreadcrumbHint], Optional[Breadcrumb]]
TransactionProcessor = Callable[[Event, Hint], Optional[Event]]
LogProcessor = Callable[[Log, Hint], Optional[Log]]
TracesSampler = Callable[[SamplingContext], Union[float, int, bool]]
@@ -234,35 +291,16 @@ if TYPE_CHECKING:
"internal",
"profile",
"profile_chunk",
"metric_bucket",
"monitor",
"span",
"log_item",
"trace_metric",
]
SessionStatus = Literal["ok", "exited", "crashed", "abnormal"]
ContinuousProfilerMode = Literal["thread", "gevent", "unknown"]
ProfilerMode = Union[ContinuousProfilerMode, Literal["sleep"]]
# Type of the metric.
MetricType = Literal["d", "s", "g", "c"]
# Value of the metric.
MetricValue = Union[int, float, str]
# Internal representation of tags as a tuple of tuples (this is done in order to allow for the same key to exist
# multiple times).
MetricTagsInternal = Tuple[Tuple[str, str], ...]
# External representation of tags as a dictionary.
MetricTagValue = Union[str, int, float, None]
MetricTags = Mapping[str, MetricTagValue]
# Value inside the generator for the metric value.
FlushedMetricValue = Union[int, float]
BucketKey = Tuple[MetricType, str, MeasurementUnit, MetricTagsInternal]
MetricMetaKey = Tuple[MetricType, str, MeasurementUnit]
MonitorConfigScheduleType = Literal["crontab", "interval"]
MonitorConfigScheduleUnit = Literal[
"year",
@@ -0,0 +1,7 @@
from .utils import (
set_data_normalized,
GEN_AI_MESSAGE_ROLE_MAPPING,
GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING,
normalize_message_role,
normalize_message_roles,
) # noqa: F401
@@ -1,6 +1,7 @@
import inspect
from functools import wraps
from sentry_sdk.consts import SPANDATA
import sentry_sdk.utils
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
@@ -9,7 +10,9 @@ from sentry_sdk.utils import ContextVar
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Callable, Any
from typing import Optional, Callable, Awaitable, Any, Union, TypeVar
F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])
_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
@@ -25,13 +28,13 @@ def get_ai_pipeline_name():
def ai_track(description, **span_kwargs):
# type: (str, Any) -> Callable[..., Any]
# type: (str, Any) -> Callable[[F], F]
def decorator(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
# type: (F) -> F
def sync_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
with start_span(name=description, op=op, **span_kwargs) as span:
for k, v in kwargs.pop("sentry_tags", {}).items():
@@ -39,7 +42,7 @@ def ai_track(description, **span_kwargs):
for k, v in kwargs.pop("sentry_data", {}).items():
span.set_data(k, v)
if curr_pipeline:
span.set_data("ai.pipeline.name", curr_pipeline)
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
return f(*args, **kwargs)
else:
_ai_pipeline_name.set(description)
@@ -60,7 +63,7 @@ def ai_track(description, **span_kwargs):
async def async_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline")
op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
with start_span(name=description, op=op, **span_kwargs) as span:
for k, v in kwargs.pop("sentry_tags", {}).items():
@@ -68,7 +71,7 @@ def ai_track(description, **span_kwargs):
for k, v in kwargs.pop("sentry_data", {}).items():
span.set_data(k, v)
if curr_pipeline:
span.set_data("ai.pipeline.name", curr_pipeline)
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
return await f(*args, **kwargs)
else:
_ai_pipeline_name.set(description)
@@ -87,29 +90,48 @@ def ai_track(description, **span_kwargs):
return res
if inspect.iscoroutinefunction(f):
return wraps(f)(async_wrapped)
return wraps(f)(async_wrapped) # type: ignore
else:
return wraps(f)(sync_wrapped)
return wraps(f)(sync_wrapped) # type: ignore
return decorator
def record_token_usage(
span, prompt_tokens=None, completion_tokens=None, total_tokens=None
span,
input_tokens=None,
input_tokens_cached=None,
output_tokens=None,
output_tokens_reasoning=None,
total_tokens=None,
):
# type: (Span, Optional[int], Optional[int], Optional[int]) -> None
# type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
# TODO: move pipeline name elsewhere
ai_pipeline_name = get_ai_pipeline_name()
if ai_pipeline_name:
span.set_data("ai.pipeline.name", ai_pipeline_name)
if prompt_tokens is not None:
span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
if completion_tokens is not None:
span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
if (
total_tokens is None
and prompt_tokens is not None
and completion_tokens is not None
):
total_tokens = prompt_tokens + completion_tokens
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, ai_pipeline_name)
if input_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
if input_tokens_cached is not None:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
input_tokens_cached,
)
if output_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
if output_tokens_reasoning is not None:
span.set_data(
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
output_tokens_reasoning,
)
if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
if total_tokens is not None:
span.set_measurement("ai_total_tokens_used", total_tokens)
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
@@ -1,32 +1,144 @@
import json
from collections import deque
from typing import TYPE_CHECKING
from sys import getsizeof
if TYPE_CHECKING:
from typing import Any
from typing import Any, Callable, Dict, List, Optional, Tuple
from sentry_sdk.tracing import Span
from sentry_sdk.tracing import Span
import sentry_sdk
from sentry_sdk.utils import logger
MAX_GEN_AI_MESSAGE_BYTES = 20_000 # 20KB
def _normalize_data(data):
# type: (Any) -> Any
class GEN_AI_ALLOWED_MESSAGE_ROLES:
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = {
GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"],
GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"],
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"],
GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"],
}
GEN_AI_MESSAGE_ROLE_MAPPING = {}
for target_role, source_roles in GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING.items():
for source_role in source_roles:
GEN_AI_MESSAGE_ROLE_MAPPING[source_role] = target_role
def _normalize_data(data, unpack=True):
# type: (Any, bool) -> Any
# convert pydantic data (e.g. OpenAI v1+) to json compatible format
if hasattr(data, "model_dump"):
try:
return data.model_dump()
return _normalize_data(data.model_dump(), unpack=unpack)
except Exception as e:
logger.warning("Could not convert pydantic data to JSON: %s", e)
return data
return data if isinstance(data, (int, float, bool, str)) else str(data)
if isinstance(data, list):
if len(data) == 1:
return _normalize_data(data[0]) # remove empty dimensions
return list(_normalize_data(x) for x in data)
if unpack and len(data) == 1:
return _normalize_data(data[0], unpack=unpack) # remove empty dimensions
return list(_normalize_data(x, unpack=unpack) for x in data)
if isinstance(data, dict):
return {k: _normalize_data(v) for (k, v) in data.items()}
return data
return {k: _normalize_data(v, unpack=unpack) for (k, v) in data.items()}
return data if isinstance(data, (int, float, bool, str)) else str(data)
def set_data_normalized(span, key, value):
# type: (Span, str, Any) -> None
normalized = _normalize_data(value)
span.set_data(key, normalized)
def set_data_normalized(span, key, value, unpack=True):
# type: (Span, str, Any, bool) -> None
normalized = _normalize_data(value, unpack=unpack)
if isinstance(normalized, (int, float, bool, str)):
span.set_data(key, normalized)
else:
span.set_data(key, json.dumps(normalized))
def normalize_message_role(role):
# type: (str) -> str
"""
Normalize a message role to one of the 4 allowed gen_ai role values.
Maps "ai" -> "assistant" and keeps other standard roles unchanged.
"""
return GEN_AI_MESSAGE_ROLE_MAPPING.get(role, role)
def normalize_message_roles(messages):
# type: (list[dict[str, Any]]) -> list[dict[str, Any]]
"""
Normalize roles in a list of messages to use standard gen_ai role values.
Creates a deep copy to avoid modifying the original messages.
"""
normalized_messages = []
for message in messages:
if not isinstance(message, dict):
normalized_messages.append(message)
continue
normalized_message = message.copy()
if "role" in message:
normalized_message["role"] = normalize_message_role(message["role"])
normalized_messages.append(normalized_message)
return normalized_messages
def get_start_span_function():
# type: () -> Callable[..., Any]
current_span = sentry_sdk.get_current_span()
transaction_exists = (
current_span is not None and current_span.containing_transaction is not None
)
return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction
def _find_truncation_index(messages, max_bytes):
# type: (List[Dict[str, Any]], int) -> int
"""
Find the index of the first message that would exceed the max bytes limit.
Compute the individual message sizes, and return the index of the first message from the back
of the list that would exceed the max bytes limit.
"""
running_sum = 0
for idx in range(len(messages) - 1, -1, -1):
size = len(json.dumps(messages[idx], separators=(",", ":")).encode("utf-8"))
running_sum += size
if running_sum > max_bytes:
return idx + 1
return 0
def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
# type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
serialized_json = json.dumps(messages, separators=(",", ":"))
current_size = len(serialized_json.encode("utf-8"))
if current_size <= max_bytes:
return messages, 0
truncation_index = _find_truncation_index(messages, max_bytes)
return messages[truncation_index:], truncation_index
def truncate_and_annotate_messages(
messages, span, scope, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
):
# type: (Optional[List[Dict[str, Any]]], Any, Any, int) -> Optional[List[Dict[str, Any]]]
if not messages:
return None
truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
if removed_count > 0:
scope._gen_ai_original_message_count[span.span_id] = len(messages)
return truncated_messages
@@ -51,6 +51,7 @@ else:
# When changing this, update __all__ in __init__.py too
__all__ = [
"init",
"add_attachment",
"add_breadcrumb",
"capture_event",
"capture_exception",
@@ -81,6 +82,10 @@ __all__ = [
"start_transaction",
"trace",
"monitor",
"start_session",
"end_session",
"set_transaction_name",
"update_current_span",
]
@@ -184,6 +189,20 @@ def capture_exception(
return get_current_scope().capture_exception(error, scope=scope, **scope_kwargs)
@scopemethod
def add_attachment(
bytes=None, # type: Union[None, bytes, Callable[[], bytes]]
filename=None, # type: Optional[str]
path=None, # type: Optional[str]
content_type=None, # type: Optional[str]
add_to_transactions=False, # type: bool
):
# type: (...) -> None
return get_isolation_scope().add_attachment(
bytes, filename, path, content_type, add_to_transactions
)
@scopemethod
def add_breadcrumb(
crumb=None, # type: Optional[Breadcrumb]
@@ -388,6 +407,10 @@ def start_transaction(
def set_measurement(name, value, unit=""):
# type: (str, float, MeasurementUnit) -> None
"""
.. deprecated:: 2.28.0
This function is deprecated and will be removed in the next major release.
"""
transaction = get_current_scope().transaction
if transaction is not None:
transaction.set_measurement(name, value, unit)
@@ -431,3 +454,102 @@ def continue_trace(
return get_isolation_scope().continue_trace(
environ_or_headers, op, name, source, origin
)
@scopemethod
def start_session(
session_mode="application", # type: str
):
# type: (...) -> None
return get_isolation_scope().start_session(session_mode=session_mode)
@scopemethod
def end_session():
# type: () -> None
return get_isolation_scope().end_session()
@scopemethod
def set_transaction_name(name, source=None):
# type: (str, Optional[str]) -> None
return get_current_scope().set_transaction_name(name, source)
def update_current_span(op=None, name=None, attributes=None, data=None):
# type: (Optional[str], Optional[str], Optional[dict[str, Union[str, int, float, bool]]], Optional[dict[str, Any]]) -> None
"""
Update the current active span with the provided parameters.
This function allows you to modify properties of the currently active span.
If no span is currently active, this function will do nothing.
:param op: The operation name for the span. This is a high-level description
of what the span represents (e.g., "http.client", "db.query").
You can use predefined constants from :py:class:`sentry_sdk.consts.OP`
or provide your own string. If not provided, the span's operation will
remain unchanged.
:type op: str or None
:param name: The human-readable name/description for the span. This provides
more specific details about what the span represents (e.g., "GET /api/users",
"SELECT * FROM users"). If not provided, the span's name will remain unchanged.
:type name: str or None
:param data: A dictionary of key-value pairs to add as data to the span. This
data will be merged with any existing span data. If not provided,
no data will be added.
.. deprecated:: 2.35.0
Use ``attributes`` instead. The ``data`` parameter will be removed
in a future version.
:type data: dict[str, Union[str, int, float, bool]] or None
:param attributes: A dictionary of key-value pairs to add as attributes to the span.
Attribute values must be strings, integers, floats, or booleans. These
attributes will be merged with any existing span data. If not provided,
no attributes will be added.
:type attributes: dict[str, Union[str, int, float, bool]] or None
:returns: None
.. versionadded:: 2.35.0
Example::
import sentry_sdk
from sentry_sdk.consts import OP
sentry_sdk.update_current_span(
op=OP.FUNCTION,
name="process_user_data",
attributes={"user_id": 123, "batch_size": 50}
)
"""
current_span = get_current_span()
if current_span is None:
return
if op is not None:
current_span.op = op
if name is not None:
# internally it is still description
current_span.description = name
if data is not None and attributes is not None:
raise ValueError(
"Cannot provide both `data` and `attributes`. Please use only `attributes`."
)
if data is not None:
warnings.warn(
"The `data` parameter is deprecated. Please use `attributes` instead.",
DeprecationWarning,
stacklevel=2,
)
attributes = data
if attributes is not None:
current_span.update_data(attributes)
@@ -8,6 +8,7 @@ from importlib import import_module
from typing import TYPE_CHECKING, List, Dict, cast, overload
import warnings
import sentry_sdk
from sentry_sdk._compat import PY37, check_uwsgi_thread_support
from sentry_sdk.utils import (
AnnotatedValue,
@@ -22,11 +23,16 @@ from sentry_sdk.utils import (
handle_in_app,
is_gevent,
logger,
get_before_send_log,
get_before_send_metric,
has_logs_enabled,
has_metrics_enabled,
)
from sentry_sdk.serializer import serialize
from sentry_sdk.tracing import trace
from sentry_sdk.transport import BaseHttpTransport, make_transport
from sentry_sdk.consts import (
SPANDATA,
DEFAULT_MAX_VALUE_LENGTH,
DEFAULT_OPTIONS,
INSTRUMENTER,
@@ -34,6 +40,7 @@ from sentry_sdk.consts import (
ClientConstructor,
)
from sentry_sdk.integrations import _DEFAULT_INTEGRATIONS, setup_integrations
from sentry_sdk.integrations.dedupe import DedupeIntegration
from sentry_sdk.sessions import SessionFlusher
from sentry_sdk.envelope import Envelope
from sentry_sdk.profiler.continuous_profiler import setup_continuous_profiler
@@ -44,7 +51,6 @@ from sentry_sdk.profiler.transaction_profiler import (
)
from sentry_sdk.scrubber import EventScrubber
from sentry_sdk.monitor import Monitor
from sentry_sdk.spotlight import setup_spotlight
if TYPE_CHECKING:
from typing import Any
@@ -55,13 +61,14 @@ if TYPE_CHECKING:
from typing import Union
from typing import TypeVar
from sentry_sdk._types import Event, Hint, SDKInfo
from sentry_sdk._types import Event, Hint, SDKInfo, Log, Metric
from sentry_sdk.integrations import Integration
from sentry_sdk.metrics import MetricsAggregator
from sentry_sdk.scope import Scope
from sentry_sdk.session import Session
from sentry_sdk.spotlight import SpotlightClient
from sentry_sdk.transport import Transport
from sentry_sdk._log_batcher import LogBatcher
from sentry_sdk._metrics_batcher import MetricsBatcher
I = TypeVar("I", bound=Integration) # noqa: E741
@@ -107,7 +114,7 @@ def _get_options(*args, **kwargs):
rv["environment"] = os.environ.get("SENTRY_ENVIRONMENT") or "production"
if rv["debug"] is None:
rv["debug"] = env_to_bool(os.environ.get("SENTRY_DEBUG", "False"), strict=True)
rv["debug"] = env_to_bool(os.environ.get("SENTRY_DEBUG"), strict=True) or False
if rv["server_name"] is None and hasattr(socket, "gethostname"):
rv["server_name"] = socket.gethostname()
@@ -139,6 +146,11 @@ def _get_options(*args, **kwargs):
)
rv["socket_options"] = None
if rv["keep_alive"] is None:
rv["keep_alive"] = (
env_to_bool(os.environ.get("SENTRY_KEEP_ALIVE"), strict=True) or False
)
if rv["enable_tracing"] is not None:
warnings.warn(
"The `enable_tracing` parameter is deprecated. Please use `traces_sample_rate` instead.",
@@ -168,13 +180,12 @@ class BaseClient:
def __init__(self, options=None):
# type: (Optional[Dict[str, Any]]) -> None
self.options = (
options if options is not None else DEFAULT_OPTIONS
) # type: Dict[str, Any]
self.options = options if options is not None else DEFAULT_OPTIONS # type: Dict[str, Any]
self.transport = None # type: Optional[Transport]
self.monitor = None # type: Optional[Monitor]
self.metrics_aggregator = None # type: Optional[MetricsAggregator]
self.log_batcher = None # type: Optional[LogBatcher]
self.metrics_batcher = None # type: Optional[MetricsBatcher]
def __getstate__(self, *args, **kwargs):
# type: (*Any, **Any) -> Any
@@ -206,6 +217,14 @@ class BaseClient:
# type: (*Any, **Any) -> Optional[str]
return None
def _capture_log(self, log):
# type: (Log) -> None
pass
def _capture_metric(self, metric):
# type: (Metric) -> None
pass
def capture_session(self, *args, **kwargs):
# type: (*Any, **Any) -> None
return None
@@ -348,25 +367,19 @@ class _Client(BaseClient):
self.session_flusher = SessionFlusher(capture_func=_capture_envelope)
self.metrics_aggregator = None # type: Optional[MetricsAggregator]
experiments = self.options.get("_experiments", {})
if experiments.get("enable_metrics", True):
# Context vars are not working correctly on Python <=3.6
# with gevent.
metrics_supported = not is_gevent() or PY37
if metrics_supported:
from sentry_sdk.metrics import MetricsAggregator
self.log_batcher = None
self.metrics_aggregator = MetricsAggregator(
capture_func=_capture_envelope,
enable_code_locations=bool(
experiments.get("metric_code_locations", True)
),
)
else:
logger.info(
"Metrics not supported on Python 3.6 and lower with gevent."
)
if has_logs_enabled(self.options):
from sentry_sdk._log_batcher import LogBatcher
self.log_batcher = LogBatcher(capture_func=_capture_envelope)
self.metrics_batcher = None
if has_metrics_enabled(self.options):
from sentry_sdk._metrics_batcher import MetricsBatcher
self.metrics_batcher = MetricsBatcher(capture_func=_capture_envelope)
max_request_body_size = ("always", "never", "small", "medium")
if self.options["max_request_body_size"] not in max_request_body_size:
@@ -409,7 +422,17 @@ class _Client(BaseClient):
)
if self.options.get("spotlight"):
# This is intentionally here to prevent setting up spotlight
# stuff we don't need unless spotlight is explicitly enabled
from sentry_sdk.spotlight import setup_spotlight
self.spotlight = setup_spotlight(self.options)
if not self.options["dsn"]:
sample_all = lambda *_args, **_kwargs: 1.0
self.options["send_default_pii"] = True
self.options["error_sampler"] = sample_all
self.options["traces_sampler"] = sample_all
self.options["profiles_sampler"] = sample_all
sdk_name = get_sdk_name(list(self.integrations.keys()))
SDK_INFO["name"] = sdk_name
@@ -437,7 +460,7 @@ class _Client(BaseClient):
if (
self.monitor
or self.metrics_aggregator
or self.log_batcher
or has_profiling_enabled(self.options)
or isinstance(self.transport, BaseHttpTransport)
):
@@ -461,11 +484,7 @@ class _Client(BaseClient):
Returns whether the client should send default PII (Personally Identifiable Information) data to Sentry.
"""
result = self.options.get("send_default_pii")
if result is None:
result = not self.options["dsn"] and self.spotlight is not None
return result
return self.options.get("send_default_pii") or False
@property
def dsn(self):
@@ -482,12 +501,14 @@ class _Client(BaseClient):
# type: (...) -> Optional[Event]
previous_total_spans = None # type: Optional[int]
previous_total_breadcrumbs = None # type: Optional[int]
if event.get("timestamp") is None:
event["timestamp"] = datetime.now(timezone.utc)
is_transaction = event.get("type") == "transaction"
if scope is not None:
is_transaction = event.get("type") == "transaction"
spans_before = len(cast(List[Dict[str, object]], event.get("spans", [])))
event_ = scope.apply_to_event(event, hint, self.options)
@@ -518,9 +539,20 @@ class _Client(BaseClient):
dropped_spans = event.pop("_dropped_spans", 0) + spans_delta # type: int
if dropped_spans > 0:
previous_total_spans = spans_before + dropped_spans
if scope._n_breadcrumbs_truncated > 0:
breadcrumbs = event.get("breadcrumbs", {})
values = (
breadcrumbs.get("values", [])
if not isinstance(breadcrumbs, AnnotatedValue)
else []
)
previous_total_breadcrumbs = (
len(values) + scope._n_breadcrumbs_truncated
)
if (
self.options["attach_stacktrace"]
not is_transaction
and self.options["attach_stacktrace"]
and "exception" not in event
and "stacktrace" not in event
and "threads" not in event
@@ -566,10 +598,30 @@ class _Client(BaseClient):
if event_scrubber:
event_scrubber.scrub_event(event)
if scope is not None and scope._gen_ai_original_message_count:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
if isinstance(spans, list):
for span in spans:
span_id = span.get("span_id", None)
span_data = span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{"len": scope._gen_ai_original_message_count[span_id]},
)
if previous_total_spans is not None:
event["spans"] = AnnotatedValue(
event.get("spans", []), {"len": previous_total_spans}
)
if previous_total_breadcrumbs is not None:
event["breadcrumbs"] = AnnotatedValue(
event.get("breadcrumbs", {"values": []}),
{"len": previous_total_breadcrumbs},
)
# Postprocess the event here so that annotated types do
# generally not surface in before_send
@@ -599,6 +651,14 @@ class _Client(BaseClient):
self.transport.record_lost_event(
"before_send", data_category="error"
)
# If this is an exception, reset the DedupeIntegration. It still
# remembers the dropped exception as the last exception, meaning
# that if the same exception happens again and is not dropped
# in before_send, it'd get dropped by DedupeIntegration.
if event.get("exception"):
DedupeIntegration.reset_last_seen()
event = new_event
before_send_transaction = self.options["before_send_transaction"]
@@ -740,6 +800,8 @@ class _Client(BaseClient):
if exceptions:
errored = True
for error in exceptions:
if isinstance(error, AnnotatedValue):
error = error.value or {}
mechanism = error.get("mechanism")
if isinstance(mechanism, Mapping) and mechanism.get("handled") is False:
crashed = True
@@ -847,8 +909,135 @@ class _Client(BaseClient):
return return_value
def _capture_log(self, log):
# type: (Optional[Log]) -> None
if not has_logs_enabled(self.options) or log is None:
return
current_scope = sentry_sdk.get_current_scope()
isolation_scope = sentry_sdk.get_isolation_scope()
log["attributes"]["sentry.sdk.name"] = SDK_INFO["name"]
log["attributes"]["sentry.sdk.version"] = SDK_INFO["version"]
server_name = self.options.get("server_name")
if server_name is not None and SPANDATA.SERVER_ADDRESS not in log["attributes"]:
log["attributes"][SPANDATA.SERVER_ADDRESS] = server_name
environment = self.options.get("environment")
if environment is not None and "sentry.environment" not in log["attributes"]:
log["attributes"]["sentry.environment"] = environment
release = self.options.get("release")
if release is not None and "sentry.release" not in log["attributes"]:
log["attributes"]["sentry.release"] = release
span = current_scope.span
if span is not None and "sentry.trace.parent_span_id" not in log["attributes"]:
log["attributes"]["sentry.trace.parent_span_id"] = span.span_id
if log.get("trace_id") is None:
transaction = current_scope.transaction
propagation_context = isolation_scope.get_active_propagation_context()
if transaction is not None:
log["trace_id"] = transaction.trace_id
elif propagation_context is not None:
log["trace_id"] = propagation_context.trace_id
# The user, if present, is always set on the isolation scope.
if isolation_scope._user is not None:
for log_attribute, user_attribute in (
("user.id", "id"),
("user.name", "username"),
("user.email", "email"),
):
if (
user_attribute in isolation_scope._user
and log_attribute not in log["attributes"]
):
log["attributes"][log_attribute] = isolation_scope._user[
user_attribute
]
# If debug is enabled, log the log to the console
debug = self.options.get("debug", False)
if debug:
logger.debug(
f"[Sentry Logs] [{log.get('severity_text')}] {log.get('body')}"
)
before_send_log = get_before_send_log(self.options)
if before_send_log is not None:
log = before_send_log(log, {})
if log is None:
return
if self.log_batcher:
self.log_batcher.add(log)
def _capture_metric(self, metric):
# type: (Optional[Metric]) -> None
if not has_metrics_enabled(self.options) or metric is None:
return
isolation_scope = sentry_sdk.get_isolation_scope()
metric["attributes"]["sentry.sdk.name"] = SDK_INFO["name"]
metric["attributes"]["sentry.sdk.version"] = SDK_INFO["version"]
environment = self.options.get("environment")
if environment is not None and "sentry.environment" not in metric["attributes"]:
metric["attributes"]["sentry.environment"] = environment
release = self.options.get("release")
if release is not None and "sentry.release" not in metric["attributes"]:
metric["attributes"]["sentry.release"] = release
span = sentry_sdk.get_current_span()
metric["trace_id"] = "00000000-0000-0000-0000-000000000000"
if span:
metric["trace_id"] = span.trace_id
metric["span_id"] = span.span_id
else:
propagation_context = isolation_scope.get_active_propagation_context()
if propagation_context and propagation_context.trace_id:
metric["trace_id"] = propagation_context.trace_id
if isolation_scope._user is not None:
for metric_attribute, user_attribute in (
("user.id", "id"),
("user.name", "username"),
("user.email", "email"),
):
if (
user_attribute in isolation_scope._user
and metric_attribute not in metric["attributes"]
):
metric["attributes"][metric_attribute] = isolation_scope._user[
user_attribute
]
debug = self.options.get("debug", False)
if debug:
logger.debug(
f"[Sentry Metrics] [{metric.get('type')}] {metric.get('name')}: {metric.get('value')}"
)
before_send_metric = get_before_send_metric(self.options)
if before_send_metric is not None:
metric = before_send_metric(metric, {})
if metric is None:
return
if self.metrics_batcher:
self.metrics_batcher.add(metric)
def capture_session(
self, session # type: Session
self,
session, # type: Session
):
# type: (...) -> None
if not session.release:
@@ -869,7 +1058,8 @@ class _Client(BaseClient):
...
def get_integration(
self, name_or_class # type: Union[str, Type[Integration]]
self,
name_or_class, # type: Union[str, Type[Integration]]
):
# type: (...) -> Optional[Integration]
"""Returns the integration for this client by name or class.
@@ -897,8 +1087,10 @@ class _Client(BaseClient):
if self.transport is not None:
self.flush(timeout=timeout, callback=callback)
self.session_flusher.kill()
if self.metrics_aggregator is not None:
self.metrics_aggregator.kill()
if self.log_batcher is not None:
self.log_batcher.kill()
if self.metrics_batcher is not None:
self.metrics_batcher.kill()
if self.monitor:
self.monitor.kill()
self.transport.kill()
@@ -921,8 +1113,10 @@ class _Client(BaseClient):
if timeout is None:
timeout = self.options["shutdown_timeout"]
self.session_flusher.flush()
if self.metrics_aggregator is not None:
self.metrics_aggregator.flush()
if self.log_batcher is not None:
self.log_batcher.flush()
if self.metrics_batcher is not None:
self.metrics_batcher.flush()
self.transport.flush(timeout=timeout, callback=callback)
def __enter__(self):
File diff suppressed because it is too large Load Diff
@@ -1,6 +1,7 @@
import uuid
import sentry_sdk
from sentry_sdk.utils import logger
from typing import TYPE_CHECKING
@@ -54,4 +55,8 @@ def capture_checkin(
sentry_sdk.capture_event(check_in_event)
logger.debug(
f"[Crons] Captured check-in ({check_in_event.get('check_in_id')}): {check_in_event.get('monitor_slug')} -> {check_in_event.get('status')}"
)
return check_in_event["check_in_id"]
@@ -57,39 +57,49 @@ class Envelope:
)
def add_event(
self, event # type: Event
self,
event, # type: Event
):
# type: (...) -> None
self.add_item(Item(payload=PayloadRef(json=event), type="event"))
def add_transaction(
self, transaction # type: Event
self,
transaction, # type: Event
):
# type: (...) -> None
self.add_item(Item(payload=PayloadRef(json=transaction), type="transaction"))
def add_profile(
self, profile # type: Any
self,
profile, # type: Any
):
# type: (...) -> None
self.add_item(Item(payload=PayloadRef(json=profile), type="profile"))
def add_profile_chunk(
self, profile_chunk # type: Any
self,
profile_chunk, # type: Any
):
# type: (...) -> None
self.add_item(
Item(payload=PayloadRef(json=profile_chunk), type="profile_chunk")
Item(
payload=PayloadRef(json=profile_chunk),
type="profile_chunk",
headers={"platform": profile_chunk.get("platform", "python")},
)
)
def add_checkin(
self, checkin # type: Any
self,
checkin, # type: Any
):
# type: (...) -> None
self.add_item(Item(payload=PayloadRef(json=checkin), type="check_in"))
def add_session(
self, session # type: Union[Session, Any]
self,
session, # type: Union[Session, Any]
):
# type: (...) -> None
if isinstance(session, Session):
@@ -97,13 +107,15 @@ class Envelope:
self.add_item(Item(payload=PayloadRef(json=session), type="session"))
def add_sessions(
self, sessions # type: Any
self,
sessions, # type: Any
):
# type: (...) -> None
self.add_item(Item(payload=PayloadRef(json=sessions), type="sessions"))
def add_item(
self, item # type: Item
self,
item, # type: Item
):
# type: (...) -> None
self.items.append(item)
@@ -129,7 +141,8 @@ class Envelope:
return iter(self.items)
def serialize_into(
self, f # type: Any
self,
f, # type: Any
):
# type: (...) -> None
f.write(json_dumps(self.headers))
@@ -145,7 +158,8 @@ class Envelope:
@classmethod
def deserialize_from(
cls, f # type: Any
cls,
f, # type: Any
):
# type: (...) -> Envelope
headers = parse_json(f.readline())
@@ -159,7 +173,8 @@ class Envelope:
@classmethod
def deserialize(
cls, bytes # type: bytes
cls,
bytes, # type: bytes
):
# type: (...) -> Envelope
return cls.deserialize_from(io.BytesIO(bytes))
@@ -268,14 +283,16 @@ class Item:
return "transaction"
elif ty == "event":
return "error"
elif ty == "log":
return "log_item"
elif ty == "trace_metric":
return "trace_metric"
elif ty == "client_report":
return "internal"
elif ty == "profile":
return "profile"
elif ty == "profile_chunk":
return "profile_chunk"
elif ty == "statsd":
return "metric_bucket"
elif ty == "check_in":
return "monitor"
else:
@@ -301,7 +318,8 @@ class Item:
return None
def serialize_into(
self, f # type: Any
self,
f, # type: Any
):
# type: (...) -> None
headers = dict(self.headers)
@@ -320,7 +338,8 @@ class Item:
@classmethod
def deserialize_from(
cls, f # type: Any
cls,
f, # type: Any
):
# type: (...) -> Optional[Item]
line = f.readline().rstrip()
@@ -335,7 +354,7 @@ class Item:
# if no length was specified we need to read up to the end of line
# and remove it (if it is present, i.e. not the very last char in an eof terminated envelope)
payload = f.readline().rstrip(b"\n")
if headers.get("type") in ("event", "transaction", "metric_buckets"):
if headers.get("type") in ("event", "transaction"):
rv = cls(headers=headers, payload=PayloadRef(json=parse_json(payload)))
else:
rv = cls(headers=headers, payload=payload)
@@ -343,7 +362,8 @@ class Item:
@classmethod
def deserialize(
cls, bytes # type: bytes
cls,
bytes, # type: bytes
):
# type: (...) -> Optional[Item]
return cls.deserialize_from(io.BytesIO(bytes))
@@ -15,7 +15,6 @@ DEFAULT_FLAG_CAPACITY = 100
class FlagBuffer:
def __init__(self, capacity):
# type: (int) -> None
self.capacity = capacity
@@ -64,5 +63,9 @@ def add_feature_flag(flag, result):
Records a flag and its value to be sent on subsequent error events.
We recommend you do this on flag evaluations. Flags are buffered per Sentry scope.
"""
flags = sentry_sdk.get_current_scope().flags
flags = sentry_sdk.get_isolation_scope().flags
flags.set(flag, result)
span = sentry_sdk.get_current_span()
if span:
span.set_flag(f"flag.evaluation.{flag}", result)
@@ -205,7 +205,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
scope._isolation_scope.set(old_isolation_scope)
def run(
self, callback # type: Callable[[], T]
self,
callback, # type: Callable[[], T]
):
# type: (...) -> T
"""
@@ -219,7 +220,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
return callback()
def get_integration(
self, name_or_class # type: Union[str, Type[Integration]]
self,
name_or_class, # type: Union[str, Type[Integration]]
):
# type: (...) -> Any
"""
@@ -277,7 +279,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
return self._last_event_id
def bind_client(
self, new # type: Optional[BaseClient]
self,
new, # type: Optional[BaseClient]
):
# type: (...) -> None
"""
@@ -430,7 +433,7 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
transaction=None,
instrumenter=INSTRUMENTER.SENTRY,
custom_sampling_context=None,
**kwargs
**kwargs,
):
# type: (Optional[Transaction], str, Optional[SamplingContext], Unpack[TransactionKwargs]) -> Union[Transaction, NoOpSpan]
"""
@@ -487,14 +490,16 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
@overload
def push_scope(
self, callback=None # type: Optional[None]
self,
callback=None, # type: Optional[None]
):
# type: (...) -> ContextManager[Scope]
pass
@overload
def push_scope( # noqa: F811
self, callback # type: Callable[[Scope], None]
self,
callback, # type: Callable[[Scope], None]
):
# type: (...) -> None
pass
@@ -540,14 +545,16 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
@overload
def configure_scope(
self, callback=None # type: Optional[None]
self,
callback=None, # type: Optional[None]
):
# type: (...) -> ContextManager[Scope]
pass
@overload
def configure_scope( # noqa: F811
self, callback # type: Callable[[Scope], None]
self,
callback, # type: Callable[[Scope], None]
):
# type: (...) -> None
pass
@@ -587,7 +594,8 @@ class Hub(with_metaclass(HubMeta)): # type: ignore
return inner()
def start_session(
self, session_mode="application" # type: str
self,
session_mode="application", # type: str
):
# type: (...) -> None
"""
@@ -95,6 +95,7 @@ _AUTO_ENABLING_INTEGRATIONS = [
"sentry_sdk.integrations.huey.HueyIntegration",
"sentry_sdk.integrations.huggingface_hub.HuggingfaceHubIntegration",
"sentry_sdk.integrations.langchain.LangchainIntegration",
"sentry_sdk.integrations.langgraph.LanggraphIntegration",
"sentry_sdk.integrations.litestar.LitestarIntegration",
"sentry_sdk.integrations.loguru.LoguruIntegration",
"sentry_sdk.integrations.openai.OpenAIIntegration",
@@ -131,6 +132,7 @@ _MIN_VERSIONS = {
"celery": (4, 4, 7),
"chalice": (1, 16, 0),
"clickhouse_driver": (0, 2, 0),
"cohere": (5, 4, 0),
"django": (1, 8),
"dramatiq": (1, 9),
"falcon": (1, 4),
@@ -138,13 +140,20 @@ _MIN_VERSIONS = {
"flask": (1, 1, 4),
"gql": (3, 4, 1),
"graphene": (3, 3),
"google_genai": (1, 29, 0), # google-genai
"grpc": (1, 32, 0), # grpcio
"huggingface_hub": (0, 22),
"langchain": (0, 0, 210),
"httpx": (0, 16, 0),
"huggingface_hub": (0, 24, 7),
"langchain": (0, 1, 0),
"langgraph": (0, 6, 6),
"launchdarkly": (9, 8, 0),
"litellm": (1, 77, 5),
"loguru": (0, 7, 0),
"mcp": (1, 15, 0),
"openai": (1, 0, 0),
"openai_agents": (0, 0, 19),
"openfeature": (0, 7, 1),
"pydantic_ai": (1, 0, 0),
"quart": (0, 16, 0),
"ray": (2, 7, 0),
"requests": (2, 0, 0),
@@ -20,9 +20,9 @@ from sentry_sdk.integrations._wsgi_common import (
from sentry_sdk.tracing import (
BAGGAGE_HEADER_NAME,
SOURCE_FOR_STYLE,
TRANSACTION_SOURCE_ROUTE,
TransactionSource,
)
from sentry_sdk.tracing_utils import should_propagate_trace
from sentry_sdk.tracing_utils import should_propagate_trace, add_http_request_source
from sentry_sdk.utils import (
capture_internal_exceptions,
ensure_integration_enabled,
@@ -129,7 +129,7 @@ class AioHttpIntegration(Integration):
# If this transaction name makes it to the UI, AIOHTTP's
# URL resolver did not find a route or died trying.
name="generic AIOHTTP request",
source=TRANSACTION_SOURCE_ROUTE,
source=TransactionSource.ROUTE,
origin=AioHttpIntegration.origin,
)
with sentry_sdk.start_transaction(
@@ -279,6 +279,9 @@ def create_trace_config():
span.set_data("reason", params.response.reason)
span.finish()
with capture_internal_exceptions():
add_http_request_source(span)
trace_config = TraceConfig()
trace_config.on_request_start.append(on_request_start)
@@ -3,16 +3,29 @@ from typing import TYPE_CHECKING
import sentry_sdk
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.ai.utils import (
set_data_normalized,
normalize_message_roles,
truncate_and_annotate_messages,
get_start_span_function,
)
from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS
from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
package_version,
safe_serialize,
)
try:
try:
from anthropic import NOT_GIVEN
except ImportError:
NOT_GIVEN = None
from anthropic.resources import AsyncMessages, Messages
if TYPE_CHECKING:
@@ -45,6 +58,8 @@ class AnthropicIntegration(Integration):
def _capture_exception(exc):
# type: (Any) -> None
set_span_errored()
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
@@ -53,8 +68,11 @@ def _capture_exception(exc):
sentry_sdk.capture_event(event, hint=hint)
def _calculate_token_usage(result, span):
# type: (Messages, Span) -> None
def _get_token_usage(result):
# type: (Messages) -> tuple[int, int]
"""
Get token usage from the Anthropic response.
"""
input_tokens = 0
output_tokens = 0
if hasattr(result, "usage"):
@@ -64,31 +82,13 @@ def _calculate_token_usage(result, span):
if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
output_tokens = usage.output_tokens
total_tokens = input_tokens + output_tokens
record_token_usage(span, input_tokens, output_tokens, total_tokens)
return input_tokens, output_tokens
def _get_responses(content):
# type: (list[Any]) -> list[dict[str, Any]]
def _collect_ai_data(event, model, input_tokens, output_tokens, content_blocks):
# type: (MessageStreamEvent, str | None, int, int, list[str]) -> tuple[str | None, int, int, list[str]]
"""
Get JSON of a Anthropic responses.
"""
responses = []
for item in content:
if hasattr(item, "text"):
responses.append(
{
"type": item.type,
"text": item.text,
}
)
return responses
def _collect_ai_data(event, input_tokens, output_tokens, content_blocks):
# type: (MessageStreamEvent, int, int, list[str]) -> tuple[int, int, list[str]]
"""
Count token usage and collect content blocks from the AI streaming response.
Collect model information, token usage, and collect content blocks from the AI streaming response.
"""
with capture_internal_exceptions():
if hasattr(event, "type"):
@@ -96,36 +96,135 @@ def _collect_ai_data(event, input_tokens, output_tokens, content_blocks):
usage = event.message.usage
input_tokens += usage.input_tokens
output_tokens += usage.output_tokens
model = event.message.model or model
elif event.type == "content_block_start":
pass
elif event.type == "content_block_delta":
if hasattr(event.delta, "text"):
content_blocks.append(event.delta.text)
elif hasattr(event.delta, "partial_json"):
content_blocks.append(event.delta.partial_json)
elif event.type == "content_block_stop":
pass
elif event.type == "message_delta":
output_tokens += event.usage.output_tokens
return input_tokens, output_tokens, content_blocks
return model, input_tokens, output_tokens, content_blocks
def _add_ai_data_to_span(
span, integration, input_tokens, output_tokens, content_blocks
):
# type: (Span, AnthropicIntegration, int, int, list[str]) -> None
def _set_input_data(span, kwargs, integration):
# type: (Span, dict[str, Any], AnthropicIntegration) -> None
"""
Add token usage and content blocks from the AI streaming response to the span.
Set input data for the span based on the provided keyword arguments for the anthropic message creation.
"""
with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
complete_message = "".join(content_blocks)
span.set_data(
SPANDATA.AI_RESPONSES,
[{"type": "text", "text": complete_message}],
messages = kwargs.get("messages")
if (
messages is not None
and len(messages) > 0
and should_send_default_pii()
and integration.include_prompts
):
normalized_messages = []
for message in messages:
if (
message.get("role") == "user"
and "content" in message
and isinstance(message["content"], (list, tuple))
):
for item in message["content"]:
if item.get("type") == "tool_result":
normalized_messages.append(
{
"role": "tool",
"content": {
"tool_use_id": item.get("tool_use_id"),
"output": item.get("content"),
},
}
)
else:
normalized_messages.append(message)
role_normalized_messages = normalize_message_roles(normalized_messages)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(
role_normalized_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
)
total_tokens = input_tokens + output_tokens
record_token_usage(span, input_tokens, output_tokens, total_tokens)
span.set_data(SPANDATA.AI_STREAMING, True)
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_STREAMING, kwargs.get("stream", False)
)
kwargs_keys_to_attributes = {
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
"top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
}
for key, attribute in kwargs_keys_to_attributes.items():
value = kwargs.get(key)
if value is not NOT_GIVEN and value is not None:
set_data_normalized(span, attribute, value)
# Input attributes: Tools
tools = kwargs.get("tools")
if tools is not NOT_GIVEN and tools is not None and len(tools) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
)
def _set_output_data(
span,
integration,
model,
input_tokens,
output_tokens,
content_blocks,
finish_span=False,
):
# type: (Span, AnthropicIntegration, str | None, int | None, int | None, list[Any], bool) -> None
"""
Set output data for the span based on the AI response."""
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, model)
if should_send_default_pii() and integration.include_prompts:
output_messages = {
"response": [],
"tool": [],
} # type: (dict[str, list[Any]])
for output in content_blocks:
if output["type"] == "text":
output_messages["response"].append(output["text"])
elif output["type"] == "tool_use":
output_messages["tool"].append(output)
if len(output_messages["tool"]) > 0:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
output_messages["tool"],
unpack=False,
)
if len(output_messages["response"]) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
)
record_token_usage(
span,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
if finish_span:
span.__exit__(None, None, None)
def _sentry_patched_create_common(f, *args, **kwargs):
@@ -142,31 +241,41 @@ def _sentry_patched_create_common(f, *args, **kwargs):
except TypeError:
return f(*args, **kwargs)
span = sentry_sdk.start_span(
op=OP.ANTHROPIC_MESSAGES_CREATE,
description="Anthropic messages create",
model = kwargs.get("model", "")
span = get_start_span_function()(
op=OP.GEN_AI_CHAT,
name=f"chat {model}".strip(),
origin=AnthropicIntegration.origin,
)
span.__enter__()
_set_input_data(span, kwargs, integration)
result = yield f, args, kwargs
# add data to span and finish it
messages = list(kwargs["messages"])
model = kwargs.get("model")
with capture_internal_exceptions():
span.set_data(SPANDATA.AI_MODEL_ID, model)
span.set_data(SPANDATA.AI_STREAMING, False)
if should_send_default_pii() and integration.include_prompts:
span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages)
if hasattr(result, "content"):
if should_send_default_pii() and integration.include_prompts:
span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content))
_calculate_token_usage(result, span)
span.__exit__(None, None, None)
input_tokens, output_tokens = _get_token_usage(result)
content_blocks = []
for content_block in result.content:
if hasattr(content_block, "to_dict"):
content_blocks.append(content_block.to_dict())
elif hasattr(content_block, "model_dump"):
content_blocks.append(content_block.model_dump())
elif hasattr(content_block, "text"):
content_blocks.append({"type": "text", "text": content_block.text})
_set_output_data(
span=span,
integration=integration,
model=getattr(result, "model", None),
input_tokens=input_tokens,
output_tokens=output_tokens,
content_blocks=content_blocks,
finish_span=True,
)
# Streaming response
elif hasattr(result, "_iterator"):
@@ -174,39 +283,53 @@ def _sentry_patched_create_common(f, *args, **kwargs):
def new_iterator():
# type: () -> Iterator[MessageStreamEvent]
model = None
input_tokens = 0
output_tokens = 0
content_blocks = [] # type: list[str]
for event in old_iterator:
input_tokens, output_tokens, content_blocks = _collect_ai_data(
event, input_tokens, output_tokens, content_blocks
model, input_tokens, output_tokens, content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
)
)
if event.type != "message_stop":
yield event
yield event
_add_ai_data_to_span(
span, integration, input_tokens, output_tokens, content_blocks
_set_output_data(
span=span,
integration=integration,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
span.__exit__(None, None, None)
async def new_iterator_async():
# type: () -> AsyncIterator[MessageStreamEvent]
model = None
input_tokens = 0
output_tokens = 0
content_blocks = [] # type: list[str]
async for event in old_iterator:
input_tokens, output_tokens, content_blocks = _collect_ai_data(
event, input_tokens, output_tokens, content_blocks
model, input_tokens, output_tokens, content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
)
)
if event.type != "message_stop":
yield event
yield event
_add_ai_data_to_span(
span, integration, input_tokens, output_tokens, content_blocks
_set_output_data(
span=span,
integration=integration,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
span.__exit__(None, None, None)
if str(type(result._iterator)) == "<class 'async_generator'>":
result._iterator = new_iterator_async()
@@ -248,7 +371,13 @@ def _wrap_message_create(f):
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
kwargs["integration"] = integration
return _execute_sync(f, *args, **kwargs)
try:
return _execute_sync(f, *args, **kwargs)
finally:
span = sentry_sdk.get_current_span()
if span is not None and span.status == SPANSTATUS.ERROR:
with capture_internal_exceptions():
span.__exit__(None, None, None)
return _sentry_patched_create_sync
@@ -281,6 +410,12 @@ def _wrap_message_create_async(f):
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
kwargs["integration"] = integration
return await _execute_async(f, *args, **kwargs)
try:
return await _execute_async(f, *args, **kwargs)
finally:
span = sentry_sdk.get_current_span()
if span is not None and span.status == SPANSTATUS.ERROR:
with capture_internal_exceptions():
span.__exit__(None, None, None)
return _sentry_patched_create_async
@@ -5,7 +5,7 @@ from sentry_sdk.consts import OP, SPANSTATUS
from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration
from sentry_sdk.integrations.logging import ignore_logger
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_TASK
from sentry_sdk.tracing import Transaction, TransactionSource
from sentry_sdk.utils import (
capture_internal_exceptions,
ensure_integration_enabled,
@@ -102,7 +102,7 @@ def patch_run_job():
name="unknown arq task",
status="ok",
op=OP.QUEUE_TASK_ARQ,
source=TRANSACTION_SOURCE_TASK,
source=TransactionSource.TASK,
origin=ArqIntegration.origin,
)
@@ -199,12 +199,13 @@ def patch_create_worker():
if isinstance(settings_cls, dict):
if "functions" in settings_cls:
settings_cls["functions"] = [
_get_arq_function(func) for func in settings_cls["functions"]
_get_arq_function(func)
for func in settings_cls.get("functions", [])
]
if "cron_jobs" in settings_cls:
settings_cls["cron_jobs"] = [
_get_arq_cron_job(cron_job)
for cron_job in settings_cls["cron_jobs"]
for cron_job in settings_cls.get("cron_jobs", [])
]
if hasattr(settings_cls, "functions"):
@@ -213,16 +214,17 @@ def patch_create_worker():
]
if hasattr(settings_cls, "cron_jobs"):
settings_cls.cron_jobs = [
_get_arq_cron_job(cron_job) for cron_job in settings_cls.cron_jobs
_get_arq_cron_job(cron_job)
for cron_job in (settings_cls.cron_jobs or [])
]
if "functions" in kwargs:
kwargs["functions"] = [
_get_arq_function(func) for func in kwargs["functions"]
_get_arq_function(func) for func in kwargs.get("functions", [])
]
if "cron_jobs" in kwargs:
kwargs["cron_jobs"] = [
_get_arq_cron_job(cron_job) for cron_job in kwargs["cron_jobs"]
_get_arq_cron_job(cron_job) for cron_job in kwargs.get("cron_jobs", [])
]
return old_create_worker(*args, **kwargs)
@@ -12,7 +12,6 @@ from functools import partial
import sentry_sdk
from sentry_sdk.api import continue_trace
from sentry_sdk.consts import OP
from sentry_sdk.integrations._asgi_common import (
_get_headers,
_get_request_data,
@@ -25,10 +24,7 @@ from sentry_sdk.integrations._wsgi_common import (
from sentry_sdk.sessions import track_session
from sentry_sdk.tracing import (
SOURCE_FOR_STYLE,
TRANSACTION_SOURCE_ROUTE,
TRANSACTION_SOURCE_URL,
TRANSACTION_SOURCE_COMPONENT,
TRANSACTION_SOURCE_CUSTOM,
TransactionSource,
)
from sentry_sdk.utils import (
ContextVar,
@@ -45,7 +41,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -105,6 +100,7 @@ class SentryAsgiMiddleware:
mechanism_type="asgi", # type: str
span_origin="manual", # type: str
http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: Tuple[str, ...]
asgi_version=None, # type: Optional[int]
):
# type: (...) -> None
"""
@@ -143,10 +139,32 @@ class SentryAsgiMiddleware:
self.app = app
self.http_methods_to_capture = http_methods_to_capture
if _looks_like_asgi3(app):
self.__call__ = self._run_asgi3 # type: Callable[..., Any]
else:
self.__call__ = self._run_asgi2
if asgi_version is None:
if _looks_like_asgi3(app):
asgi_version = 3
else:
asgi_version = 2
if asgi_version == 3:
self.__call__ = self._run_asgi3
elif asgi_version == 2:
self.__call__ = self._run_asgi2 # type: ignore
def _capture_lifespan_exception(self, exc):
# type: (Exception) -> None
"""Capture exceptions raise in application lifespan handlers.
The separate function is needed to support overriding in derived integrations that use different catching mechanisms.
"""
return _capture_exception(exc=exc, mechanism_type=self.mechanism_type)
def _capture_request_exception(self, exc):
# type: (Exception) -> None
"""Capture exceptions raised in incoming request handlers.
The separate function is needed to support overriding in derived integrations that use different catching mechanisms.
"""
return _capture_exception(exc=exc, mechanism_type=self.mechanism_type)
def _run_asgi2(self, scope):
# type: (Any) -> Any
@@ -161,7 +179,7 @@ class SentryAsgiMiddleware:
return await self._run_app(scope, receive, send, asgi_version=3)
async def _run_app(self, scope, receive, send, asgi_version):
# type: (Any, Any, Any, Any, int) -> Any
# type: (Any, Any, Any, int) -> Any
is_recursive_asgi_middleware = _asgi_middleware_applied.get(False)
is_lifespan = scope["type"] == "lifespan"
if is_recursive_asgi_middleware or is_lifespan:
@@ -172,7 +190,7 @@ class SentryAsgiMiddleware:
return await self.app(scope, receive, send)
except Exception as exc:
_capture_exception(exc, mechanism_type=self.mechanism_type)
self._capture_lifespan_exception(exc)
raise exc from None
_asgi_middleware_applied.set(True)
@@ -195,8 +213,8 @@ class SentryAsgiMiddleware:
method = scope.get("method", "").upper()
transaction = None
if method in self.http_methods_to_capture:
if ty in ("http", "websocket"):
if ty in ("http", "websocket"):
if ty == "websocket" or method in self.http_methods_to_capture:
transaction = continue_trace(
_get_headers(scope),
op="{}.server".format(ty),
@@ -204,37 +222,26 @@ class SentryAsgiMiddleware:
source=transaction_source,
origin=self.span_origin,
)
logger.debug(
"[ASGI] Created transaction (continuing trace): %s",
transaction,
)
else:
transaction = Transaction(
op=OP.HTTP_SERVER,
name=transaction_name,
source=transaction_source,
origin=self.span_origin,
)
logger.debug(
"[ASGI] Created transaction (new): %s", transaction
)
transaction.set_tag("asgi.type", ty)
logger.debug(
"[ASGI] Set transaction name and source on transaction: '%s' / '%s'",
transaction.name,
transaction.source,
else:
transaction = Transaction(
op=OP.HTTP_SERVER,
name=transaction_name,
source=transaction_source,
origin=self.span_origin,
)
with (
if transaction:
transaction.set_tag("asgi.type", ty)
transaction_context = (
sentry_sdk.start_transaction(
transaction,
custom_sampling_context={"asgi_scope": scope},
)
if transaction is not None
else nullcontext()
):
logger.debug("[ASGI] Started transaction: %s", transaction)
)
with transaction_context:
try:
async def _sentry_wrapped_send(event):
@@ -258,7 +265,7 @@ class SentryAsgiMiddleware:
scope, receive, _sentry_wrapped_send
)
except Exception as exc:
_capture_exception(exc, mechanism_type=self.mechanism_type)
self._capture_request_exception(exc)
raise exc from None
finally:
_asgi_middleware_applied.set(False)
@@ -270,13 +277,18 @@ class SentryAsgiMiddleware:
event["request"] = deepcopy(request_data)
# Only set transaction name if not already set by Starlette or FastAPI (or other frameworks)
already_set = event["transaction"] != _DEFAULT_TRANSACTION_NAME and event[
"transaction_info"
].get("source") in [
TRANSACTION_SOURCE_COMPONENT,
TRANSACTION_SOURCE_ROUTE,
TRANSACTION_SOURCE_CUSTOM,
]
transaction = event.get("transaction")
transaction_source = (event.get("transaction_info") or {}).get("source")
already_set = (
transaction is not None
and transaction != _DEFAULT_TRANSACTION_NAME
and transaction_source
in [
TransactionSource.COMPONENT,
TransactionSource.ROUTE,
TransactionSource.CUSTOM,
]
)
if not already_set:
name, source = self._get_transaction_name_and_source(
self.transaction_style, asgi_scope
@@ -284,12 +296,6 @@ class SentryAsgiMiddleware:
event["transaction"] = name
event["transaction_info"] = {"source": source}
logger.debug(
"[ASGI] Set transaction name and source in event_processor: '%s' / '%s'",
event["transaction"],
event["transaction_info"]["source"],
)
return event
# Helper functions.
@@ -313,7 +319,7 @@ class SentryAsgiMiddleware:
name = transaction_from_function(endpoint) or ""
else:
name = _get_url(asgi_scope, "http" if ty == "http" else "ws", host=None)
source = TRANSACTION_SOURCE_URL
source = TransactionSource.URL
elif transaction_style == "url":
# FastAPI includes the route object in the scope to let Sentry extract the
@@ -325,11 +331,11 @@ class SentryAsgiMiddleware:
name = path
else:
name = _get_url(asgi_scope, "http" if ty == "http" else "ws", host=None)
source = TRANSACTION_SOURCE_URL
source = TransactionSource.URL
if name is None:
name = _DEFAULT_TRANSACTION_NAME
source = TRANSACTION_SOURCE_ROUTE
source = TransactionSource.ROUTE
return name, source
return name, source
@@ -3,7 +3,7 @@ import sys
import sentry_sdk
from sentry_sdk.consts import OP
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.utils import event_from_exception, reraise
from sentry_sdk.utils import event_from_exception, logger, reraise
try:
import asyncio
@@ -11,7 +11,7 @@ try:
except ImportError:
raise DidNotEnable("asyncio not available")
from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
@@ -39,7 +39,7 @@ def patch_asyncio():
def _sentry_task_factory(loop, coro, **kwargs):
# type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any]
async def _coro_creating_hub_and_span():
async def _task_with_sentry_span_creation():
# type: () -> Any
result = None
@@ -51,32 +51,54 @@ def patch_asyncio():
):
try:
result = await coro
except StopAsyncIteration as e:
raise e from None
except Exception:
reraise(*_capture_exception())
return result
task = None
# Trying to use user set task factory (if there is one)
if orig_task_factory:
return orig_task_factory(loop, _coro_creating_hub_and_span(), **kwargs)
task = orig_task_factory(
loop, _task_with_sentry_span_creation(), **kwargs
)
# The default task factory in `asyncio` does not have its own function
# but is just a couple of lines in `asyncio.base_events.create_task()`
# Those lines are copied here.
if task is None:
# The default task factory in `asyncio` does not have its own function
# but is just a couple of lines in `asyncio.base_events.create_task()`
# Those lines are copied here.
# WARNING:
# If the default behavior of the task creation in asyncio changes,
# this will break!
task = Task(_coro_creating_hub_and_span(), loop=loop, **kwargs)
if task._source_traceback: # type: ignore
del task._source_traceback[-1] # type: ignore
# WARNING:
# If the default behavior of the task creation in asyncio changes,
# this will break!
task = Task(_task_with_sentry_span_creation(), loop=loop, **kwargs)
if task._source_traceback: # type: ignore
del task._source_traceback[-1] # type: ignore
# Set the task name to include the original coroutine's name
try:
cast("asyncio.Task[Any]", task).set_name(
f"{get_name(coro)} (Sentry-wrapped)"
)
except AttributeError:
# set_name might not be available in all Python versions
pass
return task
loop.set_task_factory(_sentry_task_factory) # type: ignore
except RuntimeError:
# When there is no running loop, we have nothing to patch.
pass
logger.warning(
"There is no running asyncio loop so there is nothing Sentry can patch. "
"Please make sure you call sentry_sdk.init() within a running "
"asyncio loop for the AsyncioIntegration to work. "
"See https://docs.sentry.io/platforms/python/integrations/asyncio/"
)
def _capture_exception():
@@ -10,7 +10,7 @@ import sentry_sdk
from sentry_sdk.api import continue_trace
from sentry_sdk.consts import OP
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import TRANSACTION_SOURCE_COMPONENT
from sentry_sdk.tracing import TransactionSource
from sentry_sdk.utils import (
AnnotatedValue,
capture_internal_exceptions,
@@ -61,7 +61,10 @@ def _wrap_init_error(init_error):
else:
# Fall back to AWS lambdas JSON representation of the error
sentry_event = _event_from_error_json(json.loads(args[1]))
error_info = args[1]
if isinstance(error_info, str):
error_info = json.loads(error_info)
sentry_event = _event_from_error_json(error_info)
sentry_sdk.capture_event(sentry_event)
return init_error(*args, **kwargs)
@@ -135,6 +138,8 @@ def _wrap_handler(handler):
timeout_thread = TimeoutThread(
waiting_time,
configured_time / MILLIS_TO_SECONDS,
isolation_scope=scope,
current_scope=sentry_sdk.get_current_scope(),
)
# Starting the thread to raise timeout warning exception
@@ -150,7 +155,7 @@ def _wrap_handler(handler):
headers,
op=OP.FUNCTION_AWS,
name=aws_context.function_name,
source=TRANSACTION_SOURCE_COMPONENT,
source=TransactionSource.COMPONENT,
origin=AwsLambdaIntegration.origin,
)
with sentry_sdk.start_transaction(
@@ -177,14 +177,20 @@ def _set_transaction_name_and_source(event, transaction_style, request):
name = ""
if transaction_style == "url":
name = request.route.rule or ""
try:
name = request.route.rule or ""
except RuntimeError:
pass
elif transaction_style == "endpoint":
name = (
request.route.name
or transaction_from_function(request.route.callback)
or ""
)
try:
name = (
request.route.name
or transaction_from_function(request.route.callback)
or ""
)
except RuntimeError:
pass
event["transaction"] = name
event["transaction_info"] = {"source": SOURCE_FOR_STYLE[transaction_style]}
@@ -9,12 +9,12 @@ from sentry_sdk.consts import OP, SPANSTATUS, SPANDATA
from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable
from sentry_sdk.integrations.celery.beat import (
_patch_beat_apply_entry,
_patch_redbeat_maybe_due,
_patch_redbeat_apply_async,
_setup_celery_beat_signals,
)
from sentry_sdk.integrations.celery.utils import _now_seconds_since_epoch
from sentry_sdk.integrations.logging import ignore_logger
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME, TRANSACTION_SOURCE_TASK
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME, TransactionSource
from sentry_sdk.tracing_utils import Baggage
from sentry_sdk.utils import (
capture_internal_exceptions,
@@ -73,7 +73,7 @@ class CeleryIntegration(Integration):
self.exclude_beat_tasks = exclude_beat_tasks
_patch_beat_apply_entry()
_patch_redbeat_maybe_due()
_patch_redbeat_apply_async()
_setup_celery_beat_signals(monitor_beat_tasks)
@staticmethod
@@ -319,7 +319,7 @@ def _wrap_tracer(task, f):
headers,
op=OP.QUEUE_TASK_CELERY,
name="unknown celery task",
source=TRANSACTION_SOURCE_TASK,
source=TransactionSource.TASK,
origin=CeleryIntegration.origin,
)
transaction.name = task.name
@@ -391,6 +391,7 @@ def _wrap_task_call(task, f):
)
if latency is not None:
latency *= 1000 # milliseconds
span.set_data(SPANDATA.MESSAGING_MESSAGE_RECEIVE_LATENCY, latency)
with capture_internal_exceptions():
@@ -202,12 +202,12 @@ def _patch_beat_apply_entry():
Scheduler.apply_entry = _wrap_beat_scheduler(Scheduler.apply_entry)
def _patch_redbeat_maybe_due():
def _patch_redbeat_apply_async():
# type: () -> None
if RedBeatScheduler is None:
return
RedBeatScheduler.maybe_due = _wrap_beat_scheduler(RedBeatScheduler.maybe_due)
RedBeatScheduler.apply_async = _wrap_beat_scheduler(RedBeatScheduler.apply_async)
def _setup_celery_beat_signals(monitor_beat_tasks):
@@ -4,7 +4,7 @@ from functools import wraps
import sentry_sdk
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.integrations.aws_lambda import _make_request_event_processor
from sentry_sdk.tracing import TRANSACTION_SOURCE_COMPONENT
from sentry_sdk.tracing import TransactionSource
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
@@ -67,7 +67,7 @@ def _get_view_function_response(app, view_function, function_args):
configured_time = app.lambda_context.get_remaining_time_in_millis()
scope.set_transaction_name(
app.lambda_context.function_name,
source=TRANSACTION_SOURCE_COMPONENT,
source=TransactionSource.COMPONENT,
)
scope.add_event_processor(
@@ -11,7 +11,8 @@ from typing import TYPE_CHECKING, TypeVar
# without introducing a hard dependency on `typing_extensions`
# from: https://stackoverflow.com/a/71944042/300572
if TYPE_CHECKING:
from typing import ParamSpec, Callable
from collections.abc import Iterator
from typing import Any, ParamSpec, Callable
else:
# Fake ParamSpec
class ParamSpec:
@@ -49,9 +50,7 @@ class ClickhouseDriverIntegration(Integration):
)
# If the query contains parameters then the send_data function is used to send those parameters to clickhouse
clickhouse_driver.client.Client.send_data = _wrap_send_data(
clickhouse_driver.client.Client.send_data
)
_wrap_send_data()
# Every query ends either with the Client's `receive_end_of_query` (no result expected)
# or its `receive_result` (result expected)
@@ -128,23 +127,44 @@ def _wrap_end(f: Callable[P, T]) -> Callable[P, T]:
return _inner_end
def _wrap_send_data(f: Callable[P, T]) -> Callable[P, T]:
def _inner_send_data(*args: P.args, **kwargs: P.kwargs) -> T:
instance = args[0] # type: clickhouse_driver.client.Client
data = args[2]
span = getattr(instance.connection, "_sentry_span", None)
def _wrap_send_data() -> None:
original_send_data = clickhouse_driver.client.Client.send_data
def _inner_send_data( # type: ignore[no-untyped-def] # clickhouse-driver does not type send_data
self, sample_block, data, types_check=False, columnar=False, *args, **kwargs
):
span = getattr(self.connection, "_sentry_span", None)
if span is not None:
_set_db_data(span, instance.connection)
_set_db_data(span, self.connection)
if should_send_default_pii():
db_params = span._data.get("db.params", [])
db_params.extend(data)
if isinstance(data, (list, tuple)):
db_params.extend(data)
else: # data is a generic iterator
orig_data = data
# Wrap the generator to add items to db.params as they are yielded.
# This allows us to send the params to Sentry without needing to allocate
# memory for the entire generator at once.
def wrapped_generator() -> "Iterator[Any]":
for item in orig_data:
db_params.append(item)
yield item
# Replace the original iterator with the wrapped one.
data = wrapped_generator()
span.set_data("db.params", db_params)
return f(*args, **kwargs)
return original_send_data(
self, sample_block, data, types_check, columnar, *args, **kwargs
)
return _inner_send_data
clickhouse_driver.client.Client.send_data = _inner_send_data
def _set_db_data(
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
CONTEXT_TYPE = "cloud_resource"
HTTP_TIMEOUT = 2.0
AWS_METADATA_HOST = "169.254.169.254"
AWS_TOKEN_URL = "http://{}/latest/api/token".format(AWS_METADATA_HOST)
AWS_METADATA_URL = "http://{}/latest/dynamic/instance-identity/document".format(
@@ -59,7 +61,7 @@ class CloudResourceContextIntegration(Integration):
cloud_provider = ""
aws_token = ""
http = urllib3.PoolManager()
http = urllib3.PoolManager(timeout=HTTP_TIMEOUT)
gcp_metadata = None
@@ -83,7 +85,13 @@ class CloudResourceContextIntegration(Integration):
cls.aws_token = r.data.decode()
return True
except Exception:
except urllib3.exceptions.TimeoutError:
logger.debug(
"AWS metadata service timed out after %s seconds", HTTP_TIMEOUT
)
return False
except Exception as e:
logger.debug("Error checking AWS metadata service: %s", str(e))
return False
@classmethod
@@ -131,8 +139,12 @@ class CloudResourceContextIntegration(Integration):
except Exception:
pass
except Exception:
pass
except urllib3.exceptions.TimeoutError:
logger.debug(
"AWS metadata service timed out after %s seconds", HTTP_TIMEOUT
)
except Exception as e:
logger.debug("Error fetching AWS metadata: %s", str(e))
return ctx
@@ -152,7 +164,13 @@ class CloudResourceContextIntegration(Integration):
cls.gcp_metadata = json.loads(r.data.decode("utf-8"))
return True
except Exception:
except urllib3.exceptions.TimeoutError:
logger.debug(
"GCP metadata service timed out after %s seconds", HTTP_TIMEOUT
)
return False
except Exception as e:
logger.debug("Error checking GCP metadata service: %s", str(e))
return False
@classmethod
@@ -201,8 +219,12 @@ class CloudResourceContextIntegration(Integration):
except Exception:
pass
except Exception:
pass
except urllib3.exceptions.TimeoutError:
logger.debug(
"GCP metadata service timed out after %s seconds", HTTP_TIMEOUT
)
except Exception as e:
logger.debug("Error fetching GCP metadata: %s", str(e))
return ctx
@@ -7,6 +7,8 @@ from sentry_sdk.ai.utils import set_data_normalized
from typing import TYPE_CHECKING
from sentry_sdk.tracing_utils import set_span_errored
if TYPE_CHECKING:
from typing import Any, Callable, Iterator
from sentry_sdk.tracing import Span
@@ -52,17 +54,17 @@ COLLECTED_PII_CHAT_PARAMS = {
}
COLLECTED_CHAT_RESP_ATTRS = {
"generation_id": "ai.generation_id",
"is_search_required": "ai.is_search_required",
"finish_reason": "ai.finish_reason",
"generation_id": SPANDATA.AI_GENERATION_ID,
"is_search_required": SPANDATA.AI_SEARCH_REQUIRED,
"finish_reason": SPANDATA.AI_FINISH_REASON,
}
COLLECTED_PII_CHAT_RESP_ATTRS = {
"citations": "ai.citations",
"documents": "ai.documents",
"search_queries": "ai.search_queries",
"search_results": "ai.search_results",
"tool_calls": "ai.tool_calls",
"citations": SPANDATA.AI_CITATIONS,
"documents": SPANDATA.AI_DOCUMENTS,
"search_queries": SPANDATA.AI_SEARCH_QUERIES,
"search_results": SPANDATA.AI_SEARCH_RESULTS,
"tool_calls": SPANDATA.AI_TOOL_CALLS,
}
@@ -84,6 +86,8 @@ class CohereIntegration(Integration):
def _capture_exception(exc):
# type: (Any) -> None
set_span_errored()
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
@@ -116,18 +120,18 @@ def _wrap_chat(f, streaming):
if hasattr(res.meta, "billed_units"):
record_token_usage(
span,
prompt_tokens=res.meta.billed_units.input_tokens,
completion_tokens=res.meta.billed_units.output_tokens,
input_tokens=res.meta.billed_units.input_tokens,
output_tokens=res.meta.billed_units.output_tokens,
)
elif hasattr(res.meta, "tokens"):
record_token_usage(
span,
prompt_tokens=res.meta.tokens.input_tokens,
completion_tokens=res.meta.tokens.output_tokens,
input_tokens=res.meta.tokens.input_tokens,
output_tokens=res.meta.tokens.output_tokens,
)
if hasattr(res.meta, "warnings"):
set_data_normalized(span, "ai.warnings", res.meta.warnings)
set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings)
@wraps(f)
def new_chat(*args, **kwargs):
@@ -238,7 +242,7 @@ def _wrap_embed(f):
should_send_default_pii() and integration.include_prompts
):
if isinstance(kwargs["texts"], str):
set_data_normalized(span, "ai.texts", [kwargs["texts"]])
set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]])
elif (
isinstance(kwargs["texts"], list)
and len(kwargs["texts"]) > 0
@@ -262,7 +266,7 @@ def _wrap_embed(f):
):
record_token_usage(
span,
prompt_tokens=res.meta.billed_units.input_tokens,
input_tokens=res.meta.billed_units.input_tokens,
total_tokens=res.meta.billed_units.input_tokens,
)
return res
@@ -1,5 +1,7 @@
import weakref
import sentry_sdk
from sentry_sdk.utils import ContextVar
from sentry_sdk.utils import ContextVar, logger
from sentry_sdk.integrations import Integration
from sentry_sdk.scope import add_global_event_processor
@@ -35,8 +37,31 @@ class DedupeIntegration(Integration):
if exc_info is None:
return event
last_seen = integration._last_seen.get(None)
if last_seen is not None:
# last_seen is either a weakref or the original instance
last_seen = (
last_seen() if isinstance(last_seen, weakref.ref) else last_seen
)
exc = exc_info[1]
if integration._last_seen.get(None) is exc:
if last_seen is exc:
logger.info("DedupeIntegration dropped duplicated error event %s", exc)
return None
integration._last_seen.set(exc)
# we can only weakref non builtin types
try:
integration._last_seen.set(weakref.ref(exc))
except TypeError:
integration._last_seen.set(exc)
return event
@staticmethod
def reset_last_seen():
# type: () -> None
integration = sentry_sdk.get_client().get_integration(DedupeIntegration)
if integration is None:
return
integration._last_seen.set(None)
@@ -7,8 +7,8 @@ from importlib import import_module
import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.scope import add_global_event_processor, should_send_default_pii
from sentry_sdk.serializer import add_global_repr_processor
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_URL
from sentry_sdk.serializer import add_global_repr_processor, add_repr_sequence_type
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TransactionSource
from sentry_sdk.tracing_utils import add_query_source, record_sql_queries
from sentry_sdk.utils import (
AnnotatedValue,
@@ -269,6 +269,7 @@ class DjangoIntegration(Integration):
patch_views()
patch_templates()
patch_signals()
add_template_context_repr_sequence()
if patch_caching is not None:
patch_caching()
@@ -398,7 +399,7 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
if transaction_name is None:
transaction_name = request.path_info
source = TRANSACTION_SOURCE_URL
source = TransactionSource.URL
else:
source = SOURCE_FOR_STYLE[transaction_style]
@@ -584,7 +585,7 @@ class DjangoRequestExtractor(RequestExtractor):
# type: () -> Optional[Dict[str, Any]]
try:
return self.request.data
except AttributeError:
except Exception:
return RequestExtractor.parsed_body(self)
@@ -745,3 +746,13 @@ def _set_db_data(span, cursor_or_db):
server_socket_address = connection_params.get("unix_socket")
if server_socket_address is not None:
span.set_data(SPANDATA.SERVER_SOCKET_ADDRESS, server_socket_address)
def add_template_context_repr_sequence():
# type: () -> None
try:
from django.template.context import BaseContext
add_repr_sequence_type(BaseContext)
except Exception:
pass
@@ -155,7 +155,7 @@ def patch_channels_asgi_handler_impl(cls):
http_methods_to_capture=integration.http_methods_to_capture,
)
return await middleware(self.scope)(receive, send)
return await middleware(self.scope)(receive, send) # type: ignore
cls.__call__ = sentry_patched_asgi_handler
@@ -237,9 +237,9 @@ def _asgi_middleware_mixin_factory(_check_middleware_span):
middleware_span = _check_middleware_span(old_method=f)
if middleware_span is None:
return await f(*args, **kwargs)
return await f(*args, **kwargs) # type: ignore
with middleware_span:
return await f(*args, **kwargs)
return await f(*args, **kwargs) # type: ignore
return SentryASGIMixin
@@ -45,7 +45,8 @@ def _patch_cache_method(cache, method_name, address, port):
):
# type: (CacheHandler, str, Callable[..., Any], tuple[Any, ...], dict[str, Any], Optional[str], Optional[int]) -> Any
is_set_operation = method_name.startswith("set")
is_get_operation = not is_set_operation
is_get_method = method_name == "get"
is_get_many_method = method_name == "get_many"
op = OP.CACHE_PUT if is_set_operation else OP.CACHE_GET
description = _get_span_description(method_name, args, kwargs)
@@ -69,8 +70,20 @@ def _patch_cache_method(cache, method_name, address, port):
span.set_data(SPANDATA.CACHE_KEY, key)
item_size = None
if is_get_operation:
if value:
if is_get_many_method:
if value != {}:
item_size = len(str(value))
span.set_data(SPANDATA.CACHE_HIT, True)
else:
span.set_data(SPANDATA.CACHE_HIT, False)
elif is_get_method:
default_value = None
if len(args) >= 2:
default_value = args[1]
elif "default" in kwargs:
default_value = kwargs["default"]
if value != default_value:
item_size = len(str(value))
span.set_data(SPANDATA.CACHE_HIT, True)
else:
@@ -1,18 +1,31 @@
import json
import sentry_sdk
from sentry_sdk.integrations import Integration
from sentry_sdk.consts import OP, SPANSTATUS
from sentry_sdk.api import continue_trace, get_baggage, get_traceparent
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.integrations._wsgi_common import request_body_within_bounds
from sentry_sdk.tracing import (
BAGGAGE_HEADER_NAME,
SENTRY_TRACE_HEADER_NAME,
TransactionSource,
)
from sentry_sdk.utils import (
AnnotatedValue,
capture_internal_exceptions,
event_from_exception,
)
from typing import TypeVar
from dramatiq.broker import Broker # type: ignore
from dramatiq.message import Message # type: ignore
from dramatiq.middleware import Middleware, default_middleware # type: ignore
from dramatiq.errors import Retry # type: ignore
R = TypeVar("R")
try:
from dramatiq.broker import Broker
from dramatiq.middleware import Middleware, default_middleware
from dramatiq.errors import Retry
from dramatiq.message import Message
except ImportError:
raise DidNotEnable("Dramatiq is not installed")
from typing import TYPE_CHECKING
@@ -34,10 +47,12 @@ class DramatiqIntegration(Integration):
"""
identifier = "dramatiq"
origin = f"auto.queue.{identifier}"
@staticmethod
def setup_once():
# type: () -> None
_patch_dramatiq_broker()
@@ -85,22 +100,54 @@ class SentryMiddleware(Middleware): # type: ignore[misc]
DramatiqIntegration.
"""
def before_process_message(self, broker, message):
# type: (Broker, Message) -> None
SENTRY_HEADERS_NAME = "_sentry_headers"
def before_enqueue(self, broker, message, delay):
# type: (Broker, Message[R], int) -> None
integration = sentry_sdk.get_client().get_integration(DramatiqIntegration)
if integration is None:
return
message._scope_manager = sentry_sdk.new_scope()
message._scope_manager.__enter__()
message.options[self.SENTRY_HEADERS_NAME] = {
BAGGAGE_HEADER_NAME: get_baggage(),
SENTRY_TRACE_HEADER_NAME: get_traceparent(),
}
scope = sentry_sdk.get_current_scope()
scope.transaction = message.actor_name
def before_process_message(self, broker, message):
# type: (Broker, Message[R]) -> None
integration = sentry_sdk.get_client().get_integration(DramatiqIntegration)
if integration is None:
return
message._scope_manager = sentry_sdk.isolation_scope()
scope = message._scope_manager.__enter__()
scope.clear_breadcrumbs()
scope.set_extra("dramatiq_message_id", message.message_id)
scope.add_event_processor(_make_message_event_processor(message, integration))
sentry_headers = message.options.get(self.SENTRY_HEADERS_NAME) or {}
if "retries" in message.options:
# start new trace in case of retrying
sentry_headers = {}
transaction = continue_trace(
sentry_headers,
name=message.actor_name,
op=OP.QUEUE_TASK_DRAMATIQ,
source=TransactionSource.TASK,
origin=DramatiqIntegration.origin,
)
transaction.set_status(SPANSTATUS.OK)
sentry_sdk.start_transaction(
transaction,
name=message.actor_name,
op=OP.QUEUE_TASK_DRAMATIQ,
source=TransactionSource.TASK,
)
transaction.__enter__()
def after_process_message(self, broker, message, *, result=None, exception=None):
# type: (Broker, Message, Any, Optional[Any], Optional[Exception]) -> None
# type: (Broker, Message[R], Optional[Any], Optional[Exception]) -> None
integration = sentry_sdk.get_client().get_integration(DramatiqIntegration)
if integration is None:
return
@@ -108,27 +155,38 @@ class SentryMiddleware(Middleware): # type: ignore[misc]
actor = broker.get_actor(message.actor_name)
throws = message.options.get("throws") or actor.options.get("throws")
try:
if (
exception is not None
and not (throws and isinstance(exception, throws))
and not isinstance(exception, Retry)
):
event, hint = event_from_exception(
exception,
client_options=sentry_sdk.get_client().options,
mechanism={
"type": DramatiqIntegration.identifier,
"handled": False,
},
)
sentry_sdk.capture_event(event, hint=hint)
finally:
message._scope_manager.__exit__(None, None, None)
scope_manager = message._scope_manager
transaction = sentry_sdk.get_current_scope().transaction
if not transaction:
return None
is_event_capture_required = (
exception is not None
and not (throws and isinstance(exception, throws))
and not isinstance(exception, Retry)
)
if not is_event_capture_required:
# normal transaction finish
transaction.__exit__(None, None, None)
scope_manager.__exit__(None, None, None)
return
event, hint = event_from_exception(
exception, # type: ignore[arg-type]
client_options=sentry_sdk.get_client().options,
mechanism={
"type": DramatiqIntegration.identifier,
"handled": False,
},
)
sentry_sdk.capture_event(event, hint=hint)
# transaction error
transaction.__exit__(type(exception), exception, None)
scope_manager.__exit__(type(exception), exception, None)
def _make_message_event_processor(message, integration):
# type: (Message, DramatiqIntegration) -> Callable[[Event, Hint], Optional[Event]]
# type: (Message[R], DramatiqIntegration) -> Callable[[Event, Hint], Optional[Event]]
def inner(event, hint):
# type: (Event, Hint) -> Optional[Event]
@@ -142,7 +200,7 @@ def _make_message_event_processor(message, integration):
class DramatiqMessageExtractor:
def __init__(self, message):
# type: (Message) -> None
# type: (Message[R]) -> None
self.message_data = dict(message.asdict())
def content_length(self):
@@ -5,11 +5,8 @@ from functools import wraps
import sentry_sdk
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE
from sentry_sdk.utils import (
transaction_from_function,
logger,
)
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TransactionSource
from sentry_sdk.utils import transaction_from_function
from typing import TYPE_CHECKING
@@ -61,14 +58,11 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
if not name:
name = _DEFAULT_TRANSACTION_NAME
source = TRANSACTION_SOURCE_ROUTE
source = TransactionSource.ROUTE
else:
source = SOURCE_FOR_STYLE[transaction_style]
scope.set_transaction_name(name, source=source)
logger.debug(
"[FastAPI] Set transaction name and source on scope: %s / %s", name, source
)
def patch_get_request_handler():
@@ -72,6 +72,18 @@ class FlaskIntegration(Integration):
@staticmethod
def setup_once():
# type: () -> None
try:
from quart import Quart # type: ignore
if Flask == Quart:
# This is Quart masquerading as Flask, don't enable the Flask
# integration. See https://github.com/getsentry/sentry-python/issues/2709
raise DidNotEnable(
"This is not a Flask app but rather Quart pretending to be Flask"
)
except ImportError:
pass
version = package_version("flask")
_check_minimum_version(FlaskIntegration, version)
@@ -10,7 +10,7 @@ from sentry_sdk.consts import OP
from sentry_sdk.integrations import Integration
from sentry_sdk.integrations._wsgi_common import _filter_headers
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import TRANSACTION_SOURCE_COMPONENT
from sentry_sdk.tracing import TransactionSource
from sentry_sdk.utils import (
AnnotatedValue,
capture_internal_exceptions,
@@ -75,7 +75,12 @@ def _wrap_func(func):
):
waiting_time = configured_time - TIMEOUT_WARNING_BUFFER
timeout_thread = TimeoutThread(waiting_time, configured_time)
timeout_thread = TimeoutThread(
waiting_time,
configured_time,
isolation_scope=scope,
current_scope=sentry_sdk.get_current_scope(),
)
# Starting the thread to raise timeout warning exception
timeout_thread.start()
@@ -88,7 +93,7 @@ def _wrap_func(func):
headers,
op=OP.FUNCTION_GCP,
name=environ.get("FUNCTION_NAME", ""),
source=TRANSACTION_SOURCE_COMPONENT,
source=TransactionSource.COMPONENT,
origin=GcpIntegration.origin,
)
sampling_context = {
@@ -11,24 +11,16 @@ if TYPE_CHECKING:
from typing import Any
from sentry_sdk._types import Event
MODULE_RE = r"[a-zA-Z0-9/._:\\-]+"
TYPE_RE = r"[a-zA-Z0-9._:<>,-]+"
HEXVAL_RE = r"[A-Fa-f0-9]+"
# function is everything between index at @
# and then we match on the @ plus the hex val
FUNCTION_RE = r"[^@]+?"
HEX_ADDRESS = r"\s+@\s+0x[0-9a-fA-F]+"
FRAME_RE = r"""
^(?P<index>\d+)\.\s
(?P<package>{MODULE_RE})\(
(?P<retval>{TYPE_RE}\ )?
((?P<function>{TYPE_RE})
(?P<args>\(.*\))?
)?
((?P<constoffset>\ const)?\+0x(?P<offset>{HEXVAL_RE}))?
\)\s
\[0x(?P<retaddr>{HEXVAL_RE})\]$
^(?P<index>\d+)\.\s+(?P<function>{FUNCTION_RE}){HEX_ADDRESS}(?:\s+in\s+(?P<package>.+))?$
""".format(
MODULE_RE=MODULE_RE, HEXVAL_RE=HEXVAL_RE, TYPE_RE=TYPE_RE
FUNCTION_RE=FUNCTION_RE,
HEX_ADDRESS=HEX_ADDRESS,
)
FRAME_RE = re.compile(FRAME_RE, re.MULTILINE | re.VERBOSE)
@@ -0,0 +1,301 @@
from functools import wraps
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
)
import sentry_sdk
from sentry_sdk.ai.utils import get_start_span_function
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.tracing import SPANSTATUS
try:
from google.genai.models import Models, AsyncModels
except ImportError:
raise DidNotEnable("google-genai not installed")
from .consts import IDENTIFIER, ORIGIN, GEN_AI_SYSTEM
from .utils import (
set_span_data_for_request,
set_span_data_for_response,
_capture_exception,
prepare_generate_content_args,
)
from .streaming import (
set_span_data_for_streaming_response,
accumulate_streaming_response,
)
class GoogleGenAIIntegration(Integration):
identifier = IDENTIFIER
origin = ORIGIN
def __init__(self, include_prompts=True):
# type: (GoogleGenAIIntegration, bool) -> None
self.include_prompts = include_prompts
@staticmethod
def setup_once():
# type: () -> None
# Patch sync methods
Models.generate_content = _wrap_generate_content(Models.generate_content)
Models.generate_content_stream = _wrap_generate_content_stream(
Models.generate_content_stream
)
# Patch async methods
AsyncModels.generate_content = _wrap_async_generate_content(
AsyncModels.generate_content
)
AsyncModels.generate_content_stream = _wrap_async_generate_content_stream(
AsyncModels.generate_content_stream
)
def _wrap_generate_content_stream(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
def new_generate_content_stream(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
if integration is None:
return f(self, *args, **kwargs)
_model, contents, model_name = prepare_generate_content_args(args, kwargs)
span = get_start_span_function()(
op=OP.GEN_AI_INVOKE_AGENT,
name="invoke_agent",
origin=ORIGIN,
)
span.__enter__()
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
set_span_data_for_request(span, integration, model_name, contents, kwargs)
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
chat_span = sentry_sdk.start_span(
op=OP.GEN_AI_CHAT,
name=f"chat {model_name}",
origin=ORIGIN,
)
chat_span.__enter__()
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
set_span_data_for_request(chat_span, integration, model_name, contents, kwargs)
chat_span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
chat_span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
try:
stream = f(self, *args, **kwargs)
# Create wrapper iterator to accumulate responses
def new_iterator():
# type: () -> Iterator[Any]
chunks = [] # type: List[Any]
try:
for chunk in stream:
chunks.append(chunk)
yield chunk
except Exception as exc:
_capture_exception(exc)
chat_span.set_status(SPANSTATUS.ERROR)
raise
finally:
# Accumulate all chunks and set final response data on spans
if chunks:
accumulated_response = accumulate_streaming_response(chunks)
set_span_data_for_streaming_response(
chat_span, integration, accumulated_response
)
set_span_data_for_streaming_response(
span, integration, accumulated_response
)
chat_span.__exit__(None, None, None)
span.__exit__(None, None, None)
return new_iterator()
except Exception as exc:
_capture_exception(exc)
chat_span.__exit__(None, None, None)
span.__exit__(None, None, None)
raise
return new_generate_content_stream
def _wrap_async_generate_content_stream(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
async def new_async_generate_content_stream(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
if integration is None:
return await f(self, *args, **kwargs)
_model, contents, model_name = prepare_generate_content_args(args, kwargs)
span = get_start_span_function()(
op=OP.GEN_AI_INVOKE_AGENT,
name="invoke_agent",
origin=ORIGIN,
)
span.__enter__()
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
set_span_data_for_request(span, integration, model_name, contents, kwargs)
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
chat_span = sentry_sdk.start_span(
op=OP.GEN_AI_CHAT,
name=f"chat {model_name}",
origin=ORIGIN,
)
chat_span.__enter__()
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
set_span_data_for_request(chat_span, integration, model_name, contents, kwargs)
chat_span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
chat_span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
try:
stream = await f(self, *args, **kwargs)
# Create wrapper async iterator to accumulate responses
async def new_async_iterator():
# type: () -> AsyncIterator[Any]
chunks = [] # type: List[Any]
try:
async for chunk in stream:
chunks.append(chunk)
yield chunk
except Exception as exc:
_capture_exception(exc)
chat_span.set_status(SPANSTATUS.ERROR)
raise
finally:
# Accumulate all chunks and set final response data on spans
if chunks:
accumulated_response = accumulate_streaming_response(chunks)
set_span_data_for_streaming_response(
chat_span, integration, accumulated_response
)
set_span_data_for_streaming_response(
span, integration, accumulated_response
)
chat_span.__exit__(None, None, None)
span.__exit__(None, None, None)
return new_async_iterator()
except Exception as exc:
_capture_exception(exc)
chat_span.__exit__(None, None, None)
span.__exit__(None, None, None)
raise
return new_async_generate_content_stream
def _wrap_generate_content(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
def new_generate_content(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
if integration is None:
return f(self, *args, **kwargs)
model, contents, model_name = prepare_generate_content_args(args, kwargs)
with get_start_span_function()(
op=OP.GEN_AI_INVOKE_AGENT,
name="invoke_agent",
origin=ORIGIN,
) as span:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
set_span_data_for_request(span, integration, model_name, contents, kwargs)
with sentry_sdk.start_span(
op=OP.GEN_AI_CHAT,
name=f"chat {model_name}",
origin=ORIGIN,
) as chat_span:
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
chat_span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
set_span_data_for_request(
chat_span, integration, model_name, contents, kwargs
)
try:
response = f(self, *args, **kwargs)
except Exception as exc:
_capture_exception(exc)
chat_span.set_status(SPANSTATUS.ERROR)
raise
set_span_data_for_response(chat_span, integration, response)
set_span_data_for_response(span, integration, response)
return response
return new_generate_content
def _wrap_async_generate_content(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
async def new_async_generate_content(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
if integration is None:
return await f(self, *args, **kwargs)
model, contents, model_name = prepare_generate_content_args(args, kwargs)
with get_start_span_function()(
op=OP.GEN_AI_INVOKE_AGENT,
name="invoke_agent",
origin=ORIGIN,
) as span:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, model_name)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
set_span_data_for_request(span, integration, model_name, contents, kwargs)
with sentry_sdk.start_span(
op=OP.GEN_AI_CHAT,
name=f"chat {model_name}",
origin=ORIGIN,
) as chat_span:
chat_span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
chat_span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
chat_span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
set_span_data_for_request(
chat_span, integration, model_name, contents, kwargs
)
try:
response = await f(self, *args, **kwargs)
except Exception as exc:
_capture_exception(exc)
chat_span.set_status(SPANSTATUS.ERROR)
raise
set_span_data_for_response(chat_span, integration, response)
set_span_data_for_response(span, integration, response)
return response
return new_async_generate_content
@@ -0,0 +1,16 @@
GEN_AI_SYSTEM = "gcp.gemini"
# Mapping of tool attributes to their descriptions
# These are all tools that are available in the Google GenAI API
TOOL_ATTRIBUTES_MAP = {
"google_search_retrieval": "Google Search retrieval tool",
"google_search": "Google Search tool",
"retrieval": "Retrieval tool",
"enterprise_web_search": "Enterprise web search tool",
"google_maps": "Google Maps tool",
"code_execution": "Code execution tool",
"computer_use": "Computer use tool",
}
IDENTIFIER = "google_genai"
ORIGIN = f"auto.ai.{IDENTIFIER}"
@@ -0,0 +1,155 @@
from typing import (
TYPE_CHECKING,
Any,
List,
TypedDict,
Optional,
)
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import SPANDATA
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import (
safe_serialize,
)
from .utils import (
extract_tool_calls,
extract_finish_reasons,
extract_contents_text,
extract_usage_data,
UsageData,
)
if TYPE_CHECKING:
from sentry_sdk.tracing import Span
from google.genai.types import GenerateContentResponse
class AccumulatedResponse(TypedDict):
id: Optional[str]
model: Optional[str]
text: str
finish_reasons: List[str]
tool_calls: List[dict[str, Any]]
usage_metadata: UsageData
def accumulate_streaming_response(chunks):
# type: (List[GenerateContentResponse]) -> AccumulatedResponse
"""Accumulate streaming chunks into a single response-like object."""
accumulated_text = []
finish_reasons = []
tool_calls = []
total_input_tokens = 0
total_output_tokens = 0
total_tokens = 0
total_cached_tokens = 0
total_reasoning_tokens = 0
response_id = None
model = None
for chunk in chunks:
# Extract text and tool calls
if getattr(chunk, "candidates", None):
for candidate in getattr(chunk, "candidates", []):
if hasattr(candidate, "content") and getattr(
candidate.content, "parts", []
):
extracted_text = extract_contents_text(candidate.content)
if extracted_text:
accumulated_text.append(extracted_text)
extracted_finish_reasons = extract_finish_reasons(chunk)
if extracted_finish_reasons:
finish_reasons.extend(extracted_finish_reasons)
extracted_tool_calls = extract_tool_calls(chunk)
if extracted_tool_calls:
tool_calls.extend(extracted_tool_calls)
# Accumulate token usage
extracted_usage_data = extract_usage_data(chunk)
total_input_tokens += extracted_usage_data["input_tokens"]
total_output_tokens += extracted_usage_data["output_tokens"]
total_cached_tokens += extracted_usage_data["input_tokens_cached"]
total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
total_tokens += extracted_usage_data["total_tokens"]
accumulated_response = AccumulatedResponse(
text="".join(accumulated_text),
finish_reasons=finish_reasons,
tool_calls=tool_calls,
usage_metadata=UsageData(
input_tokens=total_input_tokens,
output_tokens=total_output_tokens,
input_tokens_cached=total_cached_tokens,
output_tokens_reasoning=total_reasoning_tokens,
total_tokens=total_tokens,
),
id=response_id,
model=model,
)
return accumulated_response
def set_span_data_for_streaming_response(span, integration, accumulated_response):
# type: (Span, Any, AccumulatedResponse) -> None
"""Set span data for accumulated streaming response."""
if (
should_send_default_pii()
and integration.include_prompts
and accumulated_response.get("text")
):
span.set_data(
SPANDATA.GEN_AI_RESPONSE_TEXT,
safe_serialize([accumulated_response["text"]]),
)
if accumulated_response.get("finish_reasons"):
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
accumulated_response["finish_reasons"],
)
if accumulated_response.get("tool_calls"):
span.set_data(
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
safe_serialize(accumulated_response["tool_calls"]),
)
if accumulated_response.get("id"):
span.set_data(SPANDATA.GEN_AI_RESPONSE_ID, accumulated_response["id"])
if accumulated_response.get("model"):
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
if accumulated_response["usage_metadata"]["input_tokens"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
accumulated_response["usage_metadata"]["input_tokens"],
)
if accumulated_response["usage_metadata"]["input_tokens_cached"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
accumulated_response["usage_metadata"]["input_tokens_cached"],
)
if accumulated_response["usage_metadata"]["output_tokens"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS,
accumulated_response["usage_metadata"]["output_tokens"],
)
if accumulated_response["usage_metadata"]["output_tokens_reasoning"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
accumulated_response["usage_metadata"]["output_tokens_reasoning"],
)
if accumulated_response["usage_metadata"]["total_tokens"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
accumulated_response["usage_metadata"]["total_tokens"],
)
@@ -0,0 +1,576 @@
import copy
import inspect
from functools import wraps
from .consts import ORIGIN, TOOL_ATTRIBUTES_MAP, GEN_AI_SYSTEM
from typing import (
cast,
TYPE_CHECKING,
Iterable,
Any,
Callable,
List,
Optional,
Union,
TypedDict,
)
import sentry_sdk
from sentry_sdk.ai.utils import (
set_data_normalized,
truncate_and_annotate_messages,
normalize_message_roles,
)
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
safe_serialize,
)
from google.genai.types import GenerateContentConfig
if TYPE_CHECKING:
from sentry_sdk.tracing import Span
from google.genai.types import (
GenerateContentResponse,
ContentListUnion,
Tool,
Model,
)
class UsageData(TypedDict):
"""Structure for token usage data."""
input_tokens: int
input_tokens_cached: int
output_tokens: int
output_tokens_reasoning: int
total_tokens: int
def extract_usage_data(response):
# type: (Union[GenerateContentResponse, dict[str, Any]]) -> UsageData
"""Extract usage data from response into a structured format.
Args:
response: The GenerateContentResponse object or dictionary containing usage metadata
Returns:
UsageData: Dictionary with input_tokens, input_tokens_cached,
output_tokens, and output_tokens_reasoning fields
"""
usage_data = UsageData(
input_tokens=0,
input_tokens_cached=0,
output_tokens=0,
output_tokens_reasoning=0,
total_tokens=0,
)
# Handle dictionary response (from streaming)
if isinstance(response, dict):
usage = response.get("usage_metadata", {})
if not usage:
return usage_data
prompt_tokens = usage.get("prompt_token_count", 0) or 0
tool_use_prompt_tokens = usage.get("tool_use_prompt_token_count", 0) or 0
usage_data["input_tokens"] = prompt_tokens + tool_use_prompt_tokens
cached_tokens = usage.get("cached_content_token_count", 0) or 0
usage_data["input_tokens_cached"] = cached_tokens
reasoning_tokens = usage.get("thoughts_token_count", 0) or 0
usage_data["output_tokens_reasoning"] = reasoning_tokens
candidates_tokens = usage.get("candidates_token_count", 0) or 0
# python-genai reports output and reasoning tokens separately
# reasoning should be sub-category of output tokens
usage_data["output_tokens"] = candidates_tokens + reasoning_tokens
total_tokens = usage.get("total_token_count", 0) or 0
usage_data["total_tokens"] = total_tokens
return usage_data
if not hasattr(response, "usage_metadata"):
return usage_data
usage = response.usage_metadata
# Input tokens include both prompt and tool use prompt tokens
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
tool_use_prompt_tokens = getattr(usage, "tool_use_prompt_token_count", 0) or 0
usage_data["input_tokens"] = prompt_tokens + tool_use_prompt_tokens
# Cached input tokens
cached_tokens = getattr(usage, "cached_content_token_count", 0) or 0
usage_data["input_tokens_cached"] = cached_tokens
# Reasoning tokens
reasoning_tokens = getattr(usage, "thoughts_token_count", 0) or 0
usage_data["output_tokens_reasoning"] = reasoning_tokens
# output_tokens = candidates_tokens + reasoning_tokens
# google-genai reports output and reasoning tokens separately
candidates_tokens = getattr(usage, "candidates_token_count", 0) or 0
usage_data["output_tokens"] = candidates_tokens + reasoning_tokens
total_tokens = getattr(usage, "total_token_count", 0) or 0
usage_data["total_tokens"] = total_tokens
return usage_data
def _capture_exception(exc):
# type: (Any) -> None
"""Capture exception with Google GenAI mechanism."""
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
mechanism={"type": "google_genai", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
def get_model_name(model):
# type: (Union[str, Model]) -> str
"""Extract model name from model parameter."""
if isinstance(model, str):
return model
# Handle case where model might be an object with a name attribute
if hasattr(model, "name"):
return str(model.name)
return str(model)
def extract_contents_text(contents):
# type: (ContentListUnion) -> Optional[str]
"""Extract text from contents parameter which can have various formats."""
if contents is None:
return None
# Simple string case
if isinstance(contents, str):
return contents
# List of contents or parts
if isinstance(contents, list):
texts = []
for item in contents:
# Recursively extract text from each item
extracted = extract_contents_text(item)
if extracted:
texts.append(extracted)
return " ".join(texts) if texts else None
# Dictionary case
if isinstance(contents, dict):
if "text" in contents:
return contents["text"]
# Try to extract from parts if present in dict
if "parts" in contents:
return extract_contents_text(contents["parts"])
# Content object with parts - recurse into parts
if getattr(contents, "parts", None):
return extract_contents_text(contents.parts)
# Direct text attribute
if hasattr(contents, "text"):
return contents.text
return None
def _format_tools_for_span(tools):
# type: (Iterable[Tool | Callable[..., Any]]) -> Optional[List[dict[str, Any]]]
"""Format tools parameter for span data."""
formatted_tools = []
for tool in tools:
if callable(tool):
# Handle callable functions passed directly
formatted_tools.append(
{
"name": getattr(tool, "__name__", "unknown"),
"description": getattr(tool, "__doc__", None),
}
)
elif (
hasattr(tool, "function_declarations")
and tool.function_declarations is not None
):
# Tool object with function declarations
for func_decl in tool.function_declarations:
formatted_tools.append(
{
"name": getattr(func_decl, "name", None),
"description": getattr(func_decl, "description", None),
}
)
else:
# Check for predefined tool attributes - each of these tools
# is an attribute of the tool object, by default set to None
for attr_name, description in TOOL_ATTRIBUTES_MAP.items():
if getattr(tool, attr_name, None):
formatted_tools.append(
{
"name": attr_name,
"description": description,
}
)
break
return formatted_tools if formatted_tools else None
def extract_tool_calls(response):
# type: (GenerateContentResponse) -> Optional[List[dict[str, Any]]]
"""Extract tool/function calls from response candidates and automatic function calling history."""
tool_calls = []
# Extract from candidates, sometimes tool calls are nested under the content.parts object
if getattr(response, "candidates", []):
for candidate in response.candidates:
if not hasattr(candidate, "content") or not getattr(
candidate.content, "parts", []
):
continue
for part in candidate.content.parts:
if getattr(part, "function_call", None):
function_call = part.function_call
tool_call = {
"name": getattr(function_call, "name", None),
"type": "function_call",
}
# Extract arguments if available
if getattr(function_call, "args", None):
tool_call["arguments"] = safe_serialize(function_call.args)
tool_calls.append(tool_call)
# Extract from automatic_function_calling_history
# This is the history of tool calls made by the model
if getattr(response, "automatic_function_calling_history", None):
for content in response.automatic_function_calling_history:
if not getattr(content, "parts", None):
continue
for part in getattr(content, "parts", []):
if getattr(part, "function_call", None):
function_call = part.function_call
tool_call = {
"name": getattr(function_call, "name", None),
"type": "function_call",
}
# Extract arguments if available
if hasattr(function_call, "args"):
tool_call["arguments"] = safe_serialize(function_call.args)
tool_calls.append(tool_call)
return tool_calls if tool_calls else None
def _capture_tool_input(args, kwargs, tool):
# type: (tuple[Any, ...], dict[str, Any], Tool) -> dict[str, Any]
"""Capture tool input from args and kwargs."""
tool_input = kwargs.copy() if kwargs else {}
# If we have positional args, try to map them to the function signature
if args:
try:
sig = inspect.signature(tool)
param_names = list(sig.parameters.keys())
for i, arg in enumerate(args):
if i < len(param_names):
tool_input[param_names[i]] = arg
except Exception:
# Fallback if we can't get the signature
tool_input["args"] = args
return tool_input
def _create_tool_span(tool_name, tool_doc):
# type: (str, Optional[str]) -> Span
"""Create a span for tool execution."""
span = sentry_sdk.start_span(
op=OP.GEN_AI_EXECUTE_TOOL,
name=f"execute_tool {tool_name}",
origin=ORIGIN,
)
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
span.set_data(SPANDATA.GEN_AI_TOOL_TYPE, "function")
if tool_doc:
span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_doc)
return span
def wrapped_tool(tool):
# type: (Tool | Callable[..., Any]) -> Tool | Callable[..., Any]
"""Wrap a tool to emit execute_tool spans when called."""
if not callable(tool):
# Not a callable function, return as-is (predefined tools)
return tool
tool_name = getattr(tool, "__name__", "unknown")
tool_doc = tool.__doc__
if inspect.iscoroutinefunction(tool):
# Async function
@wraps(tool)
async def async_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
with _create_tool_span(tool_name, tool_doc) as span:
# Capture tool input
tool_input = _capture_tool_input(args, kwargs, tool)
with capture_internal_exceptions():
span.set_data(
SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_input)
)
try:
result = await tool(*args, **kwargs)
# Capture tool output
with capture_internal_exceptions():
span.set_data(
SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result)
)
return result
except Exception as exc:
_capture_exception(exc)
raise
return async_wrapped
else:
# Sync function
@wraps(tool)
def sync_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
with _create_tool_span(tool_name, tool_doc) as span:
# Capture tool input
tool_input = _capture_tool_input(args, kwargs, tool)
with capture_internal_exceptions():
span.set_data(
SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_input)
)
try:
result = tool(*args, **kwargs)
# Capture tool output
with capture_internal_exceptions():
span.set_data(
SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result)
)
return result
except Exception as exc:
_capture_exception(exc)
raise
return sync_wrapped
def wrapped_config_with_tools(config):
# type: (GenerateContentConfig) -> GenerateContentConfig
"""Wrap tools in config to emit execute_tool spans. Tools are sometimes passed directly as
callable functions as a part of the config object."""
if not config or not getattr(config, "tools", None):
return config
result = copy.copy(config)
result.tools = [wrapped_tool(tool) for tool in config.tools]
return result
def _extract_response_text(response):
# type: (GenerateContentResponse) -> Optional[List[str]]
"""Extract text from response candidates."""
if not response or not getattr(response, "candidates", []):
return None
texts = []
for candidate in response.candidates:
if not hasattr(candidate, "content") or not hasattr(candidate.content, "parts"):
continue
for part in candidate.content.parts:
if getattr(part, "text", None):
texts.append(part.text)
return texts if texts else None
def extract_finish_reasons(response):
# type: (GenerateContentResponse) -> Optional[List[str]]
"""Extract finish reasons from response candidates."""
if not response or not getattr(response, "candidates", []):
return None
finish_reasons = []
for candidate in response.candidates:
if getattr(candidate, "finish_reason", None):
# Convert enum value to string if necessary
reason = str(candidate.finish_reason)
# Remove enum prefix if present (e.g., "FinishReason.STOP" -> "STOP")
if "." in reason:
reason = reason.split(".")[-1]
finish_reasons.append(reason)
return finish_reasons if finish_reasons else None
def set_span_data_for_request(span, integration, model, contents, kwargs):
# type: (Span, Any, str, ContentListUnion, dict[str, Any]) -> None
"""Set span data for the request."""
span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
if kwargs.get("stream", False):
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
config = kwargs.get("config")
if config is None:
return
config = cast(GenerateContentConfig, config)
# Set input messages/prompts if PII is allowed
if should_send_default_pii() and integration.include_prompts:
messages = []
# Add system instruction if present
if hasattr(config, "system_instruction"):
system_instruction = config.system_instruction
if system_instruction:
system_text = extract_contents_text(system_instruction)
if system_text:
messages.append({"role": "system", "content": system_text})
# Add user message
contents_text = extract_contents_text(contents)
if contents_text:
messages.append({"role": "user", "content": contents_text})
if messages:
normalized_messages = normalize_message_roles(messages)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(
normalized_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
messages_data,
unpack=False,
)
# Extract parameters directly from config (not nested under generation_config)
for param, span_key in [
("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
("top_k", SPANDATA.GEN_AI_REQUEST_TOP_K),
("max_output_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
("presence_penalty", SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY),
("frequency_penalty", SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY),
("seed", SPANDATA.GEN_AI_REQUEST_SEED),
]:
if hasattr(config, param):
value = getattr(config, param)
if value is not None:
span.set_data(span_key, value)
# Set tools if available
if hasattr(config, "tools"):
tools = config.tools
if tools:
formatted_tools = _format_tools_for_span(tools)
if formatted_tools:
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
formatted_tools,
unpack=False,
)
def set_span_data_for_response(span, integration, response):
# type: (Span, Any, GenerateContentResponse) -> None
"""Set span data for the response."""
if not response:
return
if should_send_default_pii() and integration.include_prompts:
response_texts = _extract_response_text(response)
if response_texts:
# Format as JSON string array as per documentation
span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, safe_serialize(response_texts))
tool_calls = extract_tool_calls(response)
if tool_calls:
# Tool calls should be JSON serialized
span.set_data(SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(tool_calls))
finish_reasons = extract_finish_reasons(response)
if finish_reasons:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reasons
)
if getattr(response, "response_id", None):
span.set_data(SPANDATA.GEN_AI_RESPONSE_ID, response.response_id)
if getattr(response, "model_version", None):
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response.model_version)
usage_data = extract_usage_data(response)
if usage_data["input_tokens"]:
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage_data["input_tokens"])
if usage_data["input_tokens_cached"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
usage_data["input_tokens_cached"],
)
if usage_data["output_tokens"]:
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage_data["output_tokens"])
if usage_data["output_tokens_reasoning"]:
span.set_data(
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
usage_data["output_tokens_reasoning"],
)
if usage_data["total_tokens"]:
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage_data["total_tokens"])
def prepare_generate_content_args(args, kwargs):
# type: (tuple[Any, ...], dict[str, Any]) -> tuple[Any, Any, str]
"""Extract and prepare common arguments for generate_content methods."""
model = args[0] if args else kwargs.get("model", "unknown")
contents = args[1] if len(args) > 1 else kwargs.get("contents")
model_name = get_model_name(model)
config = kwargs.get("config")
wrapped_config = wrapped_config_with_tools(config)
if wrapped_config is not config:
kwargs["config"] = wrapped_config
return model, contents, model_name
@@ -18,6 +18,13 @@ try:
)
from gql.transport import Transport, AsyncTransport # type: ignore[import-not-found]
from gql.transport.exceptions import TransportQueryError # type: ignore[import-not-found]
try:
# gql 4.0+
from gql import GraphQLRequest
except ImportError:
GraphQLRequest = None
except ImportError:
raise DidNotEnable("gql is not installed")
@@ -92,13 +99,13 @@ def _patch_execute():
real_execute = gql.Client.execute
@ensure_integration_enabled(GQLIntegration, real_execute)
def sentry_patched_execute(self, document, *args, **kwargs):
def sentry_patched_execute(self, document_or_request, *args, **kwargs):
# type: (gql.Client, DocumentNode, Any, Any) -> Any
scope = sentry_sdk.get_isolation_scope()
scope.add_event_processor(_make_gql_event_processor(self, document))
scope.add_event_processor(_make_gql_event_processor(self, document_or_request))
try:
return real_execute(self, document, *args, **kwargs)
return real_execute(self, document_or_request, *args, **kwargs)
except TransportQueryError as e:
event, hint = event_from_exception(
e,
@@ -112,8 +119,8 @@ def _patch_execute():
gql.Client.execute = sentry_patched_execute
def _make_gql_event_processor(client, document):
# type: (gql.Client, DocumentNode) -> EventProcessor
def _make_gql_event_processor(client, document_or_request):
# type: (gql.Client, Union[DocumentNode, gql.GraphQLRequest]) -> EventProcessor
def processor(event, hint):
# type: (Event, dict[str, Any]) -> Event
try:
@@ -130,6 +137,16 @@ def _make_gql_event_processor(client, document):
)
if should_send_default_pii():
if GraphQLRequest is not None and isinstance(
document_or_request, GraphQLRequest
):
# In v4.0.0, gql moved to using GraphQLRequest instead of
# DocumentNode in execute
# https://github.com/graphql-python/gql/pull/556
document = document_or_request.document
else:
document = document_or_request
request["data"] = _data_from_document(document)
contexts = event.setdefault("contexts", {})
response = contexts.setdefault("response", {})
@@ -6,6 +6,7 @@ from grpc.aio import Channel as AsyncChannel
from grpc.aio import Server as AsyncServer
from sentry_sdk.integrations import Integration
from sentry_sdk.utils import parse_version
from .client import ClientInterceptor
from .server import ServerInterceptor
@@ -41,6 +42,8 @@ else:
P = ParamSpec("P")
GRPC_VERSION = parse_version(grpc.__version__)
def _wrap_channel_sync(func: Callable[P, Channel]) -> Callable[P, Channel]:
"Wrapper for synchronous secure and insecure channel."
@@ -127,7 +130,21 @@ def _wrap_async_server(func: Callable[P, AsyncServer]) -> Callable[P, AsyncServe
**kwargs: P.kwargs,
) -> Server:
server_interceptor = AsyncServerInterceptor()
interceptors = (server_interceptor, *(interceptors or []))
interceptors = [
server_interceptor,
*(interceptors or []),
] # type: Sequence[grpc.ServerInterceptor]
try:
# We prefer interceptors as a list because of compatibility with
# opentelemetry https://github.com/getsentry/sentry-python/issues/4389
# However, prior to grpc 1.42.0, only tuples were accepted, so we
# have no choice there.
if GRPC_VERSION is not None and GRPC_VERSION < (1, 42, 0):
interceptors = tuple(interceptors)
except Exception:
pass
return func(*args, interceptors=interceptors, **kwargs) # type: ignore
return patched_aio_server # type: ignore
@@ -65,7 +65,8 @@ class SentryUnaryUnaryClientInterceptor(ClientInterceptor, UnaryUnaryClientInter
class SentryUnaryStreamClientInterceptor(
ClientInterceptor, UnaryStreamClientInterceptor # type: ignore
ClientInterceptor,
UnaryStreamClientInterceptor, # type: ignore
):
async def intercept_unary_stream(
self,
@@ -2,7 +2,7 @@ import sentry_sdk
from sentry_sdk.consts import OP
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.integrations.grpc.consts import SPAN_ORIGIN
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_CUSTOM
from sentry_sdk.tracing import Transaction, TransactionSource
from sentry_sdk.utils import event_from_exception
from typing import TYPE_CHECKING
@@ -48,7 +48,7 @@ class ServerInterceptor(grpc.aio.ServerInterceptor): # type: ignore
dict(context.invocation_metadata()),
op=OP.GRPC_SERVER,
name=name,
source=TRANSACTION_SOURCE_CUSTOM,
source=TransactionSource.CUSTOM,
origin=SPAN_ORIGIN,
)
@@ -19,7 +19,8 @@ except ImportError:
class ClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor # type: ignore
grpc.UnaryUnaryClientInterceptor, # type: ignore
grpc.UnaryStreamClientInterceptor, # type: ignore
):
_is_intercepted = False
@@ -60,9 +61,7 @@ class ClientInterceptor(
client_call_details
)
response = continuation(
client_call_details, request
) # type: UnaryStreamCall
response = continuation(client_call_details, request) # type: UnaryStreamCall
# Setting code on unary-stream leads to execution getting stuck
# span.set_data("code", response.code().name)
@@ -2,7 +2,7 @@ import sentry_sdk
from sentry_sdk.consts import OP
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.integrations.grpc.consts import SPAN_ORIGIN
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_CUSTOM
from sentry_sdk.tracing import Transaction, TransactionSource
from typing import TYPE_CHECKING
@@ -42,7 +42,7 @@ class ServerInterceptor(grpc.ServerInterceptor): # type: ignore
metadata,
op=OP.GRPC_SERVER,
name=name,
source=TRANSACTION_SOURCE_CUSTOM,
source=TransactionSource.CUSTOM,
origin=SPAN_ORIGIN,
)
@@ -1,8 +1,13 @@
import sentry_sdk
from sentry_sdk import start_span
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME
from sentry_sdk.tracing_utils import Baggage, should_propagate_trace
from sentry_sdk.tracing_utils import (
Baggage,
should_propagate_trace,
add_http_request_source,
)
from sentry_sdk.utils import (
SENSITIVE_DATA_SUBSTITUTE,
capture_internal_exceptions,
@@ -52,7 +57,7 @@ def _install_httpx_client():
with capture_internal_exceptions():
parsed_url = parse_url(str(request.url), sanitize=False)
with sentry_sdk.start_span(
with start_span(
op=OP.HTTP_CLIENT,
name="%s %s"
% (
@@ -88,7 +93,10 @@ def _install_httpx_client():
span.set_http_status(rv.status_code)
span.set_data("reason", rv.reason_phrase)
return rv
with capture_internal_exceptions():
add_http_request_source(span)
return rv
Client.send = send
@@ -106,7 +114,7 @@ def _install_httpx_async_client():
with capture_internal_exceptions():
parsed_url = parse_url(str(request.url), sanitize=False)
with sentry_sdk.start_span(
with start_span(
op=OP.HTTP_CLIENT,
name="%s %s"
% (
@@ -144,7 +152,10 @@ def _install_httpx_async_client():
span.set_http_status(rv.status_code)
span.set_data("reason", rv.reason_phrase)
return rv
with capture_internal_exceptions():
add_http_request_source(span)
return rv
AsyncClient.send = send
@@ -9,7 +9,7 @@ from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import (
BAGGAGE_HEADER_NAME,
SENTRY_TRACE_HEADER_NAME,
TRANSACTION_SOURCE_TASK,
TransactionSource,
)
from sentry_sdk.utils import (
capture_internal_exceptions,
@@ -159,7 +159,7 @@ def patch_execute():
sentry_headers or {},
name=task.name,
op=OP.QUEUE_TASK_HUEY,
source=TRANSACTION_SOURCE_TASK,
source=TransactionSource.TASK,
origin=HueyIntegration.origin,
)
transaction.set_status(SPANSTATUS.OK)
@@ -1,24 +1,25 @@
import inspect
from functools import wraps
from sentry_sdk import consts
import sentry_sdk
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import SPANDATA
from typing import Any, Iterable, Callable
import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable, Iterable
try:
import huggingface_hub.inference._client
from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
except ImportError:
raise DidNotEnable("Huggingface not installed")
@@ -34,15 +35,26 @@ class HuggingfaceHubIntegration(Integration):
@staticmethod
def setup_once():
# type: () -> None
# Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks
huggingface_hub.inference._client.InferenceClient.text_generation = (
_wrap_text_generation(
huggingface_hub.inference._client.InferenceClient.text_generation
_wrap_huggingface_task(
huggingface_hub.inference._client.InferenceClient.text_generation,
OP.GEN_AI_GENERATE_TEXT,
)
)
huggingface_hub.inference._client.InferenceClient.chat_completion = (
_wrap_huggingface_task(
huggingface_hub.inference._client.InferenceClient.chat_completion,
OP.GEN_AI_CHAT,
)
)
def _capture_exception(exc):
# type: (Any) -> None
set_span_errored()
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
@@ -51,34 +63,70 @@ def _capture_exception(exc):
sentry_sdk.capture_event(event, hint=hint)
def _wrap_text_generation(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
def _wrap_huggingface_task(f, op):
# type: (Callable[..., Any], str) -> Callable[..., Any]
@wraps(f)
def new_text_generation(*args, **kwargs):
def new_huggingface_task(*args, **kwargs):
# type: (*Any, **Any) -> Any
integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
if integration is None:
return f(*args, **kwargs)
prompt = None
if "prompt" in kwargs:
prompt = kwargs["prompt"]
elif "messages" in kwargs:
prompt = kwargs["messages"]
elif len(args) >= 2:
kwargs["prompt"] = args[1]
prompt = kwargs["prompt"]
args = (args[0],) + args[2:]
else:
# invalid call, let it return error
if isinstance(args[1], str) or isinstance(args[1], list):
prompt = args[1]
if prompt is None:
# invalid call, dont instrument, let it return error
return f(*args, **kwargs)
model = kwargs.get("model")
streaming = kwargs.get("stream")
client = args[0]
model = client.model or kwargs.get("model") or ""
operation_name = op.split(".")[-1]
span = sentry_sdk.start_span(
op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
name="Text Generation",
op=op,
name=f"{operation_name} {model}",
origin=HuggingfaceHubIntegration.origin,
)
span.__enter__()
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name)
if model:
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
# Input attributes
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
)
attribute_mapping = {
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
"top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
}
for attribute, span_attribute in attribute_mapping.items():
value = kwargs.get(attribute, None)
if value is not None:
if isinstance(value, (int, float, bool, str)):
span.set_data(span_attribute, value)
else:
set_data_normalized(span, span_attribute, value, unpack=False)
# LLM Execution
try:
res = f(*args, **kwargs)
except Exception as e:
@@ -86,90 +134,245 @@ def _wrap_text_generation(f):
span.__exit__(None, None, None)
raise e from None
# Output attributes
finish_reason = None
response_model = None
response_text_buffer: list[str] = []
tokens_used = 0
tool_calls = None
usage = None
with capture_internal_exceptions():
if isinstance(res, str) and res is not None:
response_text_buffer.append(res)
if hasattr(res, "generated_text") and res.generated_text is not None:
response_text_buffer.append(res.generated_text)
if hasattr(res, "model") and res.model is not None:
response_model = res.model
if hasattr(res, "details") and hasattr(res.details, "finish_reason"):
finish_reason = res.details.finish_reason
if (
hasattr(res, "details")
and hasattr(res.details, "generated_tokens")
and res.details.generated_tokens is not None
):
tokens_used = res.details.generated_tokens
if hasattr(res, "usage") and res.usage is not None:
usage = res.usage
if hasattr(res, "choices") and res.choices is not None:
for choice in res.choices:
if hasattr(choice, "finish_reason"):
finish_reason = choice.finish_reason
if hasattr(choice, "message") and hasattr(
choice.message, "tool_calls"
):
tool_calls = choice.message.tool_calls
if (
hasattr(choice, "message")
and hasattr(choice.message, "content")
and choice.message.content is not None
):
response_text_buffer.append(choice.message.content)
if response_model is not None:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
if finish_reason is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
finish_reason,
)
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
if isinstance(res, str):
if should_send_default_pii() and integration.include_prompts:
if tool_calls is not None and len(tool_calls) > 0:
set_data_normalized(
span,
"ai.responses",
[res],
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
tool_calls,
unpack=False,
)
span.__exit__(None, None, None)
return res
if isinstance(res, TextGenerationOutput):
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span,
"ai.responses",
[res.generated_text],
)
if res.details is not None and res.details.generated_tokens > 0:
record_token_usage(span, total_tokens=res.details.generated_tokens)
span.__exit__(None, None, None)
return res
if len(response_text_buffer) > 0:
text_response = "".join(response_text_buffer)
if text_response:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
text_response,
)
if not isinstance(res, Iterable):
# we only know how to deal with strings and iterables, ignore
set_data_normalized(span, "unknown_response", True)
if usage is not None:
record_token_usage(
span,
input_tokens=usage.prompt_tokens,
output_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
)
elif tokens_used > 0:
record_token_usage(
span,
total_tokens=tokens_used,
)
# If the response is not a generator (meaning a streaming response)
# we are done and can return the response
if not inspect.isgenerator(res):
span.__exit__(None, None, None)
return res
if kwargs.get("details", False):
# res is Iterable[TextGenerationStreamOutput]
# text-generation stream output
def new_details_iterator():
# type: () -> Iterable[ChatCompletionStreamOutput]
# type: () -> Iterable[Any]
finish_reason = None
response_text_buffer: list[str] = []
tokens_used = 0
with capture_internal_exceptions():
tokens_used = 0
data_buf: list[str] = []
for x in res:
if hasattr(x, "token") and hasattr(x.token, "text"):
data_buf.append(x.token.text)
if hasattr(x, "details") and hasattr(
x.details, "generated_tokens"
for chunk in res:
if (
hasattr(chunk, "token")
and hasattr(chunk.token, "text")
and chunk.token.text is not None
):
tokens_used = x.details.generated_tokens
yield x
if (
len(data_buf) > 0
and should_send_default_pii()
and integration.include_prompts
):
response_text_buffer.append(chunk.token.text)
if hasattr(chunk, "details") and hasattr(
chunk.details, "finish_reason"
):
finish_reason = chunk.details.finish_reason
if (
hasattr(chunk, "details")
and hasattr(chunk.details, "generated_tokens")
and chunk.details.generated_tokens is not None
):
tokens_used = chunk.details.generated_tokens
yield chunk
if finish_reason is not None:
set_data_normalized(
span, SPANDATA.AI_RESPONSES, "".join(data_buf)
span,
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
finish_reason,
)
if should_send_default_pii() and integration.include_prompts:
if len(response_text_buffer) > 0:
text_response = "".join(response_text_buffer)
if text_response:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
text_response,
)
if tokens_used > 0:
record_token_usage(span, total_tokens=tokens_used)
record_token_usage(
span,
total_tokens=tokens_used,
)
span.__exit__(None, None, None)
return new_details_iterator()
else:
# res is Iterable[str]
else:
# chat-completion stream output
def new_iterator():
# type: () -> Iterable[str]
data_buf: list[str] = []
finish_reason = None
response_model = None
response_text_buffer: list[str] = []
tool_calls = None
usage = None
with capture_internal_exceptions():
for s in res:
if isinstance(s, str):
data_buf.append(s)
yield s
if (
len(data_buf) > 0
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(
span, SPANDATA.AI_RESPONSES, "".join(data_buf)
for chunk in res:
if hasattr(chunk, "model") and chunk.model is not None:
response_model = chunk.model
if hasattr(chunk, "usage") and chunk.usage is not None:
usage = chunk.usage
if isinstance(chunk, str):
if chunk is not None:
response_text_buffer.append(chunk)
if hasattr(chunk, "choices") and chunk.choices is not None:
for choice in chunk.choices:
if (
hasattr(choice, "delta")
and hasattr(choice.delta, "content")
and choice.delta.content is not None
):
response_text_buffer.append(
choice.delta.content
)
if (
hasattr(choice, "finish_reason")
and choice.finish_reason is not None
):
finish_reason = choice.finish_reason
if (
hasattr(choice, "delta")
and hasattr(choice.delta, "tool_calls")
and choice.delta.tool_calls is not None
):
tool_calls = choice.delta.tool_calls
yield chunk
if response_model is not None:
span.set_data(
SPANDATA.GEN_AI_RESPONSE_MODEL, response_model
)
if finish_reason is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
finish_reason,
)
if should_send_default_pii() and integration.include_prompts:
if tool_calls is not None and len(tool_calls) > 0:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
tool_calls,
unpack=False,
)
if len(response_text_buffer) > 0:
text_response = "".join(response_text_buffer)
if text_response:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
text_response,
)
if usage is not None:
record_token_usage(
span,
input_tokens=usage.prompt_tokens,
output_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
)
span.__exit__(None, None, None)
return new_iterator()
return new_text_generation
return new_huggingface_task
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,337 @@
from functools import wraps
from typing import Any, Callable, List, Optional
import sentry_sdk
from sentry_sdk.ai.utils import (
set_data_normalized,
normalize_message_roles,
truncate_and_annotate_messages,
)
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import safe_serialize
try:
from langgraph.graph import StateGraph
from langgraph.pregel import Pregel
except ImportError:
raise DidNotEnable("langgraph not installed")
class LanggraphIntegration(Integration):
identifier = "langgraph"
origin = f"auto.ai.{identifier}"
def __init__(self, include_prompts=True):
# type: (LanggraphIntegration, bool) -> None
self.include_prompts = include_prompts
@staticmethod
def setup_once():
# type: () -> None
# LangGraph lets users create agents using a StateGraph or the Functional API.
# StateGraphs are then compiled to a CompiledStateGraph. Both CompiledStateGraph and
# the functional API execute on a Pregel instance. Pregel is the runtime for the graph
# and the invocation happens on Pregel, so patching the invoke methods takes care of both.
# The streaming methods are not patched, because due to some internal reasons, LangGraph
# will automatically patch the streaming methods to run through invoke, and by doing this
# we prevent duplicate spans for invocations.
StateGraph.compile = _wrap_state_graph_compile(StateGraph.compile)
if hasattr(Pregel, "invoke"):
Pregel.invoke = _wrap_pregel_invoke(Pregel.invoke)
if hasattr(Pregel, "ainvoke"):
Pregel.ainvoke = _wrap_pregel_ainvoke(Pregel.ainvoke)
def _get_graph_name(graph_obj):
# type: (Any) -> Optional[str]
for attr in ["name", "graph_name", "__name__", "_name"]:
if hasattr(graph_obj, attr):
name = getattr(graph_obj, attr)
if name and isinstance(name, str):
return name
return None
def _normalize_langgraph_message(message):
# type: (Any) -> Any
if not hasattr(message, "content"):
return None
parsed = {"role": getattr(message, "type", None), "content": message.content}
for attr in ["name", "tool_calls", "function_call", "tool_call_id"]:
if hasattr(message, attr):
value = getattr(message, attr)
if value is not None:
parsed[attr] = value
return parsed
def _parse_langgraph_messages(state):
# type: (Any) -> Optional[List[Any]]
if not state:
return None
messages = None
if isinstance(state, dict):
messages = state.get("messages")
elif hasattr(state, "messages"):
messages = state.messages
elif hasattr(state, "get") and callable(state.get):
try:
messages = state.get("messages")
except Exception:
pass
if not messages or not isinstance(messages, (list, tuple)):
return None
normalized_messages = []
for message in messages:
try:
normalized = _normalize_langgraph_message(message)
if normalized:
normalized_messages.append(normalized)
except Exception:
continue
return normalized_messages if normalized_messages else None
def _wrap_state_graph_compile(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
def new_compile(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
if integration is None:
return f(self, *args, **kwargs)
with sentry_sdk.start_span(
op=OP.GEN_AI_CREATE_AGENT,
origin=LanggraphIntegration.origin,
) as span:
compiled_graph = f(self, *args, **kwargs)
compiled_graph_name = getattr(compiled_graph, "name", None)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "create_agent")
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, compiled_graph_name)
if compiled_graph_name:
span.description = f"create_agent {compiled_graph_name}"
else:
span.description = "create_agent"
if kwargs.get("model", None) is not None:
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, kwargs.get("model"))
tools = None
get_graph = getattr(compiled_graph, "get_graph", None)
if get_graph and callable(get_graph):
graph_obj = compiled_graph.get_graph()
nodes = getattr(graph_obj, "nodes", None)
if nodes and isinstance(nodes, dict):
tools_node = nodes.get("tools")
if tools_node:
data = getattr(tools_node, "data", None)
if data and hasattr(data, "tools_by_name"):
tools = list(data.tools_by_name.keys())
if tools is not None:
span.set_data(SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, tools)
return compiled_graph
return new_compile
def _wrap_pregel_invoke(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
def new_invoke(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
if integration is None:
return f(self, *args, **kwargs)
graph_name = _get_graph_name(self)
span_name = (
f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
)
with sentry_sdk.start_span(
op=OP.GEN_AI_INVOKE_AGENT,
name=span_name,
origin=LanggraphIntegration.origin,
) as span:
if graph_name:
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
# Store input messages to later compare with output
input_messages = None
if (
len(args) > 0
and should_send_default_pii()
and integration.include_prompts
):
input_messages = _parse_langgraph_messages(args[0])
if input_messages:
normalized_input_messages = normalize_message_roles(input_messages)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(
normalized_input_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
messages_data,
unpack=False,
)
result = f(self, *args, **kwargs)
_set_response_attributes(span, input_messages, result, integration)
return result
return new_invoke
def _wrap_pregel_ainvoke(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
async def new_ainvoke(self, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
if integration is None:
return await f(self, *args, **kwargs)
graph_name = _get_graph_name(self)
span_name = (
f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
)
with sentry_sdk.start_span(
op=OP.GEN_AI_INVOKE_AGENT,
name=span_name,
origin=LanggraphIntegration.origin,
) as span:
if graph_name:
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
input_messages = None
if (
len(args) > 0
and should_send_default_pii()
and integration.include_prompts
):
input_messages = _parse_langgraph_messages(args[0])
if input_messages:
normalized_input_messages = normalize_message_roles(input_messages)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(
normalized_input_messages, span, scope
)
if messages_data is not None:
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
messages_data,
unpack=False,
)
result = await f(self, *args, **kwargs)
_set_response_attributes(span, input_messages, result, integration)
return result
return new_ainvoke
def _get_new_messages(input_messages, output_messages):
# type: (Optional[List[Any]], Optional[List[Any]]) -> Optional[List[Any]]
"""Extract only the new messages added during this invocation."""
if not output_messages:
return None
if not input_messages:
return output_messages
# only return the new messages, aka the output messages that are not in the input messages
input_count = len(input_messages)
new_messages = (
output_messages[input_count:] if len(output_messages) > input_count else []
)
return new_messages if new_messages else None
def _extract_llm_response_text(messages):
# type: (Optional[List[Any]]) -> Optional[str]
if not messages:
return None
for message in reversed(messages):
if isinstance(message, dict):
role = message.get("role")
if role in ["assistant", "ai"]:
content = message.get("content")
if content and isinstance(content, str):
return content
return None
def _extract_tool_calls(messages):
# type: (Optional[List[Any]]) -> Optional[List[Any]]
if not messages:
return None
tool_calls = []
for message in messages:
if isinstance(message, dict):
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls and isinstance(msg_tool_calls, list):
tool_calls.extend(msg_tool_calls)
return tool_calls if tool_calls else None
def _set_response_attributes(span, input_messages, result, integration):
# type: (Any, Optional[List[Any]], Any, LanggraphIntegration) -> None
if not (should_send_default_pii() and integration.include_prompts):
return
parsed_response_messages = _parse_langgraph_messages(result)
new_messages = _get_new_messages(input_messages, parsed_response_messages)
llm_response_text = _extract_llm_response_text(new_messages)
if llm_response_text:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, llm_response_text)
elif new_messages:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, new_messages)
else:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
tool_calls = _extract_tool_calls(new_messages)
if tool_calls:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
safe_serialize(tool_calls),
unpack=False,
)
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
import sentry_sdk
from sentry_sdk.feature_flags import add_feature_flag
from sentry_sdk.integrations import DidNotEnable, Integration
try:
@@ -44,7 +44,6 @@ class LaunchDarklyIntegration(Integration):
class LaunchDarklyHook(Hook):
@property
def metadata(self):
# type: () -> Metadata
@@ -53,8 +52,8 @@ class LaunchDarklyHook(Hook):
def after_evaluation(self, series_context, data, detail):
# type: (EvaluationSeriesContext, dict[Any, Any], EvaluationDetail) -> dict[Any, Any]
if isinstance(detail.value, bool):
flags = sentry_sdk.get_current_scope().flags
flags.set(series_context.key, detail.value)
add_feature_flag(series_context.key, detail.value)
return data
def before_evaluation(self, series_context, data):
@@ -0,0 +1,262 @@
from typing import TYPE_CHECKING
import sentry_sdk
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import (
get_start_span_function,
set_data_normalized,
truncate_and_annotate_messages,
)
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import event_from_exception
if TYPE_CHECKING:
from typing import Any, Dict
from datetime import datetime
try:
import litellm # type: ignore[import-not-found]
except ImportError:
raise DidNotEnable("LiteLLM not installed")
def _get_metadata_dict(kwargs):
# type: (Dict[str, Any]) -> Dict[str, Any]
"""Get the metadata dictionary from the kwargs."""
litellm_params = kwargs.setdefault("litellm_params", {})
# we need this weird little dance, as metadata might be set but may be None initially
metadata = litellm_params.get("metadata")
if metadata is None:
metadata = {}
litellm_params["metadata"] = metadata
return metadata
def _input_callback(kwargs):
# type: (Dict[str, Any]) -> None
"""Handle the start of a request."""
integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
if integration is None:
return
# Get key parameters
full_model = kwargs.get("model", "")
try:
model, provider, _, _ = litellm.get_llm_provider(full_model)
except Exception:
model = full_model
provider = "unknown"
call_type = kwargs.get("call_type", None)
if call_type == "embedding":
operation = "embeddings"
else:
operation = "chat"
# Start a new span/transaction
span = get_start_span_function()(
op=(
consts.OP.GEN_AI_CHAT
if operation == "chat"
else consts.OP.GEN_AI_EMBEDDINGS
),
name=f"{operation} {model}",
origin=LiteLLMIntegration.origin,
)
span.__enter__()
# Store span for later
_get_metadata_dict(kwargs)["_sentry_span"] = span
# Set basic data
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider)
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
# Record messages if allowed
messages = kwargs.get("messages", [])
if messages and should_send_default_pii() and integration.include_prompts:
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(messages, span, scope)
if messages_data is not None:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
)
# Record other parameters
params = {
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
}
for key, attribute in params.items():
value = kwargs.get(key)
if value is not None:
set_data_normalized(span, attribute, value)
# Record LiteLLM-specific parameters
litellm_params = {
"api_base": kwargs.get("api_base"),
"api_version": kwargs.get("api_version"),
"custom_llm_provider": kwargs.get("custom_llm_provider"),
}
for key, value in litellm_params.items():
if value is not None:
set_data_normalized(span, f"gen_ai.litellm.{key}", value)
def _success_callback(kwargs, completion_response, start_time, end_time):
# type: (Dict[str, Any], Any, datetime, datetime) -> None
"""Handle successful completion."""
span = _get_metadata_dict(kwargs).get("_sentry_span")
if span is None:
return
integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
if integration is None:
return
try:
# Record model information
if hasattr(completion_response, "model"):
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_MODEL, completion_response.model
)
# Record response content if allowed
if should_send_default_pii() and integration.include_prompts:
if hasattr(completion_response, "choices"):
response_messages = []
for choice in completion_response.choices:
if hasattr(choice, "message"):
if hasattr(choice.message, "model_dump"):
response_messages.append(choice.message.model_dump())
elif hasattr(choice.message, "dict"):
response_messages.append(choice.message.dict())
else:
# Fallback for basic message objects
msg = {}
if hasattr(choice.message, "role"):
msg["role"] = choice.message.role
if hasattr(choice.message, "content"):
msg["content"] = choice.message.content
if hasattr(choice.message, "tool_calls"):
msg["tool_calls"] = choice.message.tool_calls
response_messages.append(msg)
if response_messages:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_messages
)
# Record token usage
if hasattr(completion_response, "usage"):
usage = completion_response.usage
record_token_usage(
span,
input_tokens=getattr(usage, "prompt_tokens", None),
output_tokens=getattr(usage, "completion_tokens", None),
total_tokens=getattr(usage, "total_tokens", None),
)
finally:
# Always finish the span and clean up
span.__exit__(None, None, None)
def _failure_callback(kwargs, exception, start_time, end_time):
# type: (Dict[str, Any], Exception, datetime, datetime) -> None
"""Handle request failure."""
span = _get_metadata_dict(kwargs).get("_sentry_span")
if span is None:
return
try:
# Capture the exception
event, hint = event_from_exception(
exception,
client_options=sentry_sdk.get_client().options,
mechanism={"type": "litellm", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
finally:
# Always finish the span and clean up
span.__exit__(type(exception), exception, None)
class LiteLLMIntegration(Integration):
"""
LiteLLM integration for Sentry.
This integration automatically captures LiteLLM API calls and sends them to Sentry
for monitoring and error tracking. It supports all 100+ LLM providers that LiteLLM
supports, including OpenAI, Anthropic, Google, Cohere, and many others.
Features:
- Automatic exception capture for all LiteLLM calls
- Token usage tracking across all providers
- Provider detection and attribution
- Input/output message capture (configurable)
- Streaming response support
- Cost tracking integration
Usage:
```python
import litellm
import sentry_sdk
# Initialize Sentry with the LiteLLM integration
sentry_sdk.init(
dsn="your-dsn",
send_default_pii=True
integrations=[
sentry_sdk.integrations.LiteLLMIntegration(
include_prompts=True # Set to False to exclude message content
)
]
)
# All LiteLLM calls will now be monitored
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello!"}]
)
```
Configuration:
- include_prompts (bool): Whether to include prompts and responses in spans.
Defaults to True. Set to False to exclude potentially sensitive data.
"""
identifier = "litellm"
origin = f"auto.ai.{identifier}"
def __init__(self, include_prompts=True):
# type: (LiteLLMIntegration, bool) -> None
self.include_prompts = include_prompts
@staticmethod
def setup_once():
# type: () -> None
"""Set up LiteLLM callbacks for monitoring."""
litellm.input_callback = litellm.input_callback or []
if _input_callback not in litellm.input_callback:
litellm.input_callback.append(_input_callback)
litellm.success_callback = litellm.success_callback or []
if _success_callback not in litellm.success_callback:
litellm.success_callback.append(_success_callback)
litellm.failure_callback = litellm.failure_callback or []
if _failure_callback not in litellm.failure_callback:
litellm.failure_callback.append(_failure_callback)
@@ -1,4 +1,6 @@
from collections.abc import Set
from copy import deepcopy
import sentry_sdk
from sentry_sdk.consts import OP
from sentry_sdk.integrations import (
@@ -9,7 +11,7 @@ from sentry_sdk.integrations import (
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from sentry_sdk.integrations.logging import ignore_logger
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE
from sentry_sdk.tracing import TransactionSource, SOURCE_FOR_STYLE
from sentry_sdk.utils import (
ensure_integration_enabled,
event_from_exception,
@@ -85,8 +87,18 @@ class SentryLitestarASGIMiddleware(SentryAsgiMiddleware):
transaction_style="endpoint",
mechanism_type="asgi",
span_origin=span_origin,
asgi_version=3,
)
def _capture_request_exception(self, exc):
# type: (Exception) -> None
"""Avoid catching exceptions from request handlers.
Those exceptions are already handled in Litestar.after_exception handler.
We still catch exceptions from application lifespan handlers.
"""
pass
def patch_app_init():
# type: () -> None
@@ -107,7 +119,6 @@ def patch_app_init():
*(kwargs.get("after_exception") or []),
]
SentryLitestarASGIMiddleware.__call__ = SentryLitestarASGIMiddleware._run_asgi3 # type: ignore
middleware = kwargs.get("middleware") or []
kwargs["middleware"] = [SentryLitestarASGIMiddleware, *middleware]
old__init__(self, *args, **kwargs)
@@ -213,9 +224,7 @@ def patch_http_route_handle():
return await old_handle(self, scope, receive, send)
sentry_scope = sentry_sdk.get_isolation_scope()
request = scope["app"].request_class(
scope=scope, receive=receive, send=send
) # type: Request[Any, Any]
request = scope["app"].request_class(scope=scope, receive=receive, send=send) # type: Request[Any, Any]
extracted_request_data = ConnectionDataExtractor(
parse_body=True, parse_query=True
)(request)
@@ -249,11 +258,11 @@ def patch_http_route_handle():
if not tx_name:
tx_name = _DEFAULT_TRANSACTION_NAME
tx_info = {"source": TRANSACTION_SOURCE_ROUTE}
tx_info = {"source": TransactionSource.ROUTE}
event.update(
{
"request": request_info,
"request": deepcopy(request_info),
"transaction": tx_name,
"transaction_info": tx_info,
}
@@ -1,13 +1,18 @@
import logging
import sys
from datetime import datetime, timezone
from fnmatch import fnmatch
import sentry_sdk
from sentry_sdk.client import BaseClient
from sentry_sdk.logger import _log_level_to_otel
from sentry_sdk.utils import (
safe_repr,
to_string,
event_from_exception,
current_stacktrace,
capture_internal_exceptions,
has_logs_enabled,
)
from sentry_sdk.integrations import Integration
@@ -33,6 +38,16 @@ LOGGING_TO_EVENT_LEVEL = {
logging.CRITICAL: "fatal", # CRITICAL is same as FATAL
}
# Map logging level numbers to corresponding OTel level numbers
SEVERITY_TO_OTEL_SEVERITY = {
logging.CRITICAL: 21, # fatal
logging.ERROR: 17, # error
logging.WARNING: 13, # warn
logging.INFO: 9, # info
logging.DEBUG: 5, # debug
}
# Capturing events from those loggers causes recursion errors. We cannot allow
# the user to unconditionally create events from those loggers under any
# circumstances.
@@ -61,14 +76,23 @@ def ignore_logger(
class LoggingIntegration(Integration):
identifier = "logging"
def __init__(self, level=DEFAULT_LEVEL, event_level=DEFAULT_EVENT_LEVEL):
# type: (Optional[int], Optional[int]) -> None
def __init__(
self,
level=DEFAULT_LEVEL,
event_level=DEFAULT_EVENT_LEVEL,
sentry_logs_level=DEFAULT_LEVEL,
):
# type: (Optional[int], Optional[int], Optional[int]) -> None
self._handler = None
self._breadcrumb_handler = None
self._sentry_logs_handler = None
if level is not None:
self._breadcrumb_handler = BreadcrumbHandler(level=level)
if sentry_logs_level is not None:
self._sentry_logs_handler = SentryLogsHandler(level=sentry_logs_level)
if event_level is not None:
self._handler = EventHandler(level=event_level)
@@ -83,6 +107,12 @@ class LoggingIntegration(Integration):
):
self._breadcrumb_handler.handle(record)
if (
self._sentry_logs_handler is not None
and record.levelno >= self._sentry_logs_handler.level
):
self._sentry_logs_handler.handle(record)
@staticmethod
def setup_once():
# type: () -> None
@@ -101,7 +131,10 @@ class LoggingIntegration(Integration):
# the integration. Otherwise we have a high chance of getting
# into a recursion error when the integration is resolved
# (this also is slower).
if ignored_loggers is not None and record.name not in ignored_loggers:
if (
ignored_loggers is not None
and record.name.strip() not in ignored_loggers
):
integration = sentry_sdk.get_client().get_integration(
LoggingIntegration
)
@@ -146,7 +179,7 @@ class _BaseHandler(logging.Handler):
# type: (LogRecord) -> bool
"""Prevents ignored loggers from recording"""
for logger in _IGNORED_LOGGERS:
if fnmatch(record.name, logger):
if fnmatch(record.name.strip(), logger):
return False
return True
@@ -231,25 +264,25 @@ class EventHandler(_BaseHandler):
event["level"] = level # type: ignore[typeddict-item]
event["logger"] = record.name
# Log records from `warnings` module as separate issues
record_caputured_from_warnings_module = (
record.name == "py.warnings" and record.msg == "%s"
)
if record_caputured_from_warnings_module:
# use the actual message and not "%s" as the message
# this prevents grouping all warnings under one "%s" issue
msg = record.args[0] # type: ignore
event["logentry"] = {
"message": msg,
"params": (),
}
if (
sys.version_info < (3, 11)
and record.name == "py.warnings"
and record.msg == "%s"
):
# warnings module on Python 3.10 and below sets record.msg to "%s"
# and record.args[0] to the actual warning message.
# This was fixed in https://github.com/python/cpython/pull/30975.
message = record.args[0]
params = ()
else:
event["logentry"] = {
"message": to_string(record.msg),
"params": record.args,
}
message = record.msg
params = record.args
event["logentry"] = {
"message": to_string(message),
"formatted": record.getMessage(),
"params": params,
}
event["extra"] = self._extra_from_record(record)
@@ -292,3 +325,97 @@ class BreadcrumbHandler(_BaseHandler):
"timestamp": datetime.fromtimestamp(record.created, timezone.utc),
"data": self._extra_from_record(record),
}
class SentryLogsHandler(_BaseHandler):
"""
A logging handler that records Sentry logs for each Python log record.
Note that you do not have to use this class if the logging integration is enabled, which it is by default.
"""
def emit(self, record):
# type: (LogRecord) -> Any
with capture_internal_exceptions():
self.format(record)
if not self._can_record(record):
return
client = sentry_sdk.get_client()
if not client.is_active():
return
if not has_logs_enabled(client.options):
return
self._capture_log_from_record(client, record)
def _capture_log_from_record(self, client, record):
# type: (BaseClient, LogRecord) -> None
otel_severity_number, otel_severity_text = _log_level_to_otel(
record.levelno, SEVERITY_TO_OTEL_SEVERITY
)
project_root = client.options["project_root"]
attrs = self._extra_from_record(record) # type: Any
attrs["sentry.origin"] = "auto.logger.log"
parameters_set = False
if record.args is not None:
if isinstance(record.args, tuple):
parameters_set = bool(record.args)
for i, arg in enumerate(record.args):
attrs[f"sentry.message.parameter.{i}"] = (
arg
if isinstance(arg, (str, float, int, bool))
else safe_repr(arg)
)
elif isinstance(record.args, dict):
parameters_set = bool(record.args)
for key, value in record.args.items():
attrs[f"sentry.message.parameter.{key}"] = (
value
if isinstance(value, (str, float, int, bool))
else safe_repr(value)
)
if parameters_set and isinstance(record.msg, str):
# only include template if there is at least one
# sentry.message.parameter.X set
attrs["sentry.message.template"] = record.msg
if record.lineno:
attrs["code.line.number"] = record.lineno
if record.pathname:
if project_root is not None and record.pathname.startswith(project_root):
attrs["code.file.path"] = record.pathname[len(project_root) + 1 :]
else:
attrs["code.file.path"] = record.pathname
if record.funcName:
attrs["code.function.name"] = record.funcName
if record.thread:
attrs["thread.id"] = record.thread
if record.threadName:
attrs["thread.name"] = record.threadName
if record.process:
attrs["process.pid"] = record.process
if record.processName:
attrs["process.executable.name"] = record.processName
if record.name:
attrs["logger.name"] = record.name
# noinspection PyProtectedMember
client._capture_log(
{
"severity_text": otel_severity_text,
"severity_number": otel_severity_number,
"body": record.message,
"attributes": attrs,
"time_unix_nano": int(record.created * 1e9),
"trace_id": None,
},
)
@@ -1,22 +1,28 @@
import enum
import sentry_sdk
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.integrations.logging import (
BreadcrumbHandler,
EventHandler,
_BaseHandler,
)
from sentry_sdk.logger import _log_level_to_otel
from sentry_sdk.utils import has_logs_enabled
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import LogRecord
from typing import Optional, Tuple
from typing import Any, Optional
try:
import loguru
from loguru import logger
from loguru._defaults import LOGURU_FORMAT as DEFAULT_FORMAT
if TYPE_CHECKING:
from loguru import Message
except ImportError:
raise DidNotEnable("LOGURU is not installed")
@@ -33,68 +39,167 @@ class LoggingLevels(enum.IntEnum):
DEFAULT_LEVEL = LoggingLevels.INFO.value
DEFAULT_EVENT_LEVEL = LoggingLevels.ERROR.value
# We need to save the handlers to be able to remove them later
# in tests (they call `LoguruIntegration.__init__` multiple times,
# and we can't use `setup_once` because it's called before
# than we get configuration).
_ADDED_HANDLERS = (None, None) # type: Tuple[Optional[int], Optional[int]]
SENTRY_LEVEL_FROM_LOGURU_LEVEL = {
"TRACE": "DEBUG",
"DEBUG": "DEBUG",
"INFO": "INFO",
"SUCCESS": "INFO",
"WARNING": "WARNING",
"ERROR": "ERROR",
"CRITICAL": "CRITICAL",
}
# Map Loguru level numbers to corresponding OTel level numbers
SEVERITY_TO_OTEL_SEVERITY = {
LoggingLevels.CRITICAL: 21, # fatal
LoggingLevels.ERROR: 17, # error
LoggingLevels.WARNING: 13, # warn
LoggingLevels.SUCCESS: 11, # info
LoggingLevels.INFO: 9, # info
LoggingLevels.DEBUG: 5, # debug
LoggingLevels.TRACE: 1, # trace
}
class LoguruIntegration(Integration):
identifier = "loguru"
level = DEFAULT_LEVEL # type: Optional[int]
event_level = DEFAULT_EVENT_LEVEL # type: Optional[int]
breadcrumb_format = DEFAULT_FORMAT
event_format = DEFAULT_FORMAT
sentry_logs_level = DEFAULT_LEVEL # type: Optional[int]
def __init__(
self,
level=DEFAULT_LEVEL,
event_level=DEFAULT_EVENT_LEVEL,
breadcrumb_format=DEFAULT_FORMAT,
event_format=DEFAULT_FORMAT,
sentry_logs_level=DEFAULT_LEVEL,
):
# type: (Optional[int], Optional[int], str | loguru.FormatFunction, str | loguru.FormatFunction) -> None
global _ADDED_HANDLERS
breadcrumb_handler, event_handler = _ADDED_HANDLERS
if breadcrumb_handler is not None:
logger.remove(breadcrumb_handler)
breadcrumb_handler = None
if event_handler is not None:
logger.remove(event_handler)
event_handler = None
if level is not None:
breadcrumb_handler = logger.add(
LoguruBreadcrumbHandler(level=level),
level=level,
format=breadcrumb_format,
)
if event_level is not None:
event_handler = logger.add(
LoguruEventHandler(level=event_level),
level=event_level,
format=event_format,
)
_ADDED_HANDLERS = (breadcrumb_handler, event_handler)
# type: (Optional[int], Optional[int], str | loguru.FormatFunction, str | loguru.FormatFunction, Optional[int]) -> None
LoguruIntegration.level = level
LoguruIntegration.event_level = event_level
LoguruIntegration.breadcrumb_format = breadcrumb_format
LoguruIntegration.event_format = event_format
LoguruIntegration.sentry_logs_level = sentry_logs_level
@staticmethod
def setup_once():
# type: () -> None
pass # we do everything in __init__
if LoguruIntegration.level is not None:
logger.add(
LoguruBreadcrumbHandler(level=LoguruIntegration.level),
level=LoguruIntegration.level,
format=LoguruIntegration.breadcrumb_format,
)
if LoguruIntegration.event_level is not None:
logger.add(
LoguruEventHandler(level=LoguruIntegration.event_level),
level=LoguruIntegration.event_level,
format=LoguruIntegration.event_format,
)
if LoguruIntegration.sentry_logs_level is not None:
logger.add(
loguru_sentry_logs_handler,
level=LoguruIntegration.sentry_logs_level,
)
class _LoguruBaseHandler(_BaseHandler):
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
if kwargs.get("level"):
kwargs["level"] = SENTRY_LEVEL_FROM_LOGURU_LEVEL.get(
kwargs.get("level", ""), DEFAULT_LEVEL
)
super().__init__(*args, **kwargs)
def _logging_to_event_level(self, record):
# type: (LogRecord) -> str
try:
return LoggingLevels(record.levelno).name.lower()
except ValueError:
return SENTRY_LEVEL_FROM_LOGURU_LEVEL[
LoggingLevels(record.levelno).name
].lower()
except (ValueError, KeyError):
return record.levelname.lower() if record.levelname else ""
class LoguruEventHandler(_LoguruBaseHandler, EventHandler):
"""Modified version of :class:`sentry_sdk.integrations.logging.EventHandler` to use loguru's level names."""
pass
class LoguruBreadcrumbHandler(_LoguruBaseHandler, BreadcrumbHandler):
"""Modified version of :class:`sentry_sdk.integrations.logging.BreadcrumbHandler` to use loguru's level names."""
pass
def loguru_sentry_logs_handler(message):
# type: (Message) -> None
# This is intentionally a callable sink instead of a standard logging handler
# since otherwise we wouldn't get direct access to message.record
client = sentry_sdk.get_client()
if not client.is_active():
return
if not has_logs_enabled(client.options):
return
record = message.record
if (
LoguruIntegration.sentry_logs_level is None
or record["level"].no < LoguruIntegration.sentry_logs_level
):
return
otel_severity_number, otel_severity_text = _log_level_to_otel(
record["level"].no, SEVERITY_TO_OTEL_SEVERITY
)
attrs = {"sentry.origin": "auto.logger.loguru"} # type: dict[str, Any]
project_root = client.options["project_root"]
if record.get("file"):
if project_root is not None and record["file"].path.startswith(project_root):
attrs["code.file.path"] = record["file"].path[len(project_root) + 1 :]
else:
attrs["code.file.path"] = record["file"].path
if record.get("line") is not None:
attrs["code.line.number"] = record["line"]
if record.get("function"):
attrs["code.function.name"] = record["function"]
if record.get("thread"):
attrs["thread.name"] = record["thread"].name
attrs["thread.id"] = record["thread"].id
if record.get("process"):
attrs["process.pid"] = record["process"].id
attrs["process.executable.name"] = record["process"].name
if record.get("name"):
attrs["logger.name"] = record["name"]
client._capture_log(
{
"severity_text": otel_severity_text,
"severity_number": otel_severity_number,
"body": record["message"],
"attributes": attrs,
"time_unix_nano": int(record["time"].timestamp() * 1e9),
"trace_id": None,
}
)
@@ -0,0 +1,552 @@
"""
Sentry integration for MCP (Model Context Protocol) servers.
This integration instruments MCP servers to create spans for tool, prompt,
and resource handler execution, and captures errors that occur during execution.
Supports the low-level `mcp.server.lowlevel.Server` API.
"""
import inspect
from functools import wraps
from typing import TYPE_CHECKING
import sentry_sdk
from sentry_sdk.ai.utils import get_start_span_function
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import Integration, DidNotEnable
from sentry_sdk.utils import safe_serialize
from sentry_sdk.scope import should_send_default_pii
try:
from mcp.server.lowlevel import Server # type: ignore[import-not-found]
from mcp.server.lowlevel.server import request_ctx # type: ignore[import-not-found]
except ImportError:
raise DidNotEnable("MCP SDK not installed")
if TYPE_CHECKING:
from typing import Any, Callable, Optional
class MCPIntegration(Integration):
identifier = "mcp"
origin = "auto.ai.mcp"
def __init__(self, include_prompts=True):
# type: (bool) -> None
"""
Initialize the MCP integration.
Args:
include_prompts: Whether to include prompts (tool results and prompt content)
in span data. Requires send_default_pii=True. Default is True.
"""
self.include_prompts = include_prompts
@staticmethod
def setup_once():
# type: () -> None
"""
Patches MCP server classes to instrument handler execution.
"""
_patch_lowlevel_server()
def _get_request_context_data():
# type: () -> tuple[Optional[str], Optional[str], str]
"""
Extract request ID, session ID, and transport type from the MCP request context.
Returns:
Tuple of (request_id, session_id, transport).
- request_id: May be None if not available
- session_id: May be None if not available
- transport: "tcp" for HTTP-based, "pipe" for stdio
"""
request_id = None # type: Optional[str]
session_id = None # type: Optional[str]
transport = "pipe" # type: str
try:
ctx = request_ctx.get()
if ctx is not None:
request_id = ctx.request_id
if hasattr(ctx, "request") and ctx.request is not None:
transport = "tcp"
request = ctx.request
if hasattr(request, "headers"):
session_id = request.headers.get("mcp-session-id")
except LookupError:
# No request context available - default to pipe
pass
return request_id, session_id, transport
def _get_span_config(handler_type, item_name):
# type: (str, str) -> tuple[str, str, str, Optional[str]]
"""
Get span configuration based on handler type.
Returns:
Tuple of (span_data_key, span_name, mcp_method_name, result_data_key)
Note: result_data_key is None for resources
"""
if handler_type == "tool":
span_data_key = SPANDATA.MCP_TOOL_NAME
mcp_method_name = "tools/call"
result_data_key = SPANDATA.MCP_TOOL_RESULT_CONTENT
elif handler_type == "prompt":
span_data_key = SPANDATA.MCP_PROMPT_NAME
mcp_method_name = "prompts/get"
result_data_key = SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT
else: # resource
span_data_key = SPANDATA.MCP_RESOURCE_URI
mcp_method_name = "resources/read"
result_data_key = None # Resources don't capture result content
span_name = f"{mcp_method_name} {item_name}"
return span_data_key, span_name, mcp_method_name, result_data_key
def _set_span_input_data(
span,
handler_name,
span_data_key,
mcp_method_name,
arguments,
request_id,
session_id,
transport,
):
# type: (Any, str, str, str, dict[str, Any], Optional[str], Optional[str], str) -> None
"""Set input span data for MCP handlers."""
# Set handler identifier
span.set_data(span_data_key, handler_name)
span.set_data(SPANDATA.MCP_METHOD_NAME, mcp_method_name)
# Set transport type
span.set_data(SPANDATA.MCP_TRANSPORT, transport)
# Set request_id if provided
if request_id:
span.set_data(SPANDATA.MCP_REQUEST_ID, request_id)
# Set session_id if provided
if session_id:
span.set_data(SPANDATA.MCP_SESSION_ID, session_id)
# Set request arguments (excluding common request context objects)
for k, v in arguments.items():
span.set_data(f"mcp.request.argument.{k}", safe_serialize(v))
def _extract_tool_result_content(result):
# type: (Any) -> Any
"""
Extract meaningful content from MCP tool result.
Tool handlers can return:
- tuple (UnstructuredContent, StructuredContent): Return the structured content (dict)
- dict (StructuredContent): Return as-is
- Iterable (UnstructuredContent): Extract text from content blocks
"""
if result is None:
return None
# Handle CombinationContent: tuple of (UnstructuredContent, StructuredContent)
if isinstance(result, tuple) and len(result) == 2:
# Return the structured content (2nd element)
return result[1]
# Handle StructuredContent: dict
if isinstance(result, dict):
return result
# Handle UnstructuredContent: iterable of ContentBlock objects
# Try to extract text content
if hasattr(result, "__iter__") and not isinstance(result, (str, bytes, dict)):
texts = []
try:
for item in result:
# Try to get text attribute from ContentBlock objects
if hasattr(item, "text"):
texts.append(item.text)
elif isinstance(item, dict) and "text" in item:
texts.append(item["text"])
except Exception:
# If extraction fails, return the original
return result
return " ".join(texts) if texts else result
return result
def _set_span_output_data(span, result, result_data_key, handler_type):
# type: (Any, Any, Optional[str], str) -> None
"""Set output span data for MCP handlers."""
if result is None:
return
# Get integration to check PII settings
integration = sentry_sdk.get_client().get_integration(MCPIntegration)
if integration is None:
return
# Check if we should include sensitive data
should_include_data = should_send_default_pii() and integration.include_prompts
# For tools, extract the meaningful content
if handler_type == "tool":
extracted = _extract_tool_result_content(result)
if extracted is not None and should_include_data:
span.set_data(result_data_key, safe_serialize(extracted))
# Set content count if result is a dict
if isinstance(extracted, dict):
span.set_data(SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT, len(extracted))
elif handler_type == "prompt":
# For prompts, count messages and set role/content only for single-message prompts
try:
messages = None # type: Optional[list[str]]
message_count = 0
# Check if result has messages attribute (GetPromptResult)
if hasattr(result, "messages") and result.messages:
messages = result.messages
message_count = len(messages)
# Also check if result is a dict with messages
elif isinstance(result, dict) and result.get("messages"):
messages = result["messages"]
message_count = len(messages)
# Always set message count if we found messages
if message_count > 0:
span.set_data(SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT, message_count)
# Only set role and content for single-message prompts if PII is allowed
if message_count == 1 and should_include_data and messages:
first_message = messages[0]
# Extract role
role = None
if hasattr(first_message, "role"):
role = first_message.role
elif isinstance(first_message, dict) and "role" in first_message:
role = first_message["role"]
if role:
span.set_data(SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE, role)
# Extract content text
content_text = None
if hasattr(first_message, "content"):
msg_content = first_message.content
# Content can be a TextContent object or similar
if hasattr(msg_content, "text"):
content_text = msg_content.text
elif isinstance(msg_content, dict) and "text" in msg_content:
content_text = msg_content["text"]
elif isinstance(msg_content, str):
content_text = msg_content
elif isinstance(first_message, dict) and "content" in first_message:
msg_content = first_message["content"]
if isinstance(msg_content, dict) and "text" in msg_content:
content_text = msg_content["text"]
elif isinstance(msg_content, str):
content_text = msg_content
if content_text:
span.set_data(result_data_key, content_text)
except Exception:
# Silently ignore if we can't extract message info
pass
# Resources don't capture result content (result_data_key is None)
# Handler data preparation and wrapping
def _prepare_handler_data(handler_type, original_args):
# type: (str, tuple[Any, ...]) -> tuple[str, dict[str, Any], str, str, str, Optional[str]]
"""
Prepare common handler data for both async and sync wrappers.
Returns:
Tuple of (handler_name, arguments, span_data_key, span_name, mcp_method_name, result_data_key)
"""
# Extract handler-specific data based on handler type
if handler_type == "tool":
handler_name = original_args[0] # tool_name
arguments = original_args[1] if len(original_args) > 1 else {}
elif handler_type == "prompt":
handler_name = original_args[0] # name
arguments = original_args[1] if len(original_args) > 1 else {}
# Include name in arguments dict for span data
arguments = {"name": handler_name, **(arguments or {})}
else: # resource
uri = original_args[0]
handler_name = str(uri) if uri else "unknown"
arguments = {}
# Get span configuration
span_data_key, span_name, mcp_method_name, result_data_key = _get_span_config(
handler_type, handler_name
)
return (
handler_name,
arguments,
span_data_key,
span_name,
mcp_method_name,
result_data_key,
)
async def _async_handler_wrapper(handler_type, func, original_args):
# type: (str, Callable[..., Any], tuple[Any, ...]) -> Any
"""
Async wrapper for MCP handlers.
Args:
handler_type: "tool", "prompt", or "resource"
func: The async handler function to wrap
original_args: Original arguments passed to the handler
"""
(
handler_name,
arguments,
span_data_key,
span_name,
mcp_method_name,
result_data_key,
) = _prepare_handler_data(handler_type, original_args)
# Start span and execute
with get_start_span_function()(
op=OP.MCP_SERVER,
name=span_name,
origin=MCPIntegration.origin,
) as span:
# Get request ID, session ID, and transport from context
request_id, session_id, transport = _get_request_context_data()
# Set input span data
_set_span_input_data(
span,
handler_name,
span_data_key,
mcp_method_name,
arguments,
request_id,
session_id,
transport,
)
# For resources, extract and set protocol
if handler_type == "resource":
uri = original_args[0]
protocol = None
if hasattr(uri, "scheme"):
protocol = uri.scheme
elif handler_name and "://" in handler_name:
protocol = handler_name.split("://")[0]
if protocol:
span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol)
try:
# Execute the async handler
result = await func(*original_args)
except Exception as e:
# Set error flag for tools
if handler_type == "tool":
span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True)
sentry_sdk.capture_exception(e)
raise
_set_span_output_data(span, result, result_data_key, handler_type)
return result
def _sync_handler_wrapper(handler_type, func, original_args):
# type: (str, Callable[..., Any], tuple[Any, ...]) -> Any
"""
Sync wrapper for MCP handlers.
Args:
handler_type: "tool", "prompt", or "resource"
func: The sync handler function to wrap
original_args: Original arguments passed to the handler
"""
(
handler_name,
arguments,
span_data_key,
span_name,
mcp_method_name,
result_data_key,
) = _prepare_handler_data(handler_type, original_args)
# Start span and execute
with get_start_span_function()(
op=OP.MCP_SERVER,
name=span_name,
origin=MCPIntegration.origin,
) as span:
# Get request ID, session ID, and transport from context
request_id, session_id, transport = _get_request_context_data()
# Set input span data
_set_span_input_data(
span,
handler_name,
span_data_key,
mcp_method_name,
arguments,
request_id,
session_id,
transport,
)
# For resources, extract and set protocol
if handler_type == "resource":
uri = original_args[0]
protocol = None
if hasattr(uri, "scheme"):
protocol = uri.scheme
elif handler_name and "://" in handler_name:
protocol = handler_name.split("://")[0]
if protocol:
span.set_data(SPANDATA.MCP_RESOURCE_PROTOCOL, protocol)
try:
# Execute the sync handler
result = func(*original_args)
except Exception as e:
# Set error flag for tools
if handler_type == "tool":
span.set_data(SPANDATA.MCP_TOOL_RESULT_IS_ERROR, True)
sentry_sdk.capture_exception(e)
raise
_set_span_output_data(span, result, result_data_key, handler_type)
return result
def _create_instrumented_handler(handler_type, func):
# type: (str, Callable[..., Any]) -> Callable[..., Any]
"""
Create an instrumented version of a handler function (async or sync).
This function wraps the user's handler with a runtime wrapper that will create
Sentry spans and capture metrics when the handler is actually called.
The wrapper preserves the async/sync nature of the original function, which is
critical for Python's async/await to work correctly.
Args:
handler_type: "tool", "prompt", or "resource" - determines span configuration
func: The handler function to instrument (async or sync)
Returns:
A wrapped version of func that creates Sentry spans on execution
"""
if inspect.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args):
# type: (*Any) -> Any
return await _async_handler_wrapper(handler_type, func, args)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args):
# type: (*Any) -> Any
return _sync_handler_wrapper(handler_type, func, args)
return sync_wrapper
def _create_instrumented_decorator(
original_decorator, handler_type, *decorator_args, **decorator_kwargs
):
# type: (Callable[..., Any], str, *Any, **Any) -> Callable[..., Any]
"""
Create an instrumented version of an MCP decorator.
This function intercepts MCP decorators (like @server.call_tool()) and injects
Sentry instrumentation into the handler registration flow. The returned decorator
will:
1. Receive the user's handler function
2. Wrap it with instrumentation via _create_instrumented_handler
3. Pass the instrumented version to the original MCP decorator
This ensures that when the handler is called at runtime, it's already wrapped
with Sentry spans and metrics collection.
Args:
original_decorator: The original MCP decorator method (e.g., Server.call_tool)
handler_type: "tool", "prompt", or "resource" - determines span configuration
decorator_args: Positional arguments to pass to the original decorator (e.g., self)
decorator_kwargs: Keyword arguments to pass to the original decorator
Returns:
A decorator function that instruments handlers before registering them
"""
def instrumented_decorator(func):
# type: (Callable[..., Any]) -> Callable[..., Any]
# First wrap the handler with instrumentation
instrumented_func = _create_instrumented_handler(handler_type, func)
# Then register it with the original MCP decorator
return original_decorator(*decorator_args, **decorator_kwargs)(
instrumented_func
)
return instrumented_decorator
def _patch_lowlevel_server():
# type: () -> None
"""
Patches the mcp.server.lowlevel.Server class to instrument handler execution.
"""
# Patch call_tool decorator
original_call_tool = Server.call_tool
def patched_call_tool(self, **kwargs):
# type: (Server, **Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]
"""Patched version of Server.call_tool that adds Sentry instrumentation."""
return lambda func: _create_instrumented_decorator(
original_call_tool, "tool", self, **kwargs
)(func)
Server.call_tool = patched_call_tool
# Patch get_prompt decorator
original_get_prompt = Server.get_prompt
def patched_get_prompt(self):
# type: (Server) -> Callable[[Callable[..., Any]], Callable[..., Any]]
"""Patched version of Server.get_prompt that adds Sentry instrumentation."""
return lambda func: _create_instrumented_decorator(
original_get_prompt, "prompt", self
)(func)
Server.get_prompt = patched_get_prompt
# Patch read_resource decorator
original_read_resource = Server.read_resource
def patched_read_resource(self):
# type: (Server) -> Callable[[Callable[..., Any]], Callable[..., Any]]
"""Patched version of Server.read_resource that adds Sentry instrumentation."""
return lambda func: _create_instrumented_decorator(
original_read_resource, "resource", self
)(func)
Server.read_resource = patched_read_resource
@@ -3,13 +3,19 @@ from functools import wraps
import sentry_sdk
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.ai.utils import (
set_data_normalized,
normalize_message_roles,
truncate_and_annotate_messages,
)
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
safe_serialize,
)
from typing import TYPE_CHECKING
@@ -19,6 +25,16 @@ if TYPE_CHECKING:
from sentry_sdk.tracing import Span
try:
try:
from openai import NotGiven
except ImportError:
NotGiven = None
try:
from openai import Omit
except ImportError:
Omit = None
from openai.resources.chat.completions import Completions, AsyncCompletions
from openai.resources import Embeddings, AsyncEmbeddings
@@ -27,6 +43,14 @@ try:
except ImportError:
raise DidNotEnable("OpenAI not installed")
RESPONSES_API_ENABLED = True
try:
# responses API support was introduced in v1.66.0
from openai.resources.responses import Responses, AsyncResponses
from openai.types.responses.response_completed_event import ResponseCompletedEvent
except ImportError:
RESPONSES_API_ENABLED = False
class OpenAIIntegration(Integration):
identifier = "openai"
@@ -46,13 +70,17 @@ class OpenAIIntegration(Integration):
def setup_once():
# type: () -> None
Completions.create = _wrap_chat_completion_create(Completions.create)
Embeddings.create = _wrap_embeddings_create(Embeddings.create)
AsyncCompletions.create = _wrap_async_chat_completion_create(
AsyncCompletions.create
)
Embeddings.create = _wrap_embeddings_create(Embeddings.create)
AsyncEmbeddings.create = _wrap_async_embeddings_create(AsyncEmbeddings.create)
if RESPONSES_API_ENABLED:
Responses.create = _wrap_responses_create(Responses.create)
AsyncResponses.create = _wrap_async_responses_create(AsyncResponses.create)
def count_tokens(self, s):
# type: (OpenAIIntegration, str) -> int
if self.tiktoken_encoding is not None:
@@ -60,8 +88,16 @@ class OpenAIIntegration(Integration):
return 0
def _capture_exception(exc):
# type: (Any) -> None
def _capture_exception(exc, manual_span_cleanup=True):
# type: (Any, bool) -> None
# Close an eventually open span
# We need to do this by hand because we are not using the start_span context manager
current_span = sentry_sdk.get_current_span()
set_span_errored(current_span)
if manual_span_cleanup and current_span is not None:
current_span.__exit__(None, None, None)
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
@@ -70,52 +106,316 @@ def _capture_exception(exc):
sentry_sdk.capture_event(event, hint=hint)
def _calculate_chat_completion_usage(
def _get_usage(usage, names):
# type: (Any, List[str]) -> int
for name in names:
if hasattr(usage, name) and isinstance(getattr(usage, name), int):
return getattr(usage, name)
return 0
def _calculate_token_usage(
messages, response, span, streaming_message_responses, count_tokens
):
# type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
completion_tokens = 0 # type: Optional[int]
prompt_tokens = 0 # type: Optional[int]
# type: (Optional[Iterable[ChatCompletionMessageParam]], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
input_tokens = 0 # type: Optional[int]
input_tokens_cached = 0 # type: Optional[int]
output_tokens = 0 # type: Optional[int]
output_tokens_reasoning = 0 # type: Optional[int]
total_tokens = 0 # type: Optional[int]
if hasattr(response, "usage"):
if hasattr(response.usage, "completion_tokens") and isinstance(
response.usage.completion_tokens, int
):
completion_tokens = response.usage.completion_tokens
if hasattr(response.usage, "prompt_tokens") and isinstance(
response.usage.prompt_tokens, int
):
prompt_tokens = response.usage.prompt_tokens
if hasattr(response.usage, "total_tokens") and isinstance(
response.usage.total_tokens, int
):
total_tokens = response.usage.total_tokens
input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
if hasattr(response.usage, "input_tokens_details"):
input_tokens_cached = _get_usage(
response.usage.input_tokens_details, ["cached_tokens"]
)
if prompt_tokens == 0:
for message in messages:
if "content" in message:
prompt_tokens += count_tokens(message["content"])
output_tokens = _get_usage(
response.usage, ["output_tokens", "completion_tokens"]
)
if hasattr(response.usage, "output_tokens_details"):
output_tokens_reasoning = _get_usage(
response.usage.output_tokens_details, ["reasoning_tokens"]
)
if completion_tokens == 0:
total_tokens = _get_usage(response.usage, ["total_tokens"])
# Manually count tokens
if input_tokens == 0:
for message in messages or []:
if isinstance(message, dict) and "content" in message:
input_tokens += count_tokens(message["content"])
elif isinstance(message, str):
input_tokens += count_tokens(message)
if output_tokens == 0:
if streaming_message_responses is not None:
for message in streaming_message_responses:
completion_tokens += count_tokens(message)
output_tokens += count_tokens(message)
elif hasattr(response, "choices"):
for choice in response.choices:
if hasattr(choice, "message"):
completion_tokens += count_tokens(choice.message)
output_tokens += count_tokens(choice.message)
if prompt_tokens == 0:
prompt_tokens = None
if completion_tokens == 0:
completion_tokens = None
if total_tokens == 0:
total_tokens = None
record_token_usage(span, prompt_tokens, completion_tokens, total_tokens)
# Do not set token data if it is 0
input_tokens = input_tokens or None
input_tokens_cached = input_tokens_cached or None
output_tokens = output_tokens or None
output_tokens_reasoning = output_tokens_reasoning or None
total_tokens = total_tokens or None
record_token_usage(
span,
input_tokens=input_tokens,
input_tokens_cached=input_tokens_cached,
output_tokens=output_tokens,
output_tokens_reasoning=output_tokens_reasoning,
total_tokens=total_tokens,
)
def _set_input_data(span, kwargs, operation, integration):
# type: (Span, dict[str, Any], str, OpenAIIntegration) -> None
# Input messages (the prompt or data sent to the model)
messages = kwargs.get("messages")
if messages is None:
messages = kwargs.get("input")
if isinstance(messages, str):
messages = [messages]
if (
messages is not None
and len(messages) > 0
and should_send_default_pii()
and integration.include_prompts
):
normalized_messages = normalize_message_roles(messages)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(normalized_messages, span, scope)
if messages_data is not None:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
)
# Input attributes: Common
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
# Input attributes: Optional
kwargs_keys_to_attributes = {
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
}
for key, attribute in kwargs_keys_to_attributes.items():
value = kwargs.get(key)
if value is not None and _is_given(value):
set_data_normalized(span, attribute, value)
# Input attributes: Tools
tools = kwargs.get("tools")
if tools is not None and _is_given(tools) and len(tools) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
)
def _set_output_data(span, response, kwargs, integration, finish_span=True):
# type: (Span, Any, dict[str, Any], OpenAIIntegration, bool) -> None
if hasattr(response, "model"):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)
# Input messages (the prompt or data sent to the model)
# used for the token usage calculation
messages = kwargs.get("messages")
if messages is None:
messages = kwargs.get("input")
if messages is not None and isinstance(messages, str):
messages = [messages]
if hasattr(response, "choices"):
if should_send_default_pii() and integration.include_prompts:
response_text = [choice.message.model_dump() for choice in response.choices]
if len(response_text) > 0:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
if finish_span:
span.__exit__(None, None, None)
elif hasattr(response, "output"):
if should_send_default_pii() and integration.include_prompts:
output_messages = {
"response": [],
"tool": [],
} # type: (dict[str, list[Any]])
for output in response.output:
if output.type == "function_call":
output_messages["tool"].append(output.dict())
elif output.type == "message":
for output_message in output.content:
try:
output_messages["response"].append(output_message.text)
except AttributeError:
# Unknown output message type, just return the json
output_messages["response"].append(output_message.dict())
if len(output_messages["tool"]) > 0:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
output_messages["tool"],
unpack=False,
)
if len(output_messages["response"]) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
)
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
if finish_span:
span.__exit__(None, None, None)
elif hasattr(response, "_iterator"):
data_buf: list[list[str]] = [] # one for each choice
old_iterator = response._iterator
def new_iterator():
# type: () -> Iterator[ChatCompletionChunk]
count_tokens_manually = True
for x in old_iterator:
with capture_internal_exceptions():
# OpenAI chat completion API
if hasattr(x, "choices"):
choice_index = 0
for choice in x.choices:
if hasattr(choice, "delta") and hasattr(
choice.delta, "content"
):
content = choice.delta.content
if len(data_buf) <= choice_index:
data_buf.append([])
data_buf[choice_index].append(content or "")
choice_index += 1
# OpenAI responses API
elif hasattr(x, "delta"):
if len(data_buf) == 0:
data_buf.append([])
data_buf[0].append(x.delta or "")
# OpenAI responses API end of streaming response
if RESPONSES_API_ENABLED and isinstance(x, ResponseCompletedEvent):
_calculate_token_usage(
messages,
x.response,
span,
None,
integration.count_tokens,
)
count_tokens_manually = False
yield x
with capture_internal_exceptions():
if len(data_buf) > 0:
all_responses = ["".join(chunk) for chunk in data_buf]
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
)
if count_tokens_manually:
_calculate_token_usage(
messages,
response,
span,
all_responses,
integration.count_tokens,
)
if finish_span:
span.__exit__(None, None, None)
async def new_iterator_async():
# type: () -> AsyncIterator[ChatCompletionChunk]
count_tokens_manually = True
async for x in old_iterator:
with capture_internal_exceptions():
# OpenAI chat completion API
if hasattr(x, "choices"):
choice_index = 0
for choice in x.choices:
if hasattr(choice, "delta") and hasattr(
choice.delta, "content"
):
content = choice.delta.content
if len(data_buf) <= choice_index:
data_buf.append([])
data_buf[choice_index].append(content or "")
choice_index += 1
# OpenAI responses API
elif hasattr(x, "delta"):
if len(data_buf) == 0:
data_buf.append([])
data_buf[0].append(x.delta or "")
# OpenAI responses API end of streaming response
if RESPONSES_API_ENABLED and isinstance(x, ResponseCompletedEvent):
_calculate_token_usage(
messages,
x.response,
span,
None,
integration.count_tokens,
)
count_tokens_manually = False
yield x
with capture_internal_exceptions():
if len(data_buf) > 0:
all_responses = ["".join(chunk) for chunk in data_buf]
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
)
if count_tokens_manually:
_calculate_token_usage(
messages,
response,
span,
all_responses,
integration.count_tokens,
)
if finish_span:
span.__exit__(None, None, None)
if str(type(response._iterator)) == "<class 'async_generator'>":
response._iterator = new_iterator_async()
else:
response._iterator = new_iterator()
else:
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
if finish_span:
span.__exit__(None, None, None)
def _new_chat_completion_common(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)
@@ -130,124 +430,29 @@ def _new_chat_completion_common(f, *args, **kwargs):
# invalid call (in all versions), messages must be iterable
return f(*args, **kwargs)
kwargs["messages"] = list(kwargs["messages"])
messages = kwargs["messages"]
model = kwargs.get("model")
streaming = kwargs.get("stream")
operation = "chat"
span = sentry_sdk.start_span(
op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE,
name="Chat Completion",
op=consts.OP.GEN_AI_CHAT,
name=f"{operation} {model}",
origin=OpenAIIntegration.origin,
)
span.__enter__()
res = yield f, args, kwargs
_set_input_data(span, kwargs, operation, integration)
with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, messages)
response = yield f, args, kwargs
set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
_set_output_data(span, response, kwargs, integration, finish_span=True)
if hasattr(res, "choices"):
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span,
"ai.responses",
list(map(lambda x: x.message, res.choices)),
)
_calculate_chat_completion_usage(
messages, res, span, None, integration.count_tokens
)
span.__exit__(None, None, None)
elif hasattr(res, "_iterator"):
data_buf: list[list[str]] = [] # one for each choice
old_iterator = res._iterator
def new_iterator():
# type: () -> Iterator[ChatCompletionChunk]
with capture_internal_exceptions():
for x in old_iterator:
if hasattr(x, "choices"):
choice_index = 0
for choice in x.choices:
if hasattr(choice, "delta") and hasattr(
choice.delta, "content"
):
content = choice.delta.content
if len(data_buf) <= choice_index:
data_buf.append([])
data_buf[choice_index].append(content or "")
choice_index += 1
yield x
if len(data_buf) > 0:
all_responses = list(
map(lambda chunk: "".join(chunk), data_buf)
)
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span, SPANDATA.AI_RESPONSES, all_responses
)
_calculate_chat_completion_usage(
messages,
res,
span,
all_responses,
integration.count_tokens,
)
span.__exit__(None, None, None)
async def new_iterator_async():
# type: () -> AsyncIterator[ChatCompletionChunk]
with capture_internal_exceptions():
async for x in old_iterator:
if hasattr(x, "choices"):
choice_index = 0
for choice in x.choices:
if hasattr(choice, "delta") and hasattr(
choice.delta, "content"
):
content = choice.delta.content
if len(data_buf) <= choice_index:
data_buf.append([])
data_buf[choice_index].append(content or "")
choice_index += 1
yield x
if len(data_buf) > 0:
all_responses = list(
map(lambda chunk: "".join(chunk), data_buf)
)
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span, SPANDATA.AI_RESPONSES, all_responses
)
_calculate_chat_completion_usage(
messages,
res,
span,
all_responses,
integration.count_tokens,
)
span.__exit__(None, None, None)
if str(type(res._iterator)) == "<class 'async_generator'>":
res._iterator = new_iterator_async()
else:
res._iterator = new_iterator()
else:
set_data_normalized(span, "unknown_response", True)
span.__exit__(None, None, None)
return res
return response
def _wrap_chat_completion_create(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
def _execute_sync(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# type: (Any, Any, Any) -> Any
gen = _new_chat_completion_common(f, *args, **kwargs)
try:
@@ -268,7 +473,7 @@ def _wrap_chat_completion_create(f):
@wraps(f)
def _sentry_patched_create_sync(*args, **kwargs):
# type: (*Any, **Any) -> Any
# type: (Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None or "messages" not in kwargs:
# no "messages" means invalid call (in all versions of openai), let it return error
@@ -282,7 +487,7 @@ def _wrap_chat_completion_create(f):
def _wrap_async_chat_completion_create(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
async def _execute_async(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# type: (Any, Any, Any) -> Any
gen = _new_chat_completion_common(f, *args, **kwargs)
try:
@@ -303,7 +508,7 @@ def _wrap_async_chat_completion_create(f):
@wraps(f)
async def _sentry_patched_create_async(*args, **kwargs):
# type: (*Any, **Any) -> Any
# type: (Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None or "messages" not in kwargs:
# no "messages" means invalid call (in all versions of openai), let it return error
@@ -315,48 +520,24 @@ def _wrap_async_chat_completion_create(f):
def _new_embeddings_create_common(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)
model = kwargs.get("model")
operation = "embeddings"
with sentry_sdk.start_span(
op=consts.OP.OPENAI_EMBEDDINGS_CREATE,
description="OpenAI Embedding Creation",
op=consts.OP.GEN_AI_EMBEDDINGS,
name=f"{operation} {model}",
origin=OpenAIIntegration.origin,
) as span:
if "input" in kwargs and (
should_send_default_pii() and integration.include_prompts
):
if isinstance(kwargs["input"], str):
set_data_normalized(span, "ai.input_messages", [kwargs["input"]])
elif (
isinstance(kwargs["input"], list)
and len(kwargs["input"]) > 0
and isinstance(kwargs["input"][0], str)
):
set_data_normalized(span, "ai.input_messages", kwargs["input"])
if "model" in kwargs:
set_data_normalized(span, "ai.model_id", kwargs["model"])
_set_input_data(span, kwargs, operation, integration)
response = yield f, args, kwargs
prompt_tokens = 0
total_tokens = 0
if hasattr(response, "usage"):
if hasattr(response.usage, "prompt_tokens") and isinstance(
response.usage.prompt_tokens, int
):
prompt_tokens = response.usage.prompt_tokens
if hasattr(response.usage, "total_tokens") and isinstance(
response.usage.total_tokens, int
):
total_tokens = response.usage.total_tokens
if prompt_tokens == 0:
prompt_tokens = integration.count_tokens(kwargs["input"] or "")
record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
_set_output_data(span, response, kwargs, integration, finish_span=False)
return response
@@ -364,9 +545,102 @@ def _new_embeddings_create_common(f, *args, **kwargs):
def _wrap_embeddings_create(f):
# type: (Any) -> Any
def _execute_sync(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# type: (Any, Any, Any) -> Any
gen = _new_embeddings_create_common(f, *args, **kwargs)
try:
f, args, kwargs = next(gen)
except StopIteration as e:
return e.value
try:
try:
result = f(*args, **kwargs)
except Exception as e:
_capture_exception(e, manual_span_cleanup=False)
raise e from None
return gen.send(result)
except StopIteration as e:
return e.value
@wraps(f)
def _sentry_patched_create_sync(*args, **kwargs):
# type: (Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)
return _execute_sync(f, *args, **kwargs)
return _sentry_patched_create_sync
def _wrap_async_embeddings_create(f):
# type: (Any) -> Any
async def _execute_async(f, *args, **kwargs):
# type: (Any, Any, Any) -> Any
gen = _new_embeddings_create_common(f, *args, **kwargs)
try:
f, args, kwargs = next(gen)
except StopIteration as e:
return await e.value
try:
try:
result = await f(*args, **kwargs)
except Exception as e:
_capture_exception(e, manual_span_cleanup=False)
raise e from None
return gen.send(result)
except StopIteration as e:
return e.value
@wraps(f)
async def _sentry_patched_create_async(*args, **kwargs):
# type: (Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return await f(*args, **kwargs)
return await _execute_async(f, *args, **kwargs)
return _sentry_patched_create_async
def _new_responses_create_common(f, *args, **kwargs):
# type: (Any, Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)
model = kwargs.get("model")
operation = "responses"
span = sentry_sdk.start_span(
op=consts.OP.GEN_AI_RESPONSES,
name=f"{operation} {model}",
origin=OpenAIIntegration.origin,
)
span.__enter__()
_set_input_data(span, kwargs, operation, integration)
response = yield f, args, kwargs
_set_output_data(span, response, kwargs, integration, finish_span=True)
return response
def _wrap_responses_create(f):
# type: (Any) -> Any
def _execute_sync(f, *args, **kwargs):
# type: (Any, Any, Any) -> Any
gen = _new_responses_create_common(f, *args, **kwargs)
try:
f, args, kwargs = next(gen)
except StopIteration as e:
@@ -385,7 +659,7 @@ def _wrap_embeddings_create(f):
@wraps(f)
def _sentry_patched_create_sync(*args, **kwargs):
# type: (*Any, **Any) -> Any
# type: (Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)
@@ -395,11 +669,11 @@ def _wrap_embeddings_create(f):
return _sentry_patched_create_sync
def _wrap_async_embeddings_create(f):
def _wrap_async_responses_create(f):
# type: (Any) -> Any
async def _execute_async(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
gen = _new_embeddings_create_common(f, *args, **kwargs)
# type: (Any, Any, Any) -> Any
gen = _new_responses_create_common(f, *args, **kwargs)
try:
f, args, kwargs = next(gen)
@@ -418,12 +692,24 @@ def _wrap_async_embeddings_create(f):
return e.value
@wraps(f)
async def _sentry_patched_create_async(*args, **kwargs):
# type: (*Any, **Any) -> Any
async def _sentry_patched_responses_async(*args, **kwargs):
# type: (Any, Any) -> Any
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return await f(*args, **kwargs)
return await _execute_async(f, *args, **kwargs)
return _sentry_patched_create_async
return _sentry_patched_responses_async
def _is_given(obj):
# type: (Any) -> bool
"""
Check for givenness safely across different openai versions.
"""
if NotGiven is not None and isinstance(obj, NotGiven):
return False
if Omit is not None and isinstance(obj, Omit):
return False
return True
@@ -0,0 +1,55 @@
from sentry_sdk.integrations import DidNotEnable, Integration
from .patches import (
_create_get_model_wrapper,
_create_get_all_tools_wrapper,
_create_run_wrapper,
_patch_agent_run,
_patch_error_tracing,
)
try:
import agents
except ImportError:
raise DidNotEnable("OpenAI Agents not installed")
def _patch_runner():
# type: () -> None
# Create the root span for one full agent run (including eventual handoffs)
# Note agents.run.DEFAULT_AGENT_RUNNER.run_sync is a wrapper around
# agents.run.DEFAULT_AGENT_RUNNER.run. It does not need to be wrapped separately.
# TODO-anton: Also patch streaming runner: agents.Runner.run_streamed
agents.run.DEFAULT_AGENT_RUNNER.run = _create_run_wrapper(
agents.run.DEFAULT_AGENT_RUNNER.run
)
# Creating the actual spans for each agent run.
_patch_agent_run()
def _patch_model():
# type: () -> None
agents.run.AgentRunner._get_model = classmethod(
_create_get_model_wrapper(agents.run.AgentRunner._get_model),
)
def _patch_tools():
# type: () -> None
agents.run.AgentRunner._get_all_tools = classmethod(
_create_get_all_tools_wrapper(agents.run.AgentRunner._get_all_tools),
)
class OpenAIAgentsIntegration(Integration):
identifier = "openai_agents"
@staticmethod
def setup_once():
# type: () -> None
_patch_error_tracing()
_patch_tools()
_patch_model()
_patch_runner()
@@ -0,0 +1 @@
SPAN_ORIGIN = "auto.ai.openai_agents"
@@ -0,0 +1,5 @@
from .models import _create_get_model_wrapper # noqa: F401
from .tools import _create_get_all_tools_wrapper # noqa: F401
from .runner import _create_run_wrapper # noqa: F401
from .agent_run import _patch_agent_run # noqa: F401
from .error_tracing import _patch_error_tracing # noqa: F401
@@ -0,0 +1,140 @@
from functools import wraps
from sentry_sdk.integrations import DidNotEnable
from ..spans import invoke_agent_span, update_invoke_agent_span, handoff_span
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Optional
try:
import agents
except ImportError:
raise DidNotEnable("OpenAI Agents not installed")
def _patch_agent_run():
# type: () -> None
"""
Patches AgentRunner methods to create agent invocation spans.
This directly patches the execution flow to track when agents start and stop.
"""
# Store original methods
original_run_single_turn = agents.run.AgentRunner._run_single_turn
original_execute_handoffs = agents._run_impl.RunImpl.execute_handoffs
original_execute_final_output = agents._run_impl.RunImpl.execute_final_output
def _start_invoke_agent_span(context_wrapper, agent, kwargs):
# type: (agents.RunContextWrapper, agents.Agent, dict[str, Any]) -> None
"""Start an agent invocation span"""
# Store the agent on the context wrapper so we can access it later
context_wrapper._sentry_current_agent = agent
invoke_agent_span(context_wrapper, agent, kwargs)
def _end_invoke_agent_span(context_wrapper, agent, output=None):
# type: (agents.RunContextWrapper, agents.Agent, Optional[Any]) -> None
"""End the agent invocation span"""
# Clear the stored agent
if hasattr(context_wrapper, "_sentry_current_agent"):
delattr(context_wrapper, "_sentry_current_agent")
update_invoke_agent_span(context_wrapper, agent, output)
def _has_active_agent_span(context_wrapper):
# type: (agents.RunContextWrapper) -> bool
"""Check if there's an active agent span for this context"""
return getattr(context_wrapper, "_sentry_current_agent", None) is not None
def _get_current_agent(context_wrapper):
# type: (agents.RunContextWrapper) -> Optional[agents.Agent]
"""Get the current agent from context wrapper"""
return getattr(context_wrapper, "_sentry_current_agent", None)
@wraps(
original_run_single_turn.__func__
if hasattr(original_run_single_turn, "__func__")
else original_run_single_turn
)
async def patched_run_single_turn(cls, *args, **kwargs):
# type: (agents.Runner, *Any, **Any) -> Any
"""Patched _run_single_turn that creates agent invocation spans"""
agent = kwargs.get("agent")
context_wrapper = kwargs.get("context_wrapper")
should_run_agent_start_hooks = kwargs.get("should_run_agent_start_hooks")
# Start agent span when agent starts (but only once per agent)
if should_run_agent_start_hooks and agent and context_wrapper:
# End any existing span for a different agent
if _has_active_agent_span(context_wrapper):
current_agent = _get_current_agent(context_wrapper)
if current_agent and current_agent != agent:
_end_invoke_agent_span(context_wrapper, current_agent)
_start_invoke_agent_span(context_wrapper, agent, kwargs)
# Call original method with all the correct parameters
result = await original_run_single_turn(*args, **kwargs)
return result
@wraps(
original_execute_handoffs.__func__
if hasattr(original_execute_handoffs, "__func__")
else original_execute_handoffs
)
async def patched_execute_handoffs(cls, *args, **kwargs):
# type: (agents.Runner, *Any, **Any) -> Any
"""Patched execute_handoffs that creates handoff spans and ends agent span for handoffs"""
context_wrapper = kwargs.get("context_wrapper")
run_handoffs = kwargs.get("run_handoffs")
agent = kwargs.get("agent")
# Create Sentry handoff span for the first handoff (agents library only processes the first one)
if run_handoffs:
first_handoff = run_handoffs[0]
handoff_agent_name = first_handoff.handoff.agent_name
handoff_span(context_wrapper, agent, handoff_agent_name)
# Call original method with all parameters
try:
result = await original_execute_handoffs(*args, **kwargs)
finally:
# End span for current agent after handoff processing is complete
if agent and context_wrapper and _has_active_agent_span(context_wrapper):
_end_invoke_agent_span(context_wrapper, agent)
return result
@wraps(
original_execute_final_output.__func__
if hasattr(original_execute_final_output, "__func__")
else original_execute_final_output
)
async def patched_execute_final_output(cls, *args, **kwargs):
# type: (agents.Runner, *Any, **Any) -> Any
"""Patched execute_final_output that ends agent span for final outputs"""
agent = kwargs.get("agent")
context_wrapper = kwargs.get("context_wrapper")
final_output = kwargs.get("final_output")
# Call original method with all parameters
try:
result = await original_execute_final_output(*args, **kwargs)
finally:
# End span for current agent after final output processing is complete
if agent and context_wrapper and _has_active_agent_span(context_wrapper):
_end_invoke_agent_span(context_wrapper, agent, final_output)
return result
# Apply patches
agents.run.AgentRunner._run_single_turn = classmethod(patched_run_single_turn)
agents._run_impl.RunImpl.execute_handoffs = classmethod(patched_execute_handoffs)
agents._run_impl.RunImpl.execute_final_output = classmethod(
patched_execute_final_output
)
@@ -0,0 +1,77 @@
from functools import wraps
import sentry_sdk
from sentry_sdk.consts import SPANSTATUS
from sentry_sdk.tracing_utils import set_span_errored
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable, Optional
def _patch_error_tracing():
# type: () -> None
"""
Patches agents error tracing function to inject our span error logic
when a tool execution fails.
In newer versions, the function is at: agents.util._error_tracing.attach_error_to_current_span
In older versions, it was at: agents._utils.attach_error_to_current_span
This works even when the module or function doesn't exist.
"""
error_tracing_module = None
# Try newer location first (agents.util._error_tracing)
try:
from agents.util import _error_tracing
error_tracing_module = _error_tracing
except (ImportError, AttributeError):
pass
# Try older location (agents._utils)
if error_tracing_module is None:
try:
import agents._utils
error_tracing_module = agents._utils
except (ImportError, AttributeError):
# Module doesn't exist in either location, nothing to patch
return
# Check if the function exists
if not hasattr(error_tracing_module, "attach_error_to_current_span"):
return
original_attach_error = error_tracing_module.attach_error_to_current_span
@wraps(original_attach_error)
def sentry_attach_error_to_current_span(error, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
"""
Wraps agents' error attachment to also set Sentry span status to error.
This allows us to properly track tool execution errors even though
the agents library swallows exceptions.
"""
# Set the current Sentry span to errored
current_span = sentry_sdk.get_current_span()
if current_span is not None:
set_span_errored(current_span)
current_span.set_data("span.status", "error")
# Optionally capture the error details if we have them
if hasattr(error, "__class__"):
current_span.set_data("error.type", error.__class__.__name__)
if hasattr(error, "__str__"):
error_message = str(error)
if error_message:
current_span.set_data("error.message", error_message)
# Call the original function
return original_attach_error(error, *args, **kwargs)
error_tracing_module.attach_error_to_current_span = (
sentry_attach_error_to_current_span
)
@@ -0,0 +1,50 @@
from functools import wraps
from sentry_sdk.integrations import DidNotEnable
from ..spans import ai_client_span, update_ai_client_span
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable
try:
import agents
except ImportError:
raise DidNotEnable("OpenAI Agents not installed")
def _create_get_model_wrapper(original_get_model):
# type: (Callable[..., Any]) -> Callable[..., Any]
"""
Wraps the agents.Runner._get_model method to wrap the get_response method of the model to create a AI client span.
"""
@wraps(
original_get_model.__func__
if hasattr(original_get_model, "__func__")
else original_get_model
)
def wrapped_get_model(cls, agent, run_config):
# type: (agents.Runner, agents.Agent, agents.RunConfig) -> agents.Model
model = original_get_model(agent, run_config)
original_get_response = model.get_response
@wraps(original_get_response)
async def wrapped_get_response(*args, **kwargs):
# type: (*Any, **Any) -> Any
with ai_client_span(agent, kwargs) as span:
result = await original_get_response(*args, **kwargs)
update_ai_client_span(span, agent, kwargs, result)
return result
model.get_response = wrapped_get_response
return model
return wrapped_get_model
@@ -0,0 +1,45 @@
from functools import wraps
import sentry_sdk
from ..spans import agent_workflow_span
from ..utils import _capture_exception
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable
def _create_run_wrapper(original_func):
# type: (Callable[..., Any]) -> Callable[..., Any]
"""
Wraps the agents.Runner.run methods to create a root span for the agent workflow runs.
Note agents.Runner.run_sync() is a wrapper around agents.Runner.run(),
so it does not need to be wrapped separately.
"""
@wraps(original_func)
async def wrapper(*args, **kwargs):
# type: (*Any, **Any) -> Any
# Isolate each workflow so that when agents are run in asyncio tasks they
# don't touch each other's scopes
with sentry_sdk.isolation_scope():
agent = args[0]
with agent_workflow_span(agent):
result = None
try:
result = await original_func(*args, **kwargs)
return result
except Exception as exc:
_capture_exception(exc)
# It could be that there is a "invoke agent" span still open
current_span = sentry_sdk.get_current_span()
if current_span is not None and current_span.timestamp is None:
current_span.__exit__(None, None, None)
raise exc from None
return wrapper
@@ -0,0 +1,77 @@
from functools import wraps
from sentry_sdk.integrations import DidNotEnable
from ..spans import execute_tool_span, update_execute_tool_span
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable
try:
import agents
except ImportError:
raise DidNotEnable("OpenAI Agents not installed")
def _create_get_all_tools_wrapper(original_get_all_tools):
# type: (Callable[..., Any]) -> Callable[..., Any]
"""
Wraps the agents.Runner._get_all_tools method of the Runner class to wrap all function tools with Sentry instrumentation.
"""
@wraps(
original_get_all_tools.__func__
if hasattr(original_get_all_tools, "__func__")
else original_get_all_tools
)
async def wrapped_get_all_tools(cls, agent, context_wrapper):
# type: (agents.Runner, agents.Agent, agents.RunContextWrapper) -> list[agents.Tool]
# Get the original tools
tools = await original_get_all_tools(agent, context_wrapper)
wrapped_tools = []
for tool in tools:
# Wrap only the function tools (for now)
if tool.__class__.__name__ != "FunctionTool":
wrapped_tools.append(tool)
continue
# Create a new FunctionTool with our wrapped invoke method
original_on_invoke = tool.on_invoke_tool
def create_wrapped_invoke(current_tool, current_on_invoke):
# type: (agents.Tool, Callable[..., Any]) -> Callable[..., Any]
@wraps(current_on_invoke)
async def sentry_wrapped_on_invoke_tool(*args, **kwargs):
# type: (*Any, **Any) -> Any
with execute_tool_span(current_tool, *args, **kwargs) as span:
# We can not capture exceptions in tool execution here because
# `_on_invoke_tool` is swallowing the exception here:
# https://github.com/openai/openai-agents-python/blob/main/src/agents/tool.py#L409-L422
# And because function_tool is a decorator with `default_tool_error_function` set as a default parameter
# I was unable to monkey patch it because those are evaluated at module import time
# and the SDK is too late to patch it. I was also unable to patch `_on_invoke_tool_impl`
# because it is nested inside this import time code. As if they made it hard to patch on purpose...
result = await current_on_invoke(*args, **kwargs)
update_execute_tool_span(span, agent, current_tool, result)
return result
return sentry_wrapped_on_invoke_tool
wrapped_tool = agents.FunctionTool(
name=tool.name,
description=tool.description,
params_json_schema=tool.params_json_schema,
on_invoke_tool=create_wrapped_invoke(tool, original_on_invoke),
strict_json_schema=tool.strict_json_schema,
is_enabled=tool.is_enabled,
)
wrapped_tools.append(wrapped_tool)
return wrapped_tools
return wrapped_get_all_tools
@@ -0,0 +1,5 @@
from .agent_workflow import agent_workflow_span # noqa: F401
from .ai_client import ai_client_span, update_ai_client_span # noqa: F401
from .execute_tool import execute_tool_span, update_execute_tool_span # noqa: F401
from .handoff import handoff_span # noqa: F401
from .invoke_agent import invoke_agent_span, update_invoke_agent_span # noqa: F401
@@ -0,0 +1,21 @@
import sentry_sdk
from sentry_sdk.ai.utils import get_start_span_function
from ..consts import SPAN_ORIGIN
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import agents
def agent_workflow_span(agent):
# type: (agents.Agent) -> sentry_sdk.tracing.Span
# Create a transaction or a span if an transaction is already active
span = get_start_span_function()(
name=f"{agent.name} workflow",
origin=SPAN_ORIGIN,
)
return span
@@ -0,0 +1,42 @@
import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from ..consts import SPAN_ORIGIN
from ..utils import (
_set_agent_data,
_set_input_data,
_set_output_data,
_set_usage_data,
_create_mcp_execute_tool_spans,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from agents import Agent
from typing import Any
def ai_client_span(agent, get_response_kwargs):
# type: (Agent, dict[str, Any]) -> sentry_sdk.tracing.Span
# TODO-anton: implement other types of operations. Now "chat" is hardcoded.
model_name = agent.model.model if hasattr(agent.model, "model") else agent.model
span = sentry_sdk.start_span(
op=OP.GEN_AI_CHAT,
description=f"chat {model_name}",
origin=SPAN_ORIGIN,
)
# TODO-anton: remove hardcoded stuff and replace something that also works for embedding and so on
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
_set_agent_data(span, agent)
return span
def update_ai_client_span(span, agent, get_response_kwargs, result):
# type: (sentry_sdk.tracing.Span, Agent, dict[str, Any], Any) -> None
_set_usage_data(span, result.usage)
_set_input_data(span, get_response_kwargs)
_set_output_data(span, result)
_create_mcp_execute_tool_spans(span, result)
@@ -0,0 +1,48 @@
import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS
from sentry_sdk.scope import should_send_default_pii
from ..consts import SPAN_ORIGIN
from ..utils import _set_agent_data
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import agents
from typing import Any
def execute_tool_span(tool, *args, **kwargs):
# type: (agents.Tool, *Any, **Any) -> sentry_sdk.tracing.Span
span = sentry_sdk.start_span(
op=OP.GEN_AI_EXECUTE_TOOL,
name=f"execute_tool {tool.name}",
origin=SPAN_ORIGIN,
)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
if tool.__class__.__name__ == "FunctionTool":
span.set_data(SPANDATA.GEN_AI_TOOL_TYPE, "function")
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool.name)
span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool.description)
if should_send_default_pii():
input = args[1]
span.set_data(SPANDATA.GEN_AI_TOOL_INPUT, input)
return span
def update_execute_tool_span(span, agent, tool, result):
# type: (sentry_sdk.tracing.Span, agents.Agent, agents.Tool, Any) -> None
_set_agent_data(span, agent)
if isinstance(result, str) and result.startswith(
"An error occurred while running the tool"
):
span.set_status(SPANSTATUS.ERROR)
if should_send_default_pii():
span.set_data(SPANDATA.GEN_AI_TOOL_OUTPUT, result)
@@ -0,0 +1,19 @@
import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from ..consts import SPAN_ORIGIN
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import agents
def handoff_span(context, from_agent, to_agent_name):
# type: (agents.RunContextWrapper, agents.Agent, str) -> None
with sentry_sdk.start_span(
op=OP.GEN_AI_HANDOFF,
name=f"handoff from {from_agent.name} to {to_agent_name}",
origin=SPAN_ORIGIN,
) as span:
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "handoff")
@@ -0,0 +1,86 @@
import sentry_sdk
from sentry_sdk.ai.utils import (
get_start_span_function,
set_data_normalized,
normalize_message_roles,
)
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import safe_serialize
from ..consts import SPAN_ORIGIN
from ..utils import _set_agent_data
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import agents
from typing import Any
def invoke_agent_span(context, agent, kwargs):
# type: (agents.RunContextWrapper, agents.Agent, dict[str, Any]) -> sentry_sdk.tracing.Span
start_span_function = get_start_span_function()
span = start_span_function(
op=OP.GEN_AI_INVOKE_AGENT,
name=f"invoke_agent {agent.name}",
origin=SPAN_ORIGIN,
)
span.__enter__()
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
if should_send_default_pii():
messages = []
if agent.instructions:
message = (
agent.instructions
if isinstance(agent.instructions, str)
else safe_serialize(agent.instructions)
)
messages.append(
{
"content": [{"text": message, "type": "text"}],
"role": "system",
}
)
original_input = kwargs.get("original_input")
if original_input is not None:
message = (
original_input
if isinstance(original_input, str)
else safe_serialize(original_input)
)
messages.append(
{
"content": [{"text": message, "type": "text"}],
"role": "user",
}
)
if len(messages) > 0:
normalized_messages = normalize_message_roles(messages)
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
normalized_messages,
unpack=False,
)
_set_agent_data(span, agent)
return span
def update_invoke_agent_span(context, agent, output):
# type: (agents.RunContextWrapper, agents.Agent, Any) -> None
span = sentry_sdk.get_current_span()
if span:
if should_send_default_pii():
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output, unpack=False
)
span.__exit__(None, None, None)
@@ -0,0 +1,199 @@
import sentry_sdk
from sentry_sdk.ai.utils import (
GEN_AI_ALLOWED_MESSAGE_ROLES,
normalize_message_roles,
set_data_normalized,
normalize_message_role,
)
from sentry_sdk.consts import SPANDATA, SPANSTATUS, OP
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import event_from_exception, safe_serialize
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
from agents import Usage
try:
import agents
except ImportError:
raise DidNotEnable("OpenAI Agents not installed")
def _capture_exception(exc):
# type: (Any) -> None
set_span_errored()
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
mechanism={"type": "openai_agents", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
def _set_agent_data(span, agent):
# type: (sentry_sdk.tracing.Span, agents.Agent) -> None
span.set_data(
SPANDATA.GEN_AI_SYSTEM, "openai"
) # See footnote for https://opentelemetry.io/docs/specs/semconv/registry/attributes/gen-ai/#gen-ai-system for explanation why.
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent.name)
if agent.model_settings.max_tokens:
span.set_data(
SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, agent.model_settings.max_tokens
)
if agent.model:
model_name = agent.model.model if hasattr(agent.model, "model") else agent.model
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
if agent.model_settings.presence_penalty:
span.set_data(
SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
agent.model_settings.presence_penalty,
)
if agent.model_settings.temperature:
span.set_data(
SPANDATA.GEN_AI_REQUEST_TEMPERATURE, agent.model_settings.temperature
)
if agent.model_settings.top_p:
span.set_data(SPANDATA.GEN_AI_REQUEST_TOP_P, agent.model_settings.top_p)
if agent.model_settings.frequency_penalty:
span.set_data(
SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
agent.model_settings.frequency_penalty,
)
if len(agent.tools) > 0:
span.set_data(
SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
safe_serialize([vars(tool) for tool in agent.tools]),
)
def _set_usage_data(span, usage):
# type: (sentry_sdk.tracing.Span, Usage) -> None
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
usage.input_tokens_details.cached_tokens,
)
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
span.set_data(
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
usage.output_tokens_details.reasoning_tokens,
)
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
def _set_input_data(span, get_response_kwargs):
# type: (sentry_sdk.tracing.Span, dict[str, Any]) -> None
if not should_send_default_pii():
return
request_messages = []
system_instructions = get_response_kwargs.get("system_instructions")
if system_instructions:
request_messages.append(
{
"role": GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM,
"content": [{"type": "text", "text": system_instructions}],
}
)
for message in get_response_kwargs.get("input", []):
if "role" in message:
normalized_role = normalize_message_role(message.get("role"))
request_messages.append(
{
"role": normalized_role,
"content": [{"type": "text", "text": message.get("content")}],
}
)
else:
if message.get("type") == "function_call":
request_messages.append(
{
"role": GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT,
"content": [message],
}
)
elif message.get("type") == "function_call_output":
request_messages.append(
{
"role": GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL,
"content": [message],
}
)
set_data_normalized(
span,
SPANDATA.GEN_AI_REQUEST_MESSAGES,
normalize_message_roles(request_messages),
unpack=False,
)
def _set_output_data(span, result):
# type: (sentry_sdk.tracing.Span, Any) -> None
if not should_send_default_pii():
return
output_messages = {
"response": [],
"tool": [],
} # type: (dict[str, list[Any]])
for output in result.output:
if output.type == "function_call":
output_messages["tool"].append(output.dict())
elif output.type == "message":
for output_message in output.content:
try:
output_messages["response"].append(output_message.text)
except AttributeError:
# Unknown output message type, just return the json
output_messages["response"].append(output_message.dict())
if len(output_messages["tool"]) > 0:
span.set_data(
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(output_messages["tool"])
)
if len(output_messages["response"]) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
)
def _create_mcp_execute_tool_spans(span, result):
# type: (sentry_sdk.tracing.Span, agents.Result) -> None
for output in result.output:
if output.__class__.__name__ == "McpCall":
with sentry_sdk.start_span(
op=OP.GEN_AI_EXECUTE_TOOL,
description=f"execute_tool {output.name}",
start_timestamp=span.start_timestamp,
) as execute_tool_span:
set_data_normalized(execute_tool_span, SPANDATA.GEN_AI_TOOL_TYPE, "mcp")
set_data_normalized(
execute_tool_span, SPANDATA.GEN_AI_TOOL_NAME, output.name
)
if should_send_default_pii():
execute_tool_span.set_data(
SPANDATA.GEN_AI_TOOL_INPUT, output.arguments
)
execute_tool_span.set_data(
SPANDATA.GEN_AI_TOOL_OUTPUT, output.output
)
if output.error:
execute_tool_span.set_status(SPANSTATUS.ERROR)
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
import sentry_sdk
from typing import TYPE_CHECKING, Any
from sentry_sdk.feature_flags import add_feature_flag
from sentry_sdk.integrations import DidNotEnable, Integration
try:
@@ -8,7 +8,6 @@ try:
from openfeature.hook import Hook
if TYPE_CHECKING:
from openfeature.flag_evaluation import FlagEvaluationDetails
from openfeature.hook import HookContext, HookHints
except ImportError:
raise DidNotEnable("OpenFeature is not installed")
@@ -25,15 +24,12 @@ class OpenFeatureIntegration(Integration):
class OpenFeatureHook(Hook):
def after(self, hook_context, details, hints):
# type: (HookContext, FlagEvaluationDetails[bool], HookHints) -> None
# type: (Any, Any, Any) -> None
if isinstance(details.value, bool):
flags = sentry_sdk.get_current_scope().flags
flags.set(details.flag_key, details.value)
add_feature_flag(details.flag_key, details.value)
def error(self, hook_context, exception, hints):
# type: (HookContext, Exception, HookHints) -> None
if isinstance(hook_context.default_value, bool):
flags = sentry_sdk.get_current_scope().flags
flags.set(hook_context.flag_key, hook_context.default_value)
add_feature_flag(hook_context.flag_key, hook_context.default_value)
@@ -116,7 +116,9 @@ def pure_eval_frame(frame):
return (n.lineno, n.col_offset)
nodes_before_stmt = [
node for node in nodes if start(node) < stmt.last_token.end # type: ignore
node
for node in nodes
if start(node) < stmt.last_token.end # type: ignore
]
if nodes_before_stmt:
# The position of the last node before or in the statement
@@ -0,0 +1,47 @@
from sentry_sdk.integrations import DidNotEnable, Integration
try:
import pydantic_ai # type: ignore
except ImportError:
raise DidNotEnable("pydantic-ai not installed")
from .patches import (
_patch_agent_run,
_patch_graph_nodes,
_patch_model_request,
_patch_tool_execution,
)
class PydanticAIIntegration(Integration):
identifier = "pydantic_ai"
origin = f"auto.ai.{identifier}"
def __init__(self, include_prompts=True):
# type: (bool) -> None
"""
Initialize the Pydantic AI integration.
Args:
include_prompts: Whether to include prompts and messages in span data.
Requires send_default_pii=True. Defaults to True.
"""
self.include_prompts = include_prompts
@staticmethod
def setup_once():
# type: () -> None
"""
Set up the pydantic-ai integration.
This patches the key methods in pydantic-ai to create Sentry spans for:
- Agent invocations (Agent.run methods)
- Model requests (AI client calls)
- Tool executions
"""
_patch_agent_run()
_patch_graph_nodes()
_patch_model_request()
_patch_tool_execution()
@@ -0,0 +1 @@
SPAN_ORIGIN = "auto.ai.pydantic_ai"
@@ -0,0 +1,4 @@
from .agent_run import _patch_agent_run # noqa: F401
from .graph_nodes import _patch_graph_nodes # noqa: F401
from .model_request import _patch_model_request # noqa: F401
from .tools import _patch_tool_execution # noqa: F401
@@ -0,0 +1,217 @@
from functools import wraps
import sentry_sdk
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import event_from_exception
from ..spans import invoke_agent_span, update_invoke_agent_span
from typing import TYPE_CHECKING
from pydantic_ai.agent import Agent # type: ignore
if TYPE_CHECKING:
from typing import Any, Callable, Optional
def _capture_exception(exc):
# type: (Any) -> None
set_span_errored()
event, hint = event_from_exception(
exc,
client_options=sentry_sdk.get_client().options,
mechanism={"type": "pydantic_ai", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
class _StreamingContextManagerWrapper:
"""Wrapper for streaming methods that return async context managers."""
def __init__(
self,
agent,
original_ctx_manager,
user_prompt,
model,
model_settings,
is_streaming=True,
):
# type: (Any, Any, Any, Any, Any, bool) -> None
self.agent = agent
self.original_ctx_manager = original_ctx_manager
self.user_prompt = user_prompt
self.model = model
self.model_settings = model_settings
self.is_streaming = is_streaming
self._isolation_scope = None # type: Any
self._span = None # type: Optional[sentry_sdk.tracing.Span]
self._result = None # type: Any
async def __aenter__(self):
# type: () -> Any
# Set up isolation scope and invoke_agent span
self._isolation_scope = sentry_sdk.isolation_scope()
self._isolation_scope.__enter__()
# Store agent reference and streaming flag
sentry_sdk.get_current_scope().set_context(
"pydantic_ai_agent", {"_agent": self.agent, "_streaming": self.is_streaming}
)
# Create invoke_agent span (will be closed in __aexit__)
self._span = invoke_agent_span(
self.user_prompt, self.agent, self.model, self.model_settings
)
self._span.__enter__()
# Enter the original context manager
result = await self.original_ctx_manager.__aenter__()
self._result = result
return result
async def __aexit__(self, exc_type, exc_val, exc_tb):
# type: (Any, Any, Any) -> None
try:
# Exit the original context manager first
await self.original_ctx_manager.__aexit__(exc_type, exc_val, exc_tb)
# Update span with output if successful
if exc_type is None and self._result and hasattr(self._result, "output"):
output = (
self._result.output if hasattr(self._result, "output") else None
)
if self._span is not None:
update_invoke_agent_span(self._span, output)
finally:
sentry_sdk.get_current_scope().remove_context("pydantic_ai_agent")
# Clean up invoke span
if self._span:
self._span.__exit__(exc_type, exc_val, exc_tb)
# Clean up isolation scope
if self._isolation_scope:
self._isolation_scope.__exit__(exc_type, exc_val, exc_tb)
def _create_run_wrapper(original_func, is_streaming=False):
# type: (Callable[..., Any], bool) -> Callable[..., Any]
"""
Wraps the Agent.run method to create an invoke_agent span.
Args:
original_func: The original run method
is_streaming: Whether this is a streaming method (for future use)
"""
@wraps(original_func)
async def wrapper(self, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# Isolate each workflow so that when agents are run in asyncio tasks they
# don't touch each other's scopes
with sentry_sdk.isolation_scope():
# Store agent reference and streaming flag in Sentry scope for access in nested spans
# We store the full agent to allow access to tools and system prompts
sentry_sdk.get_current_scope().set_context(
"pydantic_ai_agent", {"_agent": self, "_streaming": is_streaming}
)
# Extract parameters for the span
user_prompt = kwargs.get("user_prompt") or (args[0] if args else None)
model = kwargs.get("model")
model_settings = kwargs.get("model_settings")
# Create invoke_agent span
with invoke_agent_span(user_prompt, self, model, model_settings) as span:
try:
result = await original_func(self, *args, **kwargs)
# Update span with output
output = result.output if hasattr(result, "output") else None
update_invoke_agent_span(span, output)
return result
except Exception as exc:
_capture_exception(exc)
raise exc from None
finally:
sentry_sdk.get_current_scope().remove_context("pydantic_ai_agent")
return wrapper
def _create_streaming_wrapper(original_func):
# type: (Callable[..., Any]) -> Callable[..., Any]
"""
Wraps run_stream method that returns an async context manager.
"""
@wraps(original_func)
def wrapper(self, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# Extract parameters for the span
user_prompt = kwargs.get("user_prompt") or (args[0] if args else None)
model = kwargs.get("model")
model_settings = kwargs.get("model_settings")
# Call original function to get the context manager
original_ctx_manager = original_func(self, *args, **kwargs)
# Wrap it with our instrumentation
return _StreamingContextManagerWrapper(
agent=self,
original_ctx_manager=original_ctx_manager,
user_prompt=user_prompt,
model=model,
model_settings=model_settings,
is_streaming=True,
)
return wrapper
def _create_streaming_events_wrapper(original_func):
# type: (Callable[..., Any]) -> Callable[..., Any]
"""
Wraps run_stream_events method - no span needed as it delegates to run().
Note: run_stream_events internally calls self.run() with an event_stream_handler,
so the invoke_agent span will be created by the run() wrapper.
"""
@wraps(original_func)
async def wrapper(self, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
# Just call the original generator - it will call run() which has the instrumentation
try:
async for event in original_func(self, *args, **kwargs):
yield event
except Exception as exc:
_capture_exception(exc)
raise exc from None
return wrapper
def _patch_agent_run():
# type: () -> None
"""
Patches the Agent run methods to create spans for agent execution.
This patches both non-streaming (run, run_sync) and streaming
(run_stream, run_stream_events) methods.
"""
# Store original methods
original_run = Agent.run
original_run_stream = Agent.run_stream
original_run_stream_events = Agent.run_stream_events
# Wrap and apply patches for non-streaming methods
Agent.run = _create_run_wrapper(original_run, is_streaming=False)
# Wrap and apply patches for streaming methods
Agent.run_stream = _create_streaming_wrapper(original_run_stream)
Agent.run_stream_events = _create_streaming_events_wrapper(
original_run_stream_events
)
@@ -0,0 +1,105 @@
from contextlib import asynccontextmanager
from functools import wraps
import sentry_sdk
from ..spans import (
ai_client_span,
update_ai_client_span,
)
from pydantic_ai._agent_graph import ModelRequestNode # type: ignore
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable
def _extract_span_data(node, ctx):
# type: (Any, Any) -> tuple[list[Any], Any, Any]
"""Extract common data needed for creating chat spans.
Returns:
Tuple of (messages, model, model_settings)
"""
# Extract model and settings from context
model = None
model_settings = None
if hasattr(ctx, "deps"):
model = getattr(ctx.deps, "model", None)
model_settings = getattr(ctx.deps, "model_settings", None)
# Build full message list: history + current request
messages = []
if hasattr(ctx, "state") and hasattr(ctx.state, "message_history"):
messages.extend(ctx.state.message_history)
current_request = getattr(node, "request", None)
if current_request:
messages.append(current_request)
return messages, model, model_settings
def _patch_graph_nodes():
# type: () -> None
"""
Patches the graph node execution to create appropriate spans.
ModelRequestNode -> Creates ai_client span for model requests
CallToolsNode -> Handles tool calls (spans created in tool patching)
"""
# Patch ModelRequestNode to create ai_client spans
original_model_request_run = ModelRequestNode.run
@wraps(original_model_request_run)
async def wrapped_model_request_run(self, ctx):
# type: (Any, Any) -> Any
messages, model, model_settings = _extract_span_data(self, ctx)
with ai_client_span(messages, None, model, model_settings) as span:
result = await original_model_request_run(self, ctx)
# Extract response from result if available
model_response = None
if hasattr(result, "model_response"):
model_response = result.model_response
update_ai_client_span(span, model_response)
return result
ModelRequestNode.run = wrapped_model_request_run
# Patch ModelRequestNode.stream for streaming requests
original_model_request_stream = ModelRequestNode.stream
def create_wrapped_stream(original_stream_method):
# type: (Callable[..., Any]) -> Callable[..., Any]
"""Create a wrapper for ModelRequestNode.stream that creates chat spans."""
@asynccontextmanager
@wraps(original_stream_method)
async def wrapped_model_request_stream(self, ctx):
# type: (Any, Any) -> Any
messages, model, model_settings = _extract_span_data(self, ctx)
# Create chat span for streaming request
with ai_client_span(messages, None, model, model_settings) as span:
# Call the original stream method
async with original_stream_method(self, ctx) as stream:
yield stream
# After streaming completes, update span with response data
# The ModelRequestNode stores the final response in _result
model_response = None
if hasattr(self, "_result") and self._result is not None:
# _result is a NextNode containing the model_response
if hasattr(self._result, "model_response"):
model_response = self._result.model_response
update_ai_client_span(span, model_response)
return wrapped_model_request_stream
ModelRequestNode.stream = create_wrapped_stream(original_model_request_stream)
@@ -0,0 +1,35 @@
from functools import wraps
from typing import TYPE_CHECKING
from pydantic_ai import models # type: ignore
from ..spans import ai_client_span, update_ai_client_span
if TYPE_CHECKING:
from typing import Any
def _patch_model_request():
# type: () -> None
"""
Patches model request execution to create AI client spans.
In pydantic-ai, model requests are handled through the Model interface.
We need to patch the request method on models to create spans.
"""
# Patch the base Model class's request method
if hasattr(models, "Model"):
original_request = models.Model.request
@wraps(original_request)
async def wrapped_request(self, messages, *args, **kwargs):
# type: (Any, Any, *Any, **Any) -> Any
# Pass all messages (full conversation history)
with ai_client_span(messages, None, self, None) as span:
result = await original_request(self, messages, *args, **kwargs)
update_ai_client_span(span, result)
return result
models.Model.request = wrapped_request
@@ -0,0 +1,75 @@
from functools import wraps
from pydantic_ai._tool_manager import ToolManager # type: ignore
import sentry_sdk
from ..spans import execute_tool_span, update_execute_tool_span
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
try:
from pydantic_ai.mcp import MCPServer # type: ignore
HAS_MCP = True
except ImportError:
HAS_MCP = False
def _patch_tool_execution():
# type: () -> None
"""
Patch ToolManager._call_tool to create execute_tool spans.
This is the single point where ALL tool calls flow through in pydantic_ai,
regardless of toolset type (function, MCP, combined, wrapper, etc.).
By patching here, we avoid:
- Patching multiple toolset classes
- Dealing with signature mismatches from instrumented MCP servers
- Complex nested toolset handling
"""
original_call_tool = ToolManager._call_tool
@wraps(original_call_tool)
async def wrapped_call_tool(self, call, allow_partial, wrap_validation_errors):
# type: (Any, Any, bool, bool) -> Any
# Extract tool info before calling original
name = call.tool_name
tool = self.tools.get(name) if self.tools else None
# Determine tool type by checking tool.toolset
tool_type = "function" # default
if tool and HAS_MCP and isinstance(tool.toolset, MCPServer):
tool_type = "mcp"
# Get agent from Sentry scope
current_span = sentry_sdk.get_current_span()
if current_span and tool:
agent_data = (
sentry_sdk.get_current_scope()._contexts.get("pydantic_ai_agent") or {}
)
agent = agent_data.get("_agent")
# Get args for span (before validation)
# call.args can be a string (JSON) or dict
args_dict = call.args if isinstance(call.args, dict) else {}
with execute_tool_span(name, args_dict, agent, tool_type=tool_type) as span:
result = await original_call_tool(
self, call, allow_partial, wrap_validation_errors
)
update_execute_tool_span(span, result)
return result
# No span context - just call original
return await original_call_tool(
self, call, allow_partial, wrap_validation_errors
)
ToolManager._call_tool = wrapped_call_tool
@@ -0,0 +1,3 @@
from .ai_client import ai_client_span, update_ai_client_span # noqa: F401
from .execute_tool import execute_tool_span, update_execute_tool_span # noqa: F401
from .invoke_agent import invoke_agent_span, update_invoke_agent_span # noqa: F401
@@ -0,0 +1,253 @@
import sentry_sdk
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.utils import safe_serialize
from ..consts import SPAN_ORIGIN
from ..utils import (
_set_agent_data,
_set_available_tools,
_set_model_data,
_should_send_prompts,
_get_model_name,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, List, Dict
from pydantic_ai.usage import RequestUsage # type: ignore
try:
from pydantic_ai.messages import ( # type: ignore
BaseToolCallPart,
BaseToolReturnPart,
SystemPromptPart,
UserPromptPart,
TextPart,
ThinkingPart,
)
except ImportError:
# Fallback if these classes are not available
BaseToolCallPart = None
BaseToolReturnPart = None
SystemPromptPart = None
UserPromptPart = None
TextPart = None
ThinkingPart = None
def _set_usage_data(span, usage):
# type: (sentry_sdk.tracing.Span, RequestUsage) -> None
"""Set token usage data on a span."""
if usage is None:
return
if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
def _set_input_messages(span, messages):
# type: (sentry_sdk.tracing.Span, Any) -> None
"""Set input messages data on a span."""
if not _should_send_prompts():
return
if not messages:
return
try:
formatted_messages = []
system_prompt = None
# Extract system prompt from any ModelRequest with instructions
for msg in messages:
if hasattr(msg, "instructions") and msg.instructions:
system_prompt = msg.instructions
break
# Add system prompt as first message if present
if system_prompt:
formatted_messages.append(
{"role": "system", "content": [{"type": "text", "text": system_prompt}]}
)
for msg in messages:
if hasattr(msg, "parts"):
for part in msg.parts:
role = "user"
# Use isinstance checks with proper base classes
if SystemPromptPart and isinstance(part, SystemPromptPart):
role = "system"
elif (
(TextPart and isinstance(part, TextPart))
or (ThinkingPart and isinstance(part, ThinkingPart))
or (BaseToolCallPart and isinstance(part, BaseToolCallPart))
):
role = "assistant"
elif BaseToolReturnPart and isinstance(part, BaseToolReturnPart):
role = "tool"
content = [] # type: List[Dict[str, Any] | str]
tool_calls = None
tool_call_id = None
# Handle ToolCallPart (assistant requesting tool use)
if BaseToolCallPart and isinstance(part, BaseToolCallPart):
tool_call_data = {}
if hasattr(part, "tool_name"):
tool_call_data["name"] = part.tool_name
if hasattr(part, "args"):
tool_call_data["arguments"] = safe_serialize(part.args)
if tool_call_data:
tool_calls = [tool_call_data]
# Handle ToolReturnPart (tool result)
elif BaseToolReturnPart and isinstance(part, BaseToolReturnPart):
if hasattr(part, "tool_name"):
tool_call_id = part.tool_name
if hasattr(part, "content"):
content.append({"type": "text", "text": str(part.content)})
# Handle regular content
elif hasattr(part, "content"):
if isinstance(part.content, str):
content.append({"type": "text", "text": part.content})
elif isinstance(part.content, list):
for item in part.content:
if isinstance(item, str):
content.append({"type": "text", "text": item})
else:
content.append(safe_serialize(item))
else:
content.append({"type": "text", "text": str(part.content)})
# Add message if we have content or tool calls
if content or tool_calls:
message = {"role": role} # type: Dict[str, Any]
if content:
message["content"] = content
if tool_calls:
message["tool_calls"] = tool_calls
if tool_call_id:
message["tool_call_id"] = tool_call_id
formatted_messages.append(message)
if formatted_messages:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, formatted_messages, unpack=False
)
except Exception:
# If we fail to format messages, just skip it
pass
def _set_output_data(span, response):
# type: (sentry_sdk.tracing.Span, Any) -> None
"""Set output data on a span."""
if not _should_send_prompts():
return
if not response:
return
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model_name)
try:
# Extract text from ModelResponse
if hasattr(response, "parts"):
texts = []
tool_calls = []
for part in response.parts:
if TextPart and isinstance(part, TextPart) and hasattr(part, "content"):
texts.append(part.content)
elif BaseToolCallPart and isinstance(part, BaseToolCallPart):
tool_call_data = {
"type": "function",
}
if hasattr(part, "tool_name"):
tool_call_data["name"] = part.tool_name
if hasattr(part, "args"):
tool_call_data["arguments"] = safe_serialize(part.args)
tool_calls.append(tool_call_data)
if texts:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, texts)
if tool_calls:
span.set_data(
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(tool_calls)
)
except Exception:
# If we fail to format output, just skip it
pass
def ai_client_span(messages, agent, model, model_settings):
# type: (Any, Any, Any, Any) -> sentry_sdk.tracing.Span
"""Create a span for an AI client call (model request).
Args:
messages: Full conversation history (list of messages)
agent: Agent object
model: Model object
model_settings: Model settings
"""
# Determine model name for span name
model_obj = model
if agent and hasattr(agent, "model"):
model_obj = agent.model
model_name = _get_model_name(model_obj) or "unknown"
span = sentry_sdk.start_span(
op=OP.GEN_AI_CHAT,
name=f"chat {model_name}",
origin=SPAN_ORIGIN,
)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
_set_agent_data(span, agent)
_set_model_data(span, model, model_settings)
# Set streaming flag
agent_data = sentry_sdk.get_current_scope()._contexts.get("pydantic_ai_agent") or {}
is_streaming = agent_data.get("_streaming", False)
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, is_streaming)
# Add available tools if agent is available
agent_obj = agent
if not agent_obj:
# Try to get from Sentry scope
agent_data = (
sentry_sdk.get_current_scope()._contexts.get("pydantic_ai_agent") or {}
)
agent_obj = agent_data.get("_agent")
_set_available_tools(span, agent_obj)
# Set input messages (full conversation history)
if messages:
_set_input_messages(span, messages)
return span
def update_ai_client_span(span, model_response):
# type: (sentry_sdk.tracing.Span, Any) -> None
"""Update the AI client span with response data."""
if not span:
return
# Set usage data if available
if model_response and hasattr(model_response, "usage"):
_set_usage_data(span, model_response.usage)
# Set output data
_set_output_data(span, model_response)
@@ -0,0 +1,49 @@
import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.utils import safe_serialize
from ..consts import SPAN_ORIGIN
from ..utils import _set_agent_data, _should_send_prompts
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
def execute_tool_span(tool_name, tool_args, agent, tool_type="function"):
# type: (str, Any, Any, str) -> sentry_sdk.tracing.Span
"""Create a span for tool execution.
Args:
tool_name: The name of the tool being executed
tool_args: The arguments passed to the tool
agent: The agent executing the tool
tool_type: The type of tool ("function" for regular tools, "mcp" for MCP services)
"""
span = sentry_sdk.start_span(
op=OP.GEN_AI_EXECUTE_TOOL,
name=f"execute_tool {tool_name}",
origin=SPAN_ORIGIN,
)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
span.set_data(SPANDATA.GEN_AI_TOOL_TYPE, tool_type)
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
_set_agent_data(span, agent)
if _should_send_prompts() and tool_args is not None:
span.set_data(SPANDATA.GEN_AI_TOOL_INPUT, safe_serialize(tool_args))
return span
def update_execute_tool_span(span, result):
# type: (sentry_sdk.tracing.Span, Any) -> None
"""Update the execute tool span with the result."""
if not span:
return
if _should_send_prompts() and result is not None:
span.set_data(SPANDATA.GEN_AI_TOOL_OUTPUT, safe_serialize(result))
@@ -0,0 +1,112 @@
import sentry_sdk
from sentry_sdk.ai.utils import get_start_span_function, set_data_normalized
from sentry_sdk.consts import OP, SPANDATA
from ..consts import SPAN_ORIGIN
from ..utils import (
_set_agent_data,
_set_available_tools,
_set_model_data,
_should_send_prompts,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
def invoke_agent_span(user_prompt, agent, model, model_settings):
# type: (Any, Any, Any, Any) -> sentry_sdk.tracing.Span
"""Create a span for invoking the agent."""
# Determine agent name for span
name = "agent"
if agent and getattr(agent, "name", None):
name = agent.name
span = get_start_span_function()(
op=OP.GEN_AI_INVOKE_AGENT,
name=f"invoke_agent {name}",
origin=SPAN_ORIGIN,
)
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
_set_agent_data(span, agent)
_set_model_data(span, model, model_settings)
_set_available_tools(span, agent)
# Add user prompt and system prompts if available and prompts are enabled
if _should_send_prompts():
messages = []
# Add system prompts (both instructions and system_prompt)
system_texts = []
if agent:
# Check for system_prompt
system_prompts = getattr(agent, "_system_prompts", None) or []
for prompt in system_prompts:
if isinstance(prompt, str):
system_texts.append(prompt)
# Check for instructions (stored in _instructions)
instructions = getattr(agent, "_instructions", None)
if instructions:
if isinstance(instructions, str):
system_texts.append(instructions)
elif isinstance(instructions, (list, tuple)):
for instr in instructions:
if isinstance(instr, str):
system_texts.append(instr)
elif callable(instr):
# Skip dynamic/callable instructions
pass
# Add all system texts as system messages
for system_text in system_texts:
messages.append(
{
"content": [{"text": system_text, "type": "text"}],
"role": "system",
}
)
# Add user prompt
if user_prompt:
if isinstance(user_prompt, str):
messages.append(
{
"content": [{"text": user_prompt, "type": "text"}],
"role": "user",
}
)
elif isinstance(user_prompt, list):
# Handle list of user content
content = []
for item in user_prompt:
if isinstance(item, str):
content.append({"text": item, "type": "text"})
if content:
messages.append(
{
"content": content,
"role": "user",
}
)
if messages:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages, unpack=False
)
return span
def update_invoke_agent_span(span, output):
# type: (sentry_sdk.tracing.Span, Any) -> None
"""Update and close the invoke agent span."""
if span and _should_send_prompts() and output:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, str(output), unpack=False
)

Some files were not shown because too many files have changed in this diff Show More