package backupformat

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"time"

	"github.com/hamba/avro/v2/ocf"
	"github.com/mattn/go-isatty"
	"github.com/natefinch/atomic"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/schollz/progressbar/v3"
	"google.golang.org/protobuf/proto"

	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"

	"github.com/authzed/zed/internal/console"
)

// Encoder is the contract for incrementally producing a backup of SpiceDB
// relationship data.
type Encoder interface {
	// WriteSchema encodes the given schema text. What is done with revision is
	// up to the implementation (OcfEncoder may store it as OCF file metadata).
	WriteSchema(schema, revision string) error

	// Append encodes one more Relationship. The cursor describes how far the
	// backup has progressed and can be used by implementations to track it.
	Append(r *v1.Relationship, cursor string) error

	// MarkComplete reports that the last relationship has been appended and
	// the backup process has finished.
	MarkComplete()
}

// Compile-time checks that every encoder defined in this file satisfies the
// Encoder interface.
var (
	_ Encoder = (*MockEncoder)(nil)
	_ Encoder = (*OcfEncoder)(nil)
	_ Encoder = (*OcfFileEncoder)(nil)
	_ Encoder = (*RewriteEncoder)(nil)
	_ Encoder = (*ProgressRenderingEncoder)(nil)
)

// MockEncoder is an Encoder that simply records what it is given in memory.
type MockEncoder struct {
	Relationships []*v1.Relationship
	Cursors       []string
	Complete      bool
}

// Append records r and cursor. It never returns an error.
func (m *MockEncoder) Append(r *v1.Relationship, cursor string) error {
	m.Cursors = append(m.Cursors, cursor)
	m.Relationships = append(m.Relationships, r)
	return nil
}

// WriteSchema ignores its arguments and always succeeds.
func (m *MockEncoder) WriteSchema(_, _ string) error {
	return nil
}

// MarkComplete sets the Complete flag.
func (m *MockEncoder) MarkComplete() {
	m.Complete = true
}

// WithRewriter wraps e so that every relationship is passed through rw before
// it reaches e.
func WithRewriter(rw Rewriter, e Encoder) *RewriteEncoder {
	wrapped := RewriteEncoder{Rewriter: rw, Encoder: e}
	return &wrapped
}

// RewriteEncoder implements `Encoder` by running each relationship through a
// Rewriter and forwarding the result to the wrapped Encoder.
type RewriteEncoder struct {
	Rewriter
	Encoder
}

// Append rewrites r and forwards the rewritten relationship with the same
// cursor. When the Rewriter yields a nil relationship (and no error), the
// relationship is dropped and nil is returned.
func (e *RewriteEncoder) Append(r *v1.Relationship, cursor string) error {
	rewritten, err := e.RewriteRelationship(r)
	if err != nil {
		return err
	}
	if rewritten == nil {
		return nil
	}
	return e.Encoder.Append(rewritten, cursor)
}

// MarshalZerologObject embeds the log fields of the Rewriter and then of the
// wrapped Encoder, for whichever of them implement zerolog.LogObjectMarshaler.
func (e *RewriteEncoder) MarshalZerologObject(event *zerolog.Event) {
	for _, component := range []any{e.Rewriter, e.Encoder} {
		if obj, ok := component.(zerolog.LogObjectMarshaler); ok {
			event.EmbedObject(obj)
		}
	}
}

// Close closes the wrapped Encoder if it implements io.Closer; otherwise it is
// a no-op.
func (e *RewriteEncoder) Close() error {
	closer, ok := e.Encoder.(io.Closer)
	if !ok {
		return nil
	}
	return closer.Close()
}

// OcfEncoder implements `Encoder` by writing data in the AVRO Object Container
// File (OCF) format to an io.Writer.
type OcfEncoder struct {
	w   io.Writer
	enc *ocf.Encoder
}

// NewOcfEncoder returns an OcfEncoder that writes to w. The underlying OCF
// encoder is only created when the first record is written.
func NewOcfEncoder(w io.Writer) *OcfEncoder {
	oe := new(OcfEncoder)
	oe.w = w
	return oe
}

