221 lines
6.3 KiB
Go
221 lines
6.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
|
|
"gitlab.com/digineat/go-broker-test/internal/model"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// Schema for the in-memory test database; setupTestDB executes both
// statements before returning the handle.
const (
	// createTradesQTableSQLTest creates trades_q, the table whose row count
	// TestTradesHandler inspects after each POST /trades request.
	createTradesQTableSQLTest = `
CREATE TABLE IF NOT EXISTS trades_q (
id INTEGER PRIMARY KEY AUTOINCREMENT,
account TEXT NOT NULL,
symbol TEXT NOT NULL,
volume REAL NOT NULL,
open_price REAL NOT NULL,
close_price REAL NOT NULL,
side TEXT NOT NULL CHECK(side IN ('buy', 'sell')),
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);`

	// createAccountStatsTableSQLTest creates account_stats, one row per
	// account, which TestStatsHandler seeds and reads back via GET /stats/{acc}.
	createAccountStatsTableSQLTest = `
CREATE TABLE IF NOT EXISTS account_stats (
account TEXT PRIMARY KEY,
trades_count INTEGER NOT NULL DEFAULT 0,
profit REAL NOT NULL DEFAULT 0.0,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);`
)
|
|
|
|
// setupTestDB initializes an in-memory SQLite database for tests.
|
|
func setupTestDB(t *testing.T) *sql.DB {
|
|
t.Helper()
|
|
db, err := sql.Open("sqlite3", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open in-memory database: %v", err)
|
|
}
|
|
|
|
if _, err := db.Exec(createTradesQTableSQLTest); err != nil {
|
|
t.Fatalf("Failed to create trades_q table: %v", err)
|
|
}
|
|
if _, err := db.Exec(createAccountStatsTableSQLTest); err != nil {
|
|
t.Fatalf("Failed to create account_stats table: %v", err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
// createTestAppMux now uses the refactored handlers from the main package.
|
|
func createTestAppMux(db *sql.DB) *http.ServeMux {
|
|
app := &server{db: db}
|
|
|
|
mux := http.NewServeMux()
|
|
|
|
mux.HandleFunc("POST /trades", app.handlePostTrades())
|
|
mux.HandleFunc("GET /stats/{acc}", app.handleGetStats())
|
|
mux.HandleFunc("GET /healthz", app.handleGetHealthz())
|
|
return mux
|
|
}
|
|
|
|
// TestMain is the package's test entry point: it runs all tests and exits
// the process with the status code reported by the testing framework.
func TestMain(m *testing.M) {
	os.Exit(m.Run())
}
|
|
|
|
func TestTradesHandler(t *testing.T) {
|
|
db := setupTestDB(t)
|
|
defer db.Close()
|
|
mux := createTestAppMux(db)
|
|
|
|
tt := []struct {
|
|
name string
|
|
payload string
|
|
expectedStatus int
|
|
verifyDB func(t *testing.T, testDB *sql.DB)
|
|
}{
|
|
{
|
|
name: "Valid trade",
|
|
payload: `{"account":"acc1","symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
|
|
expectedStatus: http.StatusOK,
|
|
verifyDB: func(t *testing.T, testDB *sql.DB) {
|
|
var count int
|
|
err := testDB.QueryRow("SELECT COUNT(*) FROM trades_q WHERE account = ?", "acc1").Scan(&count)
|
|
if err != nil {
|
|
t.Fatalf("Failed to query db: %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Errorf("Expected 1 trade in db, got %d", count)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid JSON",
|
|
payload: `{"account":"acc2","symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Validation error - invalid symbol format (custom validator)",
|
|
payload: `{"account":"acc3","symbol":"eurusd","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Validation error - invalid symbol pattern (regexp)",
|
|
payload: `{"account":"acc3","symbol":"EUR123","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Validation error - missing required field (account)",
|
|
payload: `{"symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := db.Exec("DELETE FROM trades_q")
|
|
if err != nil {
|
|
t.Fatalf("Failed to clear trades_q table: %v", err)
|
|
}
|
|
|
|
req, _ := http.NewRequest("POST", "/trades", bytes.NewBufferString(tc.payload))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rr := httptest.NewRecorder()
|
|
mux.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d. Body: %s", tc.expectedStatus, rr.Code, rr.Body.String())
|
|
}
|
|
|
|
if tc.verifyDB != nil {
|
|
tc.verifyDB(t, db)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatsHandler(t *testing.T) {
|
|
db := setupTestDB(t)
|
|
defer db.Close()
|
|
mux := createTestAppMux(db)
|
|
|
|
_, err := db.Exec("INSERT INTO account_stats (account, trades_count, profit, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)", "acc1", 10, 123.45)
|
|
if err != nil {
|
|
t.Fatalf("Failed to insert test data: %v", err)
|
|
}
|
|
|
|
tt := []struct {
|
|
name string
|
|
accountID string
|
|
expectedStatus int
|
|
expectedBody model.AccountStats
|
|
}{
|
|
{
|
|
name: "Account exists",
|
|
accountID: "acc1",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: model.AccountStats{Account: "acc1", TradesCount: 10, Profit: 123.45},
|
|
},
|
|
{
|
|
name: "Account does not exist",
|
|
accountID: "nonexistent",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: model.AccountStats{Account: "nonexistent", TradesCount: 0, Profit: 0},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/stats/"+tc.accountID, nil)
|
|
rr := httptest.NewRecorder()
|
|
mux.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d. Body: %s", tc.expectedStatus, rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var actual model.AccountStats
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &actual); err != nil {
|
|
t.Fatalf("Failed to unmarshal actual response body: %v. Body: %s", err, rr.Body.String())
|
|
}
|
|
|
|
if actual.Account != tc.expectedBody.Account || actual.TradesCount != tc.expectedBody.TradesCount || actual.Profit != tc.expectedBody.Profit {
|
|
t.Errorf("Expected body for account %s to be %+v, got %+v", tc.accountID, tc.expectedBody, actual)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHealthzHandler(t *testing.T) {
|
|
db := setupTestDB(t)
|
|
muxHealthy := createTestAppMux(db)
|
|
|
|
t.Run("DB OK", func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/healthz", nil)
|
|
rr := httptest.NewRecorder()
|
|
muxHealthy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status %d, got %d", http.StatusOK, rr.Code)
|
|
}
|
|
})
|
|
|
|
db.Close()
|
|
|
|
t.Run("DB Ping fails", func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/healthz", nil)
|
|
rr := httptest.NewRecorder()
|
|
muxHealthy.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status %d for closed DB, got %d", http.StatusInternalServerError, rr.Code)
|
|
}
|
|
})
|
|
}
|