// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package http2

import (
	"bytes"
	"io"
	"net/http"
	"os"
	"reflect"
	"slices"
	"testing"

	"golang.org/x/net/http2/hpack"
)

// testConnFramer bundles a Framer with an hpack decoder and a testing.TB,
// and provides helpers for reading, asserting on, and writing frames.
// Helper methods report failures through t.
type testConnFramer struct {
	// t receives Helper, Errorf, and Fatalf calls from the helper methods.
	t   testing.TB
	// fr is the Framer that frames are read from and written to.
	fr  *Framer
	// dec decodes header block fragments in wantHeaders and decodeHeader.
	dec *hpack.Decoder
}

// readFrame returns the next frame from the underlying Framer.
//
// A nil Frame is returned when the read fails with io.EOF or
// os.ErrDeadlineExceeded. Any other read error fails the test.
func (tf *testConnFramer) readFrame() Frame {
	tf.t.Helper()
	frame, err := tf.fr.ReadFrame()
	switch err {
	case nil:
		return frame
	case io.EOF, os.ErrDeadlineExceeded:
		// Closed connection or nothing to read: not a test failure here.
		return nil
	}
	tf.t.Fatalf("ReadFrame: %v", err)
	return frame
}

// readFramer is the source of frames consumed by the generic readFrame function.
// *testConnFramer implements it: its readFrame returns nil when the read
// fails with io.EOF or os.ErrDeadlineExceeded.
type readFramer interface {
	readFrame() Frame
}

// readFrame reads the next frame from framer and returns it as a T.
// The test fails if no frame is available or if the frame's dynamic type
// is not T.
func readFrame[T any](t testing.TB, framer readFramer) T {
	t.Helper()
	var zero T // only used to name the wanted type in the failure message
	got := framer.readFrame()
	if got == nil {
		t.Fatalf("got no frame, want frame %T", zero)
	}
	typed, isT := got.(T)
	if !isT {
		t.Fatalf("got frame %T, want %T", got, typed)
	}
	return typed
}

// wantFrameType reads the next frame and fails the test unless one is
// available and its header type equals want.
func (tf *testConnFramer) wantFrameType(want FrameType) {
	tf.t.Helper()
	frame := tf.readFrame()
	if frame == nil {
		tf.t.Fatalf("got no frame, want frame %v", want)
	}
	got := frame.Header().Type
	if got == want {
		return
	}
	tf.t.Fatalf("got frame %v, want %v", got, want)
}

// wantUnorderedFrames reads frames until every condition in want has been satisfied.
//
// want is a list of func(*SomeFrame) bool.
// Every entry is validated before any frame is read; an entry with any other
// shape fails the test immediately.
// wantUnorderedFrames will call each func with frames of the appropriate type
// until the func returns true. Each frame is offered only to the first
// unsatisfied func that accepts its type.
// It calls t.Fatal if an unexpected frame is received (no func has that frame type,
// or all funcs with that type have returned true), or if the framer runs out of frames
// with unsatisfied funcs.
//
// Example:
//
//	// Read a SETTINGS frame, and any number of DATA frames for a stream.
//	// The SETTINGS frame may appear anywhere in the sequence.
//	// The last DATA frame must indicate the end of the stream.
//	tf.wantUnorderedFrames(
//		func(f *SettingsFrame) bool {
//			return true
//		},
//		func(f *DataFrame) bool {
//			return f.StreamEnded()
//		},
//	)
func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
	tf.t.Helper()
	// Work on a copy: satisfied conditions are cleared to nil below.
	want = slices.Clone(want)

	// Validate once, up front. This keeps the check out of the per-frame loop
	// and guarantees the reflect calls below (Type.In, Value.Call) never see
	// a non-func value, which would panic instead of failing the test.
	boolType := reflect.TypeOf(true)
	for _, f := range want {
		if f == nil {
			continue
		}
		typ := reflect.TypeOf(f)
		if typ.Kind() != reflect.Func ||
			typ.NumIn() != 1 ||
			typ.NumOut() != 1 ||
			typ.Out(0) != boolType {
			tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
		}
	}

	seen := 0
