autogen_ext.memory.canvas._text_canvas 源代码

import difflib
from typing import Any, Dict, List, Union

try:  # pragma: no cover
    from unidiff import PatchSet
except ModuleNotFoundError:  # pragma: no cover
    PatchSet = None  # type: ignore

from ._canvas import BaseCanvas


class FileRevision:
    """One stored snapshot of a file: its text plus the revision id it was saved under."""

    __slots__ = ("content", "revision")

    def __init__(self, content: str, revision: int) -> None:
        # Revision identifier. TextCanvas in this module assigns 1, 2, 3, ...;
        # the field is typed as int (could be a timestamp-like key as well).
        self.revision: int = revision
        # Full text of the file at this revision.
        self.content: str = content


class TextCanvas(BaseCanvas):
    """An in-memory canvas that stores *text* files with full revision history.

    .. warning::

        This is an experimental API and may change in the future.

    Besides the basic CRUD-style operations, this implementation adds:

    * **apply_patch** - apply a unified diff using the ``unidiff`` library.
      Each hunk is placed at its stated position, and its context/removed
      lines are verified against the current content.
    * **get_revision_content** - random access to any stored revision.
    * **get_revision_diffs** - the list of diffs between every pair of
      consecutive revisions, so callers can replay or audit the full history.
    """

    # ----------------------------------------------------------------------------------
    # Construction helpers
    # ----------------------------------------------------------------------------------
    def __init__(self) -> None:
        # For each file we keep an *ordered* list of FileRevision objects whose
        # last element is the most recent revision.
        self._files: Dict[str, List[FileRevision]] = {}

    # ----------------------------------------------------------------------------------
    # Internal utilities
    # ----------------------------------------------------------------------------------
    def _latest_idx(self, filename: str) -> int:
        """Return the list index (not the revision number) of the latest revision.

        Returns ``-1`` when *filename* has no stored revisions.
        """
        return len(self._files.get(filename, [])) - 1

    def _ensure_file(self, filename: str) -> None:
        """Raise ``ValueError`` if *filename* has never been added to the canvas."""
        if filename not in self._files:
            raise ValueError(f"File '{filename}' does not exist on the canvas; create it first.")

    def _find_revision(self, filename: str, revision: int) -> Union["FileRevision", None]:
        """Return the stored revision numbered *revision*, or ``None`` if absent."""
        for rev in self._files.get(filename, []):
            if rev.revision == revision:
                return rev
        return None

    # ----------------------------------------------------------------------------------
    # Revision inspection helpers
    # ----------------------------------------------------------------------------------
    def get_revision_content(self, filename: str, revision: int) -> str:
        """Return the exact content stored in *revision*.

        Returns an empty string if the file or revision does not exist, so that
        downstream code can handle "not found" without catching exceptions.
        """
        rev = self._find_revision(filename, revision)
        return rev.content if rev is not None else ""

    def get_revision_diffs(self, filename: str) -> List[str]:
        """Return a chronological list of unified diffs for *filename*.

        Element *i* transforms the *i*-th stored revision into the next one
        (starting with revision 1 -> 2). Returns an empty list for unknown
        files or files with a single revision.
        """
        revisions = self._files.get(filename, [])
        diffs: List[str] = []
        for older, newer in zip(revisions, revisions[1:]):
            diff = difflib.unified_diff(
                older.content.splitlines(keepends=True),
                newer.content.splitlines(keepends=True),
                fromfile=f"{filename}@r{older.revision}",
                tofile=f"{filename}@r{newer.revision}",
            )
            diffs.append("".join(diff))
        return diffs

    # ----------------------------------------------------------------------------------
    # BaseCanvas interface implementation
    # ----------------------------------------------------------------------------------
    def list_files(self) -> Dict[str, int]:
        """Return a mapping of *filename -> latest revision number*."""
        return {fname: revs[-1].revision for fname, revs in self._files.items() if revs}

    def get_latest_content(self, filename: str) -> str:  # noqa: D401 – keep API identical
        """Return the most recent content, or an empty string if the file is unknown."""
        revs = self._files.get(filename, [])
        return revs[-1].content if revs else ""

    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """Create *filename*, or append a new revision containing *new_content*.

        ``bytes`` input is decoded as UTF-8. New files start at revision 1, and
        each update uses the previous revision number plus one.

        Raises:
            ValueError: if *new_content* is neither ``str`` nor ``bytes``.
        """
        if isinstance(new_content, bytes):
            new_content = new_content.decode("utf-8")
        if not isinstance(new_content, str):
            raise ValueError(f"Expected str or bytes, got {type(new_content)}")
        if filename not in self._files:
            self._files[filename] = [FileRevision(new_content, 1)]
        else:
            last_rev_num = self._files[filename][-1].revision
            self._files[filename].append(FileRevision(new_content, last_rev_num + 1))

    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """Return a unified diff between *from_revision* and *to_revision*.

        Returns an empty string when the file is unknown or when either
        revision number is not stored for it. A revision whose content is the
        empty string is a real revision and is diffed normally.
        """
        older = self._find_revision(filename, from_revision)
        newer = self._find_revision(filename, to_revision)
        if older is None or newer is None:
            # One (or both) revision ids not found.
            return ""

        diff = difflib.unified_diff(
            older.content.splitlines(keepends=True),
            newer.content.splitlines(keepends=True),
            fromfile=f"{filename}@r{from_revision}",
            tofile=f"{filename}@r{to_revision}",
        )
        return "".join(diff)

    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """Apply *patch_data* (a unified diff) to the latest revision and store the result.

        The patch is parsed with *unidiff*; only the first file in the patch is
        applied. Each hunk is placed at its stated source position (shifted by
        the net line change of earlier hunks). Its context and removed lines
        must match the current content, ignoring line terminators. Nothing is
        committed unless every hunk applies.

        Raises:
            ValueError: if *patch_data* is not ``str``/``bytes``, *filename* is
                not on the canvas, the patch is empty, or a hunk does not match
                the current content.
            ImportError: if the ``unidiff`` package is not installed.
        """
        if isinstance(patch_data, bytes):
            patch_data = patch_data.decode("utf-8")
        if not isinstance(patch_data, str):
            raise ValueError(f"Expected str or bytes, got {type(patch_data)}")

        self._ensure_file(filename)
        original_content = self.get_latest_content(filename)

        if PatchSet is None:
            raise ImportError(
                "The 'unidiff' package is required for patch application. Install with 'pip install unidiff'."
            )

        patch = PatchSet(patch_data)
        if not patch:
            raise ValueError("Empty patch text provided.")
        # Our canvas stores exactly one file per patch operation, so we use the
        # first patched_file object.
        patched_file = patch[0]

        working_lines = original_content.splitlines(keepends=True)
        line_offset = 0
        for hunk in patched_file:
            # Unified diffs (GNU diff, difflib) number an *empty* source range
            # by the line before it: "@@ -3,0 ..." means "insert after line 3".
            # Only non-empty ranges need the 1-based -> 0-based shift.
            if hunk.source_length == 0:
                start = hunk.source_start + line_offset
            else:
                start = hunk.source_start - 1 + line_offset
            end = start + hunk.source_length

            # Verify the hunk against the working copy before touching it.
            # Compare without line terminators so CRLF/LF or a missing final
            # newline does not cause spurious mismatches.
            expected = [line.value.rstrip("\r\n") for line in hunk if line.is_context or line.is_removed]
            if start < 0 or end > len(working_lines):
                raise ValueError(
                    f"Patch hunk at source line {hunk.source_start} is outside the bounds of '{filename}'."
                )
            actual = [text.rstrip("\r\n") for text in working_lines[start:end]]
            if actual != expected:
                raise ValueError(
                    f"Patch hunk at source line {hunk.source_start} does not match the current content of '{filename}'."
                )

            # Context and added lines form the hunk's result; removed lines are
            # dropped (the "\ No newline at end of file" marker is none of these).
            replacement: List[str] = [line.value for line in hunk if line.is_added or line.is_context]
            working_lines[start:end] = replacement
            line_offset += len(replacement) - (end - start)

        # Every hunk applied cleanly; commit the new revision.
        self.add_or_update_file(filename, "".join(working_lines))

    # ----------------------------------------------------------------------------------
    # Convenience helpers
    # ----------------------------------------------------------------------------------
    def get_all_contents_for_context(self) -> str:  # noqa: D401 – keep public API stable
        """Return a summary view of every file with its *latest* revision."""
        out: List[str] = ["=== CANVAS FILES ==="]
        for fname, revs in self._files.items():
            if not revs:
                # Mirror list_files(): a file without revisions has nothing to show.
                continue
            latest = revs[-1]
            out.append(f"File: {fname} (rev {latest.revision}):\n{latest.content}\n")
        out.append("=== END OF CANVAS ===")
        return "\n".join(out)