mirror of
https://github.com/modelcontextprotocol/servers.git
synced 2026-04-26 15:55:39 +02:00
feat: Add MCP server for generating notional data
- Implement data generation server with support for insurance data - Add comprehensive test suite with 16 test cases - Support custom schemas and data relationships - Use faker, mimesis, numpy, and SDV for realistic data - Pass all type checks with pyright and lint checks with ruff Co-Authored-By: alexander@anthropic.com <alexander@anthropic.com>
This commit is contained in:
42
src/datagen/README.md
Normal file
42
src/datagen/README.md
Normal file
@@ -0,0 +1,42 @@
# MCP Data Generation Server

This server implements the Model Context Protocol (MCP) to provide notional data generation capabilities using Python libraries including Faker, Mimesis, NumPy, and SDV.

## Features

- Generate synthetic data tables based on specified schemas and parameters
- Support for multiple data generation libraries (Faker, Mimesis, SDV)
- Configurable row counts and column specifications
- Export data in CSV format

## Installation

```bash
pip install mcp-server-datagen
```

## Usage

The server exposes MCP tools for generating notional data:

- `generate_tables`: Generate multiple related tables based on a schema
- `define_schema`: Define table schemas with column specifications
- `export_csv`: Export generated data to CSV files

## Development

1. Create virtual environment and install dependencies:

```bash
uv venv
uv pip install -e ".[dev]"
```

2. Run type checking:

```bash
uv run --frozen pyright
```

3. Build package:

```bash
uv build
```
30
src/datagen/pyproject.toml
Normal file
30
src/datagen/pyproject.toml
Normal file
@@ -0,0 +1,30 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "mcp-server-datagen"
version = "0.1.0"
description = "MCP server for generating notional data using Python libraries"
requires-python = ">=3.12"
dependencies = [
    "faker>=20.1.0",
    "mimesis>=13.1.0",
    "numpy>=1.26.0",
    "sdv>=1.5.0",
    "pandas>=2.1.0",
    "mcp>=1.0.0",
    "pydantic>=2.0.0"
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "black>=23.0.0",
    "pyright>=1.1.0"
]

[dependency-groups]
dev = [
    "ruff>=0.8.2",
]
3
src/datagen/src/mcp_server_datagen/__init__.py
Normal file
3
src/datagen/src/mcp_server_datagen/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""MCP server for generating notional data."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
7
src/datagen/src/mcp_server_datagen/__main__.py
Normal file
7
src/datagen/src/mcp_server_datagen/__main__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Main entry point for the data generation server."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from mcp_server_datagen.server import serve
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(serve())
|
||||||
75
src/datagen/src/mcp_server_datagen/generators.py
Normal file
75
src/datagen/src/mcp_server_datagen/generators.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Data generation utilities using Faker, Mimesis, NumPy, and SDV."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, cast
|
||||||
|
from faker import Faker
|
||||||
|
from mimesis import Generic
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from .synthetic import SyntheticDataGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class DataGenerator:
    """Generate tabular data with Faker, Mimesis, or NumPy, delegating any
    schema containing correlated columns to the SDV-backed generator."""

    def __init__(self) -> None:
        self.faker = Faker()
        self.generic = Generic()
        self.synthetic = SyntheticDataGenerator()

    async def generate_table(
        self,
        name: str,
        schema: Dict[str, Dict[str, Any]],
        rows: int = 1000
    ) -> Dict[str, List[Any]]:
        """Generate a table of data based on the provided schema.

        Args:
            name: Name of the table.
            schema: Mapping of column name -> column spec. A spec may carry
                "generator" ("faker", "mimesis", or "numpy"), "type", and,
                for numpy columns, "min"/"max" or "categories".
            rows: Number of rows to generate (default 1000).

        Returns:
            Dictionary mapping column names to lists of generated values.

        Raises:
            ValueError: If a "category" column has no "categories" list.
        """
        data: Dict[str, List[Any]] = {}

        # If any column is marked correlated, the whole table is produced by
        # SDV so inter-column relationships are preserved.
        if any(col_spec.get("correlated", False) for col_spec in schema.values()):
            return await self.synthetic.generate_synthetic_data(name, schema, rows)

        for col_name, col_spec in schema.items():
            generator = col_spec.get("generator", "faker")
            data_type = col_spec.get("type", "string")

            if generator == "faker":
                # data_type names a Faker provider method, e.g. "first_name".
                data[col_name] = [
                    getattr(self.faker, data_type)()
                    for _ in range(rows)
                ]
            elif generator == "mimesis":
                data[col_name] = [
                    getattr(self.generic, data_type)()
                    for _ in range(rows)
                ]
            elif generator == "numpy":
                if data_type == "int":
                    int_values: NDArray[np.int64] = np.random.randint(
                        low=col_spec.get("min", 0),
                        high=col_spec.get("max", 100),
                        size=rows,
                        dtype=np.int64
                    )
                    data[col_name] = cast(List[Any], int_values.tolist())
                elif data_type == "float":
                    min_val = float(col_spec.get("min", 0.0))
                    max_val = float(col_spec.get("max", 1.0))
                    float_values = np.random.uniform(
                        low=min_val,
                        high=max_val,
                        size=rows
                    ).astype(np.float64)
                    data[col_name] = cast(List[Any], float_values.tolist())
                elif data_type == "category":
                    # BUG FIX: "category" columns were previously unhandled
                    # here, so such columns were silently dropped from the
                    # output even though the default schemas declare
                    # numpy-generated category columns.
                    categories = col_spec.get("categories", [])
                    if not categories:
                        raise ValueError(
                            f"Column {col_name!r} is 'category' but has no 'categories'"
                        )
                    data[col_name] = [
                        str(np.random.choice(categories)) for _ in range(rows)
                    ]

        return data
|
||||||
252
src/datagen/src/mcp_server_datagen/server.py
Normal file
252
src/datagen/src/mcp_server_datagen/server.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Sequence
|
||||||
|
|
||||||
|
from mcp.server import Server
|
||||||
|
from mcp.server.stdio import stdio_server
|
||||||
|
from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
|
|
||||||
|
from mcp_server_datagen.synthetic import SyntheticDataGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class DataGenServer:
|
||||||
|
"""MCP server for generating notional data."""
|
||||||
|
|
||||||
|
def __init__(self):
    """Initialize built-in insurance schemas and the synthetic generator.

    Three default tables are defined: customers, policies, and claims.
    Each column spec carries a "type" and a "generator" name; numpy
    columns add "min"/"max" or "categories", and "correlated": True marks
    a foreign key whose values are sampled from the parent table.
    """
    self.default_schemas = {
        "customers": {
            # Primary key: unique per row.
            "customer_id": {
                "type": "int",
                "generator": "numpy",
                "min": 10000,
                "max": 99999
            },
            "first_name": {
                "type": "first_name",
                "generator": "faker"
            },
            "last_name": {
                "type": "last_name",
                "generator": "faker"
            },
            "email": {
                "type": "email",
                "generator": "faker"
            },
            "phone": {
                "type": "phone_number",
                "generator": "faker"
            },
            "address": {
                "type": "address",
                "generator": "faker"
            },
            "date_of_birth": {
                "type": "date_of_birth",
                "generator": "faker"
            },
            # Standard FICO-style range.
            "credit_score": {
                "type": "int",
                "generator": "numpy",
                "min": 300,
                "max": 850
            }
        },
        "policies": {
            "policy_id": {
                "type": "int",
                "generator": "numpy",
                "min": 100000,
                "max": 999999
            },
            # Foreign key into customers; "correlated" routes generation
            # through the synthetic generator's parent-ID sampling.
            "customer_id": {
                "type": "int",
                "generator": "numpy",
                "min": 10000,
                "max": 99999,
                "correlated": True
            },
            "policy_type": {
                "type": "category",
                "generator": "numpy",
                "categories": ["auto", "home", "life", "health"]
            },
            "start_date": {
                "type": "date_this_decade",
                "generator": "faker"
            },
            "end_date": {
                "type": "date_this_decade",
                "generator": "faker"
            },
            "premium": {
                "type": "float",
                "generator": "numpy",
                "min": 500.0,
                "max": 5000.0
            },
            "coverage_amount": {
                "type": "float",
                "generator": "numpy",
                "min": 50000.0,
                "max": 1000000.0
            },
            "status": {
                "type": "category",
                "generator": "numpy",
                "categories": ["active", "expired", "cancelled", "pending"]
            }
        },
        "claims": {
            "claim_id": {
                "type": "int",
                "generator": "numpy",
                "min": 1000000,
                "max": 9999999
            },
            # Foreign key into policies.
            "policy_id": {
                "type": "int",
                "generator": "numpy",
                "min": 100000,
                "max": 999999,
                "correlated": True
            },
            "date_filed": {
                "type": "date_this_year",
                "generator": "faker"
            },
            "incident_date": {
                "type": "date_this_year",
                "generator": "faker"
            },
            "claim_type": {
                "type": "category",
                "generator": "numpy",
                "categories": ["accident", "theft", "natural_disaster", "medical", "property_damage"]
            },
            "amount_claimed": {
                "type": "float",
                "generator": "numpy",
                "min": 1000.0,
                "max": 100000.0
            },
            "status": {
                "type": "category",
                "generator": "numpy",
                "categories": ["pending", "approved", "denied", "in_review"]
            },
            "description": {
                "type": "text",
                "generator": "faker"
            }
        }
    }
    self.generator = SyntheticDataGenerator()
    # Share the defaults so the generator can build parent tables on demand.
    self.generator.default_schemas = self.default_schemas
|
||||||
|
|
||||||
|
async def list_tools(self) -> List[Tool]:
    """List available data generation tools.

    Returns:
        A single "generate_tables" tool whose input schema accepts table
        names (required), an optional positive row count, and optional
        per-table schema overrides.
    """
    return [
        Tool(
            name="generate_tables",
            description="Generate multiple tables of notional data",
            inputSchema={
                "type": "object",
                "properties": {
                    # Names of the tables to generate (default or custom).
                    "tables": {
                        "type": "array",
                        "items": {"type": "string"}
                    },
                    "rows": {"type": "integer", "minimum": 1},
                    # Optional overrides: table name -> column name -> spec.
                    "schemas": {
                        "type": "object",
                        "additionalProperties": {
                            "type": "object",
                            "additionalProperties": {
                                "type": "object",
                                "properties": {
                                    "type": {"type": "string"},
                                    "generator": {"type": "string"},
                                    "min": {"type": "number"},
                                    "max": {"type": "number"},
                                    "categories": {
                                        "type": "array",
                                        "items": {"type": "string"}
                                    }
                                }
                            }
                        }
                    }
                },
                "required": ["tables"]
            }
        )
    ]
