# SPDX-FileCopyrightText: 2025 Florian Best
# SPDX-License-Identifier: MIT OR Apache-2.0
"""LDAP Connection."""
import asyncio
import contextlib
import logging
import math
import os
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
from types import TracebackType
from typing import Any, Literal, Self, TypeAlias, cast, overload
import ldap.ldapobject
import ldap.modlist
import ldap.sasl
from ldap.schema import SCHEMA_ATTRS
from freeiam import errors
from freeiam.ldap._wrapper import Page, Result, _Response
from freeiam.ldap.attr import Attributes
from freeiam.ldap.constants import (
AnyOption,
AnyOptionValue,
Option,
OptionValue,
ResponseType,
Scope,
TLSCRLCheck,
TLSOption,
TLSProtocol,
TLSRequireCert,
Version,
)
from freeiam.ldap.controls import Controls, server_side_sorting, simple_paged_results, transaction, virtual_list_view
from freeiam.ldap.dn import DN
from freeiam.ldap.extended_operations import ExtendedRequest, ExtendedResponse, refresh_ttl, transaction_commit, transaction_start
from freeiam.ldap.schema import Schema
from freeiam.ldap.sync_connection import Connection as SynchronousConnection
__all__ = ('Connection',)

log: logging.Logger = logging.getLogger(__name__)

# python-ldap connection object type wrapped by ``Connection``.
LDAPObject: TypeAlias = ldap.ldapobject.SimpleLDAPObject
# Add list as produced by ``ldap.modlist.addModlist()``: ``[(attribute, [value, ...]), ...]``.
LDAPAddList: TypeAlias = list[tuple[str, list[bytes]]]
# Modification list: ``[(modification operation, attribute, [value, ...]), ...]``.
LDAPModList: TypeAlias = list[tuple[int, str, list[bytes]]]
# Sort keys; each list is expanded into ``server_side_sorting(*sorting)``.
Sorting: TypeAlias = list[str | tuple[str, str | None, bool]]
[docs]
class Connection:
"""
A LDAP Connection.
:ivar str uri: The LDAP URI.
:ivar int timelimit: The global timelimit.
:ivar bool automatic_reconnect: Whether automatic reconnection is enabled.
:ivar int max_connection_attempts: number of connection attempt on connection loss.
:ivar float retry_delay: The retry delay (in seconds) between the reconnection attempts.
"""
__slots__ = (
'__conn_s',
'__reconnects_counter',
'__schema',
'_conn',
'_hide_parent_exception',
'_last_auth_state',
'_options',
'_start_tls',
'automatic_reconnect',
'max_connection_attempts',
'retry_delay',
'timeout',
'uri',
)
@property
def conn(self) -> LDAPObject:
    """
    Return the underlying python-ldap connection object.

    :raises RuntimeError: If no connection object exists.
    """
    current = self._conn
    if current is not None:
        return current
    raise RuntimeError('not connected')  # noqa: TRY003
@property
def fileno(self) -> int:
    """
    Return the file descriptor of the active socket connection.

    :returns: The descriptor reported by the python-ldap object, or ``-1`` when there is
        no connection object or it has no ``_l`` handle.
    """
    connection = self._conn
    if connection and hasattr(connection, '_l'):
        return connection.fileno()
    return -1
@property
def connected(self) -> bool:
    """Return whether a socket file descriptor is available (``fileno`` is not ``-1``)."""
    descriptor = self.fileno
    return descriptor != -1
def __init__(
    self,
    uri: str | None = '',
    *,
    start_tls: bool = False,
    timeout: int = -1,
    automatic_reconnect: bool = True,
    max_connection_attempts: int = 10,
    retry_delay: float = 0.0,
    _hide_parent_exception: bool = True,
    _conn: LDAPObject | None = None,
) -> None:
    """
    Create the connection object; no network connection is opened here (see :meth:`connect`).

    :param uri: LDAP URI used by :meth:`connect`.
    :param start_tls: Whether :meth:`connect` issues StartTLS.
    :param timeout: Timeout passed to search requests.
    :param automatic_reconnect: Use ``ReconnectLDAPObject`` instead of ``SimpleLDAPObject``.
    :param max_connection_attempts: Number of connection attempts on connection loss.
    :param retry_delay: Delay in seconds between reconnection attempts.
    :param _hide_parent_exception: Passed to ``errors.LdapError.wrap()``.
    :param _conn: Existing python-ldap connection object to wrap.
    """
    # public configuration
    self.uri = uri
    self.timeout = timeout
    self.automatic_reconnect = automatic_reconnect
    self.max_connection_attempts = max_connection_attempts
    self.retry_delay = retry_delay
    # underlying connection and its lazily created synchronous counterpart
    self._conn = _conn
    self.__conn_s: SynchronousConnection | None = None
    # state that is re-applied after a reconnect
    self._start_tls = start_tls
    self._options: list[tuple[AnyOption, AnyOptionValue | Sequence[ldap.controls.RequestControl]]] = []
    self._last_auth_state: tuple[str, str | None, str | None] | None = None
    # schema cache keyed by subschema DN, invalidated when reconnects happened
    self.__schema: dict[DN | str | None, Schema] = {}
    self.__reconnects_counter = 0
    self._hide_parent_exception = _hide_parent_exception
@property
def _sync_connection(self) -> SynchronousConnection:
    """
    Return a synchronous connection sharing this object's python-ldap connection.

    It is created on first access and cached until :meth:`connect` or :meth:`disconnect`
    resets it.
    """
    if self.__conn_s is not None:
        return self.__conn_s
    # NOTE(review): start_tls is not forwarded, presumably because the shared ``_conn``
    # already negotiated TLS — confirm.
    self.__conn_s = SynchronousConnection(
        self.uri,
        timeout=self.timeout,
        automatic_reconnect=self.automatic_reconnect,
        max_connection_attempts=self.max_connection_attempts,
        retry_delay=self.retry_delay,
        _hide_parent_exception=self._hide_parent_exception,
        _conn=self._conn,
    )
    return self.__conn_s
async def __aenter__(self) -> Self:
"""Initialize asynchronous connection."""
self.connect()
return self
async def __aexit__(self, etype: type[BaseException] | None, exc: BaseException | None, etraceback: TracebackType | None) -> None:
"""Close connection on shutdown."""
await self.unbind()
self.disconnect()
def __enter__(self) -> SynchronousConnection:
    """
    Enter a synchronous ``with`` block.

    :returns: The synchronous connection (not ``self``), as returned by its own ``__enter__``.
    """
    sync_connection = self._sync_connection
    return sync_connection.__enter__()
def __exit__(self, etype: type[BaseException] | None, exc: BaseException | None, etraceback: TracebackType | None) -> None:
"""Close connection on shutdown."""
self._sync_connection.__exit__(etype, exc, etraceback)
self.disconnect()
# Typing overload: options with integer values narrow the return type to ``int``.
@overload
def get_option(
self, option: Literal[Option.ProtocolVersion | Option.Timelimit | Option.NetworkTimeout | Option.Dereference | Option.Sizelimit]
) -> int: ...
# Typing overload: any other option returns the generic option value type.
@overload
def get_option(self, option: AnyOption) -> AnyOptionValue: ...
[docs]
def get_option(self, option: AnyOption) -> AnyOptionValue:
    """
    Read an option from this LDAP connection.

    :param option: The option to query.
    :returns: The current option value.
    """
    with errors.LdapError.wrap(self._hide_parent_exception):
        value = self.conn.get_option(option)
    return cast('AnyOptionValue', value)
# Typing overloads: control options take a sequence of request controls, the protocol
# version an int; the last overload is the general signature including ``append``.
@overload
def set_option(self, option: Literal[Option.ServerControls], value: Sequence[ldap.controls.RequestControl]) -> None: ...
@overload
def set_option(self, option: Literal[Option.ClientControls], value: Sequence[ldap.controls.RequestControl]) -> None: ...
@overload
def set_option(self, option: Literal[Option.ProtocolVersion], value: int) -> None: ...
@overload
def set_option(self, option: AnyOption, value: AnyOptionValue | Sequence[ldap.controls.RequestControl], *, append: bool = True) -> None: ...
[docs]
def set_option(self, option: AnyOption, value: AnyOptionValue | Sequence[ldap.controls.RequestControl], *, append: bool = True) -> None:
    """
    Set an option on this LDAP connection.

    :param option: The option to set.
    :param value: The new value.
    :param append: Remember the option so that :meth:`_restore_options` re-applies it after a
        reconnect. The option is only remembered once it was applied successfully, and only
        its latest value is kept.
    """
    with errors.LdapError.wrap(self._hide_parent_exception):
        self.conn.set_option(option, value)
    if append:
        # Replace an earlier value of the same option instead of appending a duplicate:
        # e.g. every ``transaction()`` sets ServerControls twice, which would otherwise grow
        # the replay list (and the work done on every reconnect) without bound.
        self._options[:] = [entry for entry in self._options if entry[0] != option]
        self._options.append((option, value))
[docs]
def set_controls(self, controls: Controls) -> None:
    """
    Set LDAP controls used by all following operations on this connection.

    Server and client controls are only changed if they are not ``None`` in *controls*.
    """
    for option, control_list in ((Option.ServerControls, controls.server), (Option.ClientControls, controls.client)):
        if control_list is not None:
            self.set_option(option, control_list)
@property
def protocol_version(self) -> Version:
    """The LDAP protocol version of this connection."""
    raw_version = self.get_option(option=Option.ProtocolVersion)
    return Version(raw_version)

@protocol_version.setter
def protocol_version(self, value: Version) -> None:
    """Change the LDAP protocol version."""
    self.set_option(option=Option.ProtocolVersion, value=value)
@property
def timelimit(self) -> int:
    """The time limit option (``Option.Timelimit``) of this connection."""
    return self.get_option(option=Option.Timelimit)

@timelimit.setter
def timelimit(self, value: int) -> None:
    """Change the time limit option."""
    self.set_option(option=Option.Timelimit, value=value)
@property
def network_timeout(self) -> int:
    """The network timeout option (``Option.NetworkTimeout``) of this connection."""
    return self.get_option(option=Option.NetworkTimeout)

@network_timeout.setter
def network_timeout(self, value: int) -> None:
    """Change the network timeout option."""
    self.set_option(option=Option.NetworkTimeout, value=value)
@property
def dereference(self) -> int:
    """The alias de-reference option (``Option.Dereference``) of this connection."""
    return self.get_option(option=Option.Dereference)

@dereference.setter
def dereference(self, value: int) -> None:
    """Change the alias de-reference option."""
    self.set_option(option=Option.Dereference, value=value)
@property
def follow_referrals(self) -> bool | None:
    """
    Whether referrals are followed.

    :returns: ``None`` if the option value is ``-1``, otherwise whether it equals ``OptionValue.On``.
    """
    setting = self.get_option(Option.Referrals)
    return None if setting == -1 else setting == OptionValue.On

@follow_referrals.setter
def follow_referrals(self, value: bool) -> None:
    """Enable (``True``) or disable (``False``) following of referrals."""
    state = OptionValue.On if value else OptionValue.Off
    self.set_option(Option.Referrals, state)
@property
def sizelimit(self) -> int:
    """The size limit option (``Option.Sizelimit``) of this connection."""
    return self.get_option(option=Option.Sizelimit)

@sizelimit.setter
def sizelimit(self, value: int) -> None:
    """Change the size limit option."""
    self.set_option(option=Option.Sizelimit, value=value)
[docs]
@classmethod
def set_tls(
    cls,
    *,
    ca_certfile: str | None = None,
    ca_certdir: str | None = None,
    certfile: str | None = None,
    keyfile: str | None = None,
    require_cert: TLSRequireCert = TLSRequireCert.Demand,
    require_san: TLSRequireCert | None = None,
    minimum_protocol: TLSProtocol | None = None,
    cipher_suite: str | None = None,
    crlfile: str | None = None,
    crl_check: TLSCRLCheck | None = None,
) -> None:
    """
    Set the TLS settings globally (via :meth:`set_global_option`).

    Only settings that are not ``None`` are applied; afterwards a new TLS context is
    requested so that the pending settings take effect.

    :param ca_certfile: Path of the CA certificate file.
    :param ca_certdir: Path of the CA certificate directory.
    :param certfile: Path of the client certificate.
    :param keyfile: Path of the client certificate's private key.
    :param require_cert: Server certificate verification mode.
    :param require_san: subjectAltName verification mode.
    :param minimum_protocol: Minimum TLS protocol version.
    :param cipher_suite: Cipher suite specification.
    :param crlfile: Path of the certificate revocation list file.
    :param crl_check: CRL checking mode.
    """
    for option, value in (
        (TLSOption.CACertfile, ca_certfile),
        (TLSOption.CACertdir, ca_certdir),
        (TLSOption.Certfile, certfile),
        (TLSOption.Keyfile, keyfile),
        (TLSOption.ProtocolMin, minimum_protocol),
        (TLSOption.CipherSuite, cipher_suite),
        (TLSOption.RequireCert, require_cert),
        (TLSOption.RequireSAN, require_san),
        (TLSOption.CRLFile, crlfile),
        (TLSOption.CRLCheck, crl_check),
    ):
        if value is not None:
            cls.set_global_option(option, value)
    # apply the pending TLS settings, create new context:
    cls.set_global_option(TLSOption.NewContext, 0)
[docs]
@classmethod
def get_global_option(cls, option: AnyOption) -> AnyOptionValue:
    """
    Get a global LDAP option (not an option of a single connection; see :meth:`get_option`).

    Errors raised by python-ldap are not wrapped here.

    :param option: The option to query.
    :returns: The global option value.
    """
    return cast('AnyOptionValue', ldap.get_option(option))
[docs]
@classmethod
def set_global_option(cls, option: AnyOption, value: AnyOptionValue) -> None:
    """
    Set a global LDAP option, as opposed to :meth:`set_option` which changes one connection.

    Errors raised by python-ldap are not wrapped here.
    """
    apply_global = ldap.set_option
    apply_global(option, value)
[docs]
def connect(self, fileno: Any = None) -> None:
    """
    Connect to the LDAP server.

    Creates a ``ReconnectLDAPObject`` if :attr:`automatic_reconnect` is enabled, otherwise a
    ``SimpleLDAPObject``, and issues StartTLS if it was requested.

    :param fileno: Optional already connected socket handed to python-ldap, either an
        integer file descriptor or an object with a ``fileno()`` method.
    :raises RuntimeError: If already connected.
    """
    if self.connected:
        raise RuntimeError('already connected')  # noqa: TRY003
    common_kwargs: dict[str, Any] = {
        'trace_level': 0,
        'trace_file': None,
        'trace_stack_limit': 5,
        'fileno': fileno,
    }
    if self.automatic_reconnect:
        self._conn = ldap.ldapobject.ReconnectLDAPObject(
            self.uri,
            retry_max=self.max_connection_attempts,
            retry_delay=self.retry_delay,
            **common_kwargs,
        )
    else:
        self._conn = ldap.ldapobject.SimpleLDAPObject(self.uri, **common_kwargs)
    # a new python-ldap object invalidates the cached sync wrapper and the reconnect counter
    self.__conn_s = None
    self.__reconnects_counter = 0
    if self._start_tls:
        self.start_tls()
[docs]
def disconnect(self) -> None:
    """
    Drop the connection locally.

    Removes the socket reader and forgets the python-ldap object and the synchronous
    wrapper. No unbind request is sent (see :meth:`unbind`).
    """
    self._remove_reader()
    self._conn = self.__conn_s = None
[docs]
def reconnect(self, *, force: bool = True) -> None:
    """
    Reconnect to the LDAP server and re-apply the connection state.

    After python-ldap re-established the connection, the remembered options and the
    remembered bind are replayed. StartTLS is not issued explicitly here.

    :param force: Passed to python-ldap's ``reconnect()``.
    """
    self._remove_reader()
    with errors.LdapError.wrap(self._hide_parent_exception):
        # NOTE(review): ``reconnect()`` is presumably only provided by ReconnectLDAPObject;
        # with automatic_reconnect=False this call would fail — confirm.
        self.conn.reconnect(self.uri, self.max_connection_attempts, self.retry_delay, force=force)
    self._restore_options()
    self._restore_auth_state()
[docs]
def start_tls(self) -> None:
    """
    Issue StartTLS on the current connection (blocking call).

    The request is remembered before it is sent, so :meth:`connect` issues StartTLS again
    for new connections.
    """
    self._start_tls = True
    with errors.LdapError.wrap(self._hide_parent_exception):
        tls_target = self.conn
        tls_target.start_tls_s()
[docs]
async def get_schema(self, subschema_dn: DN | str | None = None) -> Schema:
    """
    Get the LDAP schema.

    Schemas are cached per *subschema_dn*. The whole cache is dropped when the underlying
    ``ReconnectLDAPObject`` reports reconnects since the last lookup. The fetched schema
    is also registered via ``Attributes.set_schema()``.

    :param subschema_dn: Entry whose ``subschemaSubentry`` is used; the Root DSE if empty.
    :returns: The parsed schema (empty if the schema could not be read).
    """
    conn = self.conn
    # Cached schemas belong to a connection: after a reconnect every entry is stale.
    # (Deleting only this key raised KeyError if it was never cached.)
    if isinstance(conn, ldap.ldapobject.ReconnectLDAPObject) and conn._reconnects_done > self.__reconnects_counter:
        self.__schema.clear()
        self.__reconnects_counter = conn._reconnects_done
    if not self.__schema.get(subschema_dn):
        subschemasubentry_dn: str | None = None
        try:
            subschemasubentry = await (
                self.get(subschema_dn, attrs=['subschemaSubentry']) if subschema_dn else self.get_root_dse(['subschemaSubentry'])
            )
        except (errors.NoSuchObject, errors.NoSuchAttribute, errors.InsufficientAccess, errors.UndefinedType):
            subschemasubentry_dn = None
        else:
            try:
                subschemasubentry_dn = cast('Attributes', subschemasubentry.attr)['subschemaSubentry'][0].decode('UTF-8')
            except KeyError:  # pragma: no cover; impossible?
                if subschema_dn:
                    # fall back to the schema announced by the Root DSE
                    return await self.get_schema()
                subschemasubentry_dn = None
        subschema = None
        if subschemasubentry_dn is not None:
            try:
                entry = await self.get(subschemasubentry_dn, SCHEMA_ATTRS, '(objectClass=subschema)')
            except errors.NoSuchObject:  # pragma: no cover
                entry = None
            # get() yields None if the entry does not match the subschema filter
            subschema = None if entry is None else entry.attr
        self.__schema[subschema_dn] = Schema(ldap.schema.SubSchema(cast('dict[str, list[bytes]]', subschema), 0))
        Attributes.set_schema(self.__schema[subschema_dn])
    return self.__schema[subschema_dn]
[docs]
async def bind(self, authzid: str | None, password: str | None, *, controls: Controls | None = None) -> Result:
    """
    Authenticate with a simple bind (plaintext credentials).

    The credentials are remembered for re-authentication after a reconnect, but only
    once the server accepted them.

    :param authzid: Bind identity passed to ``simple_bind()``.
    :param password: The password.
    :param controls: Optional LDAP controls.
    :returns: The bind result.
    """
    conn = self.conn
    # A failed bind leaves the session anonymous (RFC 4511, 4.2.1): forget the previous
    # identity first so that a reconnect never replays stale or rejected credentials.
    self._last_auth_state = None
    response = await self._execute(conn, conn.simple_bind, authzid, password, **Controls.expand(controls))
    self._last_auth_state = ('simple_bind_s', authzid, password)
    return Result.from_response(None, None, controls, response)
[docs]
async def bind_external(self) -> None:  # pragma: no cover
    """
    Authenticate via SASL EXTERNAL, e.g. UNIX socket or TLS client certificate.

    Issues a blocking ``sasl_interactive_bind_s()`` call.
    """
    # Forget a remembered simple bind; otherwise _restore_auth_state() would re-authenticate
    # as the old identity after a reconnect.
    # NOTE(review): assumes ReconnectLDAPObject replays the SASL bind itself — confirm.
    self._last_auth_state = None
    with errors.LdapError.wrap(self._hide_parent_exception):
        self.conn.sasl_interactive_bind_s('', ldap.sasl.external())
[docs]
async def bind_sasl_gssapi(self) -> None:  # pragma: no cover
    """
    Authenticate via SASL GSSAPI, e.g. with a Kerberos ticket.

    Issues a blocking ``sasl_interactive_bind_s()`` call.
    """
    # Forget a remembered simple bind; otherwise _restore_auth_state() would re-authenticate
    # as the old identity after a reconnect.
    # NOTE(review): assumes ReconnectLDAPObject replays the SASL bind itself — confirm.
    self._last_auth_state = None
    with errors.LdapError.wrap(self._hide_parent_exception):
        self.conn.sasl_interactive_bind_s('', ldap.sasl.gssapi())
[docs]
async def bind_oauthbearer(self, authzid: str | None, token: str) -> None:  # pragma: no cover; requires SASL module
    """
    Authenticate via SASL OAUTHBEARER with an OAuth 2.0 access token.

    Issues a blocking ``sasl_interactive_bind_s()`` call.

    :param authzid: Passed as SASL authentication name.
    :param token: The access token, passed as SASL password.
    """
    oauth = ldap.sasl.sasl(
        {
            ldap.sasl.CB_AUTHNAME: authzid,
            ldap.sasl.CB_PASS: token,
        },
        'OAUTHBEARER',
    )
    # Forget a remembered simple bind; otherwise _restore_auth_state() would re-authenticate
    # as the old identity after a reconnect.
    # NOTE(review): assumes ReconnectLDAPObject replays the SASL bind itself — confirm.
    self._last_auth_state = None
    with errors.LdapError.wrap(self._hide_parent_exception):
        self.conn.sasl_interactive_bind_s('', oauth)
def _restore_options(self) -> None:
    """Re-apply every remembered option, without remembering it a second time."""
    apply = self.set_option
    for opt, val in self._options:
        apply(opt, val, append=False)
def _restore_auth_state(self) -> None:
    """Replay the remembered bind (if any) as a blocking call on the current connection."""
    state = self._last_auth_state
    if not state:
        return
    method_name, *bind_args = state
    with errors.LdapError.wrap(self._hide_parent_exception):
        getattr(self.conn, method_name)(*bind_args)
[docs]
async def unbind(self, *, controls: Controls | None = None) -> Result | None:
    """
    Unbind from the LDAP server.

    The remembered bind is forgotten first, so a later reconnect stays anonymous.

    :param controls: Optional LDAP controls.
    :returns: The unbind result, or ``None`` if not connected or already unbound.
    """
    self._last_auth_state = None
    try:
        conn = self.conn
    except RuntimeError:  # not connected
        return None
    try:
        response = await self._execute(conn, conn.unbind_ext, **Controls.expand(controls))
    except AttributeError as exc:  # duplicated unbind
        # python-ldap formats the message with the concrete class name, so match
        # SimpleLDAPObject (automatic_reconnect=False) as well as ReconnectLDAPObject.
        if exc.args and exc.args[0] == f"{type(conn).__name__} has no attribute '_l'":
            return None
        raise  # pragma: no cover; not possible
    return Result.from_response(None, None, controls, response)