// encoder returns the OCF encoder, creating it (with Snappy compression) on
// first use. A non-empty revision is stored under metadataKeyZT in the file
// metadata, but only on the call that actually creates the encoder; later
// calls return the existing encoder and ignore revision.
func (e *OcfEncoder) encoder(revision string) (*ocf.Encoder, error) {
	if e.enc != nil {
		return e.enc, nil
	}

	schemaDef, err := avroSchemaV1()
	if err != nil {
		return nil, fmt.Errorf("unable to create avro schema: %w", err)
	}

	options := make([]ocf.EncoderFunc, 0, 2)
	options = append(options, ocf.WithCodec(ocf.Snappy))
	if revision != "" {
		metadata := map[string][]byte{metadataKeyZT: []byte(revision)}
		options = append(options, ocf.WithMetadata(metadata))
	}

	if e.enc, err = ocf.NewEncoder(schemaDef, e.w, options...); err != nil {
		return nil, fmt.Errorf("unable to create encoder: %w", err)
	}
	return e.enc, nil
}

// WriteSchema encodes schema as a SchemaV1 record. revision is recorded in the
// OCF metadata if this call is the one that creates the encoder.
func (e *OcfEncoder) WriteSchema(schema, revision string) error {
	enc, err := e.encoder(revision)
	if err != nil {
		return err
	}

	record := SchemaV1{SchemaText: schema}
	if err = enc.Encode(record); err != nil {
		return fmt.Errorf("unable to encode SpiceDB schema object: %w", err)
	}
	return nil
}

// MarshalZerologObject adds the output format to a log event.
func (e *OcfEncoder) MarshalZerologObject(event *zerolog.Event) {
	const format = "avro ocf"
	event.Str("format", format)
}

// Append converts r to a RelationshipV1 record and encodes it. The cursor is
// not used by this encoder. A caveat's context is stored as marshaled
// protobuf bytes; an expiration is only set when present and non-zero.
func (e *OcfEncoder) Append(r *v1.Relationship, _ string) error {
	var rel RelationshipV1

	resource := r.Resource
	rel.ObjectType = resource.ObjectType
	rel.ObjectID = resource.ObjectId
	rel.Relation = r.Relation

	subject := r.Subject
	rel.SubjectObjectType = subject.Object.ObjectType
	rel.SubjectObjectID = subject.Object.ObjectId
	rel.SubjectRelation = subject.OptionalRelation

	if caveat := r.OptionalCaveat; caveat != nil {
		contextBytes, err := proto.Marshal(caveat.Context)
		if err != nil {
			return fmt.Errorf("error marshaling caveat context: %w", err)
		}
		rel.CaveatName = caveat.CaveatName
		rel.CaveatContext = contextBytes
	}

	if expiresAt := r.OptionalExpiresAt; expiresAt != nil {
		if expiry := expiresAt.AsTime(); !expiry.IsZero() {
			rel.Expiration = expiry
		}
	}

	enc, err := e.encoder("")
	if err != nil {
		return err
	}
	if err = enc.Encode(rel); err != nil {
		return fmt.Errorf("unable to encode relationship: %w", err)
	}
	return nil
}

// MarkComplete is a no-op for a plain OCF stream; there is nothing extra to
// record when the backup finishes.
func (e *OcfEncoder) MarkComplete() {}

// Close flushes any buffered records to the underlying writer. It does not
// close the writer itself.
//
// The OCF encoder is created lazily on the first write, so an OcfEncoder that
// never encoded anything has nothing to flush; Close returns nil in that case
// instead of dereferencing a nil encoder.
func (e *OcfEncoder) Close() error {
	if e.enc == nil {
		return nil
	}
	if err := e.enc.Flush(); err != nil {
		return fmt.Errorf("unable to flush encoder: %w", err)
	}
	return nil
}