|
||||||
|
|
||||||
|
async def handle_generate_tables(self, params: Dict[str, Any]) -> Dict[str, Any]:
    """Handle generate_tables tool requests.

    Args:
        params: Tool arguments; expects "tables" (list of table names),
            optional "rows" (default 1000), and optional "schemas"
            (per-table column-spec overrides keyed by table name).

    Returns:
        Mapping of table name -> generated column data.

    Raises:
        ValueError: For a non-positive row count or an unknown table name.
        McpError: For any unexpected generation failure.
    """
    tables = params.get("tables", [])
    rows = params.get("rows", 1000)
    custom_schemas = params.get("schemas", {})

    if rows <= 0:
        raise ValueError("Row count must be positive")

    results = {}
    try:
        for table_name in tables:
            if table_name not in self.default_schemas and table_name not in custom_schemas:
                raise ValueError(f"Unknown table: {table_name}")

            # A custom schema takes precedence over the built-in default.
            schema = custom_schemas.get(table_name, self.default_schemas.get(table_name, {}))
            data = await self.generator.generate_synthetic_data(
                table_name=table_name,
                schema=schema,
                rows=rows
            )
            results[table_name] = data

        return results

    except ValueError:
        # BUG FIX: was `raise e`, which rewrites the traceback origin;
        # a bare `raise` re-raises with the original traceback intact.
        raise
    except Exception as e:
        # Wrap unexpected errors, chaining the cause for debuggability.
        # NOTE(review): recent mcp releases construct McpError from an
        # ErrorData object — confirm the installed version accepts a
        # plain message string.
        raise McpError(f"Error generating data: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
async def serve() -> None:
    """Start the MCP server on stdio and serve tool requests until EOF."""
    server = Server("mcp-datagen")
    datagen_server = DataGenServer()

    @server.list_tools()
    async def list_tools() -> List[Tool]:
        """List available data generation tools."""
        return await datagen_server.list_tools()

    @server.call_tool()
    async def call_tool(
        name: str, arguments: Dict[str, Any]
    ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
        """Dispatch a tool call; only "generate_tables" is supported."""
        if name == "generate_tables":
            result = await datagen_server.handle_generate_tables(arguments)
            # Reply with a single JSON text payload.
            # NOTE(review): generated values may include numpy scalar types;
            # confirm json.dumps handles them for all schemas.
            return [
                TextContent(
                    type="text",
                    text=json.dumps({"tables": result}, indent=2)
                )
            ]
        raise McpError(f"Unknown tool: {name}")

    options = server.create_initialization_options()
    # Run over stdio until the client disconnects.
    async with stdio_server() as (read_stream, write_stream):
        await server.run(read_stream, write_stream, options)


if __name__ == "__main__":
    asyncio.run(serve())
|
||||||
328
src/datagen/src/mcp_server_datagen/synthetic.py
Normal file
328
src/datagen/src/mcp_server_datagen/synthetic.py
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
"""Synthetic data generation using SDV."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Any, Set
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from faker import Faker
|
||||||
|
from mimesis import Generic
|
||||||
|
from sdv.single_table import GaussianCopulaSynthesizer
|
||||||
|
from sdv.metadata import SingleTableMetadata
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDataGenerator:
    """Handles synthetic data generation using SDV."""

    def __init__(self):
        """Initialize the generator."""
        # One fitted GaussianCopulaSynthesizer per table name.
        self.synthesizers: Dict[str, GaussianCopulaSynthesizer] = {}
        self.metadata: Dict[str, SingleTableMetadata] = {}
        self.faker = Faker()
        self.mimesis = Generic()
        # Primary-key IDs already issued per table; foreign keys sample
        # from these sets to keep relationships referentially valid.
        self.generated_ids: Dict[str, Set[int]] = {}
        # Initialize empty sets for each known table
        self.generated_ids["customers"] = set()
        self.generated_ids["policies"] = set()
        self.generated_ids["claims"] = set()
        # Optional table schemas used to auto-generate parent tables on
        # demand (populated by the server).
        self.default_schemas: Dict[str, Dict[str, Dict[str, Any]]] = {}
        # Per-table window index for hybrid ID generation
        # (see _generate_unique_id).
        self.id_counters: Dict[str, int] = {}
|
||||||
|
|
||||||
|
def create_metadata(
    self,
    table_name: str,
    schema: Dict[str, Dict[str, Any]]
) -> SingleTableMetadata:
    """Create SDV metadata for a table based on its schema.

    Args:
        table_name: Name of the table (currently unused here; kept for
            parity with the other per-table methods).
        schema: Mapping of column name -> column spec.

    Returns:
        SingleTableMetadata with one column per schema entry, typed via
        _map_type_to_sdtype.
    """
    metadata = SingleTableMetadata()

    for col_name, col_spec in schema.items():
        # Columns without an explicit type default to "string".
        data_type = col_spec.get("type", "string")
        sdtype = self._map_type_to_sdtype(data_type)
        metadata.add_column(
            column_name=col_name,
            sdtype=sdtype
        )

    return metadata
|
||||||
|
|
||||||
|
def _map_type_to_sdtype(self, data_type: str) -> str:
|
||||||
|
"""Map data type to SDV type."""
|
||||||
|
type_mapping = {
|
||||||
|
"string": "categorical",
|
||||||
|
"int": "numerical",
|
||||||
|
"float": "numerical",
|
||||||
|
"datetime": "datetime",
|
||||||
|
"boolean": "boolean",
|
||||||
|
"category": "categorical"
|
||||||
|
}
|
||||||
|
return type_mapping.get(data_type, "categorical")
|
||||||
|
|
||||||
|
def _generate_faker_value(self, generator: str) -> Any:
    """Resolve a "faker.<method>" generator spec and produce one value.

    Returns None when the spec does not start with "faker." or names a
    method that does not exist on the Faker instance.
    """
    prefix = "faker."
    if not generator.startswith(prefix):
        return None

    method_name = generator[len(prefix):]
    if not hasattr(self.faker, method_name):
        return None
    return getattr(self.faker, method_name)()
|
||||||
|
|
||||||
|
def _generate_mimesis_value(self, generator: str) -> Any:
    """Resolve a "mimesis.<category>.<method>" generator spec.

    Returns None for non-mimesis specs, malformed specs, or unknown
    category/method names (mirroring _generate_faker_value).
    """
    if not generator.startswith("mimesis."):
        return None

    parts = generator.split(".", 1)[1].split(".")
    # BUG FIX: a malformed spec such as "mimesis.text" previously raised
    # ValueError on tuple unpacking; treat it as unknown and return None,
    # consistent with the faker counterpart.
    if len(parts) != 2:
        return None
    category, method = parts
    if hasattr(self.mimesis, category):
        category_instance = getattr(self.mimesis, category)
        if hasattr(category_instance, method):
            return getattr(category_instance, method)()
    return None
|
||||||
|
|
||||||
|
def _generate_unique_id(
    self,
    table_name: str,
    col_spec: Dict[str, Any]
) -> int:
    """Generate a unique ID for a table using a hybrid sequential-random approach.

    The [min, max] range is divided into small windows; each call samples a
    random offset inside the current window and advances to the next window
    after success or repeated collisions. Issued IDs are recorded in
    self.generated_ids[table_name].

    Args:
        table_name: Table whose ID pool the value is recorded in.
        col_spec: Column spec; "min" (default 1) and "max" (default
            1000000) bound the generated value.

    Returns:
        An integer ID not previously issued for this table.
    """
    min_val = col_spec.get("min", 1)
    max_val = col_spec.get("max", 1000000)
    range_size = max_val - min_val + 1

    if table_name not in self.id_counters:
        self.id_counters[table_name] = 0
    if table_name not in self.generated_ids:
        self.generated_ids[table_name] = set()

    # Calculate a random offset within a smaller window
    window_size = max(1, range_size // 1000)  # Use 0.1% of range as window
    attempts = 0
    max_attempts = 10  # Limit retries to avoid infinite loops

    while attempts < max_attempts:
        base = min_val + (self.id_counters[table_name] * window_size)
        offset = np.random.randint(0, window_size)
        new_id = base + offset

        # Handle wraparound past the top of the range: restart at window 0.
        if new_id > max_val:
            self.id_counters[table_name] = 0
            new_id = min_val + np.random.randint(0, window_size)

        # Check if ID is unique
        if new_id not in self.generated_ids[table_name]:
            self.generated_ids[table_name].add(new_id)
            self.id_counters[table_name] += 1
            return new_id

        attempts += 1

    # If we couldn't find a unique ID in the current window, move to the
    # next window and retry.
    # NOTE(review): if the whole [min, max] range is exhausted this
    # recursion never terminates — confirm callers keep the requested row
    # count well below the ID range size.
    self.id_counters[table_name] += 1
    return self._generate_unique_id(table_name, col_spec)  # Recursive call with new window
|
||||||
|
|
||||||
|
def _generate_correlated_id(self, parent_table: str) -> int:
|
||||||
|
"""Generate a correlated ID from a parent table."""
|
||||||
|
if not self.generated_ids.get(parent_table):
|
||||||
|
raise ValueError(f"No IDs available for parent table {parent_table}")
|
||||||
|
parent_ids = list(self.generated_ids[parent_table])
|
||||||
|
return np.random.choice(parent_ids)
|
||||||
|
|
||||||
|
def _extract_parent_table(self, column_name: str) -> str:
|
||||||
|
"""Extract parent table name from column name."""
|
||||||
|
if not column_name.endswith("_id"):
|
||||||
|
raise ValueError(f"Column {column_name} is not a foreign key")
|
||||||
|
# Handle both singular and plural forms with special cases
|
||||||
|
table_name = column_name[:-3] # Remove _id
|
||||||
|
# Handle irregular plurals
|
||||||
|
irregular_plurals = {
|
||||||
|
"policy": "policies",
|
||||||
|
"company": "companies",
|
||||||
|
"category": "categories"
|
||||||
|
}
|
||||||
|
if table_name in irregular_plurals:
|
||||||
|
return irregular_plurals[table_name]
|
||||||
|
# Handle regular plurals
|
||||||
|
if not table_name.endswith('s'):
|
||||||
|
table_name += 's'
|
||||||
|
return table_name
|
||||||
|
|
||||||
|
def _clear_generated_ids(self, table_name: str) -> None:
|
||||||
|
"""Clear generated IDs for a table."""
|
||||||
|
if table_name in self.generated_ids:
|
||||||
|
del self.generated_ids[table_name]
|
||||||
|
|
||||||
|
async def fit_synthesizer(
    self,
    table_name: str,
    schema: Dict[str, Dict[str, Any]]
) -> None:
    """Fit a GaussianCopulaSynthesizer for the given table schema.

    Builds a small sample DataFrame matching the schema, fits a
    synthesizer on it, and caches it in self.synthesizers[table_name].

    Args:
        table_name: Table the synthesizer is fitted for.
        schema: Mapping of column name -> column spec.

    Raises:
        ValueError: If a correlated column's parent table has not been
            generated yet.
    """
    metadata = self.create_metadata(table_name, schema)
    fitting_size = 100  # Small fixed sample size for fitting

    # Generate sample data for fitting
    sample_data = {}
    for col_name, col_spec in schema.items():
        col_type = col_spec["type"]
        is_correlated = col_spec.get("correlated", False)

        if is_correlated and col_name.endswith("_id"):
            # For correlated fields, sample IDs from the parent table.
            parent_table = self._extract_parent_table(col_name)
            if parent_table not in self.generated_ids:
                raise ValueError(f"Parent table {parent_table} must be generated before {table_name}")
            parent_ids = list(self.generated_ids[parent_table])
            sample_data[col_name] = [
                np.random.choice(parent_ids) for _ in range(fitting_size)
            ]
        elif col_type == "int":
            min_val = col_spec.get("min", 0)
            max_val = col_spec.get("max", 100)
            if col_name.endswith("_id"):
                # Generate unique IDs for primary keys.
                # BUG FIX: this previously called
                # _generate_unique_id(min_val, max_val), passing ints where
                # (table_name, col_spec) are expected, which crashed with
                # AttributeError on col_spec.get.
                # NOTE(review): the fitting-sample IDs are recorded in
                # self.generated_ids[table_name] — confirm that polluting
                # the real ID pool with fitting samples is acceptable.
                unique_ids = set()
                while len(unique_ids) < fitting_size:
                    unique_ids.add(self._generate_unique_id(table_name, col_spec))
                sample_data[col_name] = list(unique_ids)
            else:
                sample_data[col_name] = [
                    np.random.randint(min_val, max_val + 1)
                    for _ in range(fitting_size)
                ]
        elif col_type == "float":
            min_val = col_spec.get("min", 0.0)
            max_val = col_spec.get("max", 1.0)
            sample_data[col_name] = [
                np.random.uniform(min_val, max_val)
                for _ in range(fitting_size)
            ]
        elif col_type == "category":
            categories = col_spec.get("categories", [])
            sample_data[col_name] = [
                np.random.choice(categories)
                for _ in range(fitting_size)
            ]
        elif col_type == "datetime":
            if "generator" in col_spec:
                generator_str = col_spec["generator"]
                if generator_str.startswith("faker."):
                    sample_data[col_name] = [
                        self._generate_faker_value(generator_str)
                        for _ in range(fitting_size)
                    ]
                elif generator_str.startswith("mimesis."):
                    sample_data[col_name] = [
                        self._generate_mimesis_value(generator_str)
                        for _ in range(fitting_size)
                    ]
            else:
                # Default to uniform timestamps within the current year.
                current_year = datetime.now().year
                start = datetime(current_year, 1, 1)
                end = datetime(current_year, 12, 31)
                sample_data[col_name] = [
                    start + timedelta(
                        seconds=np.random.randint(0, int((end - start).total_seconds()))
                    )
                    for _ in range(fitting_size)
                ]
        elif col_type == "string":
            if "generator" in col_spec:
                generator_str = col_spec["generator"]
                if generator_str.startswith("faker."):
                    sample_data[col_name] = [
                        self._generate_faker_value(generator_str)
                        for _ in range(fitting_size)
                    ]
                elif generator_str.startswith("mimesis."):
                    sample_data[col_name] = [
                        self._generate_mimesis_value(generator_str)
                        for _ in range(fitting_size)
                    ]
            else:
                # Default to a random 10-character lowercase string.
                sample_data[col_name] = [
                    ''.join(np.random.choice(list('abcdefghijklmnopqrstuvwxyz'), size=10))
                    for _ in range(fitting_size)
                ]

    # Create DataFrame and fit synthesizer
    df = pd.DataFrame(sample_data)
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.fit(df)
    self.synthesizers[table_name] = synthesizer
|
||||||
|
|
||||||
|
async def generate_synthetic_data(
    self,
    table_name: str,
    schema: Dict[str, Dict[str, Any]],
    rows: int = 1000
) -> Dict[str, List[Any]]:
    """Generate synthetic data for a table.

    Args:
        table_name: Table being generated; used to record the IDs handed
            out for its primary-key column.
        schema: Mapping of column name -> column spec ("type", optional
            "generator", "min"/"max", "categories", "correlated").
        rows: Number of rows to generate.

    Returns:
        Mapping of column name -> list of generated values.

    Raises:
        ValueError: If a correlated column references a parent table whose
            schema is unknown.
    """
    # Initialize result dictionary with empty lists for all columns
    result: Dict[str, List[Any]] = {col_name: [] for col_name in schema.keys()}

    # Collect (parent_table, column) pairs for every correlated column.
    parent_tables = set()
    for col_name, col_spec in schema.items():
        if col_spec.get("correlated", False):
            parent_table = self._extract_parent_table(col_name)
            parent_tables.add((parent_table, col_name))

    # Recursively generate parent tables first so their IDs exist for
    # foreign-key sampling.
    for parent_table, col_name in parent_tables:
        if parent_table not in self.generated_ids or not self.generated_ids[parent_table]:
            if hasattr(self, 'default_schemas') and parent_table in self.default_schemas:
                parent_schema = self.default_schemas[parent_table]
                await self.generate_synthetic_data(parent_table, parent_schema, rows)
            else:
                raise ValueError(f"Parent table {parent_table} schema not found")

    # Generate data row by row, column by column.
    for _ in range(rows):
        for col_name, col_spec in schema.items():
            col_type = col_spec["type"]
            # NOTE: stays None for specs no branch below recognizes.
            value = None

            if col_name.endswith("_id") and not col_spec.get("correlated", False):
                # Primary key: unique within this table.
                value = self._generate_unique_id(table_name, col_spec)
            elif col_spec.get("correlated", False):
                # Foreign key: sample an existing parent-table ID.
                parent_table = self._extract_parent_table(col_name)
                value = self._generate_correlated_id(parent_table)
            elif col_type == "string":
                if "generator" in col_spec:
                    if col_spec["generator"].startswith("faker."):
                        value = self._generate_faker_value(col_spec["generator"])
                    elif col_spec["generator"].startswith("mimesis."):
                        value = self._generate_mimesis_value(col_spec["generator"])
                elif "categories" in col_spec:
                    value = np.random.choice(col_spec["categories"])
                else:
                    value = self._generate_faker_value("faker.word")
            elif col_type == "int":
                # NOTE(review): np.random.randint's upper bound is
                # exclusive, so the spec's "max" is never produced here —
                # confirm intent.
                value = np.random.randint(col_spec.get("min", 0), col_spec.get("max", 100))
            elif col_type == "float":
                value = np.random.uniform(col_spec.get("min", 0.0), col_spec.get("max", 1.0))
            elif col_type == "datetime":
                if "generator" in col_spec:
                    value = self._generate_faker_value(col_spec["generator"])
                else:
                    value = self._generate_faker_value("faker.date_time_this_decade")
            elif col_type == "category":
                value = np.random.choice(col_spec["categories"])

            result[col_name].append(value)

    # Record the primary-key IDs issued for this table so later tables can
    # reference them as foreign keys.
    for col_name, values in result.items():
        if col_name.endswith("_id") and not schema[col_name].get("correlated", False):
            if table_name not in self.generated_ids:
                self.generated_ids[table_name] = set()
            self.generated_ids[table_name].update(values)

    return result
|
||||||
168
src/datagen/tests/unit/test_insurance_data.py
Normal file
168
src/datagen/tests/unit/test_insurance_data.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Unit tests for insurance-specific data generation."""
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from mcp_server_datagen.synthetic import SyntheticDataGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def data_generator(customers_schema, policies_schema, claims_schema):
    """Create a data generator instance for testing."""
    generator = SyntheticDataGenerator()
    # Default schemas let the generator build parent tables on demand when
    # a correlated (foreign-key) column is requested first.
    generator.default_schemas = {
        "customers": customers_schema,
        "policies": policies_schema,
        "claims": claims_schema
    }
    return generator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def customers_schema():
    """Create the customers table schema (primary key plus Faker fields)."""
    return {
        "customer_id": {"type": "int", "min": 10000, "max": 99999},
        "first_name": {"type": "string", "generator": "faker.first_name"},
        "last_name": {"type": "string", "generator": "faker.last_name"},
        "email": {"type": "string", "generator": "faker.email"},
        "phone": {"type": "string", "generator": "faker.phone_number"},
        "address": {"type": "string", "generator": "faker.address"},
        "date_of_birth": {"type": "datetime", "generator": "faker.date_of_birth"},
        # Standard FICO-style range.
        "credit_score": {"type": "int", "min": 300, "max": 850}
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def policies_schema():
    """Create the policies table schema (FK to customers via "correlated")."""
    return {
        "policy_id": {"type": "int", "min": 100000, "max": 999999},
        # Foreign key: values sampled from generated customer IDs.
        "customer_id": {"type": "int", "min": 10000, "max": 99999, "correlated": True},
        "policy_type": {"type": "category", "categories": ["auto", "home", "life", "health"]},
        "start_date": {"type": "datetime"},
        "end_date": {"type": "datetime"},
        "premium": {"type": "float", "min": 500.0, "max": 5000.0},
        "coverage_amount": {"type": "float", "min": 50000.0, "max": 1000000.0},
        "status": {"type": "category", "categories": ["active", "expired", "cancelled", "pending"]}
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def claims_schema():
    """Create the claims table schema (FK to policies via "correlated")."""
    return {
        "claim_id": {"type": "int", "min": 1000000, "max": 9999999},
        # Foreign key: values sampled from generated policy IDs.
        "policy_id": {"type": "int", "min": 100000, "max": 999999, "correlated": True},
        "date_filed": {"type": "datetime"},
        "incident_date": {"type": "datetime"},
        "claim_type": {"type": "category", "categories": [
            "accident", "theft", "natural_disaster", "medical", "property_damage"
        ]},
        "amount_claimed": {"type": "float", "min": 1000.0, "max": 100000.0},
        "status": {"type": "category", "categories": ["pending", "approved", "denied", "in_review"]},
        "description": {"type": "string", "generator": "mimesis.text.text"}
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_generate_customers_table(data_generator, customers_schema):
    """Test generation of customers table with 10,000 rows."""
    rows = 10000
    data = await data_generator.generate_synthetic_data("customers", customers_schema, rows)

    # Verify row count: every column must have exactly `rows` values.
    assert all(len(values) == rows for values in data.values())

    # Verify data types and ranges (IDs may come back as numpy integers).
    assert all(isinstance(x, (int, np.integer)) for x in data["customer_id"])
    assert all(10000 <= x <= 99999 for x in data["customer_id"])
    assert all(300 <= x <= 850 for x in data["credit_score"])

    # Verify Faker-generated fields are non-empty, well-formed strings.
    assert all(isinstance(x, str) and "@" in x for x in data["email"])
    assert all(isinstance(x, str) and len(x) > 0 for x in data["first_name"])
    assert all(isinstance(x, str) and len(x) > 0 for x in data["last_name"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_generate_policies_table(data_generator, policies_schema):
    """Generate 10,000 policy rows and validate types, ranges and categories."""
    row_count = 10000
    table = await data_generator.generate_synthetic_data("policies", policies_schema, row_count)

    # Every column must contain exactly row_count values.
    for column_values in table.values():
        assert len(column_values) == row_count

    # Numeric columns stay within their configured ranges.
    for pid in table["policy_id"]:
        assert isinstance(pid, (int, np.integer))
        assert 100000 <= pid <= 999999
    for premium in table["premium"]:
        assert isinstance(premium, (float, np.floating))
        assert 500.0 <= premium <= 5000.0
    for coverage in table["coverage_amount"]:
        assert 50000.0 <= coverage <= 1000000.0

    # Categorical columns may only contain their declared categories.
    allowed_types = {"auto", "home", "life", "health"}
    allowed_statuses = {"active", "expired", "cancelled", "pending"}
    for policy_type in table["policy_type"]:
        assert policy_type in allowed_types
    for status in table["status"]:
        assert status in allowed_statuses
|
@pytest.mark.asyncio
async def test_generate_claims_table(data_generator, claims_schema):
    """Generate 10,000 claim rows and validate types, ranges and categories."""
    row_count = 10000
    table = await data_generator.generate_synthetic_data("claims", claims_schema, row_count)

    # Every column must contain exactly row_count values.
    for column_values in table.values():
        assert len(column_values) == row_count

    # Numeric columns stay within their configured ranges.
    for claim_id in table["claim_id"]:
        assert isinstance(claim_id, (int, np.integer))
        assert 1000000 <= claim_id <= 9999999
    for amount in table["amount_claimed"]:
        assert isinstance(amount, (float, np.floating))
        assert 1000.0 <= amount <= 100000.0

    # Categorical columns may only contain their declared categories.
    allowed_types = {"accident", "theft", "natural_disaster", "medical", "property_damage"}
    allowed_statuses = {"pending", "approved", "denied", "in_review"}
    for claim_type in table["claim_type"]:
        assert claim_type in allowed_types
    for status in table["status"]:
        assert status in allowed_statuses

    # Mimesis-backed descriptions are non-empty strings.
    for description in table["description"]:
        assert isinstance(description, str) and len(description) > 0
|
@pytest.mark.asyncio
async def test_data_relationships(data_generator, customers_schema, policies_schema, claims_schema):
    """Foreign keys in policies and claims must reference existing parent rows."""
    # Generate all three related tables.
    customers = await data_generator.generate_synthetic_data("customers", customers_schema, 1000)
    policies = await data_generator.generate_synthetic_data("policies", policies_schema, 2000)
    claims = await data_generator.generate_synthetic_data("claims", claims_schema, 3000)

    # Every policy.customer_id points at a generated customer.
    assert set(policies["customer_id"]) <= set(customers["customer_id"])

    # Every claim.policy_id points at a generated policy.
    assert set(claims["policy_id"]) <= set(policies["policy_id"])
|
@pytest.mark.asyncio
async def test_csv_export(data_generator, customers_schema, tmp_path):
    """Round-trip generated customer data through a CSV file."""
    row_count = 100
    table = await data_generator.generate_synthetic_data("customers", customers_schema, row_count)

    # Write the table out via pandas into pytest's per-test temp dir.
    out_file = tmp_path / "customers.csv"
    pd.DataFrame(table).to_csv(out_file, index=False)

    # Reading it back preserves the row count and every schema column.
    reloaded = pd.read_csv(out_file)
    assert len(reloaded) == row_count
    for column in customers_schema.keys():
        assert column in reloaded.columns
163
src/datagen/tests/unit/test_server.py
Normal file
163
src/datagen/tests/unit/test_server.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Unit tests for MCP data generation server."""
|
||||||
|
import pytest
|
||||||
|
from typing import Dict
|
||||||
|
from mcp_server_datagen.server import DataGenServer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def server():
    """Provide a fresh DataGenServer instance for each test."""
    instance = DataGenServer()
    return instance
|
@pytest.mark.asyncio
async def test_list_tools(server):
    """The server advertises a non-empty tool list including generate_tables."""
    tools = await server.list_tools()

    # A non-empty list of tool descriptors is returned.
    assert isinstance(tools, list)
    assert len(tools) > 0

    # generate_tables must be among the advertised tools.
    names = [tool.name for tool in tools]
    assert "generate_tables" in names

    # Its input schema exposes the expected top-level properties.
    generate_tables = next(tool for tool in tools if tool.name == "generate_tables")
    assert generate_tables.inputSchema is not None
    for prop in ("tables", "rows", "schemas"):
        assert prop in generate_tables.inputSchema["properties"]
|
@pytest.mark.asyncio
async def test_generate_insurance_tables(server):
    """Test generation of insurance tables through the server.

    Requests the three default insurance tables and verifies the result is a
    mapping of table name -> column data with the requested row count.
    """
    params = {
        "tables": ["customers", "policies", "claims"],
        "rows": 100,
    }

    result = await server.handle_generate_tables(params)

    # isinstance checks use the builtin `dict`; typing.Dict is intended for
    # annotations only, not runtime type checks.
    assert isinstance(result, dict)
    assert all(table in result for table in params["tables"])

    # Each table is a column -> values mapping with the requested row count.
    # Iterate .values() directly: the table name is not needed here.
    for table_data in result.values():
        assert isinstance(table_data, dict)
        assert len(next(iter(table_data.values()))) == params["rows"]
|
@pytest.mark.asyncio
async def test_generate_custom_schema(server):
    """Generation honours a caller-supplied schema."""
    schemas = {
        "test_table": {
            "id": {"type": "int", "min": 1, "max": 100},
            "name": {"type": "string", "generator": "faker.name"},
            "description": {"type": "string", "generator": "mimesis.text.text"},
        }
    }

    result = await server.handle_generate_tables(
        {"tables": ["test_table"], "rows": 50, "schemas": schemas}
    )

    # The custom table is generated with the requested row count and
    # every column obeys its column spec.
    assert "test_table" in result
    generated = result["test_table"]
    assert len(generated["id"]) == 50
    for value in generated["id"]:
        assert 1 <= value <= 100
    for value in generated["name"]:
        assert isinstance(value, str)
    for value in generated["description"]:
        assert isinstance(value, str)
|
@pytest.mark.asyncio
async def test_invalid_table_name(server):
    """Requesting an unknown table name raises ValueError."""
    request = {"tables": ["nonexistent_table"], "rows": 100}

    with pytest.raises(ValueError):
        await server.handle_generate_tables(request)
|
@pytest.mark.asyncio
async def test_invalid_row_count(server):
    """Requesting a negative row count raises ValueError."""
    request = {"tables": ["customers"], "rows": -1}

    with pytest.raises(ValueError):
        await server.handle_generate_tables(request)
|
@pytest.mark.asyncio
async def test_large_dataset_generation(server):
    """Generate 10,000 rows per table and check referential integrity."""
    result = await server.handle_generate_tables(
        {"tables": ["customers", "policies", "claims"], "rows": 10000}
    )

    # Every table received the requested number of rows.
    for table_data in result.values():
        first_column = next(iter(table_data.values()))
        assert len(first_column) == 10000

    customers = result["customers"]
    policies = result["policies"]
    claims = result["claims"]

    # policies.customer_id values must all exist in customers.customer_id.
    assert set(policies["customer_id"]) <= set(customers["customer_id"])

    # claims.policy_id values must all exist in policies.policy_id.
    assert set(claims["policy_id"]) <= set(policies["policy_id"])
|
@pytest.mark.asyncio
async def test_csv_export_format(server):
    """Generated data survives a CSV round-trip with all default columns intact."""
    import os
    import tempfile

    import pandas as pd

    result = await server.handle_generate_tables({"tables": ["customers"], "rows": 100})

    # Write and read back inside the temp dir's lifetime (it is deleted on exit).
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "customers.csv")
        pd.DataFrame(result["customers"]).to_csv(csv_path, index=False)

        reloaded = pd.read_csv(csv_path)

    # Row count and every default-schema column survive the round trip.
    assert len(reloaded) == 100
    for column in server.default_schemas["customers"].keys():
        assert column in reloaded.columns
114
src/datagen/tests/unit/test_synthetic.py
Normal file
114
src/datagen/tests/unit/test_synthetic.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""Unit tests for synthetic data generation."""
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from mcp_server_datagen.synthetic import SyntheticDataGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def data_generator():
    """Provide a SyntheticDataGenerator instance for each test."""
    generator = SyntheticDataGenerator()
    return generator
|
@pytest.fixture
def sample_schema():
    """A small mixed-type schema exercised across the synthetic-data tests."""
    schema = {
        "id": {"type": "int", "min": 1, "max": 1000},
        "name": {"type": "string", "categories": ["Alice", "Bob", "Charlie"]},
        "age": {"type": "int", "min": 18, "max": 100},
        "score": {"type": "float", "min": 0.0, "max": 1.0},
    }
    return schema
|
@pytest.mark.asyncio
async def test_create_metadata(data_generator, sample_schema):
    """Schema columns map onto metadata entries with the expected sdtypes."""
    metadata = data_generator.create_metadata("test_table", sample_schema)

    # The metadata covers exactly the schema's columns.
    assert set(sample_schema.keys()) == set(metadata.columns.keys())

    # int/float columns become numerical; category-backed strings categorical.
    expected_sdtypes = {
        "id": "numerical",
        "name": "categorical",
        "age": "numerical",
        "score": "numerical",
    }
    for column, sdtype in expected_sdtypes.items():
        assert metadata.columns[column]["sdtype"] == sdtype
|
@pytest.mark.asyncio
async def test_generate_synthetic_data(data_generator, sample_schema):
    """Generated data matches the schema's columns, types and ranges."""
    row_count = 100
    table = await data_generator.generate_synthetic_data("test_table", sample_schema, row_count)

    # Same column set as the schema, each column with row_count values.
    assert set(table.keys()) == set(sample_schema.keys())
    for column_values in table.values():
        assert len(column_values) == row_count

    # Per-column type and range checks.
    for value in table["id"]:
        assert isinstance(value, (int, np.integer))
        assert 1 <= value <= 1000

    for value in table["name"]:
        assert isinstance(value, str)
        assert value in ["Alice", "Bob", "Charlie"]

    for value in table["age"]:
        assert isinstance(value, (int, np.integer))
        assert 18 <= value <= 100

    for value in table["score"]:
        assert isinstance(value, (float, np.floating))
        assert 0.0 <= value <= 1.0
|
@pytest.mark.asyncio
async def test_generate_large_dataset(data_generator):
    """10,000-row generation respects ranges and yields mostly-unique IDs."""
    schema = {
        "customer_id": {"type": "int", "min": 10000, "max": 99999},
        "first_name": {"type": "string"},
        "last_name": {"type": "string"},
        "age": {"type": "int", "min": 18, "max": 100},
        "credit_score": {"type": "int", "min": 300, "max": 850},
    }

    row_count = 10000
    table = await data_generator.generate_synthetic_data("customers", schema, row_count)

    # Every column must contain exactly row_count values.
    for column_values in table.values():
        assert len(column_values) == row_count

    # Integer columns stay within their configured ranges.
    for cid in table["customer_id"]:
        assert 10000 <= cid <= 99999
    for age in table["age"]:
        assert 18 <= age <= 100
    for score in table["credit_score"]:
        assert 300 <= score <= 850

    # IDs are drawn randomly, so allow a small fraction of collisions.
    assert len(set(table["customer_id"])) > row_count * 0.95
|
@pytest.mark.asyncio
async def test_multiple_table_generation(data_generator):
    """Correlated columns keep referential integrity across related tables."""
    customers_schema = {
        "customer_id": {"type": "int", "min": 1, "max": 1000},
        "name": {"type": "string"},
    }
    policies_schema = {
        "policy_id": {"type": "int", "min": 1, "max": 2000},
        # correlated: drawn from the customers table's customer_id values.
        "customer_id": {"type": "int", "min": 1, "max": 1000, "correlated": True},
        "premium": {"type": "float", "min": 500.0, "max": 5000.0},
    }

    # Generate the parent table first, then the child table.
    customers = await data_generator.generate_synthetic_data("customers", customers_schema, 100)
    policies = await data_generator.generate_synthetic_data("policies", policies_schema, 200)

    # Every policy must reference a customer_id that actually exists.
    assert set(policies["customer_id"]) <= set(customers["customer_id"])
||||||
1115
src/datagen/uv.lock
generated
Normal file
1115
src/datagen/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user