[docs]
async def whoami(self, *, controls: Controls | None = None) -> DN | str | None:
    """
    Return the authorization identity ("Who am I?" operation, blocking call).

    :returns: A :class:`DN` if the identity starts with ``dn:``, the raw identity string
        otherwise, or ``None`` if a ``RuntimeError`` occurred (e.g. not connected).
    """
    try:
        with errors.LdapError.wrap(self._hide_parent_exception):
            identity = self.conn.whoami_s(**Controls.expand(controls))
    except RuntimeError:
        return None
    without_prefix = identity.removeprefix('dn:')
    if without_prefix != identity:
        return DN(without_prefix)
    return identity  # pragma: no cover
[docs]
async def change_password(self, dn: DN | str, old_password: str, new_password: str, *, controls: Controls | None = None) -> Result:
    """
    Change a password via python-ldap's ``passwd()`` operation.

    :param dn: The entry whose password is changed.
    :param old_password: The current password.
    :param new_password: The new password.
    :param controls: Optional LDAP controls.
    """
    conn = self.conn
    user = str(dn)
    response = await self._execute(conn, conn.passwd, user, old_password, new_password, **Controls.expand(controls))
    return Result.from_response(dn, None, controls, response)  # pragma: no cover
[docs]
async def exists(self, dn: DN | str, unique: bool = False, *, controls: Controls | None = None) -> bool:
    """
    Check whether an LDAP object exists.

    A base-scope lookup requesting no attributes (``1.1``) is issued. Only ``NoSuchObject``
    means "does not exist"; every other error propagates.
    """
    try:
        await self.get(dn, ['1.1'], unique=unique, controls=controls)
    except errors.NoSuchObject:
        return False
    else:
        return True
[docs]
async def get(
    self,
    dn: DN | str,
    attrs: list[str] | None = None,
    filter_expr: str = '(objectClass=*)',
    *,
    unique: bool = False,
    controls: Controls | None = None,
) -> Result:
    """
    Get a single LDAP object via a base-scope :meth:`search`.

    :returns: The first search result. NOTE: despite the annotation, ``None`` is returned
        if the search yields no entry (presumably when *filter_expr* does not match).
    :raises errors.NoSuchObject: Propagated from :meth:`search` if *dn* does not exist.
    """
    # search() is used instead of ``anext(self.search_iter(...))``: the garbage collector would
    # call ``aclose()`` on the unfinished generator, causing an unnecessary cancel.
    results = await self.search(base=dn, scope=Scope.BASE, filter_expr=filter_expr, attrs=attrs, unique=unique, controls=controls)
    if results:
        return results[0]
    return None  # type: ignore[return-value]
[docs]
async def get_attr(
    self, dn: DN | str, attr: str, filter_expr: str = '(objectClass=*)', *, unique: bool = False, controls: Controls | None = None
) -> list[bytes]:
    """
    Get the values of one attribute of an LDAP object.

    If the attribute is missing under the requested name, the schema is loaded (which
    registers it with :class:`Attributes`) and the lookup is retried once.

    :raises KeyError: If the attribute is still missing after loading the schema.
    """
    entry = await self.get(dn, attrs=[attr], filter_expr=filter_expr, unique=unique, controls=controls)
    attributes = entry.attr
    assert attributes is not None  # noqa: S101
    try:
        values = attributes[attr]
    except KeyError:
        await self.get_schema()
        values = attributes[attr]
    return values
[docs]
async def search_iter(
    self,
    base: DN | str = '',
    scope: Scope = Scope.SUBTREE,
    filter_expr: str = '(objectClass=*)',
    attrs: list[str] | None = None,
    *,
    unique: bool = False,
    sizelimit: int | None = None,
    sorting: Sorting | None = None,
    controls: Controls | None = None,
    _attrsonly: bool = False,
) -> AsyncGenerator[Result, None]:
    """
    Search iteratively, yielding results as the server delivers them.

    If the generator is closed early, the outstanding search operation is cancelled.

    :param base: Search base DN.
    :param scope: Search scope.
    :param filter_expr: LDAP filter.
    :param attrs: Attributes to request.
    :param unique: Raise :class:`errors.NotUnique` as soon as more than one result arrived.
    :param sizelimit: Maximum number of entries (falsy means no limit).
    :param sorting: Sort keys for (critical) server side sorting.
    :param controls: Additional LDAP controls.
    :raises errors.NoSuchObject: If *base* does not exist; annotated with the search parameters.
    """
    conn = self.conn
    # results are only collected when uniqueness must be checked, so plain iteration
    # does not accumulate every entry in memory
    seen: list[Result] = []
    if sorting:
        controls = Controls.set_server(controls, server_side_sorting(*sorting, criticality=True))
    # NOTE: sizelimit is not forced to 1 for unique searches; uniqueness is checked client-side.
    try:
        async for response in self._execute_iter(
            conn,
            conn.search_ext,
            str(base),
            scope,
            filterstr=filter_expr,
            attrlist=attrs,
            attrsonly=int(_attrsonly),
            **Controls.expand(controls),
            timeout=self.timeout,
            sizelimit=sizelimit or OptionValue.NoLimit,
        ):
            Result.set_controls(response, controls)
            assert response.data is not None  # noqa: S101
            results = [Result.from_response(dn, attributes, controls, response) for dn, attributes in response.data]
            if unique:
                seen.extend(results)
                if len(seen) > 1:
                    raise errors.NotUnique(seen)
            try:
                for result in results:
                    yield result
            except GeneratorExit:
                # The consumer closed the generator: cancel the operation immediately
                # (synchronously) instead of awaiting an asynchronous cancel.
                with contextlib.suppress(errors.NoSuchOperation):
                    assert response.msgid is not None  # noqa: S101
                    self._sync_connection.cancel(response.msgid)
                raise
    except errors.NoSuchObject as no_object_error:
        no_object_error.base_dn = DN.get(base)
        no_object_error.filter = filter_expr
        no_object_error.scope = scope
        no_object_error.attrs = attrs
        raise
[docs]
async def search(
    self,
    base: DN | str = '',
    scope: Scope = Scope.SUBTREE,
    filter_expr: str = '(objectClass=*)',
    attrs: list[str] | None = None,
    *,
    unique: bool = False,
    sizelimit: int | None = None,
    sorting: Sorting | None = None,
    controls: Controls | None = None,
    _attrsonly: bool = False,
) -> list[Result]:
    """
    Search for DN and attributes of LDAP objects and return all results at once.

    :param base: Search base DN.
    :param scope: Search scope.
    :param filter_expr: LDAP filter.
    :param attrs: Attributes to request.
    :param unique: Raise :class:`errors.NotUnique` if more than one object matched.
    :param sizelimit: Maximum number of entries (falsy means no limit).
    :param sorting: Sort keys for server side sorting.
    :param controls: Additional LDAP controls.
    :returns: All search results.
    :raises errors.NoSuchObject: If *base* does not exist; annotated with the search parameters.
    """
    conn = self.conn
    if sorting:
        controls = Controls.set_server(controls, server_side_sorting(*sorting))
    try:
        response = await self._execute(
            conn,
            conn.search_ext,
            str(base),
            scope,
            filterstr=filter_expr,
            attrlist=attrs,
            attrsonly=int(_attrsonly),
            **Controls.expand(controls),
            timeout=self.timeout,
            sizelimit=sizelimit or OptionValue.NoLimit,
        )
        Result.set_controls(response, controls)
        assert response.data is not None  # noqa: S101
        results = [Result.from_response(dn, attributes, controls, response) for dn, attributes in response.data]
        if unique and len(results) > 1:
            raise errors.NotUnique(results)
    except errors.NoSuchObject as no_object_error:
        no_object_error.base_dn = DN.get(base)
        no_object_error.filter = filter_expr
        no_object_error.scope = scope
        no_object_error.attrs = attrs
        raise
    return results
[docs]
async def search_dn(
    self,
    base: DN | str = '',
    scope: Scope = Scope.SUBTREE,
    filter_expr: str = '(objectClass=*)',
    *,
    unique: bool = False,
    sizelimit: int | None = None,
    sorting: Sorting | None = None,
    controls: Controls | None = None,
) -> AsyncGenerator[DN, None]:
    """
    Search for the DNs of LDAP objects.

    Only the DNs are needed, so no attributes are requested (special attribute list ``1.1``).
    """
    # FIXME: iterating via search_iter() hangs forever, as the iterative search is unfinished
    # while the FD reader is replaced; therefore the search completes before yielding.
    for result in await self.search(
        base, scope, filter_expr, ['1.1'], unique=unique, sizelimit=sizelimit, sorting=sorting, controls=controls, _attrsonly=True
    ):
        assert result.dn is not None  # noqa: S101
        yield result.dn
[docs]
async def search_paginated(
    self,
    base: DN | str = '',
    scope: Scope = Scope.SUBTREE,
    filter_expr: str = '(objectClass=*)',
    attrs: list[str] | None = None,
    *,
    page_size: int = 100,
    sorting: Sorting,
    unique: bool = False,
    sizelimit: int | None = None,
    controls: Controls | None = None,
) -> AsyncGenerator[Result, None]:
    """
    Search page by page using the Virtual List View (VLV) control.

    VLV requires server side sorting, hence *sorting* is mandatory. Every yielded result
    carries a :class:`Page` with the page number, its position on the page, the total
    number of results reported by the server and the number of the last page.

    :param page_size: Number of entries requested per page.
    :param sorting: Sort keys for server side sorting.
    :param sizelimit: Passed to every page's search (falsy means no limit).
    """
    controls = Controls.set_server(controls, server_side_sorting(*sorting))
    res_vlv = virtual_list_view.response()
    context_id = None
    length = None  # total number of entries (``contentCount``) reported by the server
    last_page = None
    page = 1
    while True:
        # 0-based index of this page's first entry; the VLV offset itself is 1-based
        offset = (page - 1) * page_size
        if length is not None and offset >= length:
            break  # every entry has already been returned
        pagination = virtual_list_view(
            before_count=0,
            after_count=page_size - 1,
            offset=offset + 1,
            content_count=0,
            greater_than_or_equal=None,
            context_id=context_id,
            criticality=True,
        )
        controls = Controls.set_server(controls, pagination)
        current = None
        for entry_number, result in enumerate(
            await self.search(base, scope, filter_expr, attrs, unique=unique, sizelimit=sizelimit, controls=controls), 1
        ):
            if last_page is None:
                vlv = cast('ldap.controls.vlv.VLVResponseControl', controls.get(res_vlv))
                length = vlv.contentCount
                last_page = math.ceil(length / (page_size or length))
            result.page = Page(
                page=page,
                entry=entry_number,
                page_size=page_size,
                results=length,
                last_page=last_page,
            )
            current = result
            yield result
        if current is None:  # no (more) search results
            break
        page += 1
        vlv = cast('ldap.controls.vlv.VLVResponseControl', controls.get(res_vlv))
        context_id = vlv.context_id
        length = vlv.contentCount
[docs]
async def search_paged(
    self,
    base: DN | str = '',
    scope: Scope = Scope.SUBTREE,
    filter_expr: str = '(objectClass=*)',
    attrs: list[str] | None = None,
    page_size: int = 100,
    *,
    unique: bool = False,
    sizelimit: int | None = None,
    sorting: Sorting | None = None,
    controls: Controls | None = None,
) -> AsyncGenerator[Result, None]:
    """
    Search page by page using the Simple Paged Results control.

    The next page is requested with the cookie returned by the server, until a page is
    empty, the server returns no paging control, or the cookie is empty. Every result
    carries a :class:`Page`; *unique* is checked per page.

    :param page_size: Number of entries requested per page.
    :param sizelimit: Passed to every page's search (falsy means no limit).
    """
    pagination = simple_paged_results(size=page_size, cookie='', criticality=True)
    controls = Controls.append_server(controls, pagination)
    if sorting:
        controls = Controls.set_server(controls, server_side_sorting(*sorting))
    page = 0
    while True:
        current = None
        page += 1
        entry_number = 0
        async for result in self.search_iter(base, scope, filter_expr, attrs, unique=unique, sizelimit=sizelimit, controls=controls):
            entry_number += 1
            result.page = Page(page=page, entry=entry_number, page_size=page_size)
            current = result
            yield result
        if current is None:  # no search results
            break
        control = controls.get(pagination)
        if not control:  # pragma: no cover
            break  # Server doesn't support pagination
        # continue with the cookie the server returned for the next page
        pagination.cookie = cast('ldap.controls.pagedresults.SimplePagedResultsControl', control).cookie
        if not pagination.cookie:
            break
[docs]
async def add(
    self,
    dn: DN | str,
    attrs: dict[str, list[bytes]] | Attributes,
    *,
    controls: Controls | None = None,
) -> Result:
    """Create an LDAP object from an attribute mapping (converted with ``ldap.modlist.addModlist``)."""
    return await self.add_al(dn, ldap.modlist.addModlist(attrs), controls=controls)
[docs]
async def add_al(
    self,
    dn: DN | str,
    al: LDAPAddList,
    *,
    controls: Controls | None = None,
) -> Result:
    """Create an LDAP object from an add list (``[(attribute, [value, ...]), ...]``)."""
    conn = self.conn
    target = str(dn)
    response = await self._execute(conn, conn.add_ext, target, al, **Controls.expand(controls))
    return Result.from_response(dn, None, controls, response)
[docs]
async def modify(
    self,
    dn: DN | str,
    oldattr: dict[str, list[bytes]] | Attributes,
    newattr: dict[str, list[bytes]] | Attributes,
    *,
    controls: Controls | None = None,
) -> Result:
    """
    Modify an LDAP object by diffing its old and new attributes.

    The modification list is computed with ``ldap.modlist.modifyModlist`` and applied by
    :meth:`modify_ml`.
    """
    return await self.modify_ml(dn, ldap.modlist.modifyModlist(oldattr, newattr), controls=controls)
[docs]
async def modify_ml(
    self,
    dn: DN | str,
    ml: LDAPModList,
    *,
    controls: Controls | None = None,
) -> Result:
    """
    Modify an LDAP object from a modification list.

    If *ml* changes a value of the RDN, the object is renamed first (without *controls*)
    and the modification is then applied to the new DN.
    """
    conn = self.conn
    target = self._compute_changed_dn(DN.get(dn), ml)
    if dn != target:
        dn = cast('DN', (await self.rename(dn, target)).dn)
    response = await self._execute(conn, conn.modify_ext, str(dn), ml, **Controls.expand(controls))
    return Result.from_response(dn, None, controls, response)
@classmethod
def _compute_changed_dn(cls, dn: DN | str, ml: LDAPModList) -> DN:
    """
    Get the DN an object has after applying *ml*.

    Values that *ml* adds or replaces (anything but ``MOD_DELETE``) for attributes of the
    first RDN replace the corresponding RDN value; for a list of values the first is used.

    >>> str(Connection._compute_changed_dn('cn=foo,dc=bar', [(ldap.MOD_REPLACE, 'cn', b'foo')]))
    'cn=foo,dc=bar'
    >>> str(Connection._compute_changed_dn('cn=foo,dc=bar', [(ldap.MOD_REPLACE, 'cn', b'bar')]))
    'cn=bar,dc=bar'
    >>> Connection._compute_changed_dn('cn=foo,dc=bar', [(ldap.MOD_REPLACE, 'cn', 'föo'.encode('UTF-8'))]) == 'cn=föo,dc=bar'
    True

    :param dn: The current DN; a string is converted with ``DN.get``.
    :param ml: The modification list.
    :returns: The new DN, or *dn* itself if the RDN is unchanged.
    """
    dn = DN.get(dn)
    rdn = dn.rdns[0]
    dn_vals = {x[0].lower(): x[1] for x in rdn}
    new_vals = {
        key.lower(): val.decode('UTF-8') if isinstance(val, bytes) else val[0].decode('UTF-8')
        for op, key, val in ml
        if key.lower() in dn_vals and val and op != ldap.MOD_DELETE
    }
    new_rdn_ava = [(attr, new_vals.get(attr.lower(), dn_vals[attr.lower()]), ldap.AVA_STRING) for attr, *_ in rdn]
    new_rdn = DN(ldap.dn.dn2str([new_rdn_ava]))
    if dn[0] != new_rdn:
        parent = dn.parent
        # an entry without parent (directly below the root) consists of the RDN only
        return new_rdn if parent is None else new_rdn + parent
    return dn
[docs]
async def move(
    self,
    dn: DN | str,
    newposition: DN | str,
    *,
    controls: Controls | None = None,
) -> Result:
    """Move an LDAP object below *newposition*, keeping its RDN."""
    source = DN.get(dn)
    destination = DN.get(newposition)
    return await self.rename(source, source[0] + destination, delete_old=True, controls=controls)
[docs]
async def rename(
    self,
    dn: DN | str,
    newdn: DN | str,
    delete_old: bool = True,
    *,
    controls: Controls | None = None,
) -> Result:
    """
    Rename (and possibly move) an LDAP object.

    :param dn: The current DN.
    :param newdn: The complete new DN; its first RDN becomes the new RDN and its parent the
        new superior.
    :param delete_old: Remove the old RDN value from the entry.
    :param controls: Optional LDAP controls.
    :returns: The result, with ``dn`` set to *newdn*.
    """
    conn = self.conn
    newdn = DN.get(newdn)
    parent = newdn.parent
    # without a parent no new superior is sent (instead of the literal string 'None')
    newsuperior = None if parent is None else str(parent)
    response = await self._execute(conn, conn.rename, str(dn), str(newdn[0]), newsuperior, int(delete_old), **Controls.expand(controls))
    return Result.from_response(newdn, None, controls, response)
[docs]
async def modrdn(
    self,
    dn: DN | str,
    newrdn: DN | str,
    delete_old: bool = True,
    *,
    controls: Controls | None = None,
) -> Result:
    """Change the RDN of an LDAP object, keeping it below its current parent."""
    replacement = DN.get(newrdn)
    parent = cast('DN', DN.get(dn).parent)
    return await self.rename(dn, replacement + parent, delete_old, controls=controls)
[docs]
async def delete(self, dn: DN | str, *, controls: Controls | None = None) -> Result:
    """Delete a single LDAP object; see :meth:`delete_recursive` for subtrees."""
    conn = self.conn
    target = str(dn)
    response = await self._execute(conn, conn.delete_ext, target, **Controls.expand(controls))
    return Result.from_response(dn, None, controls, response)
[docs]
async def delete_recursive(self, dn: DN | str, *, controls: Controls | None = None) -> Result:
    """
    Delete an LDAP object together with its subtree.

    A plain delete is tried first. If the server rejects it with ``NotAllowedOnNonleaf``,
    every direct child is deleted recursively (without *controls*) and the delete of *dn*
    is repeated.
    """
    try:
        outcome = await self.delete(dn, controls=controls)
    except errors.NotAllowedOnNonleaf:
        async for child_dn in self.search_dn(dn, Scope.ONELEVEL):
            await self.delete_recursive(child_dn)
        outcome = await self.delete(dn, controls=controls)
    return outcome
[docs]
async def compare(
    self,
    dn: DN | str,
    attr: str,
    value: bytes,
    *,
    controls: Controls | None = None,
) -> bool:
    """
    Compare an attribute value of an LDAP object on the server.

    :returns: ``True`` if the server reports compareTrue, ``False`` for compareFalse.
    :raises errors.NoSuchObject: If *dn* does not exist (``base_dn`` is set on the error).
    """
    conn = self.conn
    try:
        await self._execute(conn, conn.compare_ext, str(dn), attr, value, **Controls.expand(controls))
    except errors.NoSuchObject as no_object_error:
        no_object_error.base_dn = DN.get(dn)
        raise
    except errors.CompareTrue:
        matches = True
    except errors.CompareFalse:
        matches = False
    else:
        raise RuntimeError()  # pragma: no cover; impossible
    return matches
[docs]
async def compare_dn(self, entry: DN | str, dn: DN | str) -> bool:
    """
    Check an existing entry against a DN using compare operations.

    For each DN yielded by ``entry.walk()`` (enumerated as *depth*), every AVA of
    ``dn.rdns[-depth - 1]`` is compared against that DN's entry. ``False`` is returned on
    the first mismatch. A missing entry is only tolerated for AVAs whose attribute equals
    the first attribute of ``entry.rdns[-1]``.
    """
    dn = DN.get(dn)
    entry = DN.get(entry)
    for depth, ancestor in enumerate(entry.walk()):
        for attr_name, attr_value, _ava_type in dn.rdns[-depth - 1]:
            try:
                equal = await self.compare(str(ancestor), attr_name, attr_value.encode('UTF-8'))
                if not equal:  # pragma: no cover; https://github.com/nedbat/coveragepy/issues/2014
                    return False
            except errors.NoSuchObject:
                if attr_name != entry.rdns[-1][0][0]:
                    raise
    return True
[docs]
async def get_root_dse(self, attrs: list[str] | None = None, filter_expr: str = '(objectClass=*)') -> Result:
    """
    Get the Root DSE (Directory Server Entry, the entry with the empty DN).

    :param attrs: Attributes to request; all user (``*``) and operational (``+``)
        attributes if empty or ``None``.
    """
    requested = attrs if attrs else ['*', '+']
    return await self.get('', requested, filter_expr=filter_expr)
[docs]
async def get_naming_contexts(self) -> list[str]:
    """Return the ``namingContexts`` values of the Root DSE, decoded as UTF-8."""
    raw_values = await self.get_attr('', 'namingContexts')
    return [raw.decode('UTF-8') for raw in raw_values]
[docs]
async def abandon(self, msgid: int, *, controls: Controls | None = None) -> Result:
    """
    Abandon an outstanding LDAP operation.

    ``_execute`` treats ``abandon_ext`` as an operation without message ID, so the result
    carries an empty response.
    """
    log.debug('Abandon: %s', msgid)
    conn = self.conn
    expanded = Controls.expand(controls)
    response = await self._execute(conn, conn.abandon_ext, msgid, **expanded)
    return Result.from_response(None, None, controls, response)
[docs]
async def cancel(self, msgid: int, *, controls: Controls | None = None) -> bool:
    """
    Cancel an outstanding LDAP operation.

    :returns: ``False`` if not connected or the server answers "too late", else ``True``.
    :raises errors.NoSuchOperation: If the server does not know *msgid* (logged as warning).
    """
    log.debug('Cancel: %s', msgid)
    try:
        conn = self.conn
    except RuntimeError:
        return False
    try:
        await self._execute(conn, conn.cancel, msgid, **Controls.expand(controls))
    except errors.NoSuchOperation:
        log.warning('Cancel failed', extra={'msgid': msgid})
        raise
    except (errors.Cancelled, errors.Success):  # pragma: no cover; theoretically, according to python-ldap
        cancelled = True
    except errors.TooLate:  # pragma: no cover
        cancelled = False
    else:  # pragma: no cover
        cancelled = True
    return cancelled