// OcfFileEncoder implements `Encoder` by writing AVRO OCF data to a file, and
// by recording the latest cursor in a sibling ".lock" file so that a stopped
// backup can be resumed.
type OcfFileEncoder struct {
	file             *os.File
	lastSyncedCursor string
	completed        bool
	*OcfEncoder
}

// lockFileName returns the lockfile path: the backup file's name with a
// ".lock" suffix appended.
func (fe *OcfFileEncoder) lockFileName() string {
	name := fe.file.Name()
	return name + ".lock"
}

// Cursor returns the cursor stored in the lockfile. A missing lockfile means
// the backup already ran to completion, which is reported as an error.
func (fe *OcfFileEncoder) Cursor() (string, error) {
	data, err := os.ReadFile(fe.lockFileName())
	switch {
	case os.IsNotExist(err):
		return "", errors.New("completed backup file already exists")
	case err != nil:
		return "", err
	default:
		return string(data), nil
	}
}

// NewFileEncoder opens filename for reading and appending (creating it with
// mode 0644 if necessary) and returns an encoder that writes to it. The name
// "-" selects stdout instead. existed reports whether filename could already
// be stat'ed before opening; it is always false for stdout.
func NewFileEncoder(filename string) (e *OcfFileEncoder, existed bool, err error) {
	if filename == "-" {
		return &OcfFileEncoder{file: os.Stdout, OcfEncoder: &OcfEncoder{w: os.Stdout}}, false, nil
	}

	_, statErr := os.Stat(filename)
	existed = statErr == nil

	f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644)
	if err != nil {
		return nil, existed, fmt.Errorf("unable to open backup file: %w", err)
	}

	return &OcfFileEncoder{file: f, OcfEncoder: &OcfEncoder{w: f}}, existed, nil
}

// Append encodes r into the backup file. If the cursor changed since the last
// write, Append then atomically persists it to the lockfile so an interrupted
// backup can be resumed from there.
func (fe *OcfFileEncoder) Append(r *v1.Relationship, cursor string) error {
	if err := fe.OcfEncoder.Append(r, cursor); err != nil {
		return fmt.Errorf("error storing relationship: %w", err)
	}

	// Only write to disk when necessary. Stdout is skipped entirely: a
	// streamed backup cannot be resumed, and its derived lockfile path would
	// be "/dev/stdout.lock".
	if fe.persistsCursor() && cursor != fe.lastSyncedCursor {
		if err := atomic.WriteFile(fe.lockFileName(), bytes.NewBufferString(cursor)); err != nil {
			return fmt.Errorf("failed to store cursor in lockfile: %w", err)
		}
		fe.lastSyncedCursor = cursor
	}

	return nil
}

// persistsCursor reports whether progress is tracked in a lockfile. Backups
// written to stdout are not resumable, so they do not get one.
func (fe *OcfFileEncoder) persistsCursor() bool {
	return fe.file != nil && fe.file != os.Stdout
}

// MarkComplete records that the backup finished, so that Close removes the
// lockfile.
func (fe *OcfFileEncoder) MarkComplete() { fe.completed = true }

// Close flushes buffered records, then syncs and closes the backup file. If the
// backup was marked complete, it also removes the lockfile. All errors
// encountered are joined and returned.
func (fe *OcfFileEncoder) Close() error {
	// Don't throw any errors if the file is nil when flushing/closing.
	closeFile := func() error {
		if fe.file == nil {
			return nil
		}
		if fe.enc == nil {
			// Nothing was ever encoded, so there is nothing to flush or
			// sync, but the file handle must still be released.
			return fe.file.Close()
		}
		// The flush error is reported rather than discarded; a failed
		// flush means the backup is incomplete.
		return errors.Join(fe.OcfEncoder.Close(), fe.file.Sync(), fe.file.Close())
	}

	// The lockfile name is only computed here, after the nil-file check in
	// persistsCursor, so a nil file cannot cause a panic.
	removeLockFile := func() error {
		if !fe.completed || !fe.persistsCursor() {
			return nil
		}
		// The lockfile is only written once a cursor is observed, so it may
		// legitimately be absent (e.g. a backup that only contains a schema).
		if err := os.Remove(fe.lockFileName()); err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}
		return nil
	}

	return errors.Join(
		closeFile(),
		removeLockFile(),
	)
}

