Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions sqle/driver/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/actiontech/sqle/sqle/driver/mysql/plocale"
rulepkg "github.com/actiontech/sqle/sqle/driver/mysql/rule"
_ "github.com/actiontech/sqle/sqle/driver/mysql/rule/ai"
aiutil "github.com/actiontech/sqle/sqle/driver/mysql/rule/ai/util"
"github.com/actiontech/sqle/sqle/driver/mysql/session"
"github.com/actiontech/sqle/sqle/driver/mysql/util"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
Expand Down Expand Up @@ -639,6 +640,83 @@ func (p *PluginProcessor) GetDriverMetas() (*driverV2.DriverMetas, error) {
return metas, nil
}

func (i *MysqlDriverImpl) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
node, err := util.ParseOneSql(sql)
if err != nil {
return nil, err
}

if _, ok := node.(*ast.SelectStmt); !ok {
log.NewEntry().Errorf("get selectivity of sql columns failed, sql is not a select statement, sql: %s", sql)
return nil, nil
}

selectVisitor := &util.SelectVisitor{}
node.Accept(selectVisitor)

result := make(map[string]map[string]float32)

for _, selectNode := range selectVisitor.SelectList {
if selectNode.From == nil || selectNode.From.TableRefs == nil {
continue
}

// 获取表别名映射关系
aliasInfo := aiutil.GetTableAliasInfoFromJoin(selectNode.From.TableRefs)
aliasMap := make(map[string]string)
allTables := make([]string, 0, len(aliasInfo))

for _, alias := range aliasInfo {
if alias.TableAliasName != "" {
aliasMap[alias.TableAliasName] = alias.TableName
}
allTables = append(allTables, alias.TableName)
}

// 提取列并按表分组
tableColumns := util.ExtractColumnsFromSelectStmt(selectNode, aliasMap, allTables)

// 遍历每个表,获取其列的选择性
for tableName, columnSet := range tableColumns {
columns := make([]string, 0, len(columnSet))
for colName := range columnSet {
columns = append(columns, colName)
}

if len(columns) == 0 {
continue
}

// 构造 TableName 对象
var schemaName string
for _, alias := range aliasInfo {
if alias.TableName == tableName {
schemaName = alias.SchemaName
break
}
}
tableNameObj := util.NewTableName(schemaName, tableName)

columnSelectivityMap, err := i.Ctx.GetSelectivityOfColumns(tableNameObj, columns)
if err != nil {
log.NewEntry().Errorf("get selectivity of columns failed, table: %s, columns: %v, error: %v", tableName, columns, err)
continue
}

if result[tableName] == nil {
result[tableName] = make(map[string]float32)
}
for columnName, selectivity := range columnSelectivityMap {
if selectivity > 0 {
result[tableName][columnName] = float32(selectivity)
}
}
}
}

return result, nil
}

func (p *PluginProcessor) Open(l *logrus.Entry, cfg *driverV2.Config) (driver.Plugin, error) {
return NewInspect(l, cfg)
}
Expand Down
103 changes: 103 additions & 0 deletions sqle/driver/mysql/util/parser_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,106 @@ func ConvertAliasToTable(alias string, tables []*ast.TableSource) (*ast.TableNam
}
return nil, errors.New("can not find table")
}

// TableColumnMap 表示按表分组的列名集合
type TableColumnMap map[string]map[string]struct{}

