# SPDX-FileCopyrightText: 2025 Florian Best
# SPDX-License-Identifier: MIT OR Apache-2.0
"""LDAP Distinguished Name (DN) utilities."""
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Self

import ldap.dn

from freeiam.errors import InvalidDN
from freeiam.ldap.constants import AVA, DNFormat

if TYPE_CHECKING:
    from collections.abc import Generator, Iterable
# Public names exported by a star-import of this module.
__all__ = ('DN',)
@functools.lru_cache
def _parse_dn(dn: str) -> tuple[tuple[tuple[str, str, int], ...], ...]:
    """
    Parse *dn* with :func:`ldap.dn.str2dn` and memoize the result.

    The parsed RDNs are frozen into nested tuples so the cached value can be
    shared between callers without any of them being able to modify it.
    Exceptions raised by ``str2dn`` (e.g. ``ldap.DECODING_ERROR``) are not
    cached by ``lru_cache`` and propagate unchanged.
    """
    return tuple(tuple(rdn) for rdn in ldap.dn.str2dn(dn))


def _to_dn(dn: str) -> list[list[tuple[str, str, int]]]:
    """
    Parse *dn* into a fresh, caller-owned list of RDNs.

    Each RDN is a list of ``(attribute, value, flags)`` tuples. The parse
    itself is memoized, but a new outer/inner list structure is built on every
    call: previously the cached list object itself was returned, so mutating
    the result (for example through ``DN.rdns``) corrupted the cache for every
    later DN parsed from the same string.
    """
    return [list(rdn) for rdn in _parse_dn(dn)]
