StrictDoc Documentation
strictdoc/backend/sdoc_source_code/reader_python.py
Source file coverage
Path:
strictdoc/backend/sdoc_source_code/reader_python.py
Lines:
391
Non-empty lines:
331
Non-empty lines covered with requirements:
331 / 331 (100.0%)
Functions:
8
Functions covered by requirements:
8 / 8 (100.0%)
1
"""
2
@relation(SDOC-SRS-142, SDOC-SRS-147, scope=file)
3
"""
4
 
5
from itertools import islice
6
from typing import Any, List, Optional, Sequence, Tuple
7
 
8
import tree_sitter_python
9
from tree_sitter import Language, Node, Parser
10
 
11
from strictdoc.backend.sdoc_source_code.marker_parser import MarkerParser
12
from strictdoc.backend.sdoc_source_code.models.language import LanguageItem
13
from strictdoc.backend.sdoc_source_code.models.language_item_marker import (
14
    LanguageItemMarker,
15
)
16
from strictdoc.backend.sdoc_source_code.models.line_marker import LineMarker
17
from strictdoc.backend.sdoc_source_code.models.range_marker import (
18
    RangeMarker,
19
)
20
from strictdoc.backend.sdoc_source_code.models.source_file_info import (
21
    RelationMarkerType,
22
    SourceFileTraceabilityInfo,
23
)
24
from strictdoc.backend.sdoc_source_code.models.source_location import ByteRange
25
from strictdoc.backend.sdoc_source_code.parse_context import ParseContext
26
from strictdoc.backend.sdoc_source_code.processors.general_language_marker_processors import (
27
    language_item_marker_processor,
28
    line_marker_processor,
29
    range_marker_processor,
30
    source_file_traceability_info_processor,
31
)
32
from strictdoc.backend.sdoc_source_code.tree_sitter_helpers import traverse_tree
33
from strictdoc.helpers.file_stats import SourceFileStats
34
from strictdoc.helpers.file_system import file_open_read_bytes
35
 
36
 