frame:
	for seen < len(want) && !tf.t.Failed() {
		fr := tf.readFrame()
		if fr == nil {
			break
		}
		frType := reflect.TypeOf(fr)
		for i, f := range want {
			if f == nil {
				// Already satisfied.
				continue
			}
			if reflect.TypeOf(f).In(0) != frType {
				continue
			}
			out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
			if out[0].Bool() {
				want[i] = nil
				seen++
			}
			// The frame is consumed whether or not the condition was satisfied.
			continue frame
		}
		// Errorf marks the test failed, which ends the loop; the Fatalf
		// below then reports what was still outstanding.
		tf.t.Errorf("got unexpected frame type %T", fr)
	}
	if seen < len(want) {
		for _, f := range want {
			if f == nil {
				continue
			}
			tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
		}
		tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
	}
}

// wantHeader describes the header frames expected by (*testConnFramer).wantHeaders.
type wantHeader struct {
	// streamID is the stream ID the HEADERS frame must carry.
	streamID  uint32
	// endStream is the value the HEADERS frame's StreamEnded() must report.
	endStream bool
	// header lists the fields to verify. wantHeaders compares only the keys
	// present here; other decoded fields are not checked.
	header    http.Header
}

// wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
// and asserts that they contain the expected headers.
//
// The HEADERS frame's stream ID and END_STREAM flag must equal want.streamID
// and want.endStream. CONTINUATION frames are read until one reports
// HeadersEnded. Only the keys present in want.header are compared; for each,
// the decoded values must equal the wanted values exactly. Decoded fields
// that are not listed in want.header are ignored.
func (tf *testConnFramer) wantHeaders(want wantHeader) {
	tf.t.Helper()

	hf := readFrame[*HeadersFrame](tf.t, tf)
	if hf.StreamID != want.streamID {
		tf.t.Fatalf("got stream ID %v, want %v", hf.StreamID, want.streamID)
	}
	if ended := hf.StreamEnded(); ended != want.endStream {
		tf.t.Fatalf("got stream ended %v, want %v", ended, want.endStream)
	}

	// Collect every decoded field; the emit func is removed again on return
	// so that it does not outlive gotHeader.
	gotHeader := make(http.Header)
	tf.dec.SetEmitFunc(func(f hpack.HeaderField) {
		gotHeader[f.Name] = append(gotHeader[f.Name], f.Value)
	})
	defer tf.dec.SetEmitFunc(nil)
	if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
		tf.t.Fatalf("decoding HEADERS frame: %v", err)
	}
	headersEnded := hf.HeadersEnded()
	for !headersEnded {
		// readFrame fails the test itself if no frame is available or the
		// frame is not a CONTINUATION, so no nil check is needed here.
		cf := readFrame[*ContinuationFrame](tf.t, tf)
		if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
			tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
		}
		headersEnded = cf.HeadersEnded()
	}
	if err := tf.dec.Close(); err != nil {
		tf.t.Fatalf("hpack decoding error: %v", err)
	}

	for k, wantVals := range want.header {
		gotVals := gotHeader[k]
		if !reflect.DeepEqual(wantVals, gotVals) {
			// Report the decoded value as "got" and the expectation as "want".
			tf.t.Fatalf("got header %q = %q; want %q", k, gotVals, wantVals)
		}
	}
}

// decodeHeader supports some older server tests.
// TODO: rewrite those tests to use newer, more convenient test APIs.
//
// It decodes headerBlock with tf.dec and returns the decoded fields as
// {name, value} pairs in the order they were emitted, omitting any field
// named "date". A decoding error fails the test.
func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
	// Attribute failures to the calling test, like the other helpers.
	tf.t.Helper()
	tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
		if hf.Name == "date" {
			return
		}
		pairs = append(pairs, [2]string{hf.Name, hf.Value})
	})
	// Remove the emit func on return so it does not outlive this call.
	defer tf.dec.SetEmitFunc(nil)
	if _, err := tf.dec.Write(headerBlock); err != nil {
		tf.t.Fatalf("hpack decoding error: %v", err)
	}
	if err := tf.dec.Close(); err != nil {
		tf.t.Fatalf("hpack decoding error: %v", err)
	}
	return pairs
}

// wantData describes the DATA frames expected by (*testConnFramer).wantData.
type wantData struct {
	// NOTE(review): streamID is not read by the wantData method in this
	// file, so the stream ID of received frames goes unchecked — confirm
	// whether that is intended.
	streamID  uint32
	// endStream is whether the frames read must end with END_STREAM set.
	endStream bool
	// size is the expected total payload size in bytes.
	// It is overwritten with len(data) when data is non-nil.
	size      int
	// data, if non-nil, is the expected concatenated payload.
	data      []byte
	multiple  bool // data may be spread across multiple DATA frames
}