class DN:
    """
    A LDAP Distinguished Name.

    The DN string is parsed on construction into a list of RDNs, each RDN
    being a list of ``(attribute, value, flags)`` tuples. Equality and hashing
    lower-case attribute types, lower-case the values of the attributes listed
    in ``_CASE_INSENSITIVE_ATTRIBUTES``, and ignore the order of the AVAs
    inside a multi-valued RDN. The order of the RDNs themselves matters.
    """

    # Attribute types whose *values* are compared case-insensitively by __eq__/__hash__.
    _CASE_INSENSITIVE_ATTRIBUTES = ('c', 'cn', 'dc', 'l', 'o', 'ou', 'uid')
    __slots__ = ('_cached_hash', '_cached_key', '_cached_normalized', '_dn', '_format', 'dn')

    @classmethod
    def get(cls, dn: Self | str) -> Self:
        """Get a DN from a string, or return an existing DN instance unchanged."""
        return cls(dn) if isinstance(dn, str) else dn

    @classmethod
    def escape(cls, value: str) -> str:
        """Escape special characters of an attribute value for use in a DN string."""
        return ldap.dn.escape_dn_chars(value)

    @classmethod
    def compose(cls, *parts: DN | str | tuple[str, str] | tuple[str, str, int]) -> Self:
        """
        Compose a DN from different segments.

        Every part may be a :class:`DN`, a DN string, an ``(attr, value)`` pair
        (stored with the ``AVA.String`` flag) or an ``(attr, value, flags)``
        triple. The parts are concatenated in the given order.

        >>> base = DN('dc=freeiam,dc=org')
        >>> str(DN.compose(('cn', 'admin'), 'ou=foo,ou=bar', base))
        'cn=admin,ou=foo,ou=bar,dc=freeiam,dc=org'

        :raises TypeError: if a part is not a DN, a string, or a 2-/3-tuple.
        :raises InvalidDN: if a string part cannot be parsed.
        """
        rdns: list[list[tuple[str, str, int]]] = []
        for part in parts:
            if isinstance(part, DN):
                rdns.extend(part.rdns)
            elif isinstance(part, str):
                rdns.extend(cls(part).rdns)
            elif isinstance(part, tuple) and len(part) == 3:  # noqa: PLR2004
                rdns.append([part])
            elif isinstance(part, tuple) and len(part) == 2:  # noqa: PLR2004
                rdns.append([(part[0], part[1], AVA.String)])
            else:
                raise TypeError(part)
        return cls(ldap.dn.dn2str(rdns))

    @classmethod
    def normalize(cls, dn: Self | str) -> str:
        """
        Normalize DN into its canonical string form.

        >>> DN.normalize('cn = foo , cn = bar')
        'cn=foo,cn=bar'
        """
        return str(cls.get(dn))

    @classmethod
    def get_unique(cls, dns: Iterable[Self | str]) -> set[Self]:
        """
        Return a unique set of DNs.

        Entries may be DN strings or DN instances; duplicates are detected with
        DN equality, so differently spelled strings of the same DN collapse.

        >>> len(DN.get_unique(['CN=users,dc=freeiam,dc=org', 'cn=users,dc=freeiam,dc=org', 'cn = users,dc=freeiam,dc=org', 'CN=Users,dc=freeiam,dc=org']))
        1
        """
        return {cls.get(dn) for dn in dns}

    @classmethod
    def get_unique_str(cls, dns: Iterable[Self | str]) -> set[str]:
        """
        Return a set of normalized DN strings from DNs (or DN strings).

        >>> DN.get_unique_str(DN.get_unique(['cn=foo', 'cn=bar']) - DN.get_unique(['cn = foo'])) == {'cn=bar'}
        True
        """
        return {cls.normalize(dn) for dn in dns}

    @property
    def rdn(self) -> tuple[str, str] | tuple[()]:
        """
        Get attr and value of the first AVA of the first RDN (``()`` for an empty DN).

        >>> DN('cn=foo,cn=bar').rdn
        ('cn', 'foo')
        """
        try:
            return self.multi_rdn[0]
        except IndexError:
            return ()

    @property
    def attribute(self) -> str | None:
        """
        Get attribute name of the first RDN component (``None`` for an empty DN).

        >>> DN('cn=foo,cn=bar').attribute
        'cn'
        """
        rdn = self.rdn
        return rdn[0] if rdn else None

    @property
    def value(self) -> str | None:
        """
        Get value of the first RDN component (``None`` for an empty DN).

        >>> DN('cn=foo,cn=bar').value
        'foo'
        """
        rdn = self.rdn
        return rdn[1] if rdn else None

    @property
    def multi_rdn(self) -> tuple[tuple[str, str], ...]:
        """
        Get all ``(attribute, value)`` pairs of the first RDN.

        >>> DN('uid=1+cn=2,dc=3').multi_rdn
        (('uid', '1'), ('cn', '2'))
        """
        try:
            first = self._dn[0]
        except IndexError:
            return ()
        return tuple((attr, value) for attr, value, _flags in first)

    @property
    def attributes(self) -> tuple[str, ...]:
        """
        Get the attribute names of all AVAs in the first RDN.

        >>> DN('uid=1+cn=2,dc=3').attributes
        ('uid', 'cn')
        """
        return tuple(attr for attr, _value in self.multi_rdn)

    @property
    def values(self) -> tuple[str, ...]:
        """
        Get the values of all AVAs in the first RDN.

        >>> DN('uid=1+cn=2,dc=3').values
        ('1', '2')
        """
        return tuple(value for _attr, value in self.multi_rdn)

    @property
    def rdns(self) -> list[list[tuple[str, str, int]]]:
        """
        Get the parsed RDNs as lists of ``(attribute, value, flags)`` tuples.

        This is the instance's own list; mutating it is not reflected in the
        already cached string form or hash of this DN.
        """
        return self._dn

    @property
    def parent(self) -> Self | None:
        """
        Get the parent DN (``None`` if the DN has fewer than two components).

        >>> DN('cn=item,cn=parent').parent == DN('cn=parent')
        True
        """
        if len(self._dn) > 1:
            return self[1:]
        return None

    def __init__(self, dn: str, format: DNFormat | None = None) -> None:  # noqa: A002
        """
        Parse *dn*.

        :param dn: the DN string.
        :param format: stored as ``_format``; no method of this class reads it.
        :raises InvalidDN: if *dn* cannot be parsed, even after undoing the
            ``\\?`` escaping emitted by Samba.
        """
        self.dn = dn
        self._format = format
        self._cached_hash: int | None = None
        self._cached_key: tuple[tuple[tuple[str, str, int], ...], ...] | None = None
        self._cached_normalized: str | None = None
        try:
            self._dn = _to_dn(self.dn)
        except ldap.DECODING_ERROR:
            try:
                # Samba LDAP returns broken DN: https://bugzilla.samba.org/show_bug.cgi?id=14073
                self._dn = _to_dn(self.dn.replace(r'\?', '?'))
            except ldap.DECODING_ERROR as exc:
                err = InvalidDN()
                err._description = 'Malformed DN syntax'
                err._info = f'{self.dn!r}: {exc}'.removesuffix(': ')
                raise err from exc

    def get_parent(self, end: Self | str) -> Self | None:
        """
        Get the parent DN, but not beyond a certain base.

        Returns ``None`` if this DN equals *end* or is not below *end*.

        >>> base = DN('dc=freeiam,dc=org')
        >>> DN('cn=foo,dc=freeiam,dc=org').get_parent(base) == base
        True
        >>> DN('dc=freeiam,dc=org').get_parent(base) is None
        True
        """
        if not self.endswith(end) or self == end:
            return None
        return self.parent

    def endswith(self, other: Self | str) -> bool:
        """
        Check if DN is (a descendant of, or equal to) another base DN.

        >>> DN('cn=foo,cn=bar').endswith('cn=bar')
        True
        >>> DN('cn=foo,cn=bar').endswith('cn=foo')
        False
        >>> DN('cn=foo').endswith('cn=foo,cn=bar')
        False
        >>> DN('cn=foo,cn=bar').endswith('')
        True
        """
        other = self.get(other)
        if not other:
            # Every DN ends with the empty DN.
            return True
        return self[-len(other):] == other

    def startswith(self, other: Self | str) -> bool:
        """
        Check if DN starts with another DN.

        >>> DN('cn=foo,cn=bar').startswith('cn=foo')
        True
        >>> DN('cn=foo,cn=bar').startswith('cn=bar')
        False
        >>> DN('cn=foo,cn=bar').startswith('')
        True
        """
        other = self.get(other)
        return self[: len(other)] == other

    def walk(self, base: Self | str | None = None) -> Generator[Self, None, None]:
        """
        Walk the reversed DN components from the given base.

        Yields the DNs from *base* (or the last component when no base is
        given) down to this DN itself.

        >>> [str(x) for x in DN('cn=foo,cn=bar,cn=baz,cn=blub').walk('cn=baz,cn=blub')]
        ['cn=baz,cn=blub', 'cn=bar,cn=baz,cn=blub', 'cn=foo,cn=bar,cn=baz,cn=blub']
        >>> [str(x) for x in DN('cn=foo,cn=bar,cn=baz,cn=blub').walk()]
        ['cn=blub', 'cn=baz,cn=blub', 'cn=bar,cn=baz,cn=blub', 'cn=foo,cn=bar,cn=baz,cn=blub']

        :raises ValueError: if this DN does not end with *base*.
        """
        base = self.get(base or '')
        if not self.endswith(base):
            msg = 'DN does not end with given base'
            raise ValueError(msg)
        for i in reversed(range(len(self) - (len(base) or 1) + 1)):
            yield self[i:]

    def __str__(self) -> str:
        """
        Get a normalized string representation of the DN (cached).

        >>> str(DN('cn = foo , cn = bar')) == "cn=foo,cn=bar"
        True
        """
        if self._cached_normalized is None:
            self._cached_normalized = ldap.dn.dn2str(self._dn)
        return self._cached_normalized

    def __repr__(self) -> str:
        """
        Get a representation.

        >>> repr(DN('cn=foo,cn=bar')) == "DN('cn=foo,cn=bar')"
        True
        """
        return f'{type(self).__name__}({str(self)!r})'

    def __len__(self) -> int:
        """Return number of components (RDNs) of the DN."""
        return len(self._dn)

    def __getitem__(self, key: int | slice) -> Self:
        """
        Get slice or item of the DN components as a new DN.

        :raises IndexError: for an out-of-range integer *key*.
        """
        if isinstance(key, slice):
            return self.__class__(ldap.dn.dn2str(self._dn[key]))
        return self.__class__(ldap.dn.dn2str([self._dn[key]]))

    def _normalized_key(self) -> tuple[tuple[tuple[str, str, int], ...], ...]:
        """
        Return the cached canonical form shared by ``__eq__`` and ``__hash__``.

        Attribute types are lower-cased, values of ``_CASE_INSENSITIVE_ATTRIBUTES``
        are lower-cased, and the AVAs inside each RDN are sorted so that the
        AVA order of a multi-valued RDN does not matter.
        """
        if self._cached_key is None:
            insensitive = self._CASE_INSENSITIVE_ATTRIBUTES
            self._cached_key = tuple(
                tuple(sorted(
                    (attr.lower(), val.lower() if attr.lower() in insensitive else val, flags)
                    for attr, val, flags in rdn
                )) for rdn in self._dn
            )  # fmt: skip
        return self._cached_key

    def __eq__(self, other: object) -> bool:
        """
        Check normalized DNs for equality.

        Strings are parsed as DNs; a string that is not a valid DN compares
        unequal. Other types are not comparable (``NotImplemented``).

        >>> DN('cn=foo') == DN('cn=foo')
        True
        >>> DN('cn=foo') == DN('cn=bar')
        False
        >>> DN('Cn=Foo') == DN('cn=foo')
        True
        >>> DN('Cn=foo') == DN('cn=bar')
        False
        >>> DN('uid=Administrator') == DN('uid=administrator')
        True
        >>> DN('foo=Foo') == DN('foo=foo')
        False
        >>> DN('cn=foo,cn=bar') == DN('cn=foo,cn=bar')
        True
        >>> DN('cn=bar,cn=foo') == DN('cn=foo,cn=bar')
        False
        >>> DN('cn=foo+cn=bar') == DN('cn=foo+cn=bar')
        True
        >>> DN('cn=bar+cn=foo') == DN('cn=foo+cn=bar')
        True
        >>> DN('cn=bar+Cn=foo') == DN('cn=foo+Cn=bar')
        True
        >>> DN(r'cn=%s31foo' % chr(92)) == DN(r'cn=1foo')
        True
        >>> DN('cn=foo') == 'cn=foo'
        True
        >>> DN('cn=foo') == 42
        False
        """
        if isinstance(other, str):
            try:
                other = DN(other)
            except InvalidDN:
                return False
        if not isinstance(other, DN):
            return NotImplemented
        # Compare the canonical forms themselves: comparing hashes (as done
        # before) would treat colliding hashes as equal DNs.
        return self._normalized_key() == other._normalized_key()

    def __ne__(self, other: object) -> bool:
        """Inverse of ``__eq__``."""
        return not self == other

    def __hash__(self) -> int:
        """Hash the canonical form, consistent with ``__eq__`` (cached)."""
        if self._cached_hash is None:
            self._cached_hash = hash(self._normalized_key())
        return self._cached_hash

    def __add__(self, other: Self | str | tuple[str, str] | tuple[str, str, int]) -> Self:
        """
        Append *other* (a DN, DN string or RDN tuple) below the root side of this DN.

        Empty operands are handled instead of producing a malformed ``'cn=foo,'``.

        >>> str(DN('cn=foo') + 'dc=freeiam,dc=org')
        'cn=foo,dc=freeiam,dc=org'
        >>> str(DN('cn=foo') + '')
        'cn=foo'
        """
        if not isinstance(other, (DN, str, tuple)):
            return NotImplemented
        return self.compose(self, other)