Skip to content

Cst

FunctionAndClassCollector

Bases: cst.CSTVisitor

A CSTVisitor that collects the names of functions and classes from a CST tree.

Source code in write_the/cst/function_and_class_collector.py
class FunctionAndClassCollector(cst.CSTVisitor):
    """
    A CSTVisitor that records the names of the functions and classes selected by the `force` / `update` flags.

    Class names are stored in `classes`. Function names are stored in `functions`,
    written as `ClassName.function_name` when the function is visited while a class
    name is being tracked.

    NOTE(review): only a single class name is tracked and it is cleared to None when
    any ClassDef is left, so functions visited after a nested class inside the same
    outer class are recorded without a class prefix -- confirm whether nested classes
    need to be supported.
    """

    def __init__(self, force, update=False):
        """
        Stores the selection flags and starts with empty result lists.

        Args:
          force (bool): When true, every visited function and class is collected.
          update (bool): Only consulted when `force` is false. When true, nodes that already have a docstring are collected; when false, nodes without a docstring are collected.
        """
        self.force = force
        self.update = update
        self.current_class = None
        self.classes = []
        self.functions = []

    def _should_collect(self, node) -> bool:
        """
        Decides whether a node is selected under the current flags.

        Args:
          node (cst.CSTNode): The FunctionDef or ClassDef node being visited.

        Returns:
          bool: True if `force` is set, or if the node's docstring presence matches `update`.
        """
        if self.force:
            return True
        # update=True selects documented nodes, update=False selects undocumented ones.
        return bool(has_docstring(node)) == bool(self.update)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """
        Records the function's name if it is selected by the `force` / `update` flags.

        Args:
          node (cst.FunctionDef): The FunctionDef node to visit.
        """
        function_name = node.name.value
        if self.current_class:
            function_name = f"{self.current_class}.{function_name}"
        if self._should_collect(node):
            self.functions.append(function_name)

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """
        Starts tracking the class name and records it if it is selected by the `force` / `update` flags.

        Args:
          node (cst.ClassDef): The ClassDef node to visit.
        """
        class_name = node.name.value
        self.current_class = class_name
        if self._should_collect(node):
            self.classes.append(class_name)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        """
        Stops tracking a class name when a ClassDef node is left.

        Args:
          node (cst.ClassDef): The ClassDef node being left (not otherwise used).
        """
        self.current_class = None

__init__(force, update=False)

Initializes the FunctionAndClassCollector.

Parameters:

Name Type Description Default
force bool

Whether to force the collection of functions and classes even if they have docstrings.

required
update bool

When true (and force is false), collect only the functions and classes that already have docstrings; when false, collect only those without a docstring.

False
Source code in write_the/cst/function_and_class_collector.py
def __init__(self, force, update=False):
    """
    Stores the selection flags and starts with empty result lists.

    Args:
      force (bool): Whether to collect functions and classes even if they have docstrings.
      update (bool): Stored for the visit methods; intended to select nodes that already have docstrings.
    """
    self.force, self.update = force, update
    self.functions, self.classes = [], []
    # No class is being traversed yet.
    self.current_class = None

leave_ClassDef(node)

Resets the current class name when leaving a ClassDef node.

Source code in write_the/cst/function_and_class_collector.py
def leave_ClassDef(self, node: cst.ClassDef) -> None:
    """
    Resets the current class name when leaving a ClassDef node.

    Args:
      node (cst.ClassDef): The ClassDef node being left (not otherwise used).

    Note:
      The name is cleared to None rather than restored to an enclosing class.
    """
    self.current_class = None

visit_ClassDef(node)

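Visits a ClassDef node and adds its name to the list of classes when force is True, when update is True and the class already has a docstring, or when update is False and the class has no docstring. Also sets the current class name so that functions visited inside the class are recorded as ClassName.function_name.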

Parameters:

Name Type Description Default
node cst.ClassDef

The ClassDef node to visit.

required
Source code in write_the/cst/function_and_class_collector.py
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """
    Starts tracking the class name and records it if the class is selected.

    The class is selected when `force` is true, when `update` is true and the class
    already has a docstring, or when `update` is false and it has none.

    Args:
      node (cst.ClassDef): The ClassDef node to visit.
    """
    class_name = node.name.value
    # Tracked so that functions visited next are recorded as ClassName.function_name.
    self.current_class = class_name
    if self.force or bool(has_docstring(node)) == bool(self.update):
        self.classes.append(class_name)

visit_FunctionDef(node)

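Visits a FunctionDef node and adds its name to the list of functions when force is True, when update is True and the function already has a docstring, or when update is False and the function has no docstring. The name is prefixed with the current class name (ClassName.function_name) when one is set.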

Parameters:

Name Type Description Default
node cst.FunctionDef

The FunctionDef node to visit.

required
Source code in write_the/cst/function_and_class_collector.py
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
    """
    Records the function's name if the function is selected.

    The function is selected when `force` is true, when `update` is true and the
    function already has a docstring, or when `update` is false and it has none.
    The recorded name is prefixed with the current class name when one is set.

    Args:
      node (cst.FunctionDef): The FunctionDef node to visit.
    """
    function_name = node.name.value
    if self.current_class:
        function_name = f"{self.current_class}.{function_name}"
    if self.force or bool(has_docstring(node)) == bool(self.update):
        self.functions.append(function_name)

get_node_names(tree, force, update=False)

Gets the names of functions and classes from a CST tree.

Parameters:

Name Type Description Default
tree cst.CSTNode

The CST tree to traverse.

required
force bool

Whether to force the collection of functions and classes even if they have docstrings.

required
update bool

When true (and force is false), collect only the functions and classes that already have docstrings; when false, collect only those without a docstring. Defaults to False.

False

Returns:

Type Description

list[str]: A list of function and class names.

Source code in write_the/cst/function_and_class_collector.py
def get_node_names(tree, force, update=False):
    """
    Collects the names of the selected classes and functions in a CST tree.

    Args:
      tree (cst.CSTNode): The CST tree to traverse.
      force (bool): Whether to collect functions and classes even if they have docstrings.
      update (bool, optional): Passed to the collector together with `force`. Defaults to False.

    Returns:
      list[str]: The collected class names followed by the collected function names.
    """
    visitor = FunctionAndClassCollector(force, update)
    tree.visit(visitor)
    # Classes first, then functions, in a fresh list.
    return [*visitor.classes, *visitor.functions]

DocstringAdder

Bases: cst.CSTTransformer

Source code in write_the/cst/docstring_adder.py
class DocstringAdder(cst.CSTTransformer):
    """
    A CSTTransformer that inserts docstrings into function and class definitions.

    Docstrings are looked up by key: the node's name, or `ClassName.name` while a
    class name is being tracked.

    NOTE(review): the tracked class name is cleared to None when any ClassDef is left,
    so methods that follow a nested class are looked up without a class prefix -- confirm
    whether nested classes need to be supported.
    """

    def __init__(self, docstrings, force, indent="    "):
        """
        Stores the docstrings to insert and the insertion options.

        Args:
          docstrings (dict): Mapping from node key to docstring text (read with `.get`).
          force (bool): Whether to replace a docstring that already exists.
          indent (str): The indentation unit applied to the inserted docstring.
        """
        self.docstrings = docstrings
        self.force = force
        self.indent = indent
        self.current_class = None

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """
        Passes the updated function definition through `add_docstring`.

        Args:
          original_node (cst.FunctionDef): The original CST node representing the function definition.
          updated_node (cst.FunctionDef): The updated CST node representing the function definition.

        Returns:
          cst.FunctionDef: The node returned by `add_docstring`.
        """
        return self.add_docstring(updated_node)

    def visit_ClassDef(self, original_node: cst.ClassDef) -> None:
        """
        Starts tracking the name of the class being entered.

        Args:
          original_node (cst.ClassDef): The ClassDef node being entered.
        """
        self.current_class = original_node.name.value

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """
        Stops tracking the class name, then passes the updated class definition through `add_docstring`.

        Args:
          original_node (cst.ClassDef): The original CST node representing the class definition.
          updated_node (cst.ClassDef): The updated CST node representing the class definition.

        Returns:
          cst.ClassDef: The node returned by `add_docstring`.
        """
        # Cleared first so the class itself is looked up by its bare name.
        self.current_class = None
        return self.add_docstring(updated_node)

    def add_docstring(self, node):
        """
        Inserts the docstring registered for a node, if there is one to insert.

        Args:
          node (cst.CSTNode): The function or class definition to add a docstring to.

        Returns:
          cst.CSTNode: The node with the docstring inserted as its first statement, or the node unchanged when no docstring is registered or it already has one and `force` is not set.

        Note:
          If the node already has a docstring and the force flag is set, the existing docstring is removed before adding the new one.
        """
        key = node.name.value
        if self.current_class:
            key = f"{self.current_class}.{key}"

        docstring = self.docstrings.get(key, None)
        if not docstring:
            return node

        already_documented = has_docstring(node)
        if already_documented and not self.force:
            return node
        if already_documented:
            # Only reached with force set: drop the old docstring first.
            node = remove_docstring(node)

        # Double up the backslash of literal backslash-n sequences before the text
        # is embedded in a triple-quoted string and parsed as source.
        escaped = re.sub(r"(?<!\\)\\n", "\\\\\\\\n", docstring)
        dedented = textwrap.dedent(escaped)
        # Two indent units are used while a class name is being tracked.
        indent = self.indent * 2 if self.current_class else self.indent
        indented = textwrap.indent(dedented, indent)
        statement = cst.parse_statement(f'"""{indented}{indent}"""')
        new_body = node.body.with_changes(body=(statement, *node.body.body))
        return node.with_changes(body=new_body)

add_docstring(node)

Adds a docstring to a CST node if it doesn't have one.

Parameters:

Name Type Description Default
node cst.CSTNode

The CST node to add a docstring to.

required

Returns:

Type Description

cst.CSTNode: The updated CST node with a docstring added if it didn't have one.

Note

If the node already has a docstring and the force flag is set, the existing docstring is removed before adding the new one.

Source code in write_the/cst/docstring_adder.py
def add_docstring(self, node):
    """
    Inserts the docstring registered for a node, if there is one to insert.

    The lookup key is the node's name, prefixed with the current class name
    (`ClassName.name`) when one is set.

    Args:
      node (cst.CSTNode): The function or class definition to add a docstring to.

    Returns:
      cst.CSTNode: The node with the docstring inserted as its first statement, or the node unchanged when no docstring is registered or it already has one and `force` is not set.

    Note:
      If the node already has a docstring and the force flag is set, the existing docstring is removed before adding the new one.
    """
    key = node.name.value
    if self.current_class:
        key = f"{self.current_class}.{key}"

    docstring = self.docstrings.get(key, None)
    if not docstring:
        return node

    already_documented = has_docstring(node)
    if already_documented and not self.force:
        return node
    if already_documented:
        # Only reached with force set: drop the old docstring first.
        node = remove_docstring(node)

    # Double up the backslash of literal backslash-n sequences before the text
    # is embedded in a triple-quoted string and parsed as source.
    escaped = re.sub(r"(?<!\\)\\n", "\\\\\\\\n", docstring)
    dedented = textwrap.dedent(escaped)
    # Two indent units are used while a class name is being tracked.
    indent = self.indent * 2 if self.current_class else self.indent
    indented = textwrap.indent(dedented, indent)
    statement = cst.parse_statement(f'"""{indented}{indent}"""')
    new_body = node.body.with_changes(body=(statement, *node.body.body))
    return node.with_changes(body=new_body)

leave_ClassDef(original_node, updated_node)

Adds a docstring to a class definition if it doesn't have one.

Parameters:

Name Type Description Default
original_node cst.ClassDef

The original CST node representing the class definition.

required
updated_node cst.ClassDef

The updated CST node representing the class definition.

required

Returns:

Type Description
cst.ClassDef

cst.ClassDef: The updated CST node with a docstring added if it didn't have one.

Source code in write_the/cst/docstring_adder.py
def leave_ClassDef(
    self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
    """
    Stops tracking the class name, then passes the updated class definition through `add_docstring`.

    Args:
      original_node (cst.ClassDef): The original CST node representing the class definition.
      updated_node (cst.ClassDef): The updated CST node representing the class definition.

    Returns:
      cst.ClassDef: The node returned by `add_docstring`.
    """
    # Cleared before the call so the class itself is handled with no current class set.
    self.current_class = None
    return self.add_docstring(updated_node)

leave_FunctionDef(original_node, updated_node)

Adds a docstring to a function definition if it doesn't have one.

Parameters:

Name Type Description Default
original_node cst.FunctionDef

The original CST node representing the function definition.

required
updated_node cst.FunctionDef

The updated CST node representing the function definition.

required

Returns:

Type Description
cst.FunctionDef

cst.FunctionDef: The updated CST node with a docstring added if it didn't have one.

Source code in write_the/cst/docstring_adder.py
def leave_FunctionDef(
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
    """
    Passes the updated function definition through `add_docstring`.

    Args:
      original_node (cst.FunctionDef): The original CST node representing the function definition (not otherwise used).
      updated_node (cst.FunctionDef): The updated CST node representing the function definition.

    Returns:
      cst.FunctionDef: The node returned by `add_docstring`.
    """
    return self.add_docstring(updated_node)

DocstringRemover

Bases: cst.CSTTransformer

Source code in write_the/cst/docstring_remover.py
class DocstringRemover(cst.CSTTransformer):
    """
    A CSTTransformer that strips docstrings from the named functions and classes.

    Functions are matched by name, or by `ClassName.name` while a class name is being
    tracked; classes are matched by their bare name.

    NOTE(review): the tracked class name is cleared to None when any ClassDef is left,
    so methods that follow a nested class are matched without a class prefix -- confirm
    whether nested classes need to be supported.
    """

    def __init__(self, nodes):
        """
        Stores the names of the nodes whose docstrings should be removed.

        Args:
          nodes (list): Names of the nodes to remove docstrings from.
        """
        self.current_class = None
        self.nodes = nodes

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """
        Strips the docstring from a function whose name is in the list of nodes.

        Args:
          original_node (cst.FunctionDef): The original FunctionDef node.
          updated_node (cst.FunctionDef): The updated FunctionDef node.

        Returns:
          cst.FunctionDef: The updated node passed through `remove_docstring` if its name is listed, otherwise the updated node unchanged.
        """
        short_name = original_node.name.value
        if self.current_class:
            lookup_name = f"{self.current_class}.{short_name}"
        else:
            lookup_name = short_name
        if lookup_name not in self.nodes:
            return updated_node
        return remove_docstring(updated_node)

    def visit_ClassDef(self, original_node: cst.ClassDef) -> None:
        """
        Starts tracking the name of the class being entered.

        Args:
          original_node (cst.ClassDef): The ClassDef node being entered.
        """
        self.current_class = original_node.name.value

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """
        Stops tracking the class name and strips the class docstring if the class is in the list of nodes.

        Args:
          original_node (cst.ClassDef): The original ClassDef node.
          updated_node (cst.ClassDef): The updated ClassDef node.

        Returns:
          cst.ClassDef: The updated node passed through `remove_docstring` if its name is listed, otherwise the updated node unchanged.
        """
        self.current_class = None
        if original_node.name.value not in self.nodes:
            return updated_node
        return remove_docstring(updated_node)

__init__(nodes)

Initializes the DocstringRemover object.

Parameters:

Name Type Description Default
nodes list

A list of nodes to remove docstrings from.

required
Source code in write_the/cst/docstring_remover.py
def __init__(self, nodes):
    """
    Stores the names of the nodes whose docstrings should be removed.

    Args:
      nodes (list): Names of the nodes to remove docstrings from.
    """
    # No class is being traversed yet.
    self.current_class = None
    self.nodes = nodes

leave_ClassDef(original_node, updated_node)

Removes the docstring from a ClassDef node if it is in the list of nodes and resets the current_class attribute to None.

Parameters:

Name Type Description Default
original_node cst.ClassDef

The original ClassDef node.

required
updated_node cst.ClassDef

The updated ClassDef node.

required

Returns:

Type Description
cst.ClassDef

cst.ClassDef: The updated ClassDef node with the docstring removed if it is in the list of nodes.

Source code in write_the/cst/docstring_remover.py
def leave_ClassDef(
    self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
    """
    Stops tracking the class name and strips the class docstring if the class is in the list of nodes.

    Args:
      original_node (cst.ClassDef): The original ClassDef node.
      updated_node (cst.ClassDef): The updated ClassDef node.

    Returns:
      cst.ClassDef: The updated node passed through `remove_docstring` if its name is listed, otherwise the updated node unchanged.
    """
    self.current_class = None
    # Classes are matched by their bare name.
    if original_node.name.value not in self.nodes:
        return updated_node
    return remove_docstring(updated_node)

leave_FunctionDef(original_node, updated_node)

Removes the docstring from a FunctionDef node if it is in the list of nodes.

Parameters:

Name Type Description Default
original_node cst.FunctionDef

The original FunctionDef node.

required
updated_node cst.FunctionDef

The updated FunctionDef node.

required

Returns:

Type Description
cst.FunctionDef

cst.FunctionDef: The updated FunctionDef node with the docstring removed if it is in the list of nodes.

Source code in write_the/cst/docstring_remover.py
def leave_FunctionDef(
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
    """
    Strips the docstring from a function whose name is in the list of nodes.

    The name checked is the function's name, prefixed with the current class name
    (`ClassName.name`) when one is set.

    Args:
      original_node (cst.FunctionDef): The original FunctionDef node.
      updated_node (cst.FunctionDef): The updated FunctionDef node.

    Returns:
      cst.FunctionDef: The updated node passed through `remove_docstring` if its name is listed, otherwise the updated node unchanged.
    """
    short_name = original_node.name.value
    if self.current_class:
        lookup_name = f"{self.current_class}.{short_name}"
    else:
        lookup_name = short_name
    if lookup_name not in self.nodes:
        return updated_node
    return remove_docstring(updated_node)

remove_docstrings_from_tree(tree, nodes)

Removes the docstrings from a tree of nodes.

Parameters:

Name Type Description Default
tree cst.CSTNode

The tree of nodes to remove the docstrings from.

required
nodes list

A list of nodes to remove docstrings from.

required

Returns:

Type Description

cst.CSTNode: The tree of nodes with the docstrings removed.

Source code in write_the/cst/docstring_remover.py
def remove_docstrings_from_tree(tree, nodes):
    """
    Runs a DocstringRemover over a tree.

    Args:
      tree (cst.CSTNode): The tree of nodes to remove the docstrings from.
      nodes (list): Names of the nodes to remove docstrings from.

    Returns:
      cst.CSTNode: The tree returned by visiting it with the remover.
    """
    return tree.visit(DocstringRemover(nodes))

get_docstring(node)

Retrieves the docstring of a CSTNode if it has one.

Parameters:

Name Type Description Default
node cst.CSTNode

The node to check.

required

Returns:

Type Description
Optional[str]

Optional[str]: The docstring of the node if it exists, None otherwise.

Notes

Only retrieves docstrings for FunctionDef and ClassDef nodes.

Source code in write_the/cst/utils.py
def get_docstring(node: cst.CSTNode) -> Optional[str]:
    """
    Returns the docstring of a node, if `has_docstring` reports one.

    Args:
      node (cst.CSTNode): The node to check.

    Returns:
      Optional[str]: The `value` of the string node that forms the docstring, or None when the node has no docstring.

    Notes:
      Only retrieves docstrings for FunctionDef and ClassDef nodes.
      NOTE(review): the returned text is the string node's `value` attribute as-is; presumably that is the literal source text including its quotes -- confirm against callers.
    """
    if not has_docstring(node):
        return None
    # has_docstring has already verified the shape of the first statement.
    first_line = node.body.body[0]
    return first_line.body[0].value.value

has_docstring(node)

Checks if a CSTNode has a docstring.

Parameters:

Name Type Description Default
node cst.CSTNode

The node to check.

required

Returns:

Name Type Description
bool bool

Whether or not the node has a docstring.

Notes

Only checks for docstrings on FunctionDef and ClassDef nodes.

Source code in write_the/cst/utils.py
def has_docstring(node: cst.CSTNode) -> bool:
    """
    Reports whether a function or class definition starts with a docstring.

    A docstring is recognised when the first statement of the body is a
    SimpleStatementLine whose first element is an expression wrapping a SimpleString.

    Args:
      node (cst.CSTNode): The node to check.

    Returns:
      bool: True if the node has a docstring, False otherwise (including for any node that is not a FunctionDef or ClassDef).
    """
    if not isinstance(node, (cst.FunctionDef, cst.ClassDef)):
        return False
    statements = node.body.body
    if not statements or not isinstance(statements[0], cst.SimpleStatementLine):
        return False
    first = statements[0].body[0]
    return isinstance(first, cst.Expr) and isinstance(first.value, cst.SimpleString)

nodes_to_tree(nodes)

Converts a list of CSTNodes into a CSTModule.

Parameters:

Name Type Description Default
nodes list[cst.CSTNode]

The list of nodes to convert.

required

Returns:

Type Description

cst.Module: The CSTModule containing the given nodes.

Source code in write_the/cst/utils.py
def nodes_to_tree(nodes):
    """
    Wraps a list of CST nodes in a new cst.Module.

    Args:
      nodes (list[cst.CSTNode]): The nodes to use as the module body.

    Returns:
      cst.Module: A module whose body is the given nodes.
    """
    return cst.Module(body=nodes)

remove_docstring(node)

Removes the docstring from a CSTNode.

Parameters:

Name Type Description Default
node cst.CSTNode

The node to remove the docstring from.

required

Returns:

Type Description

cst.CSTNode: The node with the docstring removed.

Source code in write_the/cst/utils.py
def remove_docstring(node):
    """
    Returns a copy of the node without its leading docstring statement.

    Args:
      node (cst.CSTNode): The node to remove the docstring from.

    Returns:
      cst.CSTNode: The node with its first body statement dropped when that statement is a docstring; otherwise the node unchanged.
    """
    statements = node.body.body
    if not statements:
        return node
    head = statements[0]
    if not isinstance(head, cst.SimpleStatementLine):
        return node
    expression = head.body[0]
    if isinstance(expression, cst.Expr) and isinstance(
        expression.value, cst.SimpleString
    ):
        # Keep everything after the docstring statement.
        trimmed_body = node.body.with_changes(body=statements[1:])
        return node.with_changes(body=trimmed_body)
    return node

Background

Bases: Node

A class representing a background in a CST tree.

Parameters:

Name Type Description Default
body cst.CSTNode

The CST node of the background.

required
Source code in write_the/cst/node_batcher.py
class Background(Node):
    """
    A Node that represents the background of a CST tree.

    The constructor does not call `Node.__init__`; it sets `node`, `name`, `code`
    and `tokens` itself.

    Args:
      body (cst.CSTNode): The CST node of the background.
    """

    def __init__(self, body) -> None:
        """
        Initializes a Background object.

        Args:
          body (cst.CSTNode): The CST node of the background. Its `code` attribute is read, so presumably this is a cst.Module -- TODO confirm.
        """
        self.name = "background"
        self.node = body
        self.code = body.code
        # Unlike Node, no response allowance is added to the token count.
        gpt4_encoding = tiktoken.encoding_for_model("gpt-4")
        self.tokens = len(gpt4_encoding.encode(self.code))

__init__(body)

Initializes a Background object.

Parameters:

Name Type Description Default
body cst.CSTNode

The CST node of the background.

required
Source code in write_the/cst/node_batcher.py
def __init__(self, body) -> None:
    """
    Initializes a Background object.

    Args:
      body (cst.CSTNode): The CST node of the background. Its `code` attribute is read, so presumably this is a cst.Module -- TODO confirm.
    """
    self.name = "background"
    self.node = body
    self.code = body.code
    # Token count of the code under the "gpt-4" encoding.
    gpt4_encoding = tiktoken.encoding_for_model("gpt-4")
    self.tokens = len(gpt4_encoding.encode(self.code))

Node

A class representing a node in a CST tree.

Parameters:

Name Type Description Default
name str

The name of the node.

required
node cst.CSTNode

The CST node.

required
code str

The code of the node.

required
tokens int

The number of tokens in the node.

required
Source code in write_the/cst/node_batcher.py
class Node:
    """
    A named node of a CST tree together with its code and token count.

    Attributes:
      name (str): The name of the node.
      node (cst.CSTNode): The CST node.
      code (str): The code of the node.
      tokens (int): The number of tokens in the node's code plus the response allowance.
    """

    name: str
    node: cst.CSTNode
    code: str
    tokens: int

    def __init__(self, *, tree, node_name, response_size=80) -> None:
        """
        Looks up a node by name in the tree and measures its code.

        Args:
          tree (cst.Module): The CST tree.
          node_name (str): The name of the node to extract.
          response_size (int): Allowance added to the token count of the node's code.
        """
        extracted = extract_node_from_tree(tree=tree, node=node_name)
        self.node = extracted
        self.name = node_name
        self.code = get_code_from_node(extracted)
        # Token count of the code under the "gpt-4" encoding.
        gpt4_encoding = tiktoken.encoding_for_model("gpt-4")
        code_tokens = gpt4_encoding.encode(self.code)
        self.tokens = len(code_tokens) + response_size

__init__(*, tree, node_name, response_size=80)

Initializes a Node object.

Parameters:

Name Type Description Default
tree cst.Module

The CST tree.

required
node_name str

The name of the node.

required
response_size int

The size of the response.

80
Source code in write_the/cst/node_batcher.py
def __init__(self, *, tree, node_name, response_size=80) -> None:
    """
    Looks up a node by name in the tree and measures its code.

    Args:
      tree (cst.Module): The CST tree.
      node_name (str): The name of the node to extract.
      response_size (int): Allowance added to the token count of the node's code.
    """
    extracted = extract_node_from_tree(tree=tree, node=node_name)
    self.node = extracted
    self.name = node_name
    self.code = get_code_from_node(extracted)
    # Token count of the code under the "gpt-4" encoding.
    gpt4_encoding = tiktoken.encoding_for_model("gpt-4")
    code_tokens = gpt4_encoding.encode(self.code)
    self.tokens = len(code_tokens) + response_size

NodeBatch dataclass

A class representing a batch of nodes in a CST tree.

Parameters:

Name Type Description Default
tree cst.Module

The CST tree.

required
background Optional[Background]

The background of the tree.

required
max_tokens int

The maximum number of tokens in the batch.

required
prompt_size int

The size of the prompt.

required
nodes List[Node]

The list of nodes in the batch.

field(default_factory=list)
max_batch_size Optional[int]

The maximum size of the batch.

None
send_node_context bool

Whether to send the context of the nodes.

False
Source code in write_the/cst/node_batcher.py
@dataclass
class NodeBatch:
    """
    A group of nodes from one CST tree that share a token budget.

    Args:
      tree (cst.Module): The CST tree.
      background (Optional[Background]): The background of the tree, if any.
      max_tokens (int): The maximum number of tokens in the batch.
      prompt_size (int): The size of the prompt, counted against the budget.
      nodes (List[Node]): The list of nodes in the batch.
      max_batch_size (Optional[int]): The maximum number of nodes in the batch; a falsy value means no limit.
      send_node_context (bool): Whether `code` returns the whole tree instead of a reduced one.
    """

    tree: cst.Module
    background: Optional[Background]
    max_tokens: int
    prompt_size: int
    nodes: List[Node] = field(default_factory=list)
    max_batch_size: Optional[int] = None
    send_node_context: bool = False

    @property
    def tokens(self) -> int:
        """
        Gets the number of tokens in the batch.

        Returns:
          int: The prompt size plus the tokens of every node, plus the background's tokens when a background is set.
        """
        total = self.prompt_size + sum(member.tokens for member in self.nodes)
        if not self.background:
            return total
        return total + self.background.tokens

    @property
    def node_names(self) -> List[str]:
        """
        Gets the names of the nodes in the batch.

        Returns:
          List[str]: The names of the nodes, in batch order.
        """
        return [member.name for member in self.nodes]

    @property
    def space_available(self) -> int:
        """
        Gets the amount of space available in the batch.

        Returns:
          int: `max_tokens` minus the tokens currently used.
        """
        return self.max_tokens - self.tokens

    @property
    def code(self):
        """
        Gets the code of the batch.

        With `send_node_context` the whole tree's code is returned. Otherwise, when a
        background is set, every named node that is neither in the batch nor the class
        of a batched `Class.method` is removed from the tree; without a background,
        only the batch nodes are extracted into a new module.

        Returns:
          str: The code of the batch.
        """
        if self.send_node_context:
            return self.tree.code

        batch_names = self.node_names
        if not self.background:
            # No background: build a module from the batch nodes only.
            extracted = extract_nodes_from_tree(self.tree, batch_names)
            return nodes_to_tree(extracted).code

        # Classes that own a batched method must stay in the tree.
        owning_classes = [name.split(".")[0] for name in batch_names if "." in name]
        removable: List[str] = [
            name
            for name in get_node_names(self.tree, True)
            if name not in batch_names and name not in owning_classes
        ]
        return remove_nodes_from_tree(self.tree, removable).code

    def add(self, node: Node):
        """
        Adds a node to the batch.

        Args:
          node (Node): The node to add.

        Raises:
          ValueError: If the node does not fit in the remaining token space, or the batch already holds `max_batch_size` nodes.
        """
        if self.space_available - node.tokens < 0:
            raise ValueError("No space available in batch!")
        if self.max_batch_size and len(self.nodes) + 1 > self.max_batch_size:
            raise ValueError("No space available in batch!")
        self.nodes.append(node)

code property

Gets the code of the batch.

Returns:

Name Type Description
str

The code of the batch.

node_names: List[str] property

Gets the names of the nodes in the batch.

Returns:

Type Description
List[str]

List[str]: The names of the nodes in the batch.

space_available: int property

Gets the amount of space available in the batch.

Returns:

Name Type Description
int int

The amount of space available in the batch.

tokens: int property

Gets the number of tokens in the batch.

Returns:

Name Type Description
int int

The number of tokens in the batch.

add(node)

Adds a node to the batch.

Parameters:

Name Type Description Default
node Node

The node to add.

required

Raises:

Type Description
ValueError

If there is no space available in the batch.

Source code in write_the/cst/node_batcher.py
def add(self, node: Node):
    """
    Adds a node to the batch.

    Args:
      node (Node): The node to add.

    Raises:
      ValueError: If the node does not fit in the remaining token space, or the batch already holds `max_batch_size` nodes.
    """
    # Token budget check.
    if self.space_available - node.tokens < 0:
        raise ValueError("No space available in batch!")
    # Node-count check; skipped when max_batch_size is falsy.
    if self.max_batch_size and len(self.nodes) + 1 > self.max_batch_size:
        raise ValueError("No space available in batch!")
    self.nodes.append(node)

create_batches(tree, node_names, max_tokens, prompt_size, response_size_per_node, max_batch_size=None, send_background_context=True, send_node_context=True, remove_docstrings=True)

Creates batches of nodes from a tree.

Parameters:

Name Type Description Default
tree cst.Module

The tree to create batches from.

required
node_names List[str]

The names of the nodes to create batches for.

required
max_tokens int

The maximum number of tokens per batch.

required
prompt_size int

The size of the prompt for each node.

required
response_size_per_node int

The size of the response for each node.

required
max_batch_size Optional[int]

The maximum number of nodes per batch.

None
send_background_context bool

Whether to send background context.

True
send_node_context bool

Whether to send node context.

True
remove_docstrings bool

Whether to remove docstrings from the tree.

True

Returns:

Type Description
List[NodeBatch]

List[NodeBatch]: A list of batches of nodes.

Source code in write_the/cst/node_batcher.py
def create_batches(
    tree,
    node_names,
    max_tokens,
    prompt_size,
    response_size_per_node,
    max_batch_size=None,
    send_background_context=True,
    send_node_context=True,
    remove_docstrings=True,
) -> List[NodeBatch]:
    """
    Greedily packs the named nodes of a tree into consecutive batches.

    Nodes are added to the current batch in the order given by `node_names`.
    When a batch rejects a node (its `add` raises ValueError), that batch is
    closed and the node is placed in a fresh one.

    Args:
      tree (cst.Module): The tree to create batches from.
      node_names (List[str]): The names of the nodes to create batches for.
      max_tokens (int): The maximum number of tokens per batch.
      prompt_size (int): The size of the prompt for each node.
      response_size_per_node (int): The size of the response for each node.
      max_batch_size (Optional[int]): The maximum number of nodes per batch.
      send_background_context (bool): Whether to send background context.
      send_node_context (bool): Whether to send node context.
      remove_docstrings (bool): Whether to remove docstrings from the tree.

    Returns:
      List[NodeBatch]: The batches, in order. The last (possibly empty) batch
        is always included, so the list is never empty.

    Raises:
      ValueError: If a node does not fit even into a freshly created batch.
    """
    if remove_docstrings:
        tree = remove_docstrings_from_tree(tree, node_names)  # TODO: fix to use Class.method syntax

    # The background is extracted once and shared by every batch.
    background = extract_background(tree) if send_background_context else None

    def new_batch():
        """
        Builds an empty NodeBatch configured from the enclosing call's settings.

        Returns:
          NodeBatch: A new, empty batch.
        """
        return NodeBatch(
            tree=tree,
            max_tokens=max_tokens,
            prompt_size=prompt_size,
            background=background,
            max_batch_size=max_batch_size,
            send_node_context=send_node_context,
        )

    finished = []
    open_batch = new_batch()
    for name in node_names:
        candidate = Node(
            tree=tree, node_name=name, response_size=response_size_per_node
        )
        try:
            open_batch.add(candidate)
        except ValueError:
            # The open batch is full: close it and start a new one with this node.
            finished.append(open_batch)
            open_batch = new_batch()
            open_batch.add(candidate)
    finished.append(open_batch)
    return finished

extract_background(tree)

Extracts the background from a CST tree.

Parameters:

Name Type Description Default
tree cst.Module

The CST tree.

required

Returns:

Name Type Description
Background

The background of the tree.

Source code in write_the/cst/node_batcher.py
def extract_background(tree):
    """
    Builds the background of a CST tree.

    The background is what is left of the tree once every node named by
    `get_node_names(tree, force=True)` has been removed from it.

    Args:
      tree (cst.Module): The CST tree.

    Returns:
      Background: A Background whose body is the stripped tree.
    """
    names_to_strip = get_node_names(tree, force=True)
    return Background(body=remove_nodes_from_tree(tree, names_to_strip))

NodeExtractor

Bases: cst.CSTVisitor

Source code in write_the/cst/node_extractor.py
class NodeExtractor(cst.CSTVisitor):
    """
    A CSTVisitor that gathers the function and class nodes whose names were requested.

    Methods are matched by the qualified name `ClassName.method_name`;
    functions outside a class and classes are matched by their bare name.
    """

    def __init__(self, nodes):
        """
        Initializes the NodeExtractor.

        Args:
          nodes (list of str): The names of the nodes to extract.
        """
        self.current_class = None
        self.extracted_nodes = []
        self.nodes = nodes

    def visit_FunctionDef(self, node: cst.FunctionDef):
        """
        Records a FunctionDef node when its (possibly class-qualified) name was requested.

        Args:
          node (cst.FunctionDef): The FunctionDef node to visit.

        Side Effects:
          Appends the node to extracted_nodes if its name is in the nodes list.
        """
        qualified_name = node.name.value
        if self.current_class:
            # Inside a class: match on "ClassName.function_name".
            qualified_name = f"{self.current_class}.{qualified_name}"
        if qualified_name in self.nodes:
            self.extracted_nodes.append(node)

    def visit_ClassDef(self, node: cst.ClassDef):
        """
        Remembers the class being entered and records it when its name was requested.

        Args:
          node (cst.ClassDef): The ClassDef node to visit.

        Side Effects:
          Sets current_class to the class name and appends the node to
          extracted_nodes if that name is in the nodes list.
        """
        class_name = node.name.value
        self.current_class = class_name
        if class_name not in self.nodes:
            return
        self.extracted_nodes.append(node)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        """
        Clears current_class when a class body has been fully visited.

        Args:
          node (cst.ClassDef): The ClassDef node being left.
        """
        self.current_class = None

visit_ClassDef(node)

Visits a ClassDef node and sets the current_class attribute. If the class name is in the nodes list, it also adds the node to the extracted_nodes list.

Parameters:

Name Type Description Default
node cst.ClassDef

The ClassDef node to visit.

required
Side Effects

Modifies the current_class attribute of the NodeExtractor instance, setting it to the name of the visited node. If the class name is in the nodes list, it also modifies the extracted_nodes list, adding the node.

Source code in write_the/cst/node_extractor.py
def visit_ClassDef(self, node: cst.ClassDef):
    """
    Remembers the class being entered and records it when its name was requested.

    Args:
      node (cst.ClassDef): The ClassDef node to visit.

    Side Effects:
      Sets the current_class attribute to the name of the visited class. If
      that name is in the nodes list, the node is also appended to the
      extracted_nodes list.
    """
    class_name = node.name.value
    self.current_class = class_name
    if class_name not in self.nodes:
        return
    self.extracted_nodes.append(node)

visit_FunctionDef(node)

Visits a FunctionDef node and adds it to the extracted_nodes list if its name is in the nodes list.

Parameters:

Name Type Description Default
node cst.FunctionDef

The FunctionDef node to visit.

required
Side Effects

Modifies the extracted_nodes list of the NodeExtractor instance, adding the node if its name is in the nodes list.

Source code in write_the/cst/node_extractor.py
def visit_FunctionDef(self, node: cst.FunctionDef):
    """
    Records a FunctionDef node when its (possibly class-qualified) name was requested.

    Inside a class the name compared against the nodes list is
    "ClassName.function_name"; otherwise it is the bare function name.

    Args:
      node (cst.FunctionDef): The FunctionDef node to visit.

    Side Effects:
      Appends the node to the extracted_nodes list if its name is in the nodes list.
    """
    qualified_name = node.name.value
    if self.current_class:
        qualified_name = f"{self.current_class}.{qualified_name}"
    if qualified_name in self.nodes:
        self.extracted_nodes.append(node)

extract_nodes_from_tree(tree, nodes)

Extracts specified nodes from a CST tree.

Parameters:

Name Type Description Default
tree cst.CSTNode

The CST tree to extract nodes from.

required
nodes list of str

A list of node names to extract.

required

Returns:

Type Description

list of cst.CSTNode: A list of extracted nodes.

Examples:

>>> extract_nodes_from_tree(tree, ['my_function', 'MyClass', 'MyClass.my_method'])
[FunctionDef(...), ClassDef(...), FunctionDef(...)]
Source code in write_the/cst/node_extractor.py
def extract_nodes_from_tree(tree, nodes):
    """
    Collects the function and class nodes of a CST tree that match the given names.

    Args:
      tree (cst.CSTNode): The CST tree to extract nodes from.
      nodes (list of str): The names of the nodes to extract. Methods are
        addressed as "ClassName.method_name"; classes and functions outside
        a class by their bare name.

    Returns:
      list of cst.CSTNode: The matching nodes, in the order they were visited.

    Examples:
      >>> extract_nodes_from_tree(tree, ['my_function', 'MyClass.my_method'])
      [FunctionDef(...), FunctionDef(...)]
    """
    collector = NodeExtractor(nodes)
    # The visitor accumulates matches on itself; the visit result is not needed.
    tree.visit(collector)
    return collector.extracted_nodes

NodeRemover

Bases: cst.CSTTransformer

Source code in write_the/cst/node_remover.py
class NodeRemover(cst.CSTTransformer):
    """
    A CSTTransformer that deletes the function and class nodes whose names were requested.

    Methods are matched by the qualified name `ClassName.method_name`;
    functions outside a class and classes are matched by their bare name.
    """

    def __init__(self, nodes):
        """
        Initializes a NodeRemover instance.

        Args:
          nodes (list): The names of the nodes to remove.
        """
        self.current_class = None
        self.nodes = nodes

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.RemovalSentinel:
        """
        Drops a FunctionDef node when its (possibly class-qualified) name is in the removal list.

        Args:
          original_node (cst.FunctionDef): The original FunctionDef node.
          updated_node (cst.FunctionDef): The updated FunctionDef node.

        Returns:
          cst.RemovalSentinel: The removal sentinel if the function is to be
            removed; otherwise the updated node is returned unchanged.
        """
        short_name = original_node.name.value
        owner = self.current_class
        # Inside a class the function is addressed as "ClassName.function_name".
        qualified_name = f"{owner}.{short_name}" if owner else short_name
        return cst.RemoveFromParent() if qualified_name in self.nodes else updated_node

    def visit_ClassDef(self, original_node: cst.ClassDef) -> None:
        """
        Remembers the name of the class being entered.

        Args:
          original_node (cst.ClassDef): The ClassDef node being entered.
        """
        self.current_class = original_node.name.value

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.RemovalSentinel:
        """
        Clears current_class and drops the ClassDef node when its name is in the removal list.

        Args:
          original_node (cst.ClassDef): The original ClassDef node.
          updated_node (cst.ClassDef): The updated ClassDef node.

        Returns:
          cst.RemovalSentinel: The removal sentinel if the class is to be
            removed; otherwise the updated node is returned unchanged.
        """
        self.current_class = None
        should_remove = original_node.name.value in self.nodes
        return cst.RemoveFromParent() if should_remove else updated_node

__init__(nodes)

Initializes a NodeRemover instance.

Parameters:

Name Type Description Default
nodes list

A list of nodes to remove.

required
Source code in write_the/cst/node_remover.py
def __init__(self, nodes):
    """
    Sets up a NodeRemover with the names it should remove.

    Args:
      nodes (list): The names of the nodes to remove.
    """
    # No class has been entered yet when the transformer is created.
    self.current_class = None
    self.nodes = nodes

leave_ClassDef(original_node, updated_node)

Removes a ClassDef node from the tree if it is in the list of nodes to remove.

Parameters:

Name Type Description Default
original_node cst.ClassDef

The original ClassDef node.

required
updated_node cst.ClassDef

The updated ClassDef node.

required

Returns:

Type Description
cst.RemovalSentinel

cst.RemovalSentinel: A sentinel indicating whether the node should be removed.

Source code in write_the/cst/node_remover.py
def leave_ClassDef(
    self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.RemovalSentinel:
    """
    Clears current_class and drops the ClassDef node when its name is in the removal list.

    Args:
      original_node (cst.ClassDef): The original ClassDef node.
      updated_node (cst.ClassDef): The updated ClassDef node.

    Returns:
      cst.RemovalSentinel: The removal sentinel if the class is to be removed;
        otherwise the updated node is returned unchanged.
    """
    # The class body is done, so following functions are no longer methods of it.
    self.current_class = None
    should_remove = original_node.name.value in self.nodes
    return cst.RemoveFromParent() if should_remove else updated_node

leave_FunctionDef(original_node, updated_node)

Removes a FunctionDef node from the tree if it is in the list of nodes to remove. The node is identified by its fully qualified name, which includes the class name if the function is a method.

Parameters:

Name Type Description Default
original_node cst.FunctionDef

The original FunctionDef node.

required
updated_node cst.FunctionDef

The updated FunctionDef node.

required

Returns:

Type Description
cst.RemovalSentinel

cst.RemovalSentinel: A sentinel indicating whether the node should be removed.

Source code in write_the/cst/node_remover.py
def leave_FunctionDef(
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.RemovalSentinel:
    """
    Drops a FunctionDef node when its name is in the removal list.

    The name compared against the list is "ClassName.function_name" while
    inside a class, and the bare function name otherwise.

    Args:
      original_node (cst.FunctionDef): The original FunctionDef node.
      updated_node (cst.FunctionDef): The updated FunctionDef node.

    Returns:
      cst.RemovalSentinel: The removal sentinel if the function is to be
        removed; otherwise the updated node is returned unchanged.
    """
    short_name = original_node.name.value
    owner = self.current_class
    qualified_name = f"{owner}.{short_name}" if owner else short_name
    return cst.RemoveFromParent() if qualified_name in self.nodes else updated_node

remove_nodes_from_tree(tree, nodes)

Removes specified nodes from a CST tree.

Parameters:

Name Type Description Default
tree cst.CSTNode

The CST tree to remove nodes from.

required
nodes list

A list of nodes to remove.

required

Returns:

Type Description

cst.CSTNode: The updated CST tree after removal of specified nodes.

Source code in write_the/cst/node_remover.py
def remove_nodes_from_tree(tree, nodes):
    """
    Returns a copy of a CST tree with the named functions and classes removed.

    Args:
      tree (cst.CSTNode): The CST tree to remove nodes from.
      nodes (list): The names of the nodes to remove. Methods are addressed
        as "ClassName.method_name".

    Returns:
      cst.CSTNode: The tree produced by running a NodeRemover over `tree`.
    """
    return tree.visit(NodeRemover(nodes))