StrictDoc Documentation
strictdoc/backend/sdoc_source_code/reader_c.py
Source file coverage
Path:
strictdoc/backend/sdoc_source_code/reader_c.py
Lines:
499
Non-empty lines:
445
Non-empty lines covered with requirements:
445 / 445 (100.0%)
Functions:
7
Functions covered by requirements:
7 / 7 (100.0%)
1
"""
2
@relation(SDOC-SRS-142, SDOC-SRS-146, scope=file)
3
"""
4
 
5
from typing import Final, List, Optional, Sequence
6
 
7
import tree_sitter_cpp
8
from tree_sitter import Language, Node, Parser
9
 
10
from strictdoc.backend.sdoc_source_code.constants import FunctionAttribute
11
from strictdoc.backend.sdoc_source_code.marker_parser import MarkerParser
12
from strictdoc.backend.sdoc_source_code.models.language import LanguageItem
13
from strictdoc.backend.sdoc_source_code.models.language_item_marker import (
14
    LanguageItemMarker,
15
    RangeMarkerType,
16
)
17
from strictdoc.backend.sdoc_source_code.models.line_marker import LineMarker
18
from strictdoc.backend.sdoc_source_code.models.range_marker import (
19
    RangeMarker,
20
)
21
from strictdoc.backend.sdoc_source_code.models.source_file_info import (
22
    SourceFileTraceabilityInfo,
23
)
24
from strictdoc.backend.sdoc_source_code.models.source_location import ByteRange
25
from strictdoc.backend.sdoc_source_code.models.source_node import SourceNode
26
from strictdoc.backend.sdoc_source_code.parse_context import ParseContext
27
from strictdoc.backend.sdoc_source_code.processors.general_language_marker_processors import (
28
    language_item_marker_processor,
29
    line_marker_processor,
30
    range_marker_processor,
31
    source_file_traceability_info_processor,
32
)
33
from strictdoc.backend.sdoc_source_code.tree_sitter_helpers import (
34
    traverse_tree,
35
    ts_find_child_node_by_type,
36
    ts_find_child_nodes_by_type,
37
)
38
from strictdoc.helpers.cast import assert_cast
39
from strictdoc.helpers.file_stats import SourceFileStats
40
from strictdoc.helpers.file_system import file_open_read_bytes
41
 
42
# Names of macros whose invocations are function definitions. When the name
# node of a parsed function definition matches one of these entries, the
# reader uses the whole macro invocation as the function display name.
KNOWN_FUNCTION_DEFINITION_MACROS: Final[frozenset[str]] = frozenset(
    {
        # Linux: SYSCALL_DEFINE0..SYSCALL_DEFINE6 and
        # COMPAT_SYSCALL_DEFINE0..COMPAT_SYSCALL_DEFINE6.
        *(
            f"{prefix_}SYSCALL_DEFINE{argument_count_}"
            for prefix_ in ("", "COMPAT_")
            for argument_count_ in range(7)
        ),
        # Linux
        "FIXTURE_SETUP",
        "FIXTURE_TEARDOWN",
        # Google Test and Linux
        "TEST",
        "TEST_F",
        "TEST_P",
        "TEST_F_SIGNAL",
        "TYPED_TEST",
        # Zephyr
        "ZTEST_USER",
    }
)
71
 
72
 