// wantData reads zero or more DATA frames, and asserts that they match the expectation.
//
// Reading stops when no frame is available, when a frame has END_STREAM set,
// when want.multiple is false (after one frame), or when END_STREAM is not
// wanted and at least want.size bytes have been read. A frame that is not a
// DATA frame fails the test. The total payload size, the END_STREAM state,
// and (when want.data is non-nil) the payload bytes are then compared
// against want.
func (tf *testConnFramer) wantData(want wantData) {
	tf.t.Helper()
	checkBytes := want.data != nil
	if checkBytes {
		// An explicit payload determines the expected size.
		want.size = len(want.data)
	}
	var (
		gotData      []byte
		gotSize      int
		gotEndStream bool
	)
	for {
		fr := tf.readFrame()
		if fr == nil {
			break
		}
		df, isData := fr.(*DataFrame)
		if !isData {
			tf.t.Fatalf("got frame %T, want DataFrame", fr)
		}
		payload := df.Data()
		if checkBytes {
			gotData = append(gotData, payload...)
		}
		gotSize += len(payload)
		// The loop ends as soon as END_STREAM is seen, so plain assignment
		// can never reset a previously observed true.
		gotEndStream = df.StreamEnded()
		if gotEndStream ||
			(!want.endStream && gotSize >= want.size) ||
			!want.multiple {
			break
		}
	}
	if gotSize != want.size {
		tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
	}
	if gotEndStream != want.endStream {
		tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
	}
	if checkBytes && !bytes.Equal(gotData, want.data) {
		tf.t.Fatalf("got data %q, want %q", gotData, want.data)
	}
}

// wantRSTStream reads the next frame, which must be an RST_STREAM frame
// with the given stream ID and error code.
func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
	tf.t.Helper()
	rst := readFrame[*RSTStreamFrame](tf.t, tf)
	if rst.StreamID == streamID && rst.ErrCode == code {
		return
	}
	tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(rst), streamID, code)
}

// wantSettings reads the next frame, which must be a SETTINGS frame without
// the ACK flag, and checks that every setting in want is present in the frame
// with the given value. Settings in the frame that are not listed in want
// are ignored.
//
// Mismatches are reported with Errorf so that all of them are shown. If the
// test is marked failed afterwards (including by an earlier failure), the
// frame is printed and the test stops.
func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
	// Attribute failures to the calling test, like the other helpers.
	tf.t.Helper()
	fr := readFrame[*SettingsFrame](tf.t, tf)
	if fr.Header().Flags.Has(FlagSettingsAck) {
		tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
	}
	for wantID, wantVal := range want {
		gotVal, ok := fr.Value(wantID)
		if !ok {
			tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
		} else if gotVal != wantVal {
			tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
		}
	}
	if tf.t.Failed() {
		tf.t.Fatalf("%v", fr)
	}
}

// wantSettingsAck reads the next frame, which must be a SETTINGS frame
// with the ACK flag set.
func (tf *testConnFramer) wantSettingsAck() {
	tf.t.Helper()
	settings := readFrame[*SettingsFrame](tf.t, tf)
	if settings.Header().Flags.Has(FlagSettingsAck) {
		return
	}
	tf.t.Fatal("Settings Frame didn't have ACK set")
}

// wantGoAway reads the next frame, which must be a GOAWAY frame whose
// LastStreamID is maxStreamID and whose error code is code.
func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
	tf.t.Helper()
	goAway := readFrame[*GoAwayFrame](tf.t, tf)
	if goAway.LastStreamID == maxStreamID && goAway.ErrCode == code {
		return
	}
	tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(goAway), maxStreamID, code)
}

// wantWindowUpdate reads the next frame, which must be a WINDOW_UPDATE frame
// with the given stream ID and increment. The stream ID is checked first.
func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
	tf.t.Helper()
	frame := readFrame[*WindowUpdateFrame](tf.t, tf)
	if gotID := frame.FrameHeader.StreamID; gotID != streamID {
		tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", gotID, streamID)
	}
	if gotIncr := frame.Increment; gotIncr != incr {
		tf.t.Fatalf("WindowUpdate increment = %d; want %d", gotIncr, incr)
	}
}

