Source code for freeiam.ldap.connection

# SPDX-FileCopyrightText: 2025 Florian Best
# SPDX-License-Identifier: MIT OR Apache-2.0
"""LDAP Connection."""

import asyncio
import contextlib
import logging
import math
import os
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
from types import TracebackType
from typing import Any, Literal, Self, TypeAlias, cast, overload

import ldap.ldapobject
import ldap.modlist
import ldap.sasl
from ldap.schema import SCHEMA_ATTRS

from freeiam import errors
from freeiam.ldap._wrapper import Page, Result, _Response
from freeiam.ldap.attr import Attributes
from freeiam.ldap.constants import (
    AnyOption,
    AnyOptionValue,
    Option,
    OptionValue,
    ResponseType,
    Scope,
    TLSCRLCheck,
    TLSOption,
    TLSProtocol,
    TLSRequireCert,
    Version,
)
from freeiam.ldap.controls import Controls, server_side_sorting, simple_paged_results, transaction, virtual_list_view
from freeiam.ldap.dn import DN
from freeiam.ldap.extended_operations import ExtendedRequest, ExtendedResponse, refresh_ttl, transaction_commit, transaction_start
from freeiam.ldap.schema import Schema
from freeiam.ldap.sync_connection import Connection as SynchronousConnection


__all__ = ('Connection',)

log = logging.getLogger(__name__)

LDAPObject: TypeAlias = ldap.ldapobject.SimpleLDAPObject
LDAPAddList: TypeAlias = list[tuple[str, list[bytes]]]
LDAPModList: TypeAlias = list[tuple[int, str, list[bytes]]]
Sorting: TypeAlias = list[str | tuple[str, str | None, bool]]