73
class SourceFileTraceabilityReader_C:
    """
    Collect traceability information from C and C++ source code.

    The source is parsed with the tree-sitter C++ grammar. The reader visits
    the translation unit, function declarations, function definitions and
    comments, parses the markers found in their comments with MarkerParser
    and stores the results in a SourceFileTraceabilityInfo object.
    """

    @staticmethod
    def supported_elements() -> list[str]:
        """
        Return the names of the source elements supported by this reader.
        """
        return ["function", "class"]

    def __init__(self, custom_tags: Optional[set[str]] = None) -> None:
        # Passed unchanged to MarkerParser.parse() when the file-level
        # comment and the function comments are parsed.
        self.custom_tags: Optional[set[str]] = custom_tags

    def read(
        self,
        input_buffer: bytes,
        file_path: Optional[str] = None,
    ) -> SourceFileTraceabilityInfo:
        """
        Parse a source buffer and collect its traceability information.

        input_buffer is the raw content of the source file and must be of
        type bytes. file_path is handed over to the ParseContext.

        Returns the populated SourceFileTraceabilityInfo.

        Raises NotImplementedError if no name node can be found in the
        declarator of a function definition.
        """
        assert isinstance(input_buffer, bytes)

        file_stats = SourceFileStats.create(input_buffer)
        parse_context = ParseContext(file_path, file_stats)

        parser = Parser(Language(tree_sitter_cpp.language()))
        tree = parser.parse(input_buffer)

        traceability_info = SourceFileTraceabilityInfo([])

        # Every node type that is not listed below is ignored.
        for node_ in traverse_tree(tree):
            if node_.type == "translation_unit":
                self._process_translation_unit(
                    node_, input_buffer, traceability_info, parse_context
                )
            elif node_.type in ("declaration", "field_declaration"):
                self._process_function_declaration(
                    node_, traceability_info, parse_context
                )
            elif node_.type == "function_definition":
                self._process_function_definition(
                    node_, traceability_info, parse_context
                )
            elif node_.type == "comment":
                self._process_comment(node_, parse_context)

        source_file_traceability_info_processor(
            traceability_info, parse_context
        )

        traceability_info.ng_map_reqs_to_markers = (
            parse_context.map_reqs_to_markers
        )

        return traceability_info

    def read_from_file(self, file_path: str) -> SourceFileTraceabilityInfo:
        """
        Read the file at file_path as bytes and parse it with read().
        """
        with file_open_read_bytes(file_path) as file:
            sdoc_content = file.read()
        return self.read(sdoc_content, file_path=file_path)

    def _process_translation_unit(
        self,
        node_: Node,
        input_buffer: bytes,
        traceability_info: SourceFileTraceabilityInfo,
        parse_context: ParseContext,
    ) -> None:
        """
        Register the scope=file markers of the comment that opens the file.

        Nothing happens unless the first child of the translation unit is a
        comment node with text.
        """
        if len(node_.children) == 0:
            return
        comment_node = node_.children[0]
        if comment_node.type != "comment" or comment_node.text is None:
            return

        # The line range given to the parser is the range of the whole
        # translation unit, not only the range of the comment.
        # It is important that +1 is not added to the end line when the file
        # ends with a newline because currently StrictDoc does not display
        # the last empty line.
        if input_buffer.endswith(b"\n"):
            line_end = node_.end_point[0]
        else:
            line_end = node_.end_point[0] + 1

        source_node = MarkerParser.parse(
            input_string=comment_node.text.decode("utf-8"),
            line_start=node_.start_point[0] + 1,
            line_end=line_end,
            comment_line_start=node_.start_point[0] + 1,
            comment_byte_range=ByteRange.create_from_ts_node(comment_node),
            filename=parse_context.filename,
            custom_tags=self.custom_tags,
        )
        for marker_ in source_node.markers:
            if not isinstance(marker_, LanguageItemMarker):
                continue
            # At the top level, only accept the scope=file markers.
            # Everything else will be handled by functions and classes.
            if marker_.scope != RangeMarkerType.FILE:
                continue
            language_item_marker_processor(marker_, parse_context)
            traceability_info.markers.append(marker_)

    def _process_function_declaration(
        self,
        node_: Node,
        traceability_info: SourceFileTraceabilityInfo,
        parse_context: ParseContext,
    ) -> None:
        """
        Register the function declared by a (field) declaration node.

        Declarations without a function declarator or without a function
        name node are skipped.
        """
        function_declarator_node = ts_find_child_node_by_type(
            node_, "function_declarator"
        )
        if function_declarator_node is None:
            # A C++ reference declarator wraps the function declarator once.
            # Example: "TrkVertex& operator-=(const TrkVertex& c);".
            function_declarator_node = (
                self._find_reference_wrapped_function_declarator(node_)
            )
            if function_declarator_node is None:
                return

        # For normal C functions the identifier is "identifier".
        # For C++, there are:
        # Class function declarations: bool CanSend(const CanFrame &frame);         # noqa: ERA001
        # Operators:                   TrkVertex& operator-=(const TrkVertex& c);   # noqa: ERA001
        # Destructors:                 ~TrkVertex();                                # noqa: ERA001
        function_identifier_node = self._get_function_name_node(
            function_declarator_node,
        )
        if (
            function_identifier_node is None
            or function_identifier_node.text is None
        ):
            return
        function_display_name = function_identifier_node.text.decode("utf8")

        assert function_declarator_node.text is not None, node_.text
        # Remove extra trailing spaces, newlines etc added by code-formatting
        # or linters.
        function_name = " ".join(
            function_declarator_node.text.decode("utf8").split()
        )

        parent_names = self.get_node_ns(node_)
        if len(parent_names) > 0:
            scope_prefix = "::".join(parent_names)
            function_name = f"{scope_prefix}::{function_name}"
            function_display_name = f"{scope_prefix}::{function_display_name}"

        function_attributes = {FunctionAttribute.DECLARATION}
        for specifier_node_ in ts_find_child_nodes_by_type(
            node_, "storage_class_specifier"
        ):
            if specifier_node_.text == b"static":
                function_attributes.add(FunctionAttribute.STATIC)

        source_node: Optional[SourceNode] = None
        language_item_markers: List[LanguageItemMarker] = []
        function_comment_node: Optional[Node] = None

        # NOTE(review): unlike for function definitions, the comment does not
        # have to be directly adjacent to the declaration and the parsed
        # source node is not added to traceability_info.source_nodes.
        # Confirm that both differences are intended.
        prev_sibling = node_.prev_sibling
        if prev_sibling is not None and prev_sibling.type == "comment":
            function_comment_node = prev_sibling
            source_node = self._parse_function_comment(
                node_,
                function_comment_node,
                function_display_name,
                parse_context,
            )
            language_item_markers = self._register_language_item_markers(
                source_node, traceability_info, parse_context
            )

        new_function = self._create_function_item(
            node_,
            function_comment_node,
            function_name,
            function_display_name,
            language_item_markers,
            function_attributes,
            traceability_info,
        )
        if source_node is not None:
            source_node.function = new_function
        traceability_info.functions.append(new_function)

    def _process_function_definition(
        self,
        node_: Node,
        traceability_info: SourceFileTraceabilityInfo,
        parse_context: ParseContext,
    ) -> None:
        """
        Register the function defined by a function definition node.

        Definitions without a function declarator are skipped.

        Raises NotImplementedError if the declarator has no name node.
        """
        try:
            function_declarator_node = ts_find_child_node_by_type(
                node_, "function_declarator", raise_on_error=True
            )
        except LookupError:
            # Probably confused by macro, skip node to avoid processing
            # random subtrees.
            return
        if function_declarator_node is None:
            # A C++ reference declarator wraps the function declarator once.
            # Example: Foo& Foo::operator+(const Foo& c) { return *this; }
            function_declarator_node = (
                self._find_reference_wrapped_function_declarator(node_)
            )
            if function_declarator_node is None:
                return

        assert function_declarator_node.text is not None, node_.text
        function_name = function_declarator_node.text.decode("utf8")

        identifier_node = self._get_function_name_node(
            function_declarator_node
        )
        if identifier_node is None:
            # NOTE(review): declarations are silently skipped in the same
            # situation, definitions fail loudly. Confirm this is intended.
            raise NotImplementedError(function_declarator_node)

        assert identifier_node.text is not None, node_.text
        function_display_name = identifier_node.text.decode("utf8")

        # Remove extra trailing spaces, newlines etc added by code-formatting
        # or linters.
        function_name = " ".join(function_name.split())
        parent_names = self.get_node_ns(node_)

        # The first if branch handles a special case where selected
        # macros are actually function definitions. Typical examples of
        # such functions:
        # 1) Google Test TEST(...)
        # 2) Zephyr RTOS ZTEST_USER(...)
        # 3) Linux SYSCALL_DEFINE<N> etc.
        if function_display_name in KNOWN_FUNCTION_DEFINITION_MACROS:
            # Make the display name include the entire macro/function
            # signature.
            # Example of this case:
            # function_name: ZTEST_USER(semaphore, test_k_sem_correct_count_limit)  # noqa: ERA001
            # function_display_name (before assignment): ZTEST_USER
            function_display_name = function_name
        elif len(parent_names) > 0:
            scope_prefix = "::".join(parent_names)
            function_name = f"{scope_prefix}::{function_name}"
            function_display_name = f"{scope_prefix}::{function_display_name}"

        source_node: Optional[SourceNode] = None
        language_item_markers: List[LanguageItemMarker] = []
        function_comment_node: Optional[Node] = None

        # In the condition below, it is important that the comment is
        # considered a function comment only if there are no empty
        # lines between the comment and the function.
        prev_sibling = node_.prev_sibling
        if (
            prev_sibling is not None
            and prev_sibling.type == "comment"
            and prev_sibling.end_point[0] + 1 == node_.start_point[0]
        ):
            function_comment_node = prev_sibling
            source_node = self._parse_function_comment(
                node_,
                function_comment_node,
                function_display_name,
                parse_context,
            )
            traceability_info.source_nodes.append(source_node)
            language_item_markers = self._register_language_item_markers(
                source_node, traceability_info, parse_context
            )

        new_function = self._create_function_item(
            node_,
            function_comment_node,
            function_name,
            function_display_name,
            language_item_markers,
            {FunctionAttribute.DEFINITION},
            traceability_info,
        )
        traceability_info.functions.append(new_function)
        if len(language_item_markers) > 0:
            traceability_info.ng_map_names_to_markers[function_name] = (
                # FIXME: Cannot win the fight with mypy without assert_cast.
                assert_cast(language_item_markers, list)
            )
            traceability_info.ng_map_names_to_definition_functions[
                function_name
            ] = new_function
        if source_node is not None:
            source_node.function = new_function

    @staticmethod
    def _process_comment(node_: Node, parse_context: ParseContext) -> None:
        """
        Process the range and line markers found in a comment node.

        Markers of any other type are ignored here.
        """
        #
        # FIXME: Here parsing of function comments can happen as well
        #        but this time the focus is ONLY on range and line markers.
        #        The case which is handled here is when a user adds a
        #        range_start marker in a function comment.
        #        It is not good that parsing of function comments
        #        happens twice.
        #
        assert node_.text is not None, f"Comment without a text: {node_}"

        source_node = MarkerParser.parse(
            input_string=node_.text.decode("utf8"),
            line_start=node_.start_point[0] + 1,
            line_end=node_.end_point[0] + 1,
            comment_line_start=node_.start_point[0] + 1,
            comment_byte_range=ByteRange.create_from_ts_node(node_),
            filename=parse_context.filename,
            custom_tags=None,
        )

        for marker_ in source_node.markers:
            if isinstance(marker_, RangeMarker):
                range_marker_processor(marker_, parse_context)
            elif isinstance(marker_, LineMarker):
                line_marker_processor(marker_, parse_context)

    def _parse_function_comment(
        self,
        node_: Node,
        function_comment_node: Node,
        function_display_name: str,
        parse_context: ParseContext,
    ) -> SourceNode:
        """
        Parse the comment that precedes the function node node_.

        The parsed range starts at the first line of the comment and ends at
        the last line of the function node.
        """
        assert function_comment_node.text is not None, node_.text
        comment_line_start = function_comment_node.start_point[0] + 1
        return MarkerParser.parse(
            input_string=function_comment_node.text.decode("utf8"),
            line_start=comment_line_start,
            line_end=node_.end_point[0] + 1,
            comment_line_start=comment_line_start,
            comment_byte_range=ByteRange.create_from_ts_node(
                function_comment_node
            ),
            filename=parse_context.filename,
            entity_name=function_display_name,
            custom_tags=self.custom_tags,
        )

    @staticmethod
    def _register_language_item_markers(
        source_node: SourceNode,
        traceability_info: SourceFileTraceabilityInfo,
        parse_context: ParseContext,
    ) -> List[LanguageItemMarker]:
        """
        Process the language item markers of a parsed function comment.

        Each marker is processed, appended to traceability_info.markers and
        collected in the returned list. Other marker types are skipped.
        """
        language_item_markers: List[LanguageItemMarker] = []
        for marker_ in source_node.markers:
            if isinstance(marker_, LanguageItemMarker):
                language_item_marker_processor(marker_, parse_context)
                traceability_info.markers.append(marker_)
                language_item_markers.append(marker_)
        return language_item_markers

    @staticmethod
    def _create_function_item(
        node_: Node,
        function_comment_node: Optional[Node],
        function_name: str,
        function_display_name: str,
        language_item_markers: List[LanguageItemMarker],
        function_attributes: set[FunctionAttribute],
        traceability_info: SourceFileTraceabilityInfo,
    ) -> LanguageItem:
        """
        Create the LanguageItem that describes the function node node_.
        """
        # The function range includes the top comment if it exists.
        if function_comment_node is not None:
            line_begin = function_comment_node.start_point[0] + 1
        else:
            line_begin = node_.range.start_point[0] + 1

        return LanguageItem(
            parent=traceability_info,
            name=function_name,
            display_name=function_display_name,
            line_begin=line_begin,
            line_end=node_.range.end_point[0] + 1,
            code_byte_range=ByteRange.create_from_ts_node(node_),
            child_functions=[],
            markers=language_item_markers,
            attributes=function_attributes,
        )

    @staticmethod
    def _find_reference_wrapped_function_declarator(
        node_: Node,
    ) -> Optional[Node]:
        """
        Find a function declarator wrapped in a reference declarator.

        Returns None if node_ has no reference declarator child or if that
        child has no function declarator child.
        """
        reference_declarator_node = ts_find_child_node_by_type(
            node_, "reference_declarator"
        )
        if reference_declarator_node is None:
            return None
        return ts_find_child_node_by_type(
            reference_declarator_node, "function_declarator"
        )

    @staticmethod
    def _get_function_name_node(
        function_declarator_node: Node,
    ) -> Optional[Node]:
        """
        Find the child of a function declarator that carries the name.

        The name can be a plain identifier, a field identifier, an operator
        name, a destructor name or a qualified identifier.
        """
        assert function_declarator_node.type == "function_declarator"
        function_identifier_node = ts_find_child_node_by_type(
            function_declarator_node,
            node_type=(
                "identifier",
                "field_identifier",
                "operator_name",
                "destructor_name",
                "qualified_identifier",
            ),
        )
        return function_identifier_node

    @staticmethod
    def get_node_ns(node: Node) -> Sequence[str]:
        """
        Collect the names of the classes and namespaces that enclose a node.

        The walk starts at the node itself and follows the parent links up
        to the root. A class contributes its name only if its second child
        is a type identifier, a namespace contributes its first namespace
        identifier. The names are returned outermost first.
        """
        parent_scopes: List[str] = []
        cursor: Optional[Node] = node
        while cursor is not None:
            if cursor.type == "class_specifier" and len(cursor.children) > 1:
                second_node_or_none = cursor.children[1]
                if (
                    second_node_or_none.type == "type_identifier"
                    and second_node_or_none.text is not None
                ):
                    parent_scopes.append(
                        second_node_or_none.text.decode("utf8")
                    )
            elif cursor.type == "namespace_definition":
                for child_ in cursor.children:
                    if (
                        child_.type == "namespace_identifier"
                        and child_.text is not None
                    ):
                        parent_scopes.append(child_.text.decode("utf8"))
                        break

            cursor = cursor.parent

        # The names were collected innermost first.
        parent_scopes.reverse()
        return parent_scopes