1
0
Эх сурвалжийг харах

escape underline and percent sign for DocumentSearchResult

玖亖伍 1 жил өмнө
parent
commit
0dbb5d7967

+ 41 - 13
models/DocumentSearchResult.go

@@ -1,12 +1,13 @@
 package models
 
 import (
-	"time"
-
+	"regexp"
 	"strings"
+	"time"
 
 	"github.com/beego/beego/v2/client/orm"
 	"github.com/beego/beego/v2/core/logs"
+	"github.com/beego/beego/v2/server/web"
 )
 
 type DocumentSearchResult struct {
@@ -24,11 +25,22 @@ type DocumentSearchResult struct {
 	SearchType   string    `json:"search_type"`
 }
 
+var escape_re = regexp.MustCompile(`(?mi)(\bLIKE\s+\?)`)
+var escape_replace = "${1} ESCAPE '\\'"
+
+func need_escape(keyword string) bool {
+	dbadapter, _ := web.AppConfig.String("db_adapter")
+	if strings.EqualFold(dbadapter, "sqlite3") && (strings.Contains(keyword, "\\_") || strings.Contains(keyword, "\\%")) {
+		return true
+	}
+	return false
+}
+
 func NewDocumentSearchResult() *DocumentSearchResult {
 	return &DocumentSearchResult{}
 }
 
-//分页全局搜索.
+// 分页全局搜索.
 func (m *DocumentSearchResult) FindToPager(keyword string, pageIndex, pageSize, memberId int) (searchResult []*DocumentSearchResult, totalCount int, err error) {
 	o := orm.NewOrm()
 
@@ -36,6 +48,14 @@ func (m *DocumentSearchResult) FindToPager(keyword string, pageIndex, pageSize,
 
 	keyword = "%" + strings.Replace(keyword, " ", "%", -1) + "%"
 
+	_need_escape := need_escape(keyword)
+	escape_sql := func(sql string) string {
+		if _need_escape {
+			return escape_re.ReplaceAllString(sql, escape_replace)
+		}
+		return sql
+	}
+
 	if memberId <= 0 {
 		sql1 := `SELECT count(doc.document_id) as total_count FROM md_documents AS doc
   LEFT JOIN md_books as book ON doc.book_id = book.book_id
@@ -98,7 +118,7 @@ WHERE book.privately_owned = 0 AND (book.book_name LIKE ? OR book.description LI
 ORDER BY create_time DESC
 LIMIT ? OFFSET ?;`
 
-		err = o.Raw(sql1, keyword, keyword).QueryRow(&totalCount)
+		err = o.Raw(escape_sql(sql1), keyword, keyword).QueryRow(&totalCount)
 		if err != nil {
 			logs.Error("查询搜索结果失败 -> ", err)
 			return
@@ -109,7 +129,7 @@ LIMIT ? OFFSET ?;`
        WHERE blog.blog_status = 'public' AND (blog.blog_release LIKE ? OR blog.blog_title LIKE ?);`
 
 		c := 0
-		err = o.Raw(sql3, keyword, keyword).QueryRow(&c)
+		err = o.Raw(escape_sql(sql3), keyword, keyword).QueryRow(&c)
 		if err != nil {
 			logs.Error("查询搜索结果失败 -> ", err)
 			return
@@ -120,7 +140,7 @@ LIMIT ? OFFSET ?;`
 		sql4 := `SELECT count(*) as total_count FROM md_books as book
 WHERE book.privately_owned = 0 AND (book.book_name LIKE ? OR book.description LIKE ?);`
 
-		err = o.Raw(sql4, keyword, keyword).QueryRow(&c)
+		err = o.Raw(escape_sql(sql4), keyword, keyword).QueryRow(&c)
 		if err != nil {
 			logs.Error("查询搜索结果失败 -> ", err)
 			return
@@ -128,7 +148,7 @@ WHERE book.privately_owned = 0 AND (book.book_name LIKE ? OR book.description LI
 
 		totalCount += c
 
-		_, err = o.Raw(sql2, keyword, keyword, keyword, keyword, keyword, keyword, pageSize, offset).QueryRows(&searchResult)
+		_, err = o.Raw(escape_sql(sql2), keyword, keyword, keyword, keyword, keyword, keyword, pageSize, offset).QueryRows(&searchResult)
 		if err != nil {
 			logs.Error("查询搜索结果失败 -> ", err)
 			return
@@ -226,7 +246,7 @@ FROM (
 ORDER BY create_time DESC
 LIMIT ? OFFSET ?;`
 
-		err = o.Raw(sql1, memberId, memberId, keyword, keyword).QueryRow(&totalCount)
+		err = o.Raw(escape_sql(sql1), memberId, memberId, keyword, keyword).QueryRow(&totalCount)
 		if err != nil {
 			return
 		}
@@ -237,7 +257,7 @@ LIMIT ? OFFSET ?;`
              (blog.blog_release LIKE ? OR blog.blog_title LIKE ?);`
 
 		c := 0
-		err = o.Raw(sql3, memberId, keyword, keyword).QueryRow(&c)
+		err = o.Raw(escape_sql(sql3), memberId, keyword, keyword).QueryRow(&c)
 		if err != nil {
 			logs.Error("查询搜索结果失败 -> ", err)
 			return
@@ -254,7 +274,7 @@ LIMIT ? OFFSET ?;`
 					on team.book_id = book.book_id
 WHERE (book.privately_owned = 0 OR rel1.relationship_id > 0 or team.team_member_id > 0)  AND (book.book_name LIKE ? OR book.description LIKE ?);`
 
-		err = o.Raw(sql4, memberId, memberId, keyword, keyword).QueryRow(&c)
+		err = o.Raw(escape_sql(sql4), memberId, memberId, keyword, keyword).QueryRow(&c)
 		if err != nil {
 			logs.Error("查询搜索结果失败 -> ", err)
 			return
@@ -262,7 +282,7 @@ WHERE (book.privately_owned = 0 OR rel1.relationship_id > 0 or team.team_member_
 
 		totalCount += c
 
-		_, err = o.Raw(sql2, memberId, memberId, keyword, keyword, memberId, memberId, keyword, keyword, memberId, keyword, keyword, pageSize, offset).QueryRows(&searchResult)
+		_, err = o.Raw(escape_sql(sql2), memberId, memberId, keyword, keyword, memberId, memberId, keyword, keyword, memberId, keyword, keyword, pageSize, offset).QueryRows(&searchResult)
 		if err != nil {
 			return
 		}
@@ -270,14 +290,22 @@ WHERE (book.privately_owned = 0 OR rel1.relationship_id > 0 or team.team_member_
 	return
 }
 
-//项目内搜索.
+// 项目内搜索.
 func (m *DocumentSearchResult) SearchDocument(keyword string, bookId int) (docs []*DocumentSearchResult, err error) {
 	o := orm.NewOrm()
 
 	sql := "SELECT * FROM md_documents WHERE book_id = ? AND (document_name LIKE ? OR `release` LIKE ?) "
 	keyword = "%" + keyword + "%"
 
-	_, err = o.Raw(sql, bookId, keyword, keyword).QueryRows(&docs)
+	_need_escape := need_escape(keyword)
+	escape_sql := func(sql string) string {
+		if _need_escape {
+			return escape_re.ReplaceAllString(sql, escape_replace)
+		}
+		return sql
+	}
+
+	_, err = o.Raw(escape_sql(sql), bookId, keyword, keyword).QueryRows(&docs)
 
 	return
 }