// ExtractColumnsFromSelectStmt 从 SELECT 语句中提取列,并按表分组
// 参数:
// - selectStmt: SELECT 语句节点
// - aliasMap: 表别名到实际表名的映射
// - allTables: 所有涉及的表名列表(用于处理无表前缀的列)
//
// 返回:按表名分组的列名集合
func ExtractColumnsFromSelectStmt(selectStmt *ast.SelectStmt, aliasMap map[string]string, allTables []string) TableColumnMap {
tableColumns := make(TableColumnMap)

// 收集 SELECT 列表中的所有列别名
selectAliases := make(map[string]struct{})
if selectStmt.Fields != nil {
for _, field := range selectStmt.Fields.Fields {
if field.AsName.L != "" {
selectAliases[field.AsName.L] = struct{}{}
}
}
}

// 辅助函数:从表达式中提取列并按表分组
extractColumnsFromExpr := func(expr ast.Node, skipAliases bool) {
if expr == nil {
return
}
columnVisitor := &ColumnNameVisitor{}
expr.Accept(columnVisitor)

for _, colExpr := range columnVisitor.ColumnNameList {
if colExpr.Name == nil {
continue
}

// 如果需要跳过别名且当前列名是一个别名,则跳过
if skipAliases {
if _, isAlias := selectAliases[colExpr.Name.Name.L]; isAlias && colExpr.Name.Table.L == "" {
continue
}
}

var targetTableName string

// 如果列有表前缀(可能是别名或实际表名)
if colExpr.Name.Table.L != "" {
// 先尝试从别名映射中查找
if actualTable, exists := aliasMap[colExpr.Name.Table.L]; exists {
targetTableName = actualTable
} else {
// 如果不是别名,就当作实际表名
targetTableName = colExpr.Name.Table.L
}
}

if targetTableName != "" {
if tableColumns[targetTableName] == nil {
tableColumns[targetTableName] = make(map[string]struct{})
}
tableColumns[targetTableName][colExpr.Name.Name.L] = struct{}{}
} else {
// 没有表前缀的列,可能属于任何表
// 在多表查询中,尝试将该列添加到所有表
for _, tableName := range allTables {
if tableColumns[tableName] == nil {
tableColumns[tableName] = make(map[string]struct{})
}
tableColumns[tableName][colExpr.Name.Name.L] = struct{}{}
}
}
}
}

// 从 SELECT Fields 提取列(包括聚合函数内的列)
if selectStmt.Fields != nil {
for _, field := range selectStmt.Fields.Fields {
extractColumnsFromExpr(field.Expr, false)
}
}

// 从 WHERE 条件提取列
if selectStmt.Where != nil {
extractColumnsFromExpr(selectStmt.Where, false)
}

// 从 GROUP BY 提取列(需要跳过别名引用)
if selectStmt.GroupBy != nil {
for _, item := range selectStmt.GroupBy.Items {
extractColumnsFromExpr(item.Expr, true)
}
}

// 从 HAVING 提取列
if selectStmt.Having != nil {
extractColumnsFromExpr(selectStmt.Having.Expr, false)
}

// 注意:不从 ORDER BY 提取,因为可能包含别名引用

return tableColumns
}
4 changes: 4 additions & 0 deletions sqle/driver/plugin_adapter_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,7 @@ func (s *PluginImplV1) GetDatabaseObjectDDL(ctx context.Context, objInfos []*dri
func (s *PluginImplV1) GetDatabaseDiffModifySQL(ctx context.Context, calibratedDSN *driverV2.DSN, objInfos []*driverV2.DatabasCompareSchemaInfo) ([]*driverV2.DatabaseDiffModifySQLResult, error) {
return nil, fmt.Errorf("unimplement this method")
}

func (p *PluginImplV1) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
return nil, fmt.Errorf("unimplement this method")
}
22 changes: 22 additions & 0 deletions sqle/driver/plugin_adapter_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,25 @@ func (s *PluginImplV2) GetDatabaseDiffModifySQL(ctx context.Context, calibratedD
}
return dbDiffSQLs, nil
}

func (s *PluginImplV2) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
api := "GetSelectivityOfSQLColumns"
s.preLog(api)
resp, err := s.client.GetSelectivityOfSQLColumns(ctx, &protoV2.GetSelectivityOfSQLColumnsRequest{
Session: s.Session,
Sql: sql,
})
s.afterLog(api, err)
if err != nil {
return nil, err
}
result := make(map[string]map[string]float32, len(resp.Selectivity))
for _, v := range resp.Selectivity {
colMap := make(map[string]float32, len(v.SelectivityOfColumns))
for k, sel := range v.SelectivityOfColumns {
colMap[k] = sel
}
result[v.TableName] = colMap
}
return result, nil
}
1 change: 1 addition & 0 deletions sqle/driver/plugin_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Plugin interface {
Backup(ctx context.Context, backupStrategy string, sql string, backupMaxRows uint64) (backupSqls []string, executeResult string, err error)

RecommendBackupStrategy(ctx context.Context, sql string) (*RecommendBackupStrategyRes, error)
GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string] /*table name*/ map[string] /*column name*/ float32, error)
}

type RecommendBackupStrategyRes struct {
Expand Down
24 changes: 24 additions & 0 deletions sqle/driver/v2/driver_grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,3 +674,27 @@ func (d *DriverGrpcServer) GetDatabaseDiffModifySQL(ctx context.Context, req *pr
SchemaDiffModify: scheamDiff,
}, nil
}

func (d *DriverGrpcServer) GetSelectivityOfSQLColumns(ctx context.Context, req *protoV2.GetSelectivityOfSQLColumnsRequest) (*protoV2.GetSelectivityOfSQLColumnsResponse, error) {
driver, err := d.getDriverBySession(req.Session)
if err != nil {
return &protoV2.GetSelectivityOfSQLColumnsResponse{}, err
}
selectivity, err := driver.GetSelectivityOfSQLColumns(ctx, req.Sql)
if err != nil {
return &protoV2.GetSelectivityOfSQLColumnsResponse{}, err
}
protoSelectivity := make([]*protoV2.SelectivityOfSQLColumns, 0, len(selectivity))
for tableName, colMap := range selectivity {
// 直接将 map[string]float32 赋值,无需合并操作
merged := make(map[string]float32, len(colMap))
for col, val := range colMap {
merged[col] = val
}
protoSelectivity = append(protoSelectivity, &protoV2.SelectivityOfSQLColumns{
TableName: tableName,
SelectivityOfColumns: merged,
})
}
return &protoV2.GetSelectivityOfSQLColumnsResponse{Selectivity: protoSelectivity}, nil
}
1 change: 1 addition & 0 deletions sqle/driver/v2/driver_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ type Driver interface {

Backup(ctx context.Context, req *BackupReq) (*BackupRes, error)
RecommendBackupStrategy(ctx context.Context, req *RecommendBackupStrategyReq) (*RecommendBackupStrategyRes, error)
GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string] /*table name*/ map[string] /*column name*/ float32, error)
}

const (
Expand Down
Loading
Loading