Refactor to not inherit from Server

This commit is contained in:
Justin Spahr-Summers
2024-11-21 21:38:06 +00:00
parent 812e8213dc
commit a184344b85

View File

@@ -100,7 +100,13 @@ The provided XML tags are for the assistants understanding. Emplore to make all
Start your first message fully in character with something like "Oh, Hey there! I see you've chosen the topic {topic}. Let's get started! 🚀"
"""
class McpServer(Server):
class SqliteDatabase:
def __init__(self, db_path: str):
self.db_path = str(Path(db_path).expanduser())
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._init_database()
self.insights: list[str] = []
def _init_database(self):
"""Initialize connection to the SQLite database"""
logger.debug("Initializing database connection")
@@ -127,7 +133,7 @@ class McpServer(Server):
logger.debug("Generated basic memo format")
return memo
def _execute_query(self, query: str, params=None) -> list[dict]:
def _execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
"""Execute a SQL query and return results as a list of dictionaries"""
logger.debug(f"Executing query: {query}")
try:
@@ -152,187 +158,178 @@ class McpServer(Server):
logger.error(f"Database error executing query: {e}")
raise
def __init__(self, db_path: str = "~/sqlite_mcp_server.db"):
logger.info("Initializing McpServer")
super().__init__("sqlite-manager")
async def main(db_path: str):
logger.info(f"Starting SQLite MCP Server with DB path: {db_path}")
# Initialize SQLite database
self.db_path = str(Path(db_path).expanduser())
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._init_database()
logger.debug(f"Initialized database at {self.db_path}")
db = SqliteDatabase(db_path)
server = Server("sqlite-manager")
# Register handlers
logger.debug("Registering handlers")
# Initialize insights list
self.insights = []
@server.list_resources()
async def handle_list_resources() -> list[types.Resource]:
logger.debug("Handling list_resources request")
return [
types.Resource(
uri=AnyUrl("memo://insights"),
name="Business Insights Memo",
description="A living document of discovered business insights",
mimeType="text/plain",
)
]
# REGISTER HANDLERS
logger.debug("Registering handlers")
@server.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str:
logger.debug(f"Handling read_resource request for URI: {uri}")
if uri.scheme != "memo":
logger.error(f"Unsupported URI scheme: {uri.scheme}")
raise ValueError(f"Unsupported URI scheme: {uri.scheme}")
@self.list_resources()
async def handle_list_resources() -> list[types.Resource]:
logger.debug("Handling list_resources request")
return [
types.Resource(
uri=AnyUrl("memo://insights"),
name="Business Insights Memo",
description="A living document of discovered business insights",
mimeType="text/plain",
)
]
path = str(uri).replace("memo://", "")
if not path or path != "insights":
logger.error(f"Unknown resource path: {path}")
raise ValueError(f"Unknown resource path: {path}")
@self.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str:
logger.debug(f"Handling read_resource request for URI: {uri}")
if uri.scheme != "memo":
logger.error(f"Unsupported URI scheme: {uri.scheme}")
raise ValueError(f"Unsupported URI scheme: {uri.scheme}")
return db._synthesize_memo()
path = str(uri).replace("memo://", "")
if not path or path != "insights":
logger.error(f"Unknown resource path: {path}")
raise ValueError(f"Unknown resource path: {path}")
return self._synthesize_memo()
@self.list_prompts()
async def handle_list_prompts() -> list[types.Prompt]:
logger.debug("Handling list_prompts request")
return [
types.Prompt(
name="mcp-demo",
description="A prompt to seed the database with initial data and demonstrate what you can do with an SQLite MCP Server + Claude",
arguments=[
types.PromptArgument(
name="topic",
description="Topic to seed the database with initial data",
required=True,
)
],
)
]
@self.get_prompt()
async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult:
logger.debug(f"Handling get_prompt request for {name} with args {arguments}")
if name != "mcp-demo":
logger.error(f"Unknown prompt: {name}")
raise ValueError(f"Unknown prompt: {name}")
if not arguments or "topic" not in arguments:
logger.error("Missing required argument: topic")
raise ValueError("Missing required argument: topic")
topic = arguments["topic"]
prompt = PROMPT_TEMPLATE.format(topic=topic)
logger.debug(f"Generated prompt template for topic: {topic}")
return types.GetPromptResult(
description=f"Demo template for {topic}",
messages=[
types.PromptMessage(
role="user",
content=types.TextContent(type="text", text=prompt.strip()),
@server.list_prompts()
async def handle_list_prompts() -> list[types.Prompt]:
logger.debug("Handling list_prompts request")
return [
types.Prompt(
name="mcp-demo",
description="A prompt to seed the database with initial data and demonstrate what you can do with an SQLite MCP Server + Claude",
arguments=[
types.PromptArgument(
name="topic",
description="Topic to seed the database with initial data",
required=True,
)
],
)
]
# TOOL HANDLERS
@self.list_tools()
async def handle_list_tools() -> list[types.Tool]:
"""List available tools"""
return [
types.Tool(
name="read-query",
description="Execute a SELECT query on the SQLite database",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "SELECT SQL query to execute"},
},
"required": ["query"],
},
),
types.Tool(
name="write-query",
description="Execute an INSERT, UPDATE, or DELETE query on the SQLite database",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "SQL query to execute"},
},
"required": ["query"],
},
),
types.Tool(
name="create-table",
description="Create a new table in the SQLite database",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "CREATE TABLE SQL statement"},
},
"required": ["query"],
},
),
types.Tool(
name="list-tables",
description="List all tables in the SQLite database",
inputSchema={
"type": "object",
"properties": {},
},
),
types.Tool(
name="describe-table",
description="Get the schema information for a specific table",
inputSchema={
"type": "object",
"properties": {
"table_name": {"type": "string", "description": "Name of the table to describe"},
},
"required": ["table_name"],
},
),
types.Tool(
name="append-insight",
description="Add a business insight to the memo",
inputSchema={
"type": "object",
"properties": {
"insight": {"type": "string", "description": "Business insight discovered from data analysis"},
},
"required": ["insight"],
},
),
]
@server.get_prompt()
async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult:
logger.debug(f"Handling get_prompt request for {name} with args {arguments}")
if name != "mcp-demo":
logger.error(f"Unknown prompt: {name}")
raise ValueError(f"Unknown prompt: {name}")
@self.call_tool()
async def handle_call_tool(
name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""Handle tool execution requests"""
try:
if name == "list-tables":
results = self._execute_query(
"SELECT name FROM sqlite_master WHERE type='table'"
)
return [types.TextContent(type="text", text=str(results))]
if not arguments or "topic" not in arguments:
logger.error("Missing required argument: topic")
raise ValueError("Missing required argument: topic")
elif name == "describe-table":
if not arguments or "table_name" not in arguments:
raise ValueError("Missing table_name argument")
results = self._execute_query(
f"PRAGMA table_info({arguments['table_name']})"
)
return [types.TextContent(type="text", text=str(results))]
topic = arguments["topic"]
prompt = PROMPT_TEMPLATE.format(topic=topic)
elif name == "append-insight":
if not arguments or "insight" not in arguments:
raise ValueError("Missing insight argument")
logger.debug(f"Generated prompt template for topic: {topic}")
return types.GetPromptResult(
description=f"Demo template for {topic}",
messages=[
types.PromptMessage(
role="user",
content=types.TextContent(type="text", text=prompt.strip()),
)
],
)
self.insights.append(arguments["insight"])
memo = self._synthesize_memo()
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
"""List available tools"""
return [
types.Tool(
name="read-query",
description="Execute a SELECT query on the SQLite database",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "SELECT SQL query to execute"},
},
"required": ["query"],
},
),
types.Tool(
name="write-query",
description="Execute an INSERT, UPDATE, or DELETE query on the SQLite database",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "SQL query to execute"},
},
"required": ["query"],
},
),
types.Tool(
name="create-table",
description="Create a new table in the SQLite database",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "CREATE TABLE SQL statement"},
},
"required": ["query"],
},
),
types.Tool(
name="list-tables",
description="List all tables in the SQLite database",
inputSchema={
"type": "object",
"properties": {},
},
),
types.Tool(
name="describe-table",
description="Get the schema information for a specific table",
inputSchema={
"type": "object",
"properties": {
"table_name": {"type": "string", "description": "Name of the table to describe"},
},
"required": ["table_name"],
},
),
types.Tool(
name="append-insight",
description="Add a business insight to the memo",
inputSchema={
"type": "object",
"properties": {
"insight": {"type": "string", "description": "Business insight discovered from data analysis"},
},
"required": ["insight"],
},
),
]
@server.call_tool()
async def handle_call_tool(
name: str, arguments: dict[str, Any] | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""Handle tool execution requests"""
try:
if name == "list-tables":
results = db._execute_query(
"SELECT name FROM sqlite_master WHERE type='table'"
)
return [types.TextContent(type="text", text=str(results))]
elif name == "describe-table":
if not arguments or "table_name" not in arguments:
raise ValueError("Missing table_name argument")
results = db._execute_query(
f"PRAGMA table_info({arguments['table_name']})"
)
return [types.TextContent(type="text", text=str(results))]
elif name == "append-insight":
if not arguments or "insight" not in arguments:
raise ValueError("Missing insight argument")
db.insights.append(arguments["insight"])
_ = db._synthesize_memo()
# Notify clients that the memo resource has changed
await self.request_context.session.send_resource_updated(AnyUrl("memo://insights"))
@@ -341,35 +338,34 @@ class McpServer(Server):
if not arguments:
raise ValueError("Missing arguments")
if name == "read-query":
if not arguments["query"].strip().upper().startswith("SELECT"):
raise ValueError("Only SELECT queries are allowed for read-query")
results = self._execute_query(arguments["query"])
return [types.TextContent(type="text", text=str(results))]
if not arguments:
raise ValueError("Missing arguments")
elif name == "write-query":
if arguments["query"].strip().upper().startswith("SELECT"):
raise ValueError("SELECT queries are not allowed for write-query")
results = self._execute_query(arguments["query"])
return [types.TextContent(type="text", text=str(results))]
if name == "read-query":
if not arguments["query"].strip().upper().startswith("SELECT"):
raise ValueError("Only SELECT queries are allowed for read-query")
results = db._execute_query(arguments["query"])
return [types.TextContent(type="text", text=str(results))]
elif name == "create-table":
if not arguments["query"].strip().upper().startswith("CREATE TABLE"):
raise ValueError("Only CREATE TABLE statements are allowed")
self._execute_query(arguments["query"])
return [types.TextContent(type="text", text="Table created successfully")]
elif name == "write-query":
if arguments["query"].strip().upper().startswith("SELECT"):
raise ValueError("SELECT queries are not allowed for write-query")
results = db._execute_query(arguments["query"])
return [types.TextContent(type="text", text=str(results))]
else:
raise ValueError(f"Unknown tool: {name}")
elif name == "create-table":
if not arguments["query"].strip().upper().startswith("CREATE TABLE"):
raise ValueError("Only CREATE TABLE statements are allowed")
db._execute_query(arguments["query"])
return [types.TextContent(type="text", text="Table created successfully")]
except sqlite3.Error as e:
return [types.TextContent(type="text", text=f"Database error: {str(e)}")]
except Exception as e:
return [types.TextContent(type="text", text=f"Error: {str(e)}")]
else:
raise ValueError(f"Unknown tool: {name}")
async def main(db_path: str):
logger.info(f"Starting SQLite MCP Server with DB path: {db_path}")
server = McpServer(db_path)
except sqlite3.Error as e:
return [types.TextContent(type="text", text=f"Database error: {str(e)}")]
except Exception as e:
return [types.TextContent(type="text", text=f"Error: {str(e)}")]
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
logger.info("Server running with stdio transport")
@@ -381,8 +377,7 @@ async def main(db_path: str):
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={
},
experimental_capabilities={},
),
),
)