go-broker-test/cmd/server/main_test.go

221 lines
6.3 KiB
Go
Raw Normal View History

2025-05-07 00:25:34 +03:00
package main
import (
"bytes"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"gitlab.com/digineat/go-broker-test/internal/model"
_ "github.com/mattn/go-sqlite3"
)
// createTradesQTableSQLTest is the DDL statement the tests execute to create
// the trades_q table when it does not already exist. The side column is
// restricted by a CHECK constraint to 'buy' or 'sell', and both timestamp
// columns default to CURRENT_TIMESTAMP.
const createTradesQTableSQLTest = `
CREATE TABLE IF NOT EXISTS trades_q (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account TEXT NOT NULL,
symbol TEXT NOT NULL,
volume REAL NOT NULL,
open_price REAL NOT NULL,
close_price REAL NOT NULL,
side TEXT NOT NULL CHECK(side IN ('buy', 'sell')),
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);`
// createAccountStatsTableSQLTest is the DDL statement the tests execute to
// create the account_stats table when it does not already exist. The table is
// keyed by account; trades_count and profit default to zero and updated_at
// defaults to CURRENT_TIMESTAMP.
const createAccountStatsTableSQLTest = `
CREATE TABLE IF NOT EXISTS account_stats (
account TEXT PRIMARY KEY,
trades_count INTEGER NOT NULL DEFAULT 0,
profit REAL NOT NULL DEFAULT 0.0,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);`
// setupTestDB opens an in-memory SQLite database and creates the trades_q and
// account_stats tables in it. Any failure aborts the calling test through
// t.Fatalf. The caller owns the returned handle and must close it.
func setupTestDB(t *testing.T) *sql.DB {
	t.Helper()

	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open in-memory database: %v", err)
	}

	// database/sql is a connection pool, and every new connection to ":memory:"
	// gets its own empty database. Pin the pool to a single connection so the
	// tables created below are visible to every later query on this handle.
	// NOTE(review): this assumes the handlers never hold one connection open
	// while requesting a second one; confirm against the handler code.
	db.SetMaxOpenConns(1)

	schema := []struct {
		table string
		ddl   string
	}{
		{table: "trades_q", ddl: createTradesQTableSQLTest},
		{table: "account_stats", ddl: createAccountStatsTableSQLTest},
	}
	for _, s := range schema {
		if _, err := db.Exec(s.ddl); err != nil {
			// Release the handle before failing so a broken setup does not leak it.
			db.Close()
			t.Fatalf("Failed to create %s table: %v", s.table, err)
		}
	}
	return db
}
// createTestAppMux returns a ServeMux with the trades, stats and healthz
// routes registered against handlers of a server value backed by db.
func createTestAppMux(db *sql.DB) *http.ServeMux {
	var (
		router = http.NewServeMux()
		srv    = &server{db: db}
	)

	router.HandleFunc("POST /trades", srv.handlePostTrades())
	router.HandleFunc("GET /stats/{acc}", srv.handleGetStats())
	router.HandleFunc("GET /healthz", srv.handleGetHealthz())

	return router
}
// TestMain runs the package's tests and exits the process with the resulting
// status code.
func TestMain(m *testing.M) {
	os.Exit(m.Run())
}
// TestTradesHandler sends a table of JSON payloads to POST /trades and checks
// the HTTP status returned for each one. Cases that supply verifyDB also
// inspect the database after the request has been served.
func TestTradesHandler(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	mux := createTestAppMux(db)

	tt := []struct {
		name           string
		payload        string
		expectedStatus int
		// verifyDB is optional; when set it runs after the request is served.
		verifyDB func(t *testing.T, testDB *sql.DB)
	}{
		{
			name:           "Valid trade",
			payload:        `{"account":"acc1","symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
			expectedStatus: http.StatusOK,
			verifyDB: func(t *testing.T, testDB *sql.DB) {
				var count int
				err := testDB.QueryRow("SELECT COUNT(*) FROM trades_q WHERE account = ?", "acc1").Scan(&count)
				if err != nil {
					t.Fatalf("Failed to query db: %v", err)
				}
				if count != 1 {
					t.Errorf("Expected 1 trade in db, got %d", count)
				}
			},
		},
		{
			// The payload is missing its closing brace.
			name:           "Invalid JSON",
			payload:        `{"account":"acc2","symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"`,
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Validation error - invalid symbol format (custom validator)",
			payload:        `{"account":"acc3","symbol":"eurusd","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Validation error - invalid symbol pattern (regexp)",
			payload:        `{"account":"acc3","symbol":"EUR123","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Validation error - missing required field (account)",
			payload:        `{"symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
			expectedStatus: http.StatusBadRequest,
		},
	}

	for _, tc := range tt {
		t.Run(tc.name, func(t *testing.T) {
			// Start each case from an empty table so row counts are not
			// influenced by earlier cases.
			if _, err := db.Exec("DELETE FROM trades_q"); err != nil {
				t.Fatalf("Failed to clear trades_q table: %v", err)
			}

			// httptest.NewRequest builds a server-side request and panics on a
			// malformed one, so no error value is silently discarded here.
			req := httptest.NewRequest(http.MethodPost, "/trades", bytes.NewBufferString(tc.payload))
			req.Header.Set("Content-Type", "application/json")
			rr := httptest.NewRecorder()

			mux.ServeHTTP(rr, req)

			if rr.Code != tc.expectedStatus {
				t.Errorf("Expected status %d, got %d. Body: %s", tc.expectedStatus, rr.Code, rr.Body.String())
			}
			if tc.verifyDB != nil {
				tc.verifyDB(t, db)
			}
		})
	}
}
// TestStatsHandler seeds account_stats with one row for "acc1" and then
// requests GET /stats/{acc} for an existing and a non-existing account,
// checking the status code and the Account, TradesCount and Profit fields of
// the decoded JSON body.
func TestStatsHandler(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	mux := createTestAppMux(db)

	_, err := db.Exec("INSERT INTO account_stats (account, trades_count, profit, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)", "acc1", 10, 123.45)
	if err != nil {
		t.Fatalf("Failed to insert test data: %v", err)
	}

	tt := []struct {
		name           string
		accountID      string
		expectedStatus int
		expectedBody   model.AccountStats
	}{
		{
			name:           "Account exists",
			accountID:      "acc1",
			expectedStatus: http.StatusOK,
			expectedBody:   model.AccountStats{Account: "acc1", TradesCount: 10, Profit: 123.45},
		},
		{
			// An unknown account is expected to yield 200 with zeroed stats.
			name:           "Account does not exist",
			accountID:      "nonexistent",
			expectedStatus: http.StatusOK,
			expectedBody:   model.AccountStats{Account: "nonexistent", TradesCount: 0, Profit: 0},
		},
	}

	for _, tc := range tt {
		t.Run(tc.name, func(t *testing.T) {
			// httptest.NewRequest builds a server-side request and panics on a
			// malformed one, so no error value is silently discarded here.
			req := httptest.NewRequest(http.MethodGet, "/stats/"+tc.accountID, nil)
			rr := httptest.NewRecorder()

			mux.ServeHTTP(rr, req)

			if rr.Code != tc.expectedStatus {
				t.Errorf("Expected status %d, got %d. Body: %s", tc.expectedStatus, rr.Code, rr.Body.String())
			}

			var actual model.AccountStats
			if err := json.Unmarshal(rr.Body.Bytes(), &actual); err != nil {
				t.Fatalf("Failed to unmarshal actual response body: %v. Body: %s", err, rr.Body.String())
			}
			// Only these three fields are compared; any other fields of the
			// response are deliberately left unchecked.
			if actual.Account != tc.expectedBody.Account || actual.TradesCount != tc.expectedBody.TradesCount || actual.Profit != tc.expectedBody.Profit {
				t.Errorf("Expected body for account %s to be %+v, got %+v", tc.accountID, tc.expectedBody, actual)
			}
		})
	}
}
// TestHealthzHandler calls GET /healthz twice against the same mux: first with
// an open database, expecting 200, and then after the database handle has been
// closed, expecting 500.
func TestHealthzHandler(t *testing.T) {
	db := setupTestDB(t)
	muxHealthy := createTestAppMux(db)

	t.Run("DB OK", func(t *testing.T) {
		// httptest.NewRequest builds a server-side request and panics on a
		// malformed one, so no error value is silently discarded here.
		req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
		rr := httptest.NewRecorder()

		muxHealthy.ServeHTTP(rr, req)

		if rr.Code != http.StatusOK {
			t.Errorf("Expected status %d, got %d", http.StatusOK, rr.Code)
		}
	})

	// Close the handle on purpose: the next subtest relies on the database
	// being unusable, so a failed close would invalidate its premise.
	if err := db.Close(); err != nil {
		t.Fatalf("Failed to close database: %v", err)
	}

	t.Run("DB Ping fails", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
		rr := httptest.NewRecorder()

		muxHealthy.ServeHTTP(rr, req)

		if rr.Code != http.StatusInternalServerError {
			t.Errorf("Expected status %d for closed DB, got %d", http.StatusInternalServerError, rr.Code)
		}
	})
}