You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

384 lines
7.8 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. // Copyright 2018 Lars Hoogestraat
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package handler_test
  5. import (
  6. "bytes"
  7. "context"
  8. "database/sql"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "log"
  13. "mime/multipart"
  14. "net/http"
  15. "net/http/httptest"
  16. "net/url"
  17. "os"
  18. "path/filepath"
  19. "testing"
  20. "git.hoogi.eu/snafu/go-blog/components/database"
  21. "git.hoogi.eu/snafu/go-blog/components/mail"
  22. "git.hoogi.eu/snafu/go-blog/logger"
  23. "git.hoogi.eu/snafu/go-blog/middleware"
  24. "git.hoogi.eu/snafu/go-blog/models"
  25. "git.hoogi.eu/snafu/go-blog/settings"
  26. "git.hoogi.eu/snafu/go-blog/utils"
  27. "git.hoogi.eu/snafu/session"
  28. _ "github.com/mattn/go-sqlite3"
  29. )
  30. var ctx *middleware.AppContext
  31. var db *sql.DB
  32. func setup(t *testing.T) {
  33. logger.InitLogger(ioutil.Discard, "Debug")
  34. db, err := sql.Open("sqlite3", ":memory:")
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. err = database.InitTables(db)
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. err = fillSeeds(db)
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. cfg, err := settings.LoadConfig("../go-blog.conf")
  47. if err != nil {
  48. t.Fatal(err)
  49. }
  50. cfg.File.Location = os.TempDir()
  51. userService := models.UserService{
  52. Datasource: models.SQLiteUserDatasource{
  53. SQLConn: db,
  54. },
  55. Config: cfg.User,
  56. }
  57. userInviteService := models.UserInviteService{
  58. Datasource: models.SQLiteUserInviteDatasource{
  59. SQLConn: db,
  60. },
  61. UserService: userService,
  62. }
  63. articleService := models.ArticleService{
  64. AppConfig: cfg.Application,
  65. Datasource: models.SQLiteArticleDatasource{
  66. SQLConn: db,
  67. },
  68. }
  69. siteService := models.SiteService{
  70. Datasource: models.SQLiteSiteDatasource{
  71. SQLConn: db,
  72. },
  73. }
  74. fileService := models.FileService{
  75. Config: cfg.File,
  76. Datasource: models.SQLiteFileDatasource{
  77. SQLConn: db,
  78. },
  79. }
  80. categoryService := models.CategoryService{
  81. Datasource: models.SQLiteCategoryDatasource{
  82. SQLConn: db,
  83. },
  84. }
  85. tokenService := models.TokenService{
  86. Datasource: models.SQLiteTokenDatasource{
  87. SQLConn: db,
  88. },
  89. }
  90. mailer := models.Mailer{
  91. Sender: MockSMTP{},
  92. AppConfig: &cfg.Application,
  93. }
  94. sessionService := session.SessionService{
  95. Path: "/admin",
  96. Name: "test-session",
  97. HTTPOnly: true,
  98. Secure: true,
  99. SessionProvider: session.NewInMemoryProvider(),
  100. IdleSessionTTL: 10,
  101. }
  102. ctx = &middleware.AppContext{
  103. UserService: userService,
  104. UserInviteService: userInviteService,
  105. ArticleService: articleService,
  106. CategoryService: categoryService,
  107. SiteService: siteService,
  108. FileService: fileService,
  109. TokenService: tokenService,
  110. SessionService: &sessionService,
  111. Mailer: mailer,
  112. ConfigService: cfg,
  113. }
  114. }
  115. func teardown() {
  116. if db != nil {
  117. db.Close()
  118. }
  119. }
  120. func fillSeeds(db *sql.DB) error {
  121. salt := utils.GenerateSalt()
  122. saltedPassword := "123456789012" + salt
  123. password, err := utils.CryptPassword([]byte(saltedPassword), 12)
  124. if err != nil {
  125. return err
  126. }
  127. _, err = db.Exec("INSERT INTO user (id, username, email, display_name, salt, password, active, is_admin, last_modified) VALUES (1, 'alice', 'alice@example.org', 'Alice Schneier', ?, ?, 1, 1, date('now'))", string(salt), password)
  128. if err != nil {
  129. return err
  130. }
  131. _, err = db.Exec("INSERT INTO user (id, username, email, display_name, salt, password, active, is_admin, last_modified) VALUES (2, 'bob', 'bob@example.org', 'Bob Stallman', ?, ?, 1, 0, date('now'))", string(salt), string(password))
  132. if err != nil {
  133. return err
  134. }
  135. _, err = db.Exec("INSERT INTO user (id, username, email, display_name, salt, password, active, is_admin, last_modified) VALUES (3, 'mallory', 'mallory@example.org', 'Mallory Pike', ?, ?, 0, 1, date('now'))", string(salt), string(password))
  136. if err != nil {
  137. return err
  138. }
  139. _, err = db.Exec("INSERT INTO user (id, username, email, display_name, salt, password, active, is_admin, last_modified) VALUES (4, 'eve', 'eve@example.org', 'Mallory Pike', ?, ?, 0, 0, date('now'))", string(salt), string(password))
  140. if err != nil {
  141. return err
  142. }
  143. return nil
  144. }
  145. func dummyAdminUser() *models.User {
  146. u, _ := ctx.UserService.GetByID(1)
  147. return u
  148. }
  149. func dummyUser() *models.User {
  150. u, _ := ctx.UserService.GetByID(2)
  151. return u
  152. }
  153. func setHeader(r *http.Request, key, value string) {
  154. r.Header.Set("X-Unit-Testing-Value-"+key, value)
  155. }
  156. type MockSMTP struct{}
  157. func (sm MockSMTP) Send(m mail.Mail) error {
  158. return nil
  159. }
  160. func (sm MockSMTP) SendAsync(m mail.Mail) {
  161. }
  162. func addValue(m url.Values, key, value string) {
  163. m.Add(key, value)
  164. }
  165. func addCheckboxValue(m url.Values, key string, value bool) {
  166. if value {
  167. m.Add(key, "on")
  168. }
  169. m.Add(key, "off")
  170. }
  171. func postMultipart(path string, mp []multipartRequest) (*http.Request, error) {
  172. buf := &bytes.Buffer{}
  173. mw := multipart.NewWriter(buf)
  174. defer mw.Close()
  175. for _, v := range mp {
  176. fh, err := os.Open(v.file)
  177. if err != nil {
  178. return nil, err
  179. }
  180. defer fh.Close()
  181. fw, err := mw.CreateFormFile(v.key, filepath.Base(fh.Name()))
  182. if err != nil {
  183. return nil, err
  184. }
  185. _, err = io.Copy(fw, fh)
  186. if err != nil {
  187. return nil, err
  188. }
  189. }
  190. req, err := http.NewRequest("POST", path, buf)
  191. if err != nil {
  192. return nil, err
  193. }
  194. req.Header.Set("Content-Type", mw.FormDataContentType())
  195. return req, nil
  196. }
  197. func post(path string, values url.Values) (*http.Request, error) {
  198. var b bytes.Buffer
  199. b.WriteString(values.Encode())
  200. req, err := http.NewRequest("POST", path, &b)
  201. if err != nil {
  202. return nil, err
  203. }
  204. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  205. return req, nil
  206. }
  207. func get(path string, values url.Values) (*http.Request, error) {
  208. var b bytes.Buffer
  209. b.WriteString(values.Encode())
  210. req, err := http.NewRequest("GET", path, &b)
  211. if err != nil {
  212. return nil, err
  213. }
  214. return req, nil
  215. }
  216. //reqUser the user which should be added to the context
  217. type reqUser int
  218. const (
  219. rGuest = iota
  220. rAdminUser
  221. rUser
  222. rInactiveAdminUser
  223. rInactiveUser
  224. )
  225. //request used to build an http.Request with specified values
  226. //url will not really considered as the requests are not send, the *http.Request is just passed directly to the controllers
  227. //pathvar is an array of key/value pairs used as dynamic query parameters such as /article/{id}
  228. type request struct {
  229. url string
  230. user reqUser
  231. method string
  232. values url.Values
  233. pathVar []pathVar
  234. multipartReq []multipartRequest
  235. }
  236. type multipartRequest struct {
  237. key string
  238. file string
  239. }
  240. type pathVar struct {
  241. key string
  242. value string
  243. }
  244. func (r request) buildRequest() *http.Request {
  245. var req *http.Request
  246. var err error
  247. if len(r.multipartReq) > 0 {
  248. req, err = postMultipart(r.url, r.multipartReq)
  249. } else if r.method == http.MethodPost {
  250. req, err = post(r.url, r.values)
  251. } else {
  252. req, err = get(r.url, r.values)
  253. }
  254. if err != nil {
  255. log.Print(err)
  256. }
  257. if r.pathVar != nil {
  258. for _, v := range r.pathVar {
  259. setHeader(req, v.key, v.value)
  260. }
  261. }
  262. var user *models.User
  263. if r.user == rGuest {
  264. return req
  265. } else {
  266. user, _ = ctx.UserService.GetByID(int(r.user))
  267. recorder := httptest.NewRecorder()
  268. session := ctx.SessionService.Create(recorder, req)
  269. session.SetValue("userid", user.ID)
  270. cookie := recorder.Result().Cookies()[0]
  271. req.AddCookie(cookie)
  272. }
  273. reqCtx := context.WithValue(req.Context(), middleware.UserContextKey, user)
  274. req = req.WithContext(reqCtx)
  275. return req
  276. }
  277. type responseWrapper struct {
  278. template *middleware.Template
  279. response *httptest.ResponseRecorder
  280. }
  281. func (r responseWrapper) getTemplateError() error {
  282. return r.template.Err
  283. }
  284. func (r responseWrapper) isCodeSuccess() bool {
  285. return r.response.Result().StatusCode == http.StatusOK
  286. }
  287. func (r responseWrapper) getStatus() int {
  288. return r.response.Result().StatusCode
  289. }
  290. func (r responseWrapper) getCookie(name string) (*http.Cookie, error) {
  291. for _, c := range r.response.Result().Cookies() {
  292. if c.Name == name {
  293. return c, nil
  294. }
  295. }
  296. return nil, fmt.Errorf("cookie %s not found", name)
  297. }