refactored code for better readability and maintainability

This commit is contained in:
Ben Borla
2024-12-09 15:11:37 +08:00
parent 68549532d4
commit 90fabce9d5

View File

@@ -8,7 +8,7 @@ import {
ListToolsRequestSchema, ListToolsRequestSchema,
ReadResourceRequestSchema, ReadResourceRequestSchema,
} from "@modelcontextprotocol/sdk/types.js"; } from "@modelcontextprotocol/sdk/types.js";
import mysql, { MysqlError, PoolConnection, OkPacket } from "mysql"; import mysql, { MysqlError, PoolConnection } from "mysql";
type MySQLErrorType = MysqlError | null; type MySQLErrorType = MysqlError | null;
@@ -21,187 +21,229 @@ interface ColumnRow {
data_type: string; data_type: string;
} }
type QueryResult = OkPacket | any[] | any; const config = {
const server = new Server( server: {
{
name: "example-servers/mysql", name: "example-servers/mysql",
version: "0.1.0", version: "0.1.0",
}, },
{ mysql: {
capabilities: { host: process.env.MYSQL_HOST || "127.0.0.1",
resources: {}, port: Number(process.env.MYSQL_PORT || "3306"),
tools: {}, user: process.env.MYSQL_USER || "root",
}, password: process.env.MYSQL_PASS || "",
database: process.env.MYSQL_DB || "",
connectionLimit: 10,
}, },
); paths: {
schema: "schema",
},
};
const MYSQL_HOST = process.env.MYSQL_HOST || "127.0.0.1"; const mysqlQuery = (
const MYSQL_PORT = process.env.MYSQL_PORT || "3306"; connection: PoolConnection,
const MYSQL_USER = process.env.MYSQL_USER || "root"; sql: string,
const MYSQL_PASS = process.env.MYSQL_PASS || ""; params: any[] = [],
const MYSQL_DB = process.env.MYSQL_DB || ""; ): Promise<any> => {
return new Promise((resolve, reject) => {
connection.query(sql, params, (error: MySQLErrorType, results: any) => {
if (error) reject(error);
else resolve(results);
});
});
};
const pool = mysql.createPool({ const mysqlGetConnection = (pool: mysql.Pool): Promise<PoolConnection> => {
connectionLimit: 10, return new Promise(
host: MYSQL_HOST, (
port: Number(MYSQL_PORT), resolve: (value: PoolConnection | PromiseLike<PoolConnection>) => void,
user: MYSQL_USER, reject,
password: MYSQL_PASS, ) => {
database: MYSQL_DB, pool.getConnection(
(error: MySQLErrorType, connection: PoolConnection) => {
if (error) reject(error);
else resolve(connection);
},
);
},
);
};
const mysqlBeginTransaction = (connection: PoolConnection): Promise<void> => {
return new Promise((resolve, reject) => {
connection.beginTransaction((error: MySQLErrorType) => {
if (error) reject(error);
else resolve();
});
});
};
const mysqlRollback = (connection: PoolConnection): Promise<void> => {
return new Promise((resolve, _) => {
connection.rollback(() => resolve());
});
};
const pool = mysql.createPool(config.mysql);
const server = new Server(config.server, {
capabilities: {
resources: {},
tools: {},
},
}); });
const SCHEMA_PATH = "schema"; async function executeQuery(sql: string, params: any[] = []): Promise<any> {
const connection = await mysqlGetConnection(pool);
try {
const results = await mysqlQuery(connection, sql, params);
return results;
} finally {
connection.release();
}
}
async function executeReadOnlyQuery(sql: string): Promise<any> {
const connection = await mysqlGetConnection(pool);
try {
// Set read-only mode
await mysqlQuery(connection, "SET SESSION TRANSACTION READ ONLY");
// Begin transaction
await mysqlBeginTransaction(connection);
// Execute query
const results = await mysqlQuery(connection, sql);
// Rollback transaction (since it's read-only)
await mysqlRollback(connection);
// Reset to read-write mode
await mysqlQuery(connection, "SET SESSION TRANSACTION READ WRITE");
return {
content: [
{
type: "text",
text: JSON.stringify(results, null, 2),
},
],
isError: false,
};
} catch (error) {
await mysqlRollback(connection);
throw error;
} finally {
connection.release();
}
}
// Request handlers
server.setRequestHandler(ListResourcesRequestSchema, async () => { server.setRequestHandler(ListResourcesRequestSchema, async () => {
return new Promise((resolve, reject) => { const results = (await executeQuery(
pool.query( "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
"SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()", )) as TableRow[];
(error: MySQLErrorType, results: TableRow[]) => {
if (error) reject(error); return {
resolve({ resources: results.map((row: TableRow) => ({
resources: results.map((row: TableRow) => ({ uri: new URL(
uri: new URL( `${row.table_name}/${config.paths.schema}`,
`${row.table_name}/${SCHEMA_PATH}`, `${config.mysql.host}:${config.mysql.port}`,
`${MYSQL_HOST}:${MYSQL_PORT}`, ).href,
).href, mimeType: "application/json",
mimeType: "application/json", name: `"${row.table_name}" database schema`,
name: `"${row.table_name}" database schema`, })),
})), };
});
},
);
});
}); });
server.setRequestHandler(ReadResourceRequestSchema, async (request) => { server.setRequestHandler(ReadResourceRequestSchema, async (request) => {
const resourceUrl = new URL(request.params.uri); const resourceUrl = new URL(request.params.uri);
const pathComponents = resourceUrl.pathname.split("/"); const pathComponents = resourceUrl.pathname.split("/");
const schema = pathComponents.pop(); const schema = pathComponents.pop();
const tableName = pathComponents.pop(); const tableName = pathComponents.pop();
if (schema !== SCHEMA_PATH) { if (schema !== config.paths.schema) {
throw new Error("Invalid resource URI"); throw new Error("Invalid resource URI");
} }
return new Promise((resolve, reject) => { const results = (await executeQuery(
pool.query( "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?",
"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", [tableName],
[tableName], )) as ColumnRow[];
(error: MySQLErrorType, results: ColumnRow[]) => {
if (error) reject(error);
resolve({
contents: [
{
uri: request.params.uri,
mimeType: "application/json",
text: JSON.stringify(results, null, 2),
},
],
});
},
);
});
});
server.setRequestHandler(ListToolsRequestSchema, async () => {
return { return {
tools: [ contents: [
{ {
name: "mysql_query", uri: request.params.uri,
description: "Run a read-only MySQL query", mimeType: "application/json",
inputSchema: { text: JSON.stringify(results, null, 2),
type: "object",
properties: {
sql: { type: "string" },
},
},
}, },
], ],
}; };
}); });
server.setRequestHandler(ListToolsRequestSchema, async () => ({
tools: [
{
name: "mysql_query",
description: "Run a read-only MySQL query",
inputSchema: {
type: "object",
properties: {
sql: { type: "string" },
},
},
},
],
}));
server.setRequestHandler(CallToolRequestSchema, async (request) => { server.setRequestHandler(CallToolRequestSchema, async (request) => {
if (request.params.name === "mysql_query") { if (request.params.name !== "mysql_query") {
const sql = request.params.arguments?.sql as string; throw new Error(`Unknown tool: ${request.params.name}`);
return new Promise((resolve, reject) => {
pool.getConnection((err: MySQLErrorType, connection: PoolConnection) => {
if (err) reject(err);
// @INFO: Set session to read only BEFORE beginning the transaction
connection.query(
"SET SESSION TRANSACTION READ ONLY",
(err: MySQLErrorType) => {
if (err) {
connection.release();
reject(err);
return;
}
connection.beginTransaction((err: MySQLErrorType) => {
if (err) {
connection.release();
reject(err);
return;
}
connection.query(
sql,
(error: MySQLErrorType, results: QueryResult) => {
if (error) {
connection.rollback(() => {
connection.release();
reject(error);
});
return;
}
// @INFO: Reset the transaction mode back to default before releasing
connection.rollback(() => {
connection.query(
"SET SESSION TRANSACTION READ WRITE",
(err: MySQLErrorType) => {
connection.release();
if (err) {
console.warn(
"Failed to reset transaction mode:",
err,
);
}
resolve({
content: [
{
type: "text",
text: JSON.stringify(results, null, 2),
},
],
isError: false,
});
},
);
});
},
);
});
},
);
});
});
} }
throw new Error(`Unknown tool: ${request.params.name}`);
const sql = request.params.arguments?.sql as string;
return executeReadOnlyQuery(sql);
}); });
// Server startup and shutdown
async function runServer() { async function runServer() {
const transport = new StdioServerTransport(); const transport = new StdioServerTransport();
await server.connect(transport); await server.connect(transport);
} }
process.on("SIGINT", () => { const shutdown = async (signal: string) => {
pool.end((err: MySQLErrorType) => { console.log(`Received ${signal}. Shutting down...`);
if (err) console.error("Error closing pool:", err); return new Promise<void>((resolve, reject) => {
process.exit(err ? 1 : 0); pool.end((err: MySQLErrorType) => {
if (err) {
console.error("Error closing pool:", err);
reject(err);
} else {
resolve();
}
});
}); });
};
process.on("SIGINT", async () => {
try {
await shutdown("SIGINT");
process.exit(0);
} catch (err) {
process.exit(1);
}
}); });
runServer().catch(console.error); process.on("SIGTERM", async () => {
try {
await shutdown("SIGTERM");
process.exit(0);
} catch (err) {
process.exit(1);
}
});
runServer().catch((error: unknown) => {
console.error("Server error:", error);
process.exit(1);
});