diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index d7c45694..d546feb8 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -666,6 +666,92 @@ def _filter_invalid_constraints( filtered_constraints.append(constraint) return filtered_constraints + def _filter_properties_required_field( + self, node_types: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Sanitize the 'required' field in node type properties. Ensures 'required' is a valid boolean. + converts known string values (true, yes, 1, false, no, 0) to booleans and removes unrecognized values. + """ + for node_type in node_types: + properties = node_type.get("properties", []) + if not properties: + continue + for prop in properties: + if not isinstance(prop, dict): + continue + + required_value = prop.get("required") + + # Not provided - will use Pydantic default (false) + if required_value is None: + continue + + # already a valid boolean + if isinstance(required_value, bool): + continue + + prop_name = prop.get("name", "unknown") + node_label = node_type.get("label", "unknown") + + # Convert to string to handle int values like 1 or 0 + required_str = str(required_value).lower() + + if required_str in ("true", "yes", "1"): + prop["required"] = True + logging.info( + f"Converted 'required' value '{required_value}' to True " + f"for property '{prop_name}' on node '{node_label}'" + ) + elif required_str in ("false", "no", "0"): + prop["required"] = False + logging.info( + f"Converted 'required' value '{required_value}' to False " + f"for property '{prop_name}' on node '{node_label}'" + ) + else: + logging.info( + f"Removing unrecognized 'required' value '{required_value}' " + f"for property '{prop_name}' on node '{node_label}'. " + f"Using default (False)." + ) + prop.pop("required", None) + + return node_types + + def _enforce_required_for_constraint_properties( + self, + node_types: List[Dict[str, Any]], + constraints: List[Dict[str, Any]], + ) -> None: + """Ensure properties with UNIQUENESS constraints are marked as required.""" + if not constraints: + return + + # Build a lookup for property_names and constraints + constraint_props: Dict[str, set[str]] = {} + for c in constraints: + if c.get("type") == "UNIQUENESS": + label = c.get("node_type") + prop = c.get("property_name") + if label and prop: + constraint_props.setdefault(label, set()).add(prop) + + # Skip node_types without constraints + for node_type in node_types: + label = node_type.get("label") + if label not in constraint_props: + continue + + props_to_fix = constraint_props[label] + for prop in node_type.get("properties", []): + if isinstance(prop, dict) and prop.get("name") in props_to_fix: + if prop.get("required") is not True: + logging.info( + f"Auto-setting 'required' as True for property '{prop.get('name')}' " + f"on node '{label}' (has UNIQUENESS constraint)." + ) + prop["required"] = True + def _clean_json_content(self, content: str) -> str: content = content.strip() @@ -746,12 +832,22 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema extracted_relationship_types ) + extracted_node_types = self._filter_properties_required_field( + extracted_node_types + ) + # Filter out invalid patterns before validation if extracted_patterns: extracted_patterns = self._filter_invalid_patterns( extracted_patterns, extracted_node_types, extracted_relationship_types ) + # Enforce required=true for properties with UNIQUENESS constraints + if extracted_constraints: + self._enforce_required_for_constraint_properties( + extracted_node_types, extracted_constraints + ) + # Filter out invalid constraints if extracted_constraints: extracted_constraints = self._filter_invalid_constraints( diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 6fedb511..e7269427 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -223,7 +223,13 @@ class SchemaExtractionTemplate(PromptTemplate): 8.2 Only use properties that seem to not have too many missing values in the sample. 8.3 Constraints reference node_types by label and specify which property is unique. 8.4 If a property appears in a uniqueness constraint it MUST also appear in the corresponding node_type as a property. - +9. REQUIRED PROPERTIES: +9.1 Mark a property as "required": true if every instance of that node/relationship type MUST have this property (non-nullable). +9.2 Mark a property as "required": false if the property is optional and may be absent on some instances. +9.3 Properties that are identifiers, names, or essential characteristics are typically required. +9.4 Properties that are supplementary information (phone numbers, descriptions, metadata) are typically optional. +9.5 When uncertain, default to "required": false. +9.6 If a property has a UNIQUENESS constraint, it MUST be marked as "required": true. Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME. @@ -236,7 +242,13 @@ class SchemaExtractionTemplate(PromptTemplate): "properties": [ {{ "name": "name", - "type": "STRING" + "type": "STRING", + "required": true + }}, + {{ + "name": "email", + "type": "STRING", + "required": false }} ] }} diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 98bb3fe5..3918ef8a 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -79,6 +79,28 @@ def test_node_type_additional_properties_default() -> None: assert node_type.additional_properties is True +def test_property_type_initalization() -> None: + prop = PropertyType(name="email", type="STRING") + assert prop.name == "email" + assert prop.type == "STRING" + assert prop.required is False + + +def test_property_type_with_required_true() -> None: + prop = PropertyType(name="id", type="INTEGER", required=True) + assert prop.required is True + + +def test_property_type_is_frozen() -> None: + prop = PropertyType(name="email", type="STRING", required=False) + + with pytest.raises(ValidationError): + prop.name = "other" + + with pytest.raises(ValidationError): + prop.required = True + + def test_relationship_type_initialization_from_string() -> None: relationship_type = RelationshipType.model_validate("REL") assert isinstance(relationship_type, RelationshipType) @@ -730,6 +752,55 @@ def schema_json_with_null_constraints() -> str: """ +@pytest.fixture +def schema_json_with_required_properties() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": true}, + {"name": "email", "type": "STRING", "required": false}, + {"name": "phone", "type": "STRING"} + ] + } + ], + "relationship_types": [ + {"label": "KNOWS"} + ], + "patterns": [ + ["Person", "KNOWS", "Person"] + ] + } + """ + + +@pytest.fixture +def schema_json_with_string_required_values() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": "true"}, + {"name": "email", "type": "STRING", "required": "yes"}, + {"name": "phone", "type": "STRING", "required": "false"}, + {"name": "address", "type": "STRING", "required": "no"} + ] + } + ], + "relationship_types": [ + {"label": "KNOWS"} + ], + "patterns": [ + ["Person", "KNOWS", "Person"] + ] + } + """ + + @pytest.fixture def invalid_schema_json() -> str: return """ @@ -1388,6 +1459,363 @@ def test_clean_json_content_plain_json( assert cleaned == '{"node_types": [{"label": "Person"}]}' +def test_filter_properties_required_field_valid_true( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types = [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING", "required": True}], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + assert result[0]["properties"][0]["required"] is True + + +def test_filter_properties_required_field_valid_false( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types = [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING", "required": False}], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + assert result[0]["properties"][0]["required"] is False + + +def test_filter_properties_required_field_string( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types = [ + { + "label": "Person", + "properties": [ + {"name": "prop1", "type": "STRING", "required": "true"}, + {"name": "prop2", "type": "STRING", "required": "yes"}, + {"name": "prop3", "type": "STRING", "required": "1"}, + {"name": "prop4", "type": "STRING", "required": "TRUE"}, + ], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + for prop in result[0]["properties"]: + assert prop["required"] is True + node_types = [ + { + "label": "Person", + "properties": [ + {"name": "prop1", "type": "STRING", "required": "false"}, + {"name": "prop2", "type": "STRING", "required": "no"}, + {"name": "prop3", "type": "STRING", "required": "0"}, + {"name": "prop4", "type": "STRING", "required": "FALSE"}, + ], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + for prop in result[0]["properties"]: + assert prop["required"] is False + + +def test_filter_properties_required_field_invalid_string( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types = [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": "mandatory"}, + {"name": "email", "type": "STRING", "required": "always"}, + ], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + + assert "required" not in result[0]["properties"][0] + assert "required" not in result[0]["properties"][1] + + +def test_filter_properties_required_field_int_values( + schema_from_text: SchemaFromTextExtractor, +) -> None: + """Test that int values like 1 and 0 are converted to True/False.""" + node_types = [ + { + "label": "Person", + "properties": [ + {"name": "prop1", "type": "STRING", "required": 1}, + {"name": "prop2", "type": "STRING", "required": 0}, + ], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + assert result[0]["properties"][0]["required"] is True + assert result[0]["properties"][1]["required"] is False + + +def test_filter_properties_required_field_invalid_type( + schema_from_text: SchemaFromTextExtractor, +) -> None: + """Test that unrecognized types like list and dict are removed.""" + node_types = [ + { + "label": "Person", + "properties": [ + {"name": "prop1", "type": "STRING", "required": []}, + {"name": "prop2", "type": "STRING", "required": {"value": True}}, + ], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + for prop in result[0]["properties"]: + assert "required" not in prop + + +def test_filter_properties_required_field_missing( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types = [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ] + result = schema_from_text._filter_properties_required_field(node_types) + assert "required" not in result[0]["properties"][0] + + +def test_enforce_required_for_constraint_properties_sets_required_true( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types: list[dict[str, Any]] = [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": False}, + {"name": "email", "type": "STRING", "required": False}, + ], + } + ] + constraints = [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ] + + schema_from_text._enforce_required_for_constraint_properties( + node_types, constraints + ) + + # name should now be required=true + assert node_types[0]["properties"][0]["required"] is True + # email should remain required=false + assert node_types[0]["properties"][1]["required"] is False + + +def test_enforce_required_for_constraint_properties_already_true( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types: list[dict[str, Any]] = [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": True}, + ], + } + ] + constraints = [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ] + + schema_from_text._enforce_required_for_constraint_properties( + node_types, constraints + ) + + assert node_types[0]["properties"][0]["required"] is True + + +def test_enforce_required_for_constraint_properties_missing_required_field( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types: list[dict[str, Any]] = [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, # No required field + ], + } + ] + constraints = [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ] + + schema_from_text._enforce_required_for_constraint_properties( + node_types, constraints + ) + + assert node_types[0]["properties"][0]["required"] is True + + +def test_enforce_required_for_constraint_properties_no_constraints( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types: list[dict[str, Any]] = [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": False}, + ], + } + ] + constraints: list[dict[str, Any]] = [] + + schema_from_text._enforce_required_for_constraint_properties( + node_types, constraints + ) + + assert node_types[0]["properties"][0]["required"] is False + + +def test_enforce_required_for_constraint_properties_skips_unconstrained_nodes( + schema_from_text: SchemaFromTextExtractor, +) -> None: + node_types: list[dict[str, Any]] = [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": False}, + ], + }, + { + "label": "Company", + "properties": [ + {"name": "name", "type": "STRING", "required": False}, + ], + }, + ] + constraints = [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ] + + schema_from_text._enforce_required_for_constraint_properties( + node_types, constraints + ) + + # Person.name should be required=true + assert node_types[0]["properties"][0]["required"] is True + # Company.name should remain required=false (no constraint on Company) + assert node_types[1]["properties"][0]["required"] is False + + +@pytest.mark.asyncio +async def test_schema_from_text_with_required_properties( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_required_properties: str, +) -> None: + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_required_properties + ) + + schema = await schema_from_text.run(text="Sample text for test") + + person = schema.node_type_from_label("Person") + assert person is not None + + # Check required properties + name_prop = next((p for p in person.properties if p.name == "name"), None) + email_prop = next((p for p in person.properties if p.name == "email"), None) + phone_prop = next((p for p in person.properties if p.name == "phone"), None) + + assert name_prop is not None and name_prop.required is True + assert email_prop is not None and email_prop.required is False + assert phone_prop is not None and phone_prop.required is False + + +@pytest.mark.asyncio +async def test_schema_from_text_sanitizes_string_required_values( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_string_required_values: str, +) -> None: + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_string_required_values + ) + + schema = await schema_from_text.run(text="Sample text for test") + + person = schema.node_type_from_label("Person") + assert person is not None + + # true and yes should become True + name_prop = next((p for p in person.properties if p.name == "name"), None) + email_prop = next((p for p in person.properties if p.name == "email"), None) + assert name_prop is not None and name_prop.required is True + assert email_prop is not None and email_prop.required is True + + # false and no should become False + phone_prop = next((p for p in person.properties if p.name == "phone"), None) + address_prop = next((p for p in person.properties if p.name == "address"), None) + assert phone_prop is not None and phone_prop.required is False + assert address_prop is not None and address_prop.required is False + + +@pytest.mark.asyncio +async def test_schema_from_text_handles_missing_required_field( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + valid_schema_json: str, +) -> None: + mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) + + schema = await schema_from_text.run(text="Sample text") + + person = schema.node_type_from_label("Person") + assert person is not None + + # All properties should have required=False (default) + for prop in person.properties: + assert prop.required is False + + +@pytest.mark.asyncio +async def test_schema_from_text_enforces_required_for_constrained_properties( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, +) -> None: + schema_json = """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": false}, + {"name": "email", "type": "STRING", "required": false} + ] + } + ], + "relationship_types": [], + "patterns": [], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ] + } + """ + mock_llm.ainvoke.return_value = LLMResponse(content=schema_json) + + schema = await schema_from_text.run(text="Sample text") + + person = schema.node_type_from_label("Person") + assert person is not None + + name_prop = next((p for p in person.properties if p.name == "name"), None) + email_prop = next((p for p in person.properties if p.name == "email"), None) + + # name should be auto-fixed to required=true + assert name_prop is not None and name_prop.required is True + # email should remain required=false + assert email_prop is not None and email_prop.required is False + + @pytest.mark.asyncio @patch("neo4j_graphrag.experimental.components.schema.get_structured_schema") async def test_schema_from_existing_graph(mock_get_structured_schema: Mock) -> None: