mirror of
https://github.com/modelcontextprotocol/servers.git
synced 2026-04-25 23:35:27 +02:00
fix: Handle table relationships, type checking, and data generation
Co-Authored-By: alexander@anthropic.com <alexander@anthropic.com>
This commit is contained in:
@@ -271,29 +271,38 @@ class DataGenServer:
|
||||
if rows <= 0:
|
||||
raise ValueError("Row count must be positive")
|
||||
|
||||
results = {}
|
||||
try:
|
||||
for table_name in tables:
|
||||
if table_name not in self.default_schemas and table_name not in custom_schemas:
|
||||
raise ValueError(f"Unknown table: {table_name}")
|
||||
# Validate and prepare schemas
|
||||
schemas = {}
|
||||
for table_name in tables:
|
||||
if table_name in custom_schemas:
|
||||
schemas[table_name] = custom_schemas[table_name]
|
||||
elif table_name in self.default_schemas:
|
||||
schemas[table_name] = self.default_schemas[table_name]
|
||||
else:
|
||||
raise ValueError(f"No schema found for table {table_name}")
|
||||
|
||||
# Use custom schema if provided, otherwise use default
|
||||
schema = custom_schemas.get(table_name, self.default_schemas.get(table_name, {}))
|
||||
data = await self.generator.generate_synthetic_data(
|
||||
table_name=table_name,
|
||||
schema=schema,
|
||||
rows=rows
|
||||
# Clear previously generated IDs
|
||||
self.generator._clear_generated_ids()
|
||||
|
||||
# Generate tables in correct order to maintain relationships
|
||||
result = {}
|
||||
table_order = ["customers", "policies", "claims"]
|
||||
ordered_tables = sorted(
|
||||
tables,
|
||||
key=lambda x: table_order.index(x) if x in table_order else len(table_order)
|
||||
)
|
||||
|
||||
for table_name in ordered_tables:
|
||||
try:
|
||||
result[table_name] = await self.generator.generate_synthetic_data(
|
||||
table_name,
|
||||
schemas[table_name],
|
||||
rows
|
||||
)
|
||||
results[table_name] = data
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Failed to generate table {table_name}: {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
except ValueError as e:
|
||||
# Re-raise validation errors directly
|
||||
raise e
|
||||
except Exception as e:
|
||||
# Wrap unexpected errors in McpError
|
||||
raise McpError(f"Error generating data: {str(e)}")
|
||||
return result
|
||||
|
||||
async def handle_generate_insurance_data(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Handle generate_insurance_data tool requests."""
|
||||
@@ -301,19 +310,24 @@ class DataGenServer:
|
||||
if rows <= 0:
|
||||
raise ValueError("Row count must be positive")
|
||||
|
||||
results = {}
|
||||
try:
|
||||
for table_name in ["customers", "policies", "claims"]:
|
||||
schema = self.default_schemas[table_name]
|
||||
data = await self.generator.generate_synthetic_data(
|
||||
table_name=table_name,
|
||||
schema=schema,
|
||||
rows=rows
|
||||
# Clear previously generated IDs
|
||||
self.generator._clear_generated_ids()
|
||||
|
||||
# Generate tables in correct order to maintain relationships
|
||||
result = {}
|
||||
table_order = ["customers", "policies", "claims"]
|
||||
|
||||
for table_name in table_order:
|
||||
try:
|
||||
result[table_name] = await self.generator.generate_synthetic_data(
|
||||
table_name,
|
||||
self.default_schemas[table_name],
|
||||
rows
|
||||
)
|
||||
results[table_name] = data
|
||||
return results
|
||||
except Exception as e:
|
||||
raise McpError(f"Error generating data: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Failed to generate table {table_name}: {str(e)}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def serve() -> None:
|
||||
|
||||
@@ -1,39 +1,40 @@
|
||||
"""Synthetic data generation using numpy and faker."""
|
||||
|
||||
import datetime
|
||||
from datetime import datetime, date
|
||||
import numpy as np
|
||||
from faker import Faker
|
||||
from mimesis import Generic
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Set, Union, Callable
|
||||
|
||||
# Default schemas for insurance data generation
|
||||
DEFAULT_CUSTOMER_SCHEMA = {
|
||||
"customer_id": {"type": "integer", "generator": "numpy", "min": 10000, "max": 99999},
|
||||
"first_name": {"type": "first_name", "generator": "faker"},
|
||||
"last_name": {"type": "last_name", "generator": "faker"},
|
||||
"age": {"type": "integer", "generator": "numpy", "min": 18, "max": 85},
|
||||
"credit_score": {"type": "integer", "generator": "numpy", "min": 300, "max": 850},
|
||||
"active": {"type": "boolean", "generator": "numpy"}
|
||||
"customer_id": {"type": "integer", "min": 100000, "max": 999999},
|
||||
"name": {"type": "string", "generator": "faker", "method": "name"},
|
||||
"email": {"type": "string", "generator": "faker", "method": "email"},
|
||||
"phone": {"type": "string", "generator": "faker", "method": "phone_number"},
|
||||
"address": {"type": "string", "generator": "faker", "method": "address"},
|
||||
"date_of_birth": {"type": "date", "generator": "faker"},
|
||||
"risk_score": {"type": "float", "min": 0.0, "max": 1.0}
|
||||
}
|
||||
|
||||
DEFAULT_POLICY_SCHEMA = {
|
||||
"policy_id": {"type": "integer", "generator": "numpy", "min": 100000, "max": 999999, "prefix": "POL-"},
|
||||
"customer_id": {"type": "integer", "generator": "numpy", "min": 10000, "max": 99999, "correlated": True},
|
||||
"premium": {"type": "float", "generator": "numpy", "min": 500.0, "max": 5000.0},
|
||||
"deductible": {"type": "float", "generator": "numpy", "min": 250.0, "max": 2000.0},
|
||||
"coverage_amount": {"type": "integer", "generator": "numpy", "min": 50000, "max": 1000000},
|
||||
"risk_score": {"type": "integer", "generator": "numpy", "min": 1, "max": 100},
|
||||
"policy_type": {"type": "category", "generator": "numpy", "categories": ["Auto", "Home", "Life", "Health"]},
|
||||
"start_date": {"type": "date_this_year", "generator": "faker"}
|
||||
"policy_id": {"type": "integer", "min": 100000, "max": 999999, "prefix": "POL-2024-"},
|
||||
"customer_id": {"type": "integer", "correlated": True},
|
||||
"type": {"type": "category", "categories": ["Auto", "Home", "Life", "Health"]},
|
||||
"premium": {"type": "float", "min": 500.0, "max": 5000.0},
|
||||
"coverage_amount": {"type": "integer", "min": 50000, "max": 1000000},
|
||||
"start_date": {"type": "date", "generator": "faker"},
|
||||
"status": {"type": "category", "categories": ["Active", "Pending", "Expired", "Cancelled"]},
|
||||
"deductible": {"type": "float", "min": 250.0, "max": 2000.0}
|
||||
}
|
||||
|
||||
DEFAULT_CLAIMS_SCHEMA = {
|
||||
"claim_id": {"type": "integer", "generator": "numpy", "min": 500000, "max": 999999},
|
||||
"policy_id": {"type": "integer", "generator": "numpy", "min": 100000, "max": 999999, "correlated": True},
|
||||
"amount": {"type": "float", "generator": "numpy", "min": 100.0, "max": 50000.0},
|
||||
"status": {"type": "category", "generator": "numpy", "categories": ["Open", "Closed", "Pending", "Denied"]},
|
||||
"date_filed": {"type": "date_this_year", "generator": "faker"},
|
||||
"description": {"type": "text", "generator": "faker"}
|
||||
"claim_id": {"type": "integer", "min": 1000000, "max": 9999999},
|
||||
"policy_id": {"type": "integer", "correlated": True},
|
||||
"date_filed": {"type": "datetime"},
|
||||
"amount_claimed": {"type": "float", "min": 1000.0, "max": 100000.0},
|
||||
"status": {"type": "category", "categories": ["Open", "Under Review", "Approved", "Denied"]},
|
||||
"type": {"type": "category", "categories": ["accident", "theft", "natural_disaster", "medical", "property_damage"]}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,119 +48,184 @@ class SyntheticDataGenerator:
|
||||
# Store generated IDs for relationships
|
||||
self._generated_ids: Dict[str, Set[Union[int, str]]] = {}
|
||||
self._id_counters: Dict[str, int] = {}
|
||||
self.current_table: str = "" # Track current table being generated
|
||||
self.default_schemas: Dict[str, Dict[str, Dict[str, Any]]] = {
|
||||
"customers": DEFAULT_CUSTOMER_SCHEMA,
|
||||
"policies": DEFAULT_POLICY_SCHEMA,
|
||||
"claims": DEFAULT_CLAIMS_SCHEMA
|
||||
}
|
||||
|
||||
def _map_type_to_generator(self, data_type: str) -> str:
|
||||
"""Map data type to appropriate generator."""
|
||||
# Handle legacy faker.method format
|
||||
if "." in data_type:
|
||||
return "faker"
|
||||
def _ensure_json_serializable(self, value: Any) -> Any:
|
||||
"""Convert value to JSON serializable type."""
|
||||
if isinstance(value, (np.integer, np.floating)):
|
||||
return value.item()
|
||||
elif isinstance(value, np.bool_):
|
||||
return bool(value)
|
||||
elif isinstance(value, (date, datetime)):
|
||||
return value.isoformat()
|
||||
elif isinstance(value, str):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
# Direct faker types
|
||||
if data_type in {
|
||||
"first_name", "last_name", "email", "phone_number",
|
||||
"address", "text", "date_this_year", "date_this_decade",
|
||||
"date_of_birth", "name", "string"
|
||||
}:
|
||||
return "faker"
|
||||
def _map_type_to_generator(self, col_name: str, spec: Dict[str, Any]) -> Callable[[], Any]:
|
||||
"""Map a column specification to a generator function."""
|
||||
data_type = spec["type"]
|
||||
generator = spec.get("generator", None)
|
||||
|
||||
# Numpy types
|
||||
if data_type in {"boolean", "integer", "int", "float", "category"}:
|
||||
return "numpy"
|
||||
# Handle explicit generator specification
|
||||
if generator:
|
||||
if generator.startswith("faker."):
|
||||
return lambda: self._map_faker_type(generator)
|
||||
elif generator.startswith("mimesis."):
|
||||
return lambda: self._generate_mimesis_value(generator)
|
||||
elif generator.startswith("numpy."):
|
||||
return lambda: self._ensure_json_serializable(getattr(np.random, generator.split(".", 1)[1])())
|
||||
elif generator == "faker":
|
||||
return lambda: self._map_faker_type(data_type)
|
||||
elif generator == "numpy":
|
||||
# Use default numpy generators based on type
|
||||
if data_type in ["int", "integer"]:
|
||||
min_val = spec.get("min", 0)
|
||||
max_val = spec.get("max", 100)
|
||||
return lambda: int(self._ensure_json_serializable(np.random.randint(min_val, max_val)))
|
||||
elif data_type == "float":
|
||||
min_val = spec.get("min", 0.0)
|
||||
max_val = spec.get("max", 1.0)
|
||||
return lambda: float(self._ensure_json_serializable(np.random.uniform(min_val, max_val)))
|
||||
elif data_type == "boolean":
|
||||
return lambda: bool(np.random.choice([True, False]))
|
||||
else:
|
||||
raise ValueError(f"Unsupported numpy type: {data_type}")
|
||||
else:
|
||||
raise ValueError(f"Unknown generator type: {generator}")
|
||||
|
||||
# Default to faker for string types
|
||||
return "faker"
|
||||
# Handle special ID generation cases
|
||||
if "correlated" in spec:
|
||||
parent_table = self._extract_parent_table(self.current_table)
|
||||
return lambda: self._generate_correlated_id(parent_table)
|
||||
if col_name.endswith("_id") or col_name == "id":
|
||||
return lambda: self._generate_unique_id(self.current_table, spec)
|
||||
|
||||
def _map_faker_type(self, data_type: str) -> str:
|
||||
"""Map data type to Faker method name."""
|
||||
# Handle legacy faker.method format
|
||||
if "." in data_type:
|
||||
return data_type.split(".")[-1]
|
||||
|
||||
# Direct mapping for faker types
|
||||
if data_type in {
|
||||
"first_name", "last_name", "email", "phone_number",
|
||||
"address", "text", "date_this_year", "date_this_decade",
|
||||
"date_of_birth", "name", "string"
|
||||
}:
|
||||
return "text" if data_type == "string" else data_type
|
||||
|
||||
# Legacy type mapping
|
||||
type_mapping = {
|
||||
"string": "text",
|
||||
}
|
||||
return type_mapping.get(data_type, "text")
|
||||
|
||||
def _generate_faker_value(self, data_type: str) -> str:
|
||||
"""Generate a value using Faker."""
|
||||
# Handle legacy faker.method format
|
||||
if "." in data_type:
|
||||
method_name = data_type.split(".")[-1]
|
||||
# Map basic types to generators
|
||||
if data_type == "string":
|
||||
if "prefix" in spec:
|
||||
prefix = spec["prefix"]
|
||||
if "categories" in spec:
|
||||
categories = spec["categories"]
|
||||
return lambda: f"{prefix}{np.random.choice(categories)}"
|
||||
return lambda: f"{prefix}{self._map_faker_type('text')}"
|
||||
if "categories" in spec:
|
||||
categories = spec["categories"]
|
||||
return lambda: str(np.random.choice(categories))
|
||||
return lambda: self._map_faker_type("text")
|
||||
elif data_type in ["int", "integer"]:
|
||||
min_val = spec.get("min", 0)
|
||||
max_val = spec.get("max", 100)
|
||||
return lambda: int(self._ensure_json_serializable(np.random.randint(min_val, max_val)))
|
||||
elif data_type == "float":
|
||||
min_val = spec.get("min", 0.0)
|
||||
max_val = spec.get("max", 1.0)
|
||||
return lambda: float(self._ensure_json_serializable(np.random.uniform(min_val, max_val)))
|
||||
elif data_type == "boolean":
|
||||
return lambda: bool(np.random.choice([True, False]))
|
||||
elif data_type == "category":
|
||||
categories = spec.get("categories", [])
|
||||
return lambda: str(np.random.choice(categories))
|
||||
elif data_type == "datetime":
|
||||
return lambda: self._map_faker_type("date_time_this_year")
|
||||
elif data_type == "date":
|
||||
return lambda: self._map_faker_type("date")
|
||||
else:
|
||||
method_name = data_type
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
def _map_faker_type(self, data_type: str) -> Any:
|
||||
"""Map a data type to a faker method."""
|
||||
# Handle faker.method format
|
||||
if data_type.startswith("faker."):
|
||||
method = data_type.split(".", 1)[1]
|
||||
else:
|
||||
# Map common types to faker methods
|
||||
method_map = {
|
||||
"text": "text",
|
||||
"string": "text", # Map string to text
|
||||
"first_name": "first_name",
|
||||
"last_name": "last_name",
|
||||
"email": "email",
|
||||
"phone": "phone_number",
|
||||
"address": "address",
|
||||
"company": "company",
|
||||
"date": "date",
|
||||
"date_time_this_year": "date_time_this_year",
|
||||
"date_this_year": "date_this_year",
|
||||
}
|
||||
# Try to get mapped method, fallback to text for string types
|
||||
if data_type == "string":
|
||||
method = "text"
|
||||
else:
|
||||
method = method_map.get(data_type, data_type)
|
||||
|
||||
try:
|
||||
faker_method = getattr(self.faker, method)
|
||||
return faker_method()
|
||||
except AttributeError:
|
||||
if data_type == "string":
|
||||
return self.faker.text()
|
||||
raise ValueError(f"Unsupported faker type: {data_type}")
|
||||
|
||||
faker_method = getattr(self.faker, method_name, None)
|
||||
if faker_method is None:
|
||||
return self.faker.text() # Default to text if method not found
|
||||
return faker_method()
|
||||
|
||||
def _generate_mimesis_value(self, generator: str) -> Any:
|
||||
"""Generate value using Mimesis."""
|
||||
if not generator.startswith("mimesis."):
|
||||
return None
|
||||
"""Generate a value using mimesis."""
|
||||
if "." in generator:
|
||||
methods = generator.split(".")
|
||||
obj = self.mimesis
|
||||
for method in methods[1:]: # Skip 'mimesis' prefix
|
||||
if not hasattr(obj, method):
|
||||
raise ValueError(f"Invalid mimesis method: {generator}")
|
||||
obj = getattr(obj, method)
|
||||
if callable(obj):
|
||||
return obj()
|
||||
return obj
|
||||
return getattr(self.mimesis, generator)()
|
||||
|
||||
category, method = generator.split(".", 1)[1].split(".")
|
||||
if hasattr(self.mimesis, category):
|
||||
category_instance = getattr(self.mimesis, category)
|
||||
if hasattr(category_instance, method):
|
||||
return getattr(category_instance, method)()
|
||||
return None
|
||||
|
||||
def _generate_unique_id(self, table_name: str, col_spec: Dict[str, Any]) -> int:
|
||||
def _generate_unique_id(self, table_name: str, spec: Dict[str, Any]) -> Union[int, str]:
|
||||
"""Generate a unique ID for a table."""
|
||||
# Initialize ID storage and counter for this table if not exists
|
||||
if table_name not in self._generated_ids:
|
||||
self._generated_ids[table_name] = set()
|
||||
id_type = spec.get("type", "integer")
|
||||
min_val = spec.get("min", 1)
|
||||
max_val = spec.get("max", 1000000)
|
||||
prefix = spec.get("prefix", "")
|
||||
|
||||
# Get ID range from schema
|
||||
min_val = col_spec.get("min", 1)
|
||||
max_val = col_spec.get("max", 999999)
|
||||
# Generate a unique ID
|
||||
attempts = 0
|
||||
max_attempts = 100
|
||||
while attempts < max_attempts:
|
||||
if id_type in ["integer", "int"]:
|
||||
id_val = np.random.randint(min_val, max_val + 1)
|
||||
if prefix:
|
||||
id_val = f"{prefix}{id_val}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported ID type: {id_type}")
|
||||
|
||||
# Initialize counter at min_val - 1 if not set
|
||||
if table_name not in self._id_counters:
|
||||
self._id_counters[table_name] = min_val - 1
|
||||
# Check if ID is unique for this table
|
||||
if table_name not in self._generated_ids:
|
||||
self._generated_ids[table_name] = set()
|
||||
|
||||
# Try to generate a unique ID within the range
|
||||
for _ in range(max_val - min_val + 1): # Try all possible values in range
|
||||
self._id_counters[table_name] += 1
|
||||
if self._id_counters[table_name] > max_val:
|
||||
self._id_counters[table_name] = min_val
|
||||
if id_val not in self._generated_ids[table_name]:
|
||||
self._generated_ids[table_name].add(id_val)
|
||||
return id_val
|
||||
|
||||
candidate_id = self._id_counters[table_name]
|
||||
attempts += 1
|
||||
|
||||
# Check if ID is unique
|
||||
if candidate_id not in self._generated_ids[table_name]:
|
||||
self._generated_ids[table_name].add(candidate_id)
|
||||
return candidate_id
|
||||
|
||||
raise ValueError(
|
||||
f"Could not generate unique ID for {table_name}. "
|
||||
f"All IDs in range {min_val}-{max_val} have been used."
|
||||
)
|
||||
raise ValueError(f"Failed to generate unique ID for table {table_name} after {max_attempts} attempts")
|
||||
|
||||
def _generate_correlated_id(self, parent_table: str) -> Union[int, str]:
|
||||
"""Generate a correlated ID from a parent table."""
|
||||
if not self._generated_ids.get(parent_table):
|
||||
raise ValueError(f"No IDs available for parent table {parent_table}")
|
||||
# Get the raw ID without prefix
|
||||
parent_ids = [
|
||||
int(id_val.split('-')[-1]) if isinstance(id_val, str) else id_val
|
||||
for id_val in self._generated_ids[parent_table]
|
||||
]
|
||||
|
||||
# Get a random ID from the parent table
|
||||
parent_ids = list(self._generated_ids[parent_table])
|
||||
if not parent_ids:
|
||||
raise ValueError(f"No IDs available in parent table {parent_table}")
|
||||
return np.random.choice(parent_ids)
|
||||
|
||||
def _extract_parent_table(self, col_name: str) -> str:
|
||||
@@ -171,7 +237,6 @@ class SyntheticDataGenerator:
|
||||
return "customers"
|
||||
if col_name == "claim_id":
|
||||
return "claims"
|
||||
|
||||
# General case: remove '_id' suffix
|
||||
if col_name.endswith("_id"):
|
||||
table_name = col_name[:-3]
|
||||
@@ -181,10 +246,33 @@ class SyntheticDataGenerator:
|
||||
return table_name
|
||||
raise ValueError(f"Invalid column name for parent table extraction: {col_name}")
|
||||
|
||||
def _ensure_parent_table_ids(self, parent_table: str, rows: int) -> None:
|
||||
"""Ensure parent table has generated IDs."""
|
||||
if not self._generated_ids.get(parent_table):
|
||||
# Get the schema for the parent table
|
||||
if parent_table not in self.default_schemas:
|
||||
raise ValueError(f"No schema found for parent table {parent_table}")
|
||||
|
||||
parent_schema = self.default_schemas[parent_table]
|
||||
# Try both singular_id and table_id formats
|
||||
id_field_candidates = [
|
||||
f"{parent_table[:-1]}_id", # Remove 's' and add '_id'
|
||||
f"{parent_table}_id", # Full table name with _id
|
||||
"policy_id" if parent_table == "policies" else None, # Special case for policies
|
||||
"customer_id" if parent_table == "customers" else None, # Special case for customers
|
||||
"claim_id" if parent_table == "claims" else None # Special case for claims
|
||||
]
|
||||
id_field = next((field for field in id_field_candidates if field and field in parent_schema), None)
|
||||
if not id_field:
|
||||
raise ValueError(f"No ID field found in schema for parent table {parent_table}")
|
||||
|
||||
# Generate IDs for parent table
|
||||
generator = self._map_type_to_generator(id_field, parent_schema[id_field])
|
||||
self._generated_ids[parent_table] = {generator() for _ in range(rows)} # Use set comprehension
|
||||
|
||||
def _clear_generated_ids(self):
|
||||
"""Clear all generated IDs."""
|
||||
self._generated_ids = {}
|
||||
self._id_counters = {}
|
||||
|
||||
async def generate_synthetic_data(
|
||||
self,
|
||||
@@ -195,94 +283,48 @@ class SyntheticDataGenerator:
|
||||
"""Generate synthetic data for a table."""
|
||||
data: Dict[str, List[Any]] = {}
|
||||
|
||||
# Generate data for each column
|
||||
# If this is a parent table, ensure we generate IDs first
|
||||
if table_name in self.default_schemas:
|
||||
id_field = f"{table_name[:-1]}_id" # Remove 's' and add '_id'
|
||||
if id_field in schema:
|
||||
data[id_field] = [
|
||||
self._generate_unique_id(table_name, schema[id_field])
|
||||
for _ in range(rows)
|
||||
]
|
||||
self._generated_ids[table_name] = set(data[id_field])
|
||||
|
||||
# First pass: Generate non-correlated fields
|
||||
for col_name, col_spec in schema.items():
|
||||
generator = col_spec.get("generator", "numpy")
|
||||
|
||||
if generator == "faker":
|
||||
# Handle legacy faker.method format
|
||||
if "." in col_spec.get("type", ""):
|
||||
method_name = col_spec["type"].split(".")[-1]
|
||||
else:
|
||||
method_name = self._map_faker_type(col_spec.get("type", "text"))
|
||||
|
||||
values = []
|
||||
for _ in range(rows):
|
||||
value = self._generate_faker_value(method_name)
|
||||
# Convert date objects to string format
|
||||
if isinstance(value, (datetime.date, datetime.datetime)):
|
||||
value = value.isoformat()
|
||||
values.append(value)
|
||||
data[col_name] = values
|
||||
|
||||
elif generator == "mimesis":
|
||||
data[col_name] = [self._generate_mimesis_value(col_spec["type"]) for _ in range(rows)]
|
||||
|
||||
elif generator == "numpy":
|
||||
if col_spec["type"] == "boolean":
|
||||
data[col_name] = [bool(x) for x in np.random.choice([True, False], size=rows)]
|
||||
elif "correlated" in col_spec and col_spec["correlated"]:
|
||||
parent_table = self._extract_parent_table(col_name)
|
||||
if not self._generated_ids.get(parent_table):
|
||||
raise ValueError(f"Parent table {parent_table} must be generated before {table_name}")
|
||||
if not col_spec.get("correlated") and col_name not in data:
|
||||
if col_name.endswith("_id"):
|
||||
# Generate unique IDs with optional prefix
|
||||
data[col_name] = [
|
||||
self._generate_correlated_id(parent_table)
|
||||
self._generate_unique_id(table_name, col_spec)
|
||||
for _ in range(rows)
|
||||
]
|
||||
elif col_spec["type"] in ["integer", "int"]:
|
||||
if col_name.endswith("_id"):
|
||||
# Generate unique IDs
|
||||
data[col_name] = []
|
||||
for _ in range(rows):
|
||||
id_val = self._generate_unique_id(table_name, col_spec)
|
||||
# Add prefix if specified
|
||||
if "prefix" in col_spec:
|
||||
id_val = f"{col_spec['prefix']}{id_val}"
|
||||
data[col_name].append(id_val)
|
||||
# Store raw ID for relationships
|
||||
if isinstance(id_val, str) and "-" in id_val:
|
||||
raw_id = int(id_val.split("-")[-1])
|
||||
if table_name not in self._generated_ids:
|
||||
self._generated_ids[table_name] = set()
|
||||
self._generated_ids[table_name].add(raw_id)
|
||||
else:
|
||||
if table_name not in self._generated_ids:
|
||||
self._generated_ids[table_name] = set()
|
||||
self._generated_ids[table_name].add(id_val)
|
||||
else:
|
||||
# Regular integer values
|
||||
min_val = col_spec.get("min", 0)
|
||||
max_val = col_spec.get("max", 100)
|
||||
data[col_name] = [
|
||||
int(x) for x in np.random.randint(min_val, max_val + 1, size=rows)
|
||||
]
|
||||
elif col_spec["type"] == "float":
|
||||
min_val = float(col_spec.get("min", 0.0))
|
||||
max_val = float(col_spec.get("max", 1.0))
|
||||
data[col_name] = [
|
||||
float(x) for x in np.random.uniform(min_val, max_val, size=rows)
|
||||
]
|
||||
elif col_spec["type"] == "category":
|
||||
categories = col_spec.get("categories", [])
|
||||
data[col_name] = list(np.random.choice(categories, size=rows))
|
||||
elif col_spec["type"] == "string" and "categories" in col_spec:
|
||||
categories = col_spec.get("categories", [])
|
||||
data[col_name] = list(np.random.choice(categories, size=rows))
|
||||
else:
|
||||
# Default to string type
|
||||
data[col_name] = [str(x) for x in range(rows)]
|
||||
# Generate other fields using appropriate generator
|
||||
generator = self._map_type_to_generator(col_name, col_spec)
|
||||
data[col_name] = [
|
||||
self._ensure_json_serializable(generator())
|
||||
for _ in range(rows)
|
||||
]
|
||||
|
||||
# Convert all numpy types to native Python types for JSON serialization
|
||||
# Second pass: Generate correlated fields
|
||||
for col_name, col_spec in schema.items():
|
||||
if col_spec.get("correlated"):
|
||||
parent_table = self._extract_parent_table(col_name)
|
||||
self._ensure_parent_table_ids(parent_table, rows)
|
||||
data[col_name] = [
|
||||
self._generate_correlated_id(parent_table)
|
||||
for _ in range(rows)
|
||||
]
|
||||
|
||||
# Ensure all numpy types are converted to native Python types
|
||||
for col_name in data:
|
||||
if isinstance(data[col_name], np.ndarray):
|
||||
data[col_name] = data[col_name].tolist()
|
||||
data[col_name] = [
|
||||
int(x) if isinstance(x, np.integer)
|
||||
else float(x) if isinstance(x, np.floating)
|
||||
else bool(x) if isinstance(x, np.bool_)
|
||||
else str(x) if isinstance(x, np.str_)
|
||||
else x
|
||||
for x in data[col_name]
|
||||
self._ensure_json_serializable(val)
|
||||
for val in data[col_name]
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
@@ -25,15 +25,30 @@ def sample_schema():
|
||||
async def test_type_mapping(data_generator, sample_schema):
|
||||
"""Test type mapping functionality."""
|
||||
# Test numpy type mapping
|
||||
assert data_generator._map_type_to_generator("integer") == "numpy"
|
||||
assert data_generator._map_type_to_generator("float") == "numpy"
|
||||
assert data_generator._map_type_to_generator("boolean") == "numpy"
|
||||
assert data_generator._map_type_to_generator("category") == "numpy"
|
||||
generator = data_generator._map_type_to_generator("id", {"type": "integer", "generator": "numpy", "min": 1, "max": 100})
|
||||
assert callable(generator)
|
||||
value = generator()
|
||||
assert isinstance(value, (int, np.integer))
|
||||
assert 1 <= value <= 100
|
||||
|
||||
generator = data_generator._map_type_to_generator("score", {"type": "float", "generator": "numpy", "min": 0.0, "max": 1.0})
|
||||
assert callable(generator)
|
||||
value = generator()
|
||||
assert isinstance(value, (float, np.floating))
|
||||
assert 0.0 <= value <= 1.0
|
||||
|
||||
generator = data_generator._map_type_to_generator("active", {"type": "boolean", "generator": "numpy"})
|
||||
assert callable(generator)
|
||||
assert isinstance(generator(), bool)
|
||||
|
||||
# Test faker type mapping
|
||||
assert data_generator._map_type_to_generator("first_name") == "faker"
|
||||
assert data_generator._map_type_to_generator("email") == "faker"
|
||||
assert data_generator._map_type_to_generator("text") == "faker"
|
||||
generator = data_generator._map_type_to_generator("name", {"type": "first_name", "generator": "faker"})
|
||||
assert callable(generator)
|
||||
assert isinstance(generator(), str)
|
||||
|
||||
generator = data_generator._map_type_to_generator("email", {"type": "email", "generator": "faker"})
|
||||
assert callable(generator)
|
||||
assert isinstance(generator(), str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user