diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py index 431fc46bec..1064f29ffd 100644 --- a/sentry_sdk/integrations/langchain.py +++ b/sentry_sdk/integrations/langchain.py @@ -22,6 +22,7 @@ from langchain_core.callbacks import ( manager, BaseCallbackHandler, + Callbacks, ) from langchain_core.agents import AgentAction, AgentFinish except ImportError: @@ -416,50 +417,57 @@ def _wrap_configure(f): # type: (Callable[..., Any]) -> Callable[..., Any] @wraps(f) - def new_configure(*args, **kwargs): - # type: (Any, Any) -> Any + def new_configure( + callback_manager_cls, # type: type + inheritable_callbacks=None, # type: Callbacks + local_callbacks=None, # type: Callbacks + *args, # type: Any + **kwargs, # type: Any + ): + # type: (...) -> Any integration = sentry_sdk.get_client().get_integration(LangchainIntegration) if integration is None: - return f(*args, **kwargs) + return f( + callback_manager_cls, + inheritable_callbacks, + local_callbacks, + *args, + **kwargs, + ) - with capture_internal_exceptions(): - new_callbacks = [] # type: List[BaseCallbackHandler] - if "local_callbacks" in kwargs: - existing_callbacks = kwargs["local_callbacks"] - kwargs["local_callbacks"] = new_callbacks - elif len(args) > 2: - existing_callbacks = args[2] - args = ( - args[0], - args[1], - new_callbacks, - ) + args[3:] - else: - existing_callbacks = [] - - if existing_callbacks: - if isinstance(existing_callbacks, list): - for cb in existing_callbacks: - new_callbacks.append(cb) - elif isinstance(existing_callbacks, BaseCallbackHandler): - new_callbacks.append(existing_callbacks) - else: - logger.debug("Unknown callback type: %s", existing_callbacks) - - already_added = False - for callback in new_callbacks: - if isinstance(callback, SentryLangchainCallback): - already_added = True - - if not already_added: - new_callbacks.append( - SentryLangchainCallback( - integration.max_spans, - integration.include_prompts, - integration.tiktoken_encoding_name, - ) - ) - return f(*args, **kwargs) + callbacks_list = local_callbacks or [] + + if isinstance(callbacks_list, BaseCallbackHandler): + callbacks_list = [callbacks_list] + elif not isinstance(callbacks_list, list): + logger.debug("Unknown callback type: %s", callbacks_list) + # Just proceed with original function call + return f( + callback_manager_cls, + inheritable_callbacks, + local_callbacks, + *args, + **kwargs, + ) + + if not any(isinstance(cb, SentryLangchainCallback) for cb in callbacks_list): + # Avoid mutating the existing callbacks list + callbacks_list = [ + *callbacks_list, + SentryLangchainCallback( + integration.max_spans, + integration.include_prompts, + integration.tiktoken_encoding_name, + ), + ] + + return f( + callback_manager_cls, + inheritable_callbacks, + callbacks_list, + *args, + **kwargs, + ) return new_configure