// wantClosed asserts that the connection is closed: the next read must fail
// with an error other than os.ErrDeadlineExceeded. Receiving a frame, or a
// deadline-exceeded error, fails the test.
func (tf *testConnFramer) wantClosed() {
	tf.t.Helper()
	frame, err := tf.fr.ReadFrame()
	switch err {
	case nil:
		tf.t.Fatalf("got unexpected frame (want closed connection): %v", frame)
	case os.ErrDeadlineExceeded:
		// Deadline exceeded: the connection is still open, just quiet.
		tf.t.Fatalf("connection is not closed; want it to be")
	}
}

// wantIdle asserts that the connection is open but has nothing to read:
// the next read must fail with os.ErrDeadlineExceeded. Receiving a frame,
// or any other error, fails the test.
func (tf *testConnFramer) wantIdle() {
	tf.t.Helper()
	frame, err := tf.fr.ReadFrame()
	if err == nil {
		tf.t.Fatalf("got unexpected frame (want idle connection): %v", frame)
	}
	if err == os.ErrDeadlineExceeded {
		return
	}
	tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
}

// writeSettings passes settings to the Framer's WriteSettings
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeSettings(settings ...Setting) {
	tf.t.Helper()
	err := tf.fr.WriteSettings(settings...)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeSettingsAck calls the Framer's WriteSettingsAck
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeSettingsAck() {
	tf.t.Helper()
	err := tf.fr.WriteSettingsAck()
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeData passes its arguments to the Framer's WriteData
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
	tf.t.Helper()
	err := tf.fr.WriteData(streamID, endStream, data)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeDataPadded passes its arguments to the Framer's WriteDataPadded
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
	tf.t.Helper()
	err := tf.fr.WriteDataPadded(streamID, endStream, data, pad)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeHeaders passes p to the Framer's WriteHeaders
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
	tf.t.Helper()
	err := tf.fr.WriteHeaders(p)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeHeadersMode writes the header block described by p in the manner
// selected by mode:
//
//   - noHeader: Write nothing.
//   - oneHeader: Write p as a single HEADERS frame.
//   - splitHeader: Write the first byte of the block fragment in a HEADERS
//     frame without END_HEADERS, and the remainder in a CONTINUATION frame
//     that carries p's original EndHeaders value.
//
// It panics if mode is not one of these values, or if mode is splitHeader
// and the block fragment is shorter than two bytes.
func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
	tf.t.Helper()
	switch mode {
	case noHeader:
		return
	case oneHeader:
		tf.writeHeaders(p)
		return
	case splitHeader:
		// Handled below.
	default:
		panic("bogus mode")
	}

	if len(p.BlockFragment) < 2 {
		panic("too small")
	}
	// p is a by-value copy, so head can be derived from it without
	// affecting the caller's parameter.
	head := p
	head.BlockFragment = p.BlockFragment[:1]
	head.EndHeaders = false
	tf.writeHeaders(head)
	tf.writeContinuation(p.StreamID, p.EndHeaders, p.BlockFragment[1:])
}

// writeContinuation passes its arguments to the Framer's WriteContinuation
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
	tf.t.Helper()
	err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writePriority passes id and p to the Framer's WritePriority
// and fails the test if the write returns an error.
func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
	// Attribute failures to the calling test, like the other write helpers.
	tf.t.Helper()
	if err := tf.fr.WritePriority(id, p); err != nil {
		tf.t.Fatal(err)
	}
}

// writeRSTStream passes streamID and code to the Framer's WriteRSTStream
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
	tf.t.Helper()
	err := tf.fr.WriteRSTStream(streamID, code)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writePing passes ack and data to the Framer's WritePing
// and fails the test if the write returns an error.
func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
	tf.t.Helper()
	err := tf.fr.WritePing(ack, data)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeGoAway passes its arguments to the Framer's WriteGoAway
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
	tf.t.Helper()
	err := tf.fr.WriteGoAway(maxStreamID, code, debugData)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}

// writeWindowUpdate passes streamID and incr to the Framer's WriteWindowUpdate
// and fails the test if the write returns an error.
func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
	tf.t.Helper()
	err := tf.fr.WriteWindowUpdate(streamID, incr)
	if err == nil {
		return
	}
	tf.t.Fatal(err)
}