// MarshalZerologObject adds the OCF format plus the backup and lockfile paths
// to a log event.
func (fe *OcfFileEncoder) MarshalZerologObject(e *zerolog.Event) {
	e.EmbedObject(fe.OcfEncoder).
		Str("file", fe.file.Name()).
		Str("lockFile", fe.lockFileName())
}

// ProgressRenderingEncoder implements `Encoder` by wrapping another Encoder
// and reporting how far it has progressed: it drives a progress bar, and
// falls back to periodic log lines when stderr is not a terminal.
type ProgressRenderingEncoder struct {
	relsProcessed uint64
	progressBar   *progressbar.ProgressBar
	startTime     time.Time
	ticker        <-chan time.Time
	Encoder
}

// WithProgress wraps e with progress reporting. The elapsed-time clock starts
// immediately, and the fallback progress log is gated by a 5-second ticker.
func WithProgress(e Encoder) *ProgressRenderingEncoder {
	pre := &ProgressRenderingEncoder{Encoder: e}
	pre.startTime = time.Now()
	pre.ticker = time.Tick(5 * time.Second)
	return pre
}

// bar returns the progress bar, creating it on first use.
func (pre *ProgressRenderingEncoder) bar() *progressbar.ProgressBar {
	if pre.progressBar != nil {
		return pre.progressBar
	}
	pre.progressBar = console.CreateProgressBar("processing backup")
	return pre.progressBar
}

// Close finishes the progress bar, then closes the wrapped Encoder if it
// implements io.Closer.
//
// The wrapped Encoder is closed even when finishing the bar fails. Otherwise
// the resources it holds (e.g. an open backup file and its lockfile) would
// never be released. Errors from both steps are joined.
func (pre *ProgressRenderingEncoder) Close() error {
	finishErr := pre.bar().Finish()

	var closeErr error
	if closer, ok := pre.Encoder.(io.Closer); ok {
		closeErr = closer.Close()
	}

	return errors.Join(finishErr, closeErr)
}

// MarshalZerologObject embeds the wrapped Encoder's log fields (when it
// supports them), then adds the processed count, throughput in relationships
// per second, and elapsed time rounded to the second.
func (pre *ProgressRenderingEncoder) MarshalZerologObject(e *zerolog.Event) {
	if inner, ok := pre.Encoder.(zerolog.LogObjectMarshaler); ok {
		e.EmbedObject(inner)
	}

	elapsed := time.Since(pre.startTime)
	e.Uint64("processed", pre.relsProcessed)
	e.Uint64("throughput", perSec(pre.relsProcessed, elapsed))
	e.Stringer("elapsed", elapsed.Round(time.Second))
}

// Append counts the relationship, forwards it to the wrapped Encoder, and
// advances the progress bar by one. If stderr is not a terminal, it also logs
// a progress line whenever the ticker has fired since the last check. The
// ticker is polled without blocking.
func (pre *ProgressRenderingEncoder) Append(r *v1.Relationship, cursor string) error {
	pre.relsProcessed++
	if err := pre.Encoder.Append(r, cursor); err != nil {
		return err
	}

	if err := pre.bar().Add(1); err != nil {
		return fmt.Errorf("error incrementing progress bar: %w", err)
	}

	if isatty.IsTerminal(os.Stderr.Fd()) {
		return nil
	}

	// Fallback for non-interactive stderr.
	select {
	case <-pre.ticker:
		log.Info().EmbedObject(pre).Msg("backup progress")
	default:
	}
	return nil
}

// perSec returns i divided by the number of whole seconds in d (truncated).
// When d is shorter than one second, i is returned unchanged.
func perSec(i uint64, d time.Duration) uint64 {
	if secs := uint64(d.Seconds()); secs > 0 {
		return i / secs
	}
	return i
}
