package main import ( "bytes" "database/sql" "encoding/json" "net/http" "net/http/httptest" "os" "testing" "gitlab.com/digineat/go-broker-test/internal/model" _ "github.com/mattn/go-sqlite3" ) const createTradesQTableSQLTest = ` CREATE TABLE IF NOT EXISTS trades_q ( id INTEGER PRIMARY KEY AUTOINCREMENT, account TEXT NOT NULL, symbol TEXT NOT NULL, volume REAL NOT NULL, open_price REAL NOT NULL, close_price REAL NOT NULL, side TEXT NOT NULL CHECK(side IN ('buy', 'sell')), created_at DATETIME DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP );` const createAccountStatsTableSQLTest = ` CREATE TABLE IF NOT EXISTS account_stats ( account TEXT PRIMARY KEY, trades_count INTEGER NOT NULL DEFAULT 0, profit REAL NOT NULL DEFAULT 0.0, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP );` // setupTestDB initializes an in-memory SQLite database for tests. func setupTestDB(t *testing.T) *sql.DB { t.Helper() db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("Failed to open in-memory database: %v", err) } if _, err := db.Exec(createTradesQTableSQLTest); err != nil { t.Fatalf("Failed to create trades_q table: %v", err) } if _, err := db.Exec(createAccountStatsTableSQLTest); err != nil { t.Fatalf("Failed to create account_stats table: %v", err) } return db } // createTestAppMux now uses the refactored handlers from the main package. func createTestAppMux(db *sql.DB) *http.ServeMux { app := &server{db: db} mux := http.NewServeMux() mux.HandleFunc("POST /trades", app.handlePostTrades()) mux.HandleFunc("GET /stats/{acc}", app.handleGetStats()) mux.HandleFunc("GET /healthz", app.handleGetHealthz()) return mux } func TestMain(m *testing.M) { code := m.Run() os.Exit(code) } func TestTradesHandler(t *testing.T) { db := setupTestDB(t) defer db.Close() mux := createTestAppMux(db) tt := []struct { name string payload string expectedStatus int verifyDB func(t *testing.T, testDB *sql.DB) }{ { name: "Valid trade", payload: `{"account":"acc1","symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`, expectedStatus: http.StatusOK, verifyDB: func(t *testing.T, testDB *sql.DB) { var count int err := testDB.QueryRow("SELECT COUNT(*) FROM trades_q WHERE account = ?", "acc1").Scan(&count) if err != nil { t.Fatalf("Failed to query db: %v", err) } if count != 1 { t.Errorf("Expected 1 trade in db, got %d", count) } }, }, { name: "Invalid JSON", payload: `{"account":"acc2","symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"`, expectedStatus: http.StatusBadRequest, }, { name: "Validation error - invalid symbol format (custom validator)", payload: `{"account":"acc3","symbol":"eurusd","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`, expectedStatus: http.StatusBadRequest, }, { name: "Validation error - invalid symbol pattern (regexp)", payload: `{"account":"acc3","symbol":"EUR123","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`, expectedStatus: http.StatusBadRequest, }, { name: "Validation error - missing required field (account)", payload: `{"symbol":"EURUSD","volume":1.0,"open":1.1,"close":1.2,"side":"buy"}`, expectedStatus: http.StatusBadRequest, }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { _, err := db.Exec("DELETE FROM trades_q") if err != nil { t.Fatalf("Failed to clear trades_q table: %v", err) } req, _ := http.NewRequest("POST", "/trades", bytes.NewBufferString(tc.payload)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) if rr.Code != tc.expectedStatus { t.Errorf("Expected status %d, got %d. Body: %s", tc.expectedStatus, rr.Code, rr.Body.String()) } if tc.verifyDB != nil { tc.verifyDB(t, db) } }) } } func TestStatsHandler(t *testing.T) { db := setupTestDB(t) defer db.Close() mux := createTestAppMux(db) _, err := db.Exec("INSERT INTO account_stats (account, trades_count, profit, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)", "acc1", 10, 123.45) if err != nil { t.Fatalf("Failed to insert test data: %v", err) } tt := []struct { name string accountID string expectedStatus int expectedBody model.AccountStats }{ { name: "Account exists", accountID: "acc1", expectedStatus: http.StatusOK, expectedBody: model.AccountStats{Account: "acc1", TradesCount: 10, Profit: 123.45}, }, { name: "Account does not exist", accountID: "nonexistent", expectedStatus: http.StatusOK, expectedBody: model.AccountStats{Account: "nonexistent", TradesCount: 0, Profit: 0}, }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { req, _ := http.NewRequest("GET", "/stats/"+tc.accountID, nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) if rr.Code != tc.expectedStatus { t.Errorf("Expected status %d, got %d. Body: %s", tc.expectedStatus, rr.Code, rr.Body.String()) } var actual model.AccountStats if err := json.Unmarshal(rr.Body.Bytes(), &actual); err != nil { t.Fatalf("Failed to unmarshal actual response body: %v. Body: %s", err, rr.Body.String()) } if actual.Account != tc.expectedBody.Account || actual.TradesCount != tc.expectedBody.TradesCount || actual.Profit != tc.expectedBody.Profit { t.Errorf("Expected body for account %s to be %+v, got %+v", tc.accountID, tc.expectedBody, actual) } }) } } func TestHealthzHandler(t *testing.T) { db := setupTestDB(t) muxHealthy := createTestAppMux(db) t.Run("DB OK", func(t *testing.T) { req, _ := http.NewRequest("GET", "/healthz", nil) rr := httptest.NewRecorder() muxHealthy.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected status %d, got %d", http.StatusOK, rr.Code) } }) db.Close() t.Run("DB Ping fails", func(t *testing.T) { req, _ := http.NewRequest("GET", "/healthz", nil) rr := httptest.NewRecorder() muxHealthy.ServeHTTP(rr, req) if rr.Code != http.StatusInternalServerError { t.Errorf("Expected status %d for closed DB, got %d", http.StatusInternalServerError, rr.Code) } }) }