[docs]
@contextlib.asynccontextmanager
async def transaction(self, set_controls: bool = True) -> AsyncGenerator[bytes, None]:
    """
    Run the enclosed operations in an LDAP transaction, which is aborted on errors.

    :param set_controls: Install the transaction control as connection-wide server control
        for the duration of the block and clear the server controls afterwards. If false,
        the connection-wide controls are neither set nor cleared.
    :yields: The transaction identifier.
    """
    result = await self.extended(transaction_start(), transaction_start.response)
    txn_id = result.extended_value
    try:
        # inside the try so that the started transaction is aborted if this fails
        if set_controls:
            self.set_controls(Controls.set_server(None, transaction(txn_id, criticality=True)))
        yield txn_id
    except BaseException:
        if set_controls:
            self.set_option(Option.ServerControls, [])
        await self.extended(transaction_commit(txn_id, commit=False), transaction_commit.response)
        raise
    else:
        if set_controls:
            self.set_option(Option.ServerControls, [])
        try:
            await self.extended(transaction_commit(txn_id, commit=True), transaction_commit.response)
        except errors.OperationsError as exc:
            # NOTE(review): a failed commit is only logged; the caller is not informed.
            log.warning('Failure during commiting transaction', extra={'error': exc})
[docs]
async def refresh_ttl(self, dn: DN | str, ttl: int) -> Result:
    """
    Perform the Refresh extended operation for *dn* with the requested *ttl*.

    :returns: The extended operation result with ``dn`` set to *dn*.
    """
    request = refresh_ttl(dn, ttl)
    outcome = await self.extended(request, refresh_ttl.response)
    outcome.dn = DN.get(dn)
    return outcome
[docs]
async def extended(
    self, request: ExtendedRequest, response_class: type[ExtendedResponse] | None = None, *, controls: Controls | None = None
) -> Result:
    """
    Perform an extended operation.

    :param request: The extended request.
    :param response_class: If given, the response OID must equal its ``responseName`` and a
        non-empty response value is decoded with it.
    :param controls: Optional LDAP controls.
    :returns: The result; ``extended_value`` holds the decoded response value (or ``None``).
    :raises errors.ProtocolError: If the response OID does not match *response_class*
        (raised inside ``LdapError.wrap``).
    """
    conn = self.conn
    response = await self._execute(conn, conn.extop, request, **Controls.expand(controls))
    extended_value = None
    if response_class:
        with errors.LdapError.wrap(self._hide_parent_exception):
            expected_oid = response_class.responseName
            if expected_oid != response.name:
                raise errors.ProtocolError({
                    'desc': 'OID in extended response does not match response class.',
                    'info': f'expected: {expected_oid}; got: {response.name}',
                })
            if response.value:
                extended_value = response_class(response.name, response.value).responseValue
    return Result.from_response(None, None, controls, response, extended_value=extended_value)
def __getstate__(self) -> dict[str, Any]:
    """
    Return the picklable state.

    Contains every slot except ``_conn`` and the double-underscore (name-mangled) slots,
    plus the ``connected`` flag.
    NOTE(review): ``_last_auth_state`` (including a bind password) ends up in the pickle.
    """
    wanted = (set(self.__slots__) - {'_conn'}) | {'connected'}
    state: dict[str, Any] = {}
    for slot in wanted:
        if slot.startswith('__'):
            continue
        state[slot] = getattr(self, slot)
    return state
def __setstate__(self, state: dict[str, Any]) -> None:
    """
    Restore the state produced by :meth:`__getstate__`.

    The name-mangled private slots are not part of the pickled state and are
    re-initialised here. If the pickled object was connected, a new connection is
    established and the options and the remembered bind are replayed.
    """
    self._conn = None
    # not pickled, but required by _sync_connection, connect() and get_schema()
    self.__conn_s = None
    self.__reconnects_counter = 0
    self.__schema = {}
    connected = state.pop('connected', None)
    for slot, value in state.items():
        setattr(self, slot, value)
    if connected:
        self.connect()
        self._restore_options()
        self._restore_auth_state()
async def _execute(self, conn: LDAPObject, operation: Callable[..., Any], *args: Any, **kwargs: Any) -> _Response:
    """
    Send the request (with retries) and wait asynchronously for its single response.

    Operations without a message ID (``abandon_ext``, ``unbind_ext``) get an empty response.

    :raises RuntimeError: If more than one response arrives (use :meth:`_execute_iter`).
    """
    msgid = await self._retry(self.request, operation, *args, **kwargs)
    if msgid is None:  # abandon_ext, unbind_ext
        return _Response(None, None, msgid, [], None, None)
    single: _Response | None = None
    async for received in self._poll(conn, msgid, 1):
        if single is None:
            single = received
            continue
        raise RuntimeError('Wrong method used! Use _execute_iter instead!')  # noqa: TRY003  # pragma: no cover
    assert single is not None  # noqa: S101
    return single
async def _execute_iter(self, conn: LDAPObject, operation: Callable[..., Any], *args: Any, **kwargs: Any) -> AsyncGenerator[_Response, None]:
    """
    Send the request (with retries) and yield its responses as they arrive.

    Nothing is yielded for operations without a message ID (``abandon_ext``, ``unbind_ext``).
    """
    msgid = await self._retry(self.request, operation, *args, **kwargs)
    if msgid is not None:
        async for response in self._poll(conn, msgid, 0):
            yield response
[docs]
def get_result(self, conn: LDAPObject, msgid: int = ResponseType.Any, _all: int = 0, timeout: int = 0) -> _Response:
    """Read a result for *msgid* from *conn* and wrap it in a ``_Response``.

    Calls ``conn.result4(msgid, all=_all, timeout=timeout, add_extop=1)`` inside
    ``errors.LdapError.wrap``. The call, any failure, and the (truncated) outcome are
    logged at debug level.

    :param conn: The python-ldap connection object to read from.
    :param msgid: Message id to wait for (default ``ResponseType.Any``).
    :param _all: Forwarded as ``all`` to ``result4``.
    :param timeout: Forwarded as ``timeout`` to ``result4``.
    :raises errors.LdapError: LDAP errors surfaced through ``errors.LdapError.wrap``.
    :raises OSError: Re-raised unchanged after logging.
    """
    log.debug('result(%r, timeout=%r)', msgid, timeout, extra={'MSGID': msgid, 'ALL': _all, 'TIMEOUT': timeout, 'FUNC': 'result'})
    try:
        with errors.LdapError.wrap(self._hide_parent_exception):
            raw = conn.result4(msgid, all=_all, timeout=timeout, add_extop=1)  # type: ignore[arg-type]
            response = _Response(*raw)
    except (errors.LdapError, OSError) as exc:
        log.debug('result(%r) -> raised %r', msgid, exc, extra={'MSGID': msgid, 'OPERATION': 'result', 'EXCEPTION': str(exc)})
        raise
    summary = repr(response)[:200]
    log.debug('result(%r) -> %s', msgid, summary, extra={'MSGID': msgid, 'OPERATION': 'result'})
    return response
[docs]
def request(self, operation: Callable[..., int], *args: Any, **kwargs: Any) -> int | None:
    """Make the LDAP request for the given operation and return its message id.

    The call and its outcome are logged at debug level. Arguments of operations that
    carry credentials are not logged: any bind variant, and python-ldap's password
    modify operation (``passwd``/``passwd_s``, which take the old and new password).

    :param operation: A python-ldap ``LDAPObject`` method, e.g. ``search_ext``.
    :returns: Whatever *operation* returns (a message id, or ``None``).
    :raises errors.LdapError: LDAP errors surfaced through ``errors.LdapError.wrap``.
    :raises OSError: Re-raised unchanged after logging.
    """
    op = operation.__name__
    # never write secrets into the log: bind credentials and passwd() old/new passwords
    redact = any(marker in op for marker in ('bind', 'passwd'))
    arg_str = '' if redact else ', '.join(map(repr, args))
    kw = '' if redact else ', '.join(f'{k}={v!r}' for k, v in kwargs.items())
    log.debug('Request %s(%s%s%s)', op, arg_str, ', ' if kw else '', kw, extra={'OPERATION': op, 'ARGUMENTS': arg_str, 'KEYWORDS': kw})
    try:
        with errors.LdapError.wrap(self._hide_parent_exception):
            msgid = operation(*args, **kwargs)
    except (errors.LdapError, OSError) as exc:
        log.debug('%s() -> %r', op, exc, extra={'OPERATION': op, 'EXCEPTION': str(exc)})
        raise
    log.debug('%s() -> %r', op, msgid, extra={'OPERATION': op, 'MSGID': msgid})
    return msgid
async def _retry(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call *func*, reconnecting and retrying on connection-level errors.

    The first call is made directly. When *func* raises ``errors.ServerDown``,
    ``errors.Unavailable``, ``errors.ConnectError`` or ``errors.Timeout``, the coroutine
    sleeps ``retry_delay`` seconds, calls ``reconnect()`` and tries again. At most
    ``max_connection_attempts`` calls are made in total, and the last error is re-raised
    once they are used up.

    A ``max_connection_attempts`` of ``0`` means "no retries": the operation is still
    executed exactly once (previously it was never executed at all). A negative value
    never counts down to zero, so retries continue indefinitely.

    :returns: The return value of *func*.
    """
    max_attempts = attempts = self.max_connection_attempts or 1
    while True:
        try:
            if attempts != max_attempts:
                # do a sync reconnect, make sure we don't end in recursion
                self.reconnect()  # FIXME: recursion!?
                # self.conn.reconnect(self.uri)
                # self._restore_auth_state()  # FIXME: recursion
            return func(*args, **kwargs)
        except (errors.ServerDown, errors.Unavailable, errors.ConnectError, errors.Timeout):
            attempts -= 1
            if not attempts:
                raise
            await asyncio.sleep(self.retry_delay)
def _poll_s(
    self, conn: LDAPObject, msgid: ResponseType = ResponseType.Any, _all: int = 0
) -> Generator[_Response, None, None]:  # pragma: no cover
    """Wait synchronously for the responses of *msgid* and yield them.

    Only meant for the synchronous variant of this class. In this class ``_retry`` is a
    coroutine function, so calling this method here would not work.

    Search entries and search continuation references are followed by further messages
    for the same operation, so reading continues after them. Any other response type
    (e.g. the final search result, or the result of a non-search operation) ends the
    iteration.
    """
    # this method must only used by the synchronous variant of this class
    while True:
        try:
            response = self._retry(self.get_result, conn, msgid, _all=_all, timeout=self.timeout)
        except errors.NoResultsReturned:  # pragma: no cover
            break
        rtype = response.type  # type: ignore[attr-defined]
        if rtype is None:
            continue
        yield response  # type: ignore[misc]
        # a referral (RES_SEARCH_REFERENCE) does not finish a search; keep reading
        if rtype not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_REFERENCE):
            break
async def _poll(self, conn: LDAPObject, msgid: ResponseType = ResponseType.Any, _all: int = 0) -> AsyncGenerator[_Response, None]:
    """Wait asynchronously for the responses of *msgid* and yield them.

    For every message a fresh future is created, and ``_ready`` is registered as the read
    callback on the connection's file descriptor; it resolves the future with the next
    response. The wait is bounded by ``_wait_for`` (``self.timeout``).

    Search entries and search continuation references are followed by further messages
    for the same operation, so reading continues after them. Any other response type
    ends the iteration.
    """
    loop = asyncio.get_running_loop()
    while True:
        # TODO: move the asyncio stuff out of here
        fut = loop.create_future()
        fd = conn.fileno()
        self._add_reader(loop, fd, self._ready, conn, msgid, fut, _all)
        try:
            response = await self._wait_for(fut)
        except errors.NoResultsReturned:  # pragma: no cover; how?
            self._remove_reader(fd)
            break
        except BaseException:
            # BaseException, not Exception: asyncio.CancelledError must also unregister the
            # reader, otherwise a stale callback stays attached to the fd after cancellation
            self._remove_reader(fd)
            raise
        rtype = response.type
        if rtype is None:  # pragma: no cover; handled in _ready()
            continue
        yield response
        # a referral (RES_SEARCH_REFERENCE) does not finish a search; keep reading
        if rtype not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_REFERENCE):
            break
def _ready(self, fd: int, conn: LDAPObject, msgid: int, fut: asyncio.Future[_Response], _all: int) -> None:
    """Read callback: fetch the next response for *msgid* and resolve *fut* with it.

    The reader stays registered while ``get_result`` returns a response without a type
    (nothing complete yet). Otherwise *fut* receives the response, or the raised error,
    and the reader for *fd* is removed. Only descriptor or connection failures are
    logged at error level. Ordinary LDAP error results (e.g. noSuchObject) are simply
    handed to the waiter; ``get_result`` already logs them at debug level.
    """
    log.debug('FD %s is ready', fd)
    if fut.done():
        # the waiter already gave up (timeout or cancellation): do not consume a message
        # nobody will receive, and do not hit InvalidStateError on set_result()
        self._remove_reader(fd)
        return
    try:
        os.fstat(fd)
        response = self.get_result(conn, msgid, _all=_all, timeout=0)
    except (OSError, errors.ServerDown) as exc:
        log.error('FD %s is not valid - maybe connection is closed', fd)  # noqa: TRY400
        fut.set_exception(exc)
        self._remove_reader(fd)
        return
    except errors.LdapError as exc:
        fut.set_exception(exc)
        self._remove_reader(fd)
        return
    if response.type is None:
        return
    fut.set_result(response)
    self._remove_reader(fd)
async def _wait_for(self, fut: asyncio.Future[_Response]) -> _Response:
    """Await *fut*, bounded by ``self.timeout`` seconds when that value is positive.

    With a positive timeout the wait goes through ``asyncio.wait_for``, which cancels
    *fut* and raises a timeout error when the limit expires. Otherwise *fut* is awaited
    without a limit.
    """
    limit = self.timeout
    awaitable = asyncio.wait_for(fut, timeout=limit) if limit > 0 else fut
    return await awaitable
@classmethod
def _add_reader(cls, loop: asyncio.AbstractEventLoop, fd: int, func: Callable[..., Any], *args: Any) -> None:
    """Watch *fd* for readability on *loop* and call ``func(fd, *args)`` when it is readable.

    ``os.fstat`` is called first so that an already closed descriptor raises ``OSError``
    here instead of being registered.
    """
    log.debug('Select on FD %s', fd)
    os.fstat(fd)
    callback_args = (fd, *args)
    # FIXME: deadlock; add_reader() replaces a reader already registered for this fd,
    # e.g. during an iterative search or another parallel operation on the same connection
    loop.add_reader(fd, func, *callback_args)
    # alternative: register the reader from the loop thread
    # loop.call_soon_threadsafe(lambda: loop.add_reader(fd, func, fd, *args))
def _remove_reader(self, fd: int | None = None) -> None:
    """Stop watching *fd* for readability on the running event loop.

    :param fd: The file descriptor. It defaults to ``self.fileno`` only when omitted or
        ``None``; descriptor ``0`` is a valid descriptor and is used as given. ``-1``
        (presumably "no open connection") is ignored.
    """
    if fd is None:
        fd = self.fileno
    log.debug('Remove reader FD %s', fd)
    if fd == -1:  # pragma: no cover
        return
    loop = asyncio.get_running_loop()
    loop.remove_reader(fd)
    # loop.call_soon_threadsafe(lambda: loop.remove_reader(fd))