[docs] class Connection: """ A LDAP Connection. :ivar str uri: The LDAP URI. :ivar int timelimit: The global timelimit. :ivar bool automatic_reconnect: Whether automatic reconnection is enabled. :ivar int max_connection_attempts: number of connection attempt on connection loss. :ivar float retry_delay: The retry delay (in seconds) between the reconnection attempts. """ __slots__ = ( '__conn_s', '__reconnects_counter', '__schema', '_conn', '_hide_parent_exception', '_last_auth_state', '_options', '_start_tls', 'automatic_reconnect', 'max_connection_attempts', 'retry_delay', 'timeout', 'uri', ) @property def conn(self) -> LDAPObject: """The underlying connection.""" if self._conn is None: raise RuntimeError('not connected') # noqa: TRY003 return self._conn @property def fileno(self) -> int: """Get the file descriptor number of the active socket connection.""" if not self._conn or not hasattr(self._conn, '_l'): return -1 return self._conn.fileno() @property def connected(self) -> bool: """Whether the connection is established.""" return self.fileno != -1 def __init__( self, uri: str | None = '', *, start_tls: bool = False, timeout: int = -1, automatic_reconnect: bool = True, max_connection_attempts: int = 10, retry_delay: float = 0.0, _hide_parent_exception: bool = True, _conn: LDAPObject | None = None, ) -> None: self._conn = _conn self.uri = uri self.timeout = timeout self.automatic_reconnect = automatic_reconnect self.max_connection_attempts = max_connection_attempts self.retry_delay = retry_delay self._start_tls = start_tls self.__reconnects_counter = 0 self.__schema: dict[DN | str | None, Schema] = {} self._last_auth_state: tuple[str, str | None, str | None] | None = None self._options: list[tuple[AnyOption, AnyOptionValue | Sequence[ldap.controls.RequestControl]]] = [] self._hide_parent_exception = _hide_parent_exception self.__conn_s: SynchronousConnection | None = None @property def _sync_connection(self) -> SynchronousConnection: if self.__conn_s is None: self.__conn_s = SynchronousConnection( self.uri, # start_tls=self._start_tls, timeout=self.timeout, automatic_reconnect=self.automatic_reconnect, max_connection_attempts=self.max_connection_attempts, retry_delay=self.retry_delay, _hide_parent_exception=self._hide_parent_exception, _conn=self._conn, ) return self.__conn_s async def __aenter__(self) -> Self: """Initialize asynchronous connection.""" self.connect() return self async def __aexit__(self, etype: type[BaseException] | None, exc: BaseException | None, etraceback: TracebackType | None) -> None: """Close connection on shutdown.""" await self.unbind() self.disconnect() def __enter__(self) -> SynchronousConnection: """Initialize synchronous connection.""" return self._sync_connection.__enter__() def __exit__(self, etype: type[BaseException] | None, exc: BaseException | None, etraceback: TracebackType | None) -> None: """Close connection on shutdown.""" self._sync_connection.__exit__(etype, exc, etraceback) self.disconnect() @overload def get_option( self, option: Literal[Option.ProtocolVersion | Option.Timelimit | Option.NetworkTimeout | Option.Dereference | Option.Sizelimit] ) -> int: ... @overload def get_option(self, option: AnyOption) -> AnyOptionValue: ...
[docs] def get_option(self, option: AnyOption) -> AnyOptionValue: """Get a LDAP connection option.""" with errors.LdapError.wrap(self._hide_parent_exception): return cast('AnyOptionValue', self.conn.get_option(option))
@overload def set_option(self, option: Literal[Option.ServerControls], value: Sequence[ldap.controls.RequestControl]) -> None: ... @overload def set_option(self, option: Literal[Option.ClientControls], value: Sequence[ldap.controls.RequestControl]) -> None: ... @overload def set_option(self, option: Literal[Option.ProtocolVersion], value: int) -> None: ... @overload def set_option(self, option: AnyOption, value: AnyOptionValue | Sequence[ldap.controls.RequestControl], *, append: bool = True) -> None: ...
[docs] def set_option(self, option: AnyOption, value: AnyOptionValue | Sequence[ldap.controls.RequestControl], *, append: bool = True) -> None: """Set a LDAP connection option.""" if append: self._options.append((option, value)) with errors.LdapError.wrap(self._hide_parent_exception): self.conn.set_option(option, value)
[docs] def set_controls(self, controls: Controls) -> None: """Set LDAP controls for all operations on this connection.""" if controls.server is not None: self.set_option(Option.ServerControls, controls.server) if controls.client is not None: self.set_option(Option.ClientControls, controls.client)
@property def protocol_version(self) -> Version: """Get the LDAP protocol version.""" return Version(self.get_option(Option.ProtocolVersion)) @protocol_version.setter def protocol_version(self, value: Version) -> None: """Set the LDAP protocol version.""" self.set_option(Option.ProtocolVersion, value) @property def timelimit(self) -> int: """Get the LDAP time limit.""" return self.get_option(Option.Timelimit) @timelimit.setter def timelimit(self, value: int) -> None: """Set the LDAP time limit.""" self.set_option(Option.Timelimit, value) @property def network_timeout(self) -> int: """Get the LDAP network timeout.""" return self.get_option(Option.NetworkTimeout) @network_timeout.setter def network_timeout(self, value: int) -> None: """Set the LDAP network timeout.""" self.set_option(Option.NetworkTimeout, value) @property def dereference(self) -> int: """Get the de-reference setting.""" return self.get_option(Option.Dereference) @dereference.setter def dereference(self, value: int) -> None: """Set the de-reference setting.""" self.set_option(Option.Dereference, value) @property def follow_referrals(self) -> bool | None: """Follow referrals enabled.""" follow = self.get_option(Option.Referrals) if follow == -1: return None return follow == OptionValue.On @follow_referrals.setter def follow_referrals(self, value: bool) -> None: """Enable following of referrals.""" self.set_option(Option.Referrals, OptionValue.On if value else OptionValue.Off) @property def sizelimit(self) -> int: """Get the sizelimit setting.""" return self.get_option(Option.Sizelimit) @sizelimit.setter def sizelimit(self, value: int) -> None: """Set the sizelimit setting.""" self.set_option(Option.Sizelimit, value)
[docs] @classmethod def set_tls( cls, *, ca_certfile: str | None = None, ca_certdir: str | None = None, certfile: str | None = None, keyfile: str | None = None, require_cert: TLSRequireCert = TLSRequireCert.Demand, require_san: TLSRequireCert | None = None, minimum_protocol: TLSProtocol | None = None, cipher_suite: str | None = None, crlfile: None = None, crl_check: TLSCRLCheck | None = None, ) -> None: """Set the TLS certificate settings globally.""" for option, value in ( (TLSOption.CACertfile, ca_certfile), (TLSOption.CACertdir, ca_certdir), (TLSOption.Certfile, certfile), (TLSOption.Keyfile, keyfile), (TLSOption.ProtocolMin, minimum_protocol), (TLSOption.CipherSuite, cipher_suite), (TLSOption.RequireCert, require_cert), (TLSOption.RequireSAN, require_san), (TLSOption.CRLFile, crlfile), (TLSOption.CRLCheck, crl_check), ): if value is not None: cls.set_global_option(option, value) # apply the pending TLS settings, create new context: cls.set_global_option(TLSOption.NewContext, 0)
[docs] @classmethod def get_global_option(cls, option: AnyOption) -> AnyOptionValue: """Get a LDAP connection option.""" return cast('AnyOptionValue', ldap.get_option(option))
[docs] @classmethod def set_global_option(cls, option: AnyOption, value: AnyOptionValue) -> None: """Set a global LDAP option.""" ldap.set_option(option, value)
[docs] def connect(self, fileno: bool | None = None) -> None: """Connect to the LDAP server.""" if self.connected: raise RuntimeError('already connected') # noqa: TRY003 if self.automatic_reconnect: self._conn = ldap.ldapobject.ReconnectLDAPObject( self.uri, trace_level=0, trace_file=None, trace_stack_limit=5, fileno=fileno, retry_max=self.max_connection_attempts, retry_delay=self.retry_delay, ) else: self._conn = ldap.ldapobject.SimpleLDAPObject( self.uri, trace_level=0, trace_file=None, trace_stack_limit=5, fileno=fileno, ) self.__conn_s = None self.__reconnects_counter = 0 if self._start_tls: self.start_tls()
[docs] def disconnect(self) -> None: """Disconnect from LDAP server.""" self._remove_reader() self._conn = None self.__conn_s = None
[docs] def reconnect(self, *, force: bool = True) -> None: """Reconnect to the LDAP server.""" self._remove_reader() with errors.LdapError.wrap(self._hide_parent_exception): self.conn.reconnect(self.uri, self.max_connection_attempts, self.retry_delay, force=force) # if self._start_tls: # self.start_tls() self._restore_options() self._restore_auth_state()
[docs] def start_tls(self) -> None: """Start TLS.""" self._start_tls = True with errors.LdapError.wrap(self._hide_parent_exception): self.conn.start_tls_s()
[docs] async def get_schema(self, subschema_dn: DN | str | None = None) -> Schema: """Get LDAP Schema.""" conn = self.conn # cache schema by connection if isinstance(conn, ldap.ldapobject.ReconnectLDAPObject) and conn._reconnects_done > self.__reconnects_counter: del self.__schema[subschema_dn] self.__reconnects_counter = self.conn._reconnects_done if not self.__schema.get(subschema_dn): try: subschemasubentry = await ( self.get(subschema_dn, attrs=['subschemaSubentry']) if subschema_dn else self.get_root_dse(['subschemaSubentry']) ) try: subschemasubentry_dn = cast('Attributes', subschemasubentry.attr)['subschemaSubentry'][0].decode('UTF-8') except KeyError: # pragma: no cover; impossible? if subschema_dn: return await self.get_schema() subschema = None except (errors.NoSuchObject, errors.NoSuchAttribute, errors.InsufficientAccess, errors.UndefinedType): subschema = None else: try: subschema = (await self.get(subschemasubentry_dn, SCHEMA_ATTRS, '(objectClass=subschema)')).attr except errors.NoSuchObject: # pragma: no cover subschema = None self.__schema[subschema_dn] = Schema(ldap.schema.SubSchema(cast('dict[str, list[bytes]]', subschema), 0)) Attributes.set_schema(self.__schema[subschema_dn]) return self.__schema[subschema_dn]
[docs] async def bind(self, authzid: str | None, password: str | None, *, controls: Controls | None = None) -> Result: """Authenticate via plaintext credentials.""" conn = self.conn self._last_auth_state = ('simple_bind_s', authzid, password) response = await self._execute(conn, conn.simple_bind, authzid, password, **Controls.expand(controls)) return Result.from_response(None, None, controls, response)
[docs] async def bind_external(self) -> None: # pragma: no cover """Authenticate via EXTERNAL method e.g. UNIX socket or TLS client certificate.""" with errors.LdapError.wrap(self._hide_parent_exception): self.conn.sasl_interactive_bind_s('', ldap.sasl.external())
[docs] async def bind_sasl_gssapi(self) -> None: # pragma: no cover """Authenticate via GSSAPI e.g. via Kerberos ticket.""" with errors.LdapError.wrap(self._hide_parent_exception): self.conn.sasl_interactive_bind_s('', ldap.sasl.gssapi())
[docs] async def bind_oauthbearer(self, authzid: str | None, token: str) -> None: # pragma: no cover; requires SASL module """Authenticate via OAuth 2.0 Access Token.""" oauth = ldap.sasl.sasl( { ldap.sasl.CB_AUTHNAME: authzid, ldap.sasl.CB_PASS: token, }, 'OAUTHBEARER', ) with errors.LdapError.wrap(self._hide_parent_exception): self.conn.sasl_interactive_bind_s('', oauth)
def _restore_options(self) -> None: for option, value in self._options: self.set_option(option, value, append=False) def _restore_auth_state(self) -> None: if self._last_auth_state: with errors.LdapError.wrap(self._hide_parent_exception): getattr(self.conn, self._last_auth_state[0])(*self._last_auth_state[1:])
[docs] async def unbind(self, *, controls: Controls | None = None) -> Result | None: """Unbind.""" self._last_auth_state = None try: conn = self.conn except RuntimeError: # not connected return None try: response = await self._execute(conn, conn.unbind_ext, **Controls.expand(controls)) except AttributeError as exc: # duplicated unbind if exc.args and exc.args[0] == "ReconnectLDAPObject has no attribute '_l'": return None raise # pragma: no cover; not possible return Result.from_response(None, None, controls, response)
[docs] async def whoami(self, *, controls: Controls | None = None) -> DN | str | None: """Get authenticated user DN (authzid). "Who am I?" Operation.""" try: with errors.LdapError.wrap(self._hide_parent_exception): dn = self.conn.whoami_s(**Controls.expand(controls)) except RuntimeError: return None if dn.startswith('dn:'): return DN(dn.removeprefix('dn:')) return dn # pragma: no cover
[docs] async def change_password(self, dn: DN | str, old_password: str, new_password: str, *, controls: Controls | None = None) -> Result: """Change password.""" conn = self.conn response = await self._execute(conn, conn.passwd, str(dn), old_password, new_password, **Controls.expand(controls)) return Result.from_response(dn, None, controls, response) # pragma: no cover
[docs] async def exists(self, dn: DN | str, unique: bool = False, *, controls: Controls | None = None) -> bool: """Check if LDAP object exists.""" try: await self.get(dn, ['1.1'], unique=unique, controls=controls) except errors.NoSuchObject: return False return True
[docs] async def get( self, dn: DN | str, attrs: list[str] | None = None, filter_expr: str = '(objectClass=*)', *, unique: bool = False, controls: Controls | None = None, ) -> Result: """Get a LDAP object.""" for obj in await self.search(base=dn, scope=Scope.BASE, filter_expr=filter_expr, attrs=attrs, unique=unique, controls=controls): return obj return None # type: ignore[return-value] # pragma: no cover; impossible
# obj, = [_ async for _ in self.search_iter(base=dn, scope=Scope.BASE, filter_expr=filter_expr, attrs=attrs, unique=unique, controls=controls)] # noqa: E501 # return obj[0] # # GC calls gen.aclose() causing unnecessary .cancel() to be called: # # return await anext(self.search_iter(base=dn, scope=Scope.BASE, filter_expr=filter_expr, attrs=attrs, unique=unique, controls=controls))
[docs] async def get_attr( self, dn: DN | str, attr: str, filter_expr: str = '(objectClass=*)', *, unique: bool = False, controls: Controls | None = None ) -> list[bytes]: """Get attribute of an LDAP object.""" attributes = (await self.get(dn, attrs=[attr], filter_expr=filter_expr, unique=unique, controls=controls)).attr assert attributes is not None # noqa: S101 try: return attributes[attr] except KeyError: await self.get_schema() return attributes[attr]
[docs] async def search_iter( self, base: DN | str = '', scope: Scope = Scope.SUBTREE, filter_expr: str = '(objectClass=*)', attrs: list[str] | None = None, *, unique: bool = False, sizelimit: bool | None = None, sorting: Sorting | None = None, controls: Controls | None = None, _attrsonly: bool = False, ) -> AsyncGenerator[Result, None]: """Search iterative for DN and Attributes of LDAP objects.""" conn = self.conn all_results = [] if sorting: controls = Controls.set_server(controls, server_side_sorting(*sorting, criticality=True)) # sizelimit = 1 if unique else sizelimit try: async for response in self._execute_iter( conn, conn.search_ext, str(base), scope, filterstr=filter_expr, attrlist=attrs, attrsonly=int(_attrsonly), **Controls.expand(controls), timeout=self.timeout, sizelimit=sizelimit or OptionValue.NoLimit, ): Result.set_controls(response, controls) assert response.data is not None # noqa: S101 results = [Result.from_response(dn, attributes, controls, response) for dn, attributes in response.data] all_results.extend(results) if unique and len(all_results) > 1: raise errors.NotUnique(all_results) try: for result in results: yield result except GeneratorExit as exc: with contextlib.suppress(errors.NoSuchOperation): # await self.cancel(response.msgid) # better do it immediately assert response.msgid is not None # noqa: S101 self._sync_connection.cancel(response.msgid) raise exc from exc except errors.NoSuchObject as no_object_error: no_object_error.base_dn = DN.get(base) no_object_error.filter = filter_expr no_object_error.scope = scope no_object_error.attrs = attrs raise
# except errors.SizelimitExceeded: # if not unique: # raise # raise errors.NotUnique() from None
[docs] async def search( self, base: DN | str = '', scope: Scope = Scope.SUBTREE, filter_expr: str = '(objectClass=*)', attrs: list[str] | None = None, *, unique: bool = False, sizelimit: bool | None = None, sorting: Sorting | None = None, controls: Controls | None = None, _attrsonly: bool = False, ) -> list[Result]: """Search for DN and Attributes of LDAP objects.""" conn = self.conn all_results = [] if sorting: controls = Controls.set_server(controls, server_side_sorting(*sorting)) try: response = await self._execute( conn, conn.search_ext, str(base), scope, filterstr=filter_expr, attrlist=attrs, attrsonly=int(_attrsonly), **Controls.expand(controls), timeout=self.timeout, sizelimit=sizelimit or OptionValue.NoLimit, ) Result.set_controls(response, controls) assert response.data is not None # noqa: S101 results = [Result.from_response(dn, attributes, controls, response) for dn, attributes in response.data] all_results.extend(results) if unique and len(all_results) > 1: raise errors.NotUnique(all_results) except errors.NoSuchObject as no_object_error: no_object_error.base_dn = DN.get(base) no_object_error.filter = filter_expr no_object_error.scope = scope no_object_error.attrs = attrs raise return results
[docs] async def search_dn( self, base: DN | str = '', scope: Scope = Scope.SUBTREE, filter_expr: str = '(objectClass=*)', *, unique: bool = False, sizelimit: bool | None = None, sorting: Sorting | None = None, controls: Controls | None = None, ) -> AsyncGenerator[DN, None]: """Search for DNs of LDAP objects.""" # FIXME: the following hangs forever as the iterative search is unfinished while the FD reader is replaced # async for result in self.search( # base, scope, filter_expr, ['1.1'], unique=unique, sizelimit=sizelimit, sorting=sorting, controls=controls, _attrsonly=True # ): for result in await self.search( base, scope, filter_expr, [], unique=unique, sizelimit=sizelimit, sorting=sorting, controls=controls, _attrsonly=True ): assert result.dn is not None # noqa: S101 yield result.dn
[docs] async def search_paginated( self, base: DN | str = '', scope: Scope = Scope.SUBTREE, filter_expr: str = '(objectClass=*)', attrs: list[str] | None = None, *, page_size: int = 100, sorting: Sorting, unique: bool = False, sizelimit: bool | None = None, controls: Controls | None = None, ) -> AsyncGenerator[Result, None]: """Search paginated using Virtual List View control.""" controls = Controls.set_server(controls, server_side_sorting(*sorting)) res_vlv = virtual_list_view.response() context_id = None length = None last_page = None page = 1 while True: offset = ((page or 1) - 1) * page_size pagination = virtual_list_view( before_count=0, after_count=page_size - 1, offset=offset + 1, content_count=0, greater_than_or_equal=None, context_id=context_id, criticality=True, ) controls = Controls.set_server(controls, pagination) if length is not None and offset > length: break # end reached vlv: ldap.controls.vlv.VLVResponseControl | None = None current = None for entry_number, result in enumerate( await self.search(base, scope, filter_expr, attrs, unique=unique, sizelimit=sizelimit, controls=controls), 1 ): if last_page is None: vlv = cast('ldap.controls.vlv.VLVResponseControl', controls.get(res_vlv)) length = vlv.contentCount last_page = math.ceil(length / (page_size or length)) result.page = Page( page=page, entry=entry_number, page_size=page_size, results=length, last_page=last_page, ) current = result yield result if current is None: # no search results break page += 1 vlv = cast('ldap.controls.vlv.VLVResponseControl', controls.get(res_vlv)) context_id = vlv.context_id length = vlv.contentCount
[docs] async def search_paged( self, base: DN | str = '', scope: Scope = Scope.SUBTREE, filter_expr: str = '(objectClass=*)', attrs: list[str] | None = None, page_size: int = 100, *, unique: bool = False, sizelimit: bool | None = None, sorting: Sorting | None = None, controls: Controls | None = None, ) -> AsyncGenerator[Result, None]: """Search paginated using SimplePagedResults control.""" pagination = simple_paged_results(size=page_size, cookie='', criticality=True) controls = Controls.append_server(controls, pagination) if sorting: controls = Controls.set_server(controls, server_side_sorting(*sorting)) page = 0 while True: current = None page += 1 entry_number = 0 async for result in self.search_iter(base, scope, filter_expr, attrs, unique=unique, sizelimit=sizelimit, controls=controls): entry_number += 1 result.page = Page(page=page, entry=entry_number, page_size=page_size) current = result yield result if current is None: # no search results break control = controls.get(pagination) if not control: # pragma: no cover break # Server doesn't support pagination pagination.cookie = cast('ldap.controls.pagedresults.SimplePagedResultsControl', controls.get(pagination)).cookie if not pagination.cookie: break
[docs] async def add( self, dn: DN | str, attrs: dict[str, list[bytes]] | Attributes, *, controls: Controls | None = None, ) -> Result: """Create a LDAP object.""" al = ldap.modlist.addModlist(attrs) return await self.add_al(dn, al, controls=controls)
[docs] async def add_al( self, dn: DN | str, al: LDAPAddList, *, controls: Controls | None = None, ) -> Result: """Create a LDAP object from addlist.""" conn = self.conn response = await self._execute(conn, conn.add_ext, str(dn), al, **Controls.expand(controls)) return Result.from_response(dn, None, controls, response)
[docs] async def modify( self, dn: DN | str, oldattr: dict[str, list[bytes]] | Attributes, newattr: dict[str, list[bytes]] | Attributes, *, controls: Controls | None = None, ) -> Result: """Modify a LDAP object.""" ml = ldap.modlist.modifyModlist(oldattr, newattr) return await self.modify_ml(dn, ml, controls=controls)
[docs] async def modify_ml( self, dn: DN | str, ml: LDAPModList, *, controls: Controls | None = None, ) -> Result: """Modify a LDAP object from modlist.""" conn = self.conn new_dn = self._compute_changed_dn(DN.get(dn), ml) if dn != new_dn: dn = cast('DN', (await self.rename(dn, new_dn)).dn) response = await self._execute(conn, conn.modify_ext, str(dn), ml, **Controls.expand(controls)) return Result.from_response(dn, None, controls, response)
@classmethod def _compute_changed_dn(cls, dn: DN, ml: LDAPModList) -> DN: """ Get changed DN. >>> Connection._compute_changed_dn('cn=foo,dc=bar', [(ldap.MOD_REPLACE, 'cn', b'foo')]) 'cn=foo,dc=bar' >>> Connection._compute_changed_dn('cn=foo,dc=bar', [(ldap.MOD_REPLACE, 'cn', b'bar')]) 'cn=bar,dc=bar' >>> Connection._compute_changed_dn('cn=foo,dc=bar', [(ldap.MOD_REPLACE, 'cn', b'föo')]) == 'cn=föo,dc=bar' True """ rdn = dn.rdns[0] dn_vals = {x[0].lower(): x[1] for x in rdn} new_vals = { key.lower(): val.decode('UTF-8') if isinstance(val, bytes) else val[0].decode('UTF-8') for op, key, val in ml if key.lower() in dn_vals and val and op != ldap.MOD_DELETE } new_rdn_ava = [(x, new_vals.get(x.lower(), dn_vals[x.lower()]), ldap.AVA_STRING) for x in [y[0] for y in rdn]] new_rdn = DN( ldap.dn.dn2str( [ [(key, val, ava_type) for key, val, ava_type in new_rdn_ava], ], ), ) if dn[0] != new_rdn: return new_rdn + cast('DN', dn.parent) return dn
[docs] async def move( self, dn: DN | str, newposition: DN | str, *, controls: Controls | None = None, ) -> Result: """Move a LDAP object.""" dn = DN.get(dn) newposition = DN.get(newposition) return await self.rename(dn, dn[0] + newposition, delete_old=True, controls=controls)
[docs] async def rename( self, dn: DN | str, newdn: DN | str, delete_old: bool = True, *, controls: Controls | None = None, ) -> Result: """Rename a LDAP object.""" conn = self.conn newdn = DN.get(newdn) response = await self._execute(conn, conn.rename, str(dn), str(newdn[0]), str(newdn.parent), int(delete_old), **Controls.expand(controls)) return Result.from_response(newdn, None, controls, response)
[docs] async def modrdn( self, dn: DN | str, newrdn: DN | str, delete_old: bool = True, *, controls: Controls | None = None, ) -> Result: """Rename a LDAP object.""" return await self.rename(dn, DN.get(newrdn) + cast('DN', DN.get(dn).parent), delete_old, controls=controls)
[docs] async def delete(self, dn: DN | str, *, controls: Controls | None = None) -> Result: """Delete a LDAP object.""" conn = self.conn response = await self._execute(conn, conn.delete_ext, str(dn), **Controls.expand(controls)) return Result.from_response(dn, None, controls, response)
[docs] async def delete_recursive(self, dn: DN | str, *, controls: Controls | None = None) -> Result: """Delete a LDAP object recursively.""" try: return await self.delete(dn, controls=controls) except errors.NotAllowedOnNonleaf: async for child in self.search_dn(dn, Scope.ONELEVEL): await self.delete_recursive(child) return await self.delete(dn, controls=controls)
[docs] async def compare( self, dn: DN | str, attr: str, value: bytes, *, controls: Controls | None = None, ) -> bool: """Compare the value of an LDAP object.""" conn = self.conn try: await self._execute(conn, conn.compare_ext, str(dn), attr, value, **Controls.expand(controls)) except errors.NoSuchObject as no_object_error: no_object_error.base_dn = DN.get(dn) raise except errors.CompareTrue: return True except errors.CompareFalse: return False raise RuntimeError() # pragma: no cover; impossible
[docs] async def compare_dn(self, entry: DN | str, dn: DN | str) -> bool: """Compare LDAP DN with existing entry.""" dn = DN.get(dn) entry = DN.get(entry) for i, parent in enumerate(entry.walk()): for attr, value, _ in dn.rdns[-i - 1]: try: equal = await self.compare(str(parent), attr, value.encode('UTF-8')) if not equal: # pragma: no cover; https://github.com/nedbat/coveragepy/issues/2014 return False except errors.NoSuchObject: if attr == entry.rdns[-1][0][0]: continue raise return True
[docs] async def get_root_dse(self, attrs: list[str] | None = None, filter_expr: str = '(objectClass=*)') -> Result: """Get Root DSE (Directory Server Entry).""" return await self.get('', attrs or ['*', '+'], filter_expr=filter_expr)
[docs] async def get_naming_contexts(self) -> list[str]: """Return namingContexts of Root DSE.""" result = await self.get_attr('', 'namingContexts') return [x.decode('UTF-8') for x in result]
[docs] async def abandon(self, msgid: int, *, controls: Controls | None = None) -> Result: """Abandon a LDAP operation.""" log.debug('Abandon: %s', msgid) conn = self.conn response = await self._execute(conn, conn.abandon_ext, msgid, **Controls.expand(controls)) return Result.from_response(None, None, controls, response)
[docs] async def cancel(self, msgid: int, *, controls: Controls | None = None) -> bool: """Cancel a LDAP operation.""" log.debug('Cancel: %s', msgid) try: conn = self.conn except RuntimeError: return False try: await self._execute(conn, conn.cancel, msgid, **Controls.expand(controls)) except errors.NoSuchOperation: log.warning('Cancel failed', extra={'msgid': msgid}) raise except (errors.Cancelled, errors.Success): # pragma: no cover; theoretically, according to python-ldap return True except errors.TooLate: # pragma: no cover return False else: # pragma: no cover return True
[docs] @contextlib.asynccontextmanager async def transaction(self, set_controls: bool = True) -> AsyncGenerator[bytes, None]: """Context manager to make a transaction, which is aborted on errors.""" result = await self.extended(transaction_start(), transaction_start.response) txn_id = result.extended_value if set_controls: self.set_controls(Controls.set_server(None, transaction(txn_id, criticality=True))) try: yield txn_id except BaseException: self.set_option(Option.ServerControls, []) await self.extended(transaction_commit(txn_id, commit=False), transaction_commit.response) raise else: self.set_option(Option.ServerControls, []) try: await self.extended(transaction_commit(txn_id, commit=True), transaction_commit.response) except errors.OperationsError as exc: log.warning('Failure during commiting transaction', extra={'error': exc})
[docs] async def refresh_ttl(self, dn: DN | str, ttl: int) -> Result: """Perform Refresh extended operation.""" req = refresh_ttl(dn, ttl) result = await self.extended(req, refresh_ttl.response) result.dn = DN.get(dn) return result
[docs] async def extended( self, request: ExtendedRequest, response_class: type[ExtendedResponse] | None = None, *, controls: Controls | None = None ) -> Result: """Perform extended operation.""" conn = self.conn response = await self._execute(conn, conn.extop, request, **Controls.expand(controls)) decoded_value = None if response_class: with errors.LdapError.wrap(self._hide_parent_exception): if response_class.responseName != response.name: raise errors.ProtocolError({ 'desc': 'OID in extended response does not match response class.', 'info': f'expected: {response_class.responseName}; got: {response.name}', }) if response.value: decoded_value = response_class(response.name, response.value).responseValue return Result.from_response(None, None, controls, response, extended_value=decoded_value)
def __getstate__(self) -> dict[str, Any]: """Return state for pickle.""" return {slot: getattr(self, slot) for slot in set(self.__slots__) - {'_conn'} | {'connected'} if not slot.startswith('__')} def __setstate__(self, state: dict[str, Any]) -> None: """Set state for pickle.""" self._conn = None connected = state.pop('connected', None) for slot, value in state.items(): setattr(self, slot, value) if connected: self.connect() self._restore_options() self._restore_auth_state() async def _execute(self, conn: LDAPObject, operation: Callable[..., Any], *args: Any, **kwargs: Any) -> _Response: """Execute the operation and wait asynchronously for the result.""" msgid = await self._retry(self.request, operation, *args, **kwargs) if msgid is None: # abandon_ext, unbind_ext return _Response(None, None, msgid, [], None, None) response: _Response | None = None async for resp in self._poll(conn, msgid, 1): if response is not None: # pragma: no cover raise RuntimeError('Wrong method used! Use _execute_iter instead!') # noqa: TRY003 response = resp assert response is not None # noqa: S101 return response async def _execute_iter(self, conn: LDAPObject, operation: Callable[..., Any], *args: Any, **kwargs: Any) -> AsyncGenerator[_Response, None]: """Execute the operation and yield the results asynchronously.""" msgid = await self._retry(self.request, operation, *args, **kwargs) if msgid is None: # abandon_ext, unbind_ext return async for response in self._poll(conn, msgid, 0): yield response
[docs] def get_result(self, conn: LDAPObject, msgid: int = ResponseType.Any, _all: int = 0, timeout: int = 0) -> _Response: """Get the LDAP result for the given msgid.""" log.debug('result(%r, timeout=%r)', msgid, timeout, extra={'MSGID': msgid, 'ALL': _all, 'TIMEOUT': timeout, 'FUNC': 'result'}) try: with errors.LdapError.wrap(self._hide_parent_exception): response = _Response(*conn.result4(msgid, all=_all, timeout=timeout, add_extop=1)) # type: ignore[arg-type] except (errors.LdapError, OSError) as exc: log.debug('result(%r) -> raised %r', msgid, exc, extra={'MSGID': msgid, 'OPERATION': 'result', 'EXCEPTION': str(exc)}) raise log.debug('result(%r) -> %s', msgid, repr(response)[:200], extra={'MSGID': msgid, 'OPERATION': 'result'}) return response
[docs] def request(self, operation: Callable[..., int], *args: Any, **kwargs: Any) -> int | None: """Make the LDAP request for the given operation.""" op = operation.__name__ arg_str = ', '.join(map(repr, args)) if 'bind' not in op else '' kw = ', '.join(f'{k}={v!r}' for k, v in kwargs.items()) if 'bind' not in op else '' log.debug('Request %s(%s%s%s)', op, arg_str, ', ' if kw else '', kw, extra={'OPERATION': op, 'ARGUMENTS': arg_str, 'KEYWORDS': kw}) try: with errors.LdapError.wrap(self._hide_parent_exception): msgid = operation(*args, **kwargs) except (errors.LdapError, OSError) as exc: log.debug('%s() -> %r', op, exc, extra={'OPERATION': op, 'EXCEPTION': str(exc)}) raise log.debug('%s() -> %r', op, msgid, extra={'OPERATION': op, 'MSGID': msgid}) return msgid
async def _retry(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Retry operation or reconnect if necessary.""" max_attempts = attempts = self.max_connection_attempts while attempts: try: if attempts != max_attempts: # do a sync reconnect, make sure we don't end in recursion self.reconnect() # FIXME: recursion!? # self.conn.reconnect(self.uri) # self._restore_auth_state() # FIXME: recursion return func(*args, **kwargs) except (errors.ServerDown, errors.Unavailable, errors.ConnectError, errors.Timeout): attempts -= 1 if not attempts: raise await asyncio.sleep(self.retry_delay) raise RuntimeError() # pragma: no cover; impossible def _poll_s( self, conn: LDAPObject, msgid: ResponseType = ResponseType.Any, _all: int = 0 ) -> Generator[_Response, None, None]: # pragma: no cover """Wait synchronously for operation to succeed.""" # this method must only used by the synchronous variant of this class while True: try: response = self._retry(self.get_result, conn, msgid, _all=_all, timeout=self.timeout) except errors.NoResultsReturned: # pragma: no cover break rtype = response.type # type: ignore[attr-defined] if rtype is None: continue yield response # type: ignore[misc] if rtype == ldap.RES_SEARCH_ENTRY: continue if rtype == ldap.RES_SEARCH_RESULT: break break async def _poll(self, conn: LDAPObject, msgid: ResponseType = ResponseType.Any, _all: int = 0) -> AsyncGenerator[_Response, None]: """Wait asynchronously for operation to succeed.""" loop = asyncio.get_running_loop() while True: # TODO: move the asyncio stuff out of here fut = loop.create_future() fd = conn.fileno() self._add_reader(loop, fd, self._ready, conn, msgid, fut, _all) try: response = await self._wait_for(fut) except errors.NoResultsReturned: # pragma: no cover; how? self._remove_reader(fd) break except Exception: self._remove_reader(fd) raise rtype = response.type if rtype is None: # pragma: no cover; handled in _ready() continue yield response if rtype == ldap.RES_SEARCH_ENTRY: continue if rtype == ldap.RES_SEARCH_RESULT: break break def _ready(self, fd: int, conn: LDAPObject, msgid: int, fut: asyncio.Future[_Response], _all: ResponseType) -> None: log.debug('FD %s is ready', fd) try: os.fstat(fd) response = self.get_result(conn, msgid, _all=_all, timeout=0) if response.type is None: return fut.set_result(response) self._remove_reader(fd) except (OSError, errors.LdapError) as exc: log.error('FD %s is not valid - maybe connection is closed', fd) # noqa: TRY400 fut.set_exception(exc) self._remove_reader(fd) async def _wait_for(self, fut: asyncio.Future[_Response]) -> _Response: if self.timeout > 0: return await asyncio.wait_for(fut, timeout=self.timeout) return await fut @classmethod def _add_reader(cls, loop: asyncio.AbstractEventLoop, fd: int, func: Callable[..., Any], *args: Any) -> None: log.debug('Select on FD %s', fd) os.fstat(fd) # FIXME: deadlock; it replaces a previous reader, e.g. when in a iterative search or other parallel operation loop.add_reader(fd, func, fd, *args) # register reader from the loop thread # loop.call_soon_threadsafe(lambda: loop.add_reader(fd, func, fd, *args)) def _remove_reader(self, fd: int | None = None) -> None: fd = fd or self.fileno log.debug('Remove reader FD %s', fd) if fd == -1: # pragma: no cover return loop = asyncio.get_running_loop() loop.remove_reader(fd)
# loop.call_soon_threadsafe(lambda: loop.remove_reader(fd))