37
class SourceFileTraceabilityReader_Python:
    """
    Read StrictDoc traceability markers from Python source code.

    The file is parsed with tree-sitter and markers are collected from:

    - the module docstring: language item markers attached to the file;
    - class/function docstrings: language item markers attached to the
      dotted, namespace-qualified name of the class/function;
    - standalone ``#`` comments (comments with no code sibling on their
      line): range and line markers. Consecutive standalone comments that
      are not separated by an empty line are merged and parsed as one text.
    """

    # How many top-level children of the module are inspected when looking
    # for the module docstring (arbitrarily chosen limit), so that a leading
    # shebang/encoding comment does not hide the docstring.
    MODULE_DOCSTRING_SEARCH_LIMIT = 30

    @staticmethod
    def supported_elements() -> list[str]:
        """Return the language item kinds this reader can trace."""
        return ["function", "class"]

    def read(
        self, input_buffer: bytes, file_path: Optional[str] = None
    ) -> SourceFileTraceabilityInfo:
        """
        Parse Python source and collect its traceability information.

        :param input_buffer: Raw bytes of the Python source file.
        :param file_path: Optional path of the file; forwarded to the parse
            context, whose filename is passed on to the marker parser.
        :returns: The collected SourceFileTraceabilityInfo. An empty buffer
            yields an empty info object without running the parser or the
            final info processor.
        """
        assert isinstance(input_buffer, bytes)

        traceability_info = SourceFileTraceabilityInfo([])
        if len(input_buffer) == 0:
            return traceability_info

        parse_context = ParseContext(
            file_path, SourceFileStats.create(input_buffer)
        )

        parser = Parser(Language(tree_sitter_python.language()))
        tree = parser.parse(input_buffer)

        module_function: Optional[LanguageItem] = None
        map_node_to_function: dict[Node, LanguageItem] = {}
        # Comment nodes already merged into an earlier comment group.
        visited_comments: set[Node] = set()

        for node_ in traverse_tree(tree):
            if node_.type == "module":
                module_function = self._read_module(
                    node_, input_buffer, traceability_info, parse_context
                )
                map_node_to_function[node_] = module_function
            elif node_.type in ("class_definition", "function_definition"):
                self._read_definition(
                    node_,
                    traceability_info,
                    parse_context,
                    map_node_to_function,
                    module_function,
                    file_path,
                )
            elif node_.type == "comment":
                self._read_standalone_comment(
                    node_, visited_comments, parse_context
                )

        assert module_function is not None
        assert module_function.name in ("module", "translation_unit")

        source_file_traceability_info_processor(
            traceability_info, parse_context
        )
        return traceability_info

    def _read_module(
        self,
        module_node: Node,
        input_buffer: bytes,
        traceability_info: SourceFileTraceabilityInfo,
        parse_context: ParseContext,
    ) -> LanguageItem:
        """
        Create the file-level "module" language item and record the language
        item markers found in the module docstring, if any.
        """
        module_function = LanguageItem(
            parent=traceability_info,
            name="module",
            display_name="module",
            line_begin=module_node.start_point[0] + 1,
            line_end=module_node.end_point[0] + 1,
            code_byte_range=ByteRange.create_from_ts_node(module_node),
            child_functions=[],
            markers=[],
            attributes=set(),
        )

        docstring_content = self._find_module_docstring_content(module_node)
        if docstring_content is None:
            return module_function
        assert docstring_content.text is not None

        # StrictDoc currently does not display the last empty line of a
        # file, so do not add +1 when the buffer ends with "\n" (byte 10).
        line_end = module_node.end_point[0]
        if input_buffer[-1] != 10:
            line_end += 1

        source_node = MarkerParser.parse(
            input_string=docstring_content.text.decode("utf-8"),
            line_start=module_node.start_point[0] + 1,
            line_end=line_end,
            comment_line_start=docstring_content.start_point[0] + 1,
            comment_byte_range=ByteRange.create_from_ts_node(docstring_content),
            filename=parse_context.filename,
        )
        for marker_ in source_node.markers:
            if isinstance(marker_, LanguageItemMarker):
                language_item_marker_processor(marker_, parse_context)
                traceability_info.markers.append(marker_)
        return module_function

    def _read_definition(
        self,
        node: Node,
        traceability_info: SourceFileTraceabilityInfo,
        parse_context: ParseContext,
        map_node_to_function: dict[Node, LanguageItem],
        module_function: Optional[LanguageItem],
        file_path: Optional[str],
    ) -> None:
        """
        Register a class/function definition as a language item, attach it to
        its enclosing language item and record the language item markers found
        in its docstring, if any.
        """
        function_name = ""
        function_block: Optional[Node] = None
        for child_ in node.children:
            # If several children match, the last one wins.
            if child_.type == "identifier" and child_.text is not None:
                function_name = child_.text.decode("utf-8")
            if child_.type == "block":
                function_block = child_
        # function_name stays "" when no identifier child is present; such a
        # definition is still registered so that nesting is preserved.

        parent_names = self.get_node_ns(node)
        if parent_names:
            function_name = f"{'.'.join(parent_names)}.{function_name}"

        language_item_markers: List[RelationMarkerType] = []
        docstring_content = self._find_block_docstring_content(function_block)
        if docstring_content is not None:
            assert docstring_content.text is not None
            source_node = MarkerParser.parse(
                input_string=docstring_content.text.decode("utf-8"),
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                comment_line_start=docstring_content.start_point[0] + 1,
                comment_byte_range=ByteRange.create_from_ts_node(
                    docstring_content
                ),
                filename=parse_context.filename,
                entity_name=function_name,
            )
            for marker_ in source_node.markers:
                if isinstance(marker_, LanguageItemMarker):
                    language_item_marker_processor(marker_, parse_context)
                    traceability_info.markers.append(marker_)
                    language_item_markers.append(marker_)

        new_function = LanguageItem(
            parent=traceability_info,
            name=function_name,
            display_name=function_name,
            line_begin=node.range.start_point[0] + 1,
            line_end=node.range.end_point[0] + 1,
            code_byte_range=ByteRange.create_from_ts_node(node),
            child_functions=[],
            # Python functions do not need to track markers.
            markers=[],
            attributes=set(),
        )
        map_node_to_function[node] = new_function

        parent_function = self.get_parent_language_item(
            node, map_node_to_function, module_function, file_path
        )
        parent_function.child_functions.append(new_function)
        traceability_info.functions.append(new_function)

        traceability_info.ng_map_names_to_markers[function_name] = (
            language_item_markers
        )

    def _read_standalone_comment(
        self,
        node: Node,
        visited_comments: set,
        parse_context: ParseContext,
    ) -> None:
        """
        Process range/line markers of the comment group starting at ``node``.

        Comments already merged into an earlier group and comments that share
        their line with code are skipped. Every comment of the processed group
        is added to ``visited_comments``.
        """
        if node in visited_comments:
            return

        assert node.parent is not None
        assert node.text is not None, f"Comment without a text: {node}"

        if not self.is_comment_alone_on_line(node):
            return

        merged_comments, last_idx = self.collect_consecutive_comments(node)

        # Fetch the sibling list once instead of re-evaluating
        # node.parent.children for every merged comment.
        siblings = node.parent.children
        visited_comments.update(siblings[siblings.index(node) : last_idx])
        last_comment = siblings[last_idx - 1]

        source_node = MarkerParser.parse(
            input_string=merged_comments,
            line_start=node.start_point[0] + 1,
            line_end=last_comment.end_point[0] + 1,
            comment_line_start=node.start_point[0] + 1,
            comment_byte_range=ByteRange(node.start_byte, last_comment.end_byte),
            filename=parse_context.filename,
            entity_name=None,
        )
        for marker_ in source_node.markers:
            if isinstance(marker_, RangeMarker):
                range_marker_processor(marker_, parse_context)
            elif isinstance(marker_, LineMarker):
                line_marker_processor(marker_, parse_context)
            # Any other marker kind found in a standalone comment is ignored.

    @staticmethod
    def _docstring_string_node(statement: Node) -> Optional[Node]:
        """
        Return the string literal of ``statement`` when it is an
        ``expression_statement`` whose first child is a ``string`` node;
        otherwise return None.
        """
        if statement.type != "expression_statement":
            return None
        if len(statement.children) == 0:
            return None
        first_child = statement.children[0]
        return first_child if first_child.type == "string" else None

    @staticmethod
    def _string_content_node(string_node: Node) -> Optional[Node]:
        """
        Return the ``string_content`` part of a tree-sitter ``string`` node.

        A string node is made of (string_start string_content string_end).
        When no content part is present (e.g. an empty literal), return None
        instead of mistaking the closing delimiter for the content.
        """
        return next(
            (
                child
                for child in string_node.children
                if child.type == "string_content"
            ),
            None,
        )

    @classmethod
    def _find_module_docstring_content(cls, module_node: Node) -> Optional[Node]:
        """
        Return the content node of the first string expression statement
        among the first MODULE_DOCSTRING_SEARCH_LIMIT module children, or
        None when there is none (or it has no content).
        """
        for child in islice(
            module_node.children, cls.MODULE_DOCSTRING_SEARCH_LIMIT
        ):
            string_node = cls._docstring_string_node(child)
            if string_node is not None:
                return cls._string_content_node(string_node)
        return None

    @classmethod
    def _find_block_docstring_content(
        cls, block: Optional[Node]
    ) -> Optional[Node]:
        """
        Return the content node of the docstring of a class/function body,
        i.e. of a string expression statement that is the first child of
        ``block``, or None when there is none (or it has no content).
        """
        if block is None or len(block.children) == 0:
            return None
        string_node = cls._docstring_string_node(block.children[0])
        if string_node is None:
            return None
        return cls._string_content_node(string_node)

    def read_from_file(self, file_path: str) -> SourceFileTraceabilityInfo:
        """
        Read the file at ``file_path`` as bytes and parse it with read().

        The file is closed before parsing starts.
        """
        with file_open_read_bytes(file_path) as file:
            input_buffer = file.read()
        return self.read(input_buffer, file_path=file_path)

    @staticmethod
    def get_parent_language_item(
        node: Node,
        map_node_to_function: dict[Node, LanguageItem],
        module_function: Optional[LanguageItem],
        file_path: Optional[str],
    ) -> LanguageItem:
        """
        Return the language item of the nearest registered enclosing
        class/function/module of ``node``.

        Falls back to ``module_function`` when no ancestor is registered in
        ``map_node_to_function``. ``file_path`` is only used in the assertion
        message.
        """
        cursor = node.parent
        while cursor is not None:
            if cursor.type in (
                "class_definition",
                "function_definition",
                "module",
            ):
                parent_function = map_node_to_function.get(cursor)
                if parent_function is not None:
                    return parent_function
            cursor = cursor.parent

        assert module_function is not None, file_path
        return module_function

    @staticmethod
    def get_node_ns(node: Node) -> Sequence[str]:
        """
        Return the names of the classes and functions enclosing ``node``.

        Names are ordered from outermost to innermost and never include the
        name of ``node`` itself, e.g. ``["Outer", "Inner"]`` for a method of
        class ``Inner`` nested in class ``Outer``. Handles nested functions,
        methods, and classes. Enclosing definitions without an identifier
        child are skipped.
        """
        scope_names: List[str] = []
        # Start at the parent: the node's own name is not part of its
        # namespace, and a node without a name must not cost the innermost
        # enclosing scope its place in the result.
        cursor: Optional[Node] = node.parent
        while cursor is not None:
            if cursor.type in ("class_definition", "function_definition"):
                # Look for the identifier child (i.e., the name).
                name_node = next(
                    (
                        child
                        for child in cursor.children
                        if child.type == "identifier"
                    ),
                    None,
                )
                if name_node is not None and name_node.text:
                    scope_names.append(name_node.text.decode("utf-8"))
            cursor = cursor.parent

        scope_names.reverse()
        return scope_names

    @staticmethod
    def collect_consecutive_comments(comment_node: Any) -> Tuple[str, int]:
        """
        Merge ``comment_node`` with the comment siblings directly following it.

        Merging stops at the first non-comment sibling, or at the first
        comment that starts more than one line after the previous comment
        ends (i.e. an empty line separates them).

        :returns: The merged comment texts joined by newlines, and the index
            in ``comment_node.parent.children`` one past the last merged
            comment.
        """
        parent = comment_node.parent

        siblings = parent.children
        idx = siblings.index(comment_node)

        merged_texts = []

        last_node = None

        while idx < len(siblings) and siblings[idx].type == "comment":
            n = siblings[idx]
            assert n.text is not None
            text = n.text.decode("utf8")

            if last_node is not None:
                # Tree-sitter line numbers are 0-based
                last_end_line = last_node.end_point[0]
                curr_start_line = n.start_point[0]

                # Stop merging if there is an empty line between comments
                if curr_start_line > last_end_line + 1:
                    break

            merged_texts.append(text)
            last_node = n
            idx += 1

        return "\n".join(merged_texts), idx

    @staticmethod
    def is_comment_alone_on_line(node: Any) -> bool:
        """
        Return True if the comment node is the only thing on its line.

        Only the siblings of ``node`` (children of the same parent) are
        inspected: the comment counts as alone when no non-comment sibling
        spans the line on which the comment starts. Non-comment nodes always
        yield False.
        """

        if node.type != "comment":
            return False

        parent = node.parent
        assert parent is not None

        comment_line = node.start_point[0]

        for sibling in parent.children:
            # Compare by equality: child lists may hold new wrapper objects
            # for the same underlying node.
            if sibling == node:
                continue
            start_line = sibling.start_point[0]
            end_line = sibling.end_point[0]

            # If sibling shares the same line as comment
            if start_line <= comment_line <= end_line:
                # If it's not a comment (code, punctuation, etc.)
                if sibling.type != "comment":
                    return False

        return True