/*
 * Copyright 2019-present Facebook, Inc.
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package gotest

import (
	"fmt"
	"net"
	"reflect"
	"testing"
	"thrift/lib/go/thrift"
	"thrift/test/go/if/thrifttest"
	"time"
)

const (
	// localConnTimeout bounds how long the tests wait when connecting to
	// the local test server.
	localConnTimeout = 1 * time.Second

	// testCallString is the payload echoed through the server; it mixes a
	// backslash, a quote, an escaped space and non-ASCII characters.
	testCallString = "this is a fairly lengthy test string \\ that ' has \x20 some 东西奇怪的"
)

// createTestHeaderServer creates a ThriftTest server that uses the header
// transport and header protocol, listens on "[::]:0" (port chosen by the
// OS), and starts serving in a background goroutine.
//
// Before returning, it dials the listening address once (bounded by
// localConnTimeout) to confirm the socket accepts connections.
//
// It returns the server, the address it is listening on, and any setup
// error. The background goroutine panics if Serve fails with anything other
// than thrift.ErrServerClosed.
func createTestHeaderServer(handler thrifttest.ThriftTest) (*thrift.SimpleServer, net.Addr, error) {
	processor := thrifttest.NewThriftTestProcessor(handler)
	transportFactory := thrift.NewHeaderTransportFactory(thrift.NewTransportFactory())
	protocolFactory := thrift.NewHeaderProtocolFactory()

	transport, err := thrift.NewServerSocket("[::]:0")
	if err != nil {
		return nil, nil, fmt.Errorf("failed to open test socket: %s", err)
	}

	if err := transport.Listen(); err != nil {
		return nil, nil, fmt.Errorf("failed to listen on socket: %s", err)
	}
	taddr := transport.Addr()

	server := thrift.NewSimpleServerContext(processor, transport,
		thrift.TransportFactories(transportFactory),
		thrift.ProtocolFactories(protocolFactory))
	go func(server *thrift.SimpleServer) {
		// Use a goroutine-local error: writing to the enclosing function's
		// err from here would be an unsynchronized write to a variable the
		// function body keeps assigning after this goroutine starts.
		if err := server.Serve(); err != nil && err != thrift.ErrServerClosed {
			panic(fmt.Errorf("failed to begin serving test socket: %s", err))
		}
	}(server)

	// Probe the listening address once so callers only get a server that
	// is reachable.
	conn, err := net.DialTimeout(taddr.Network(), taddr.String(), localConnTimeout)
	if err != nil {
		return nil, nil, fmt.Errorf(
			"failed to connect to test socket: %s:%s: %s", taddr.Network(), taddr.String(), err,
		)
	}
	conn.Close()

	return server, taddr, nil
}

// connectTestHeaderServer opens a socket to addr (with localConnTimeout as
// the socket timeout) and returns a ThriftTest client on top of it.
//
// When transportFactory is non-nil the opened socket is wrapped with it;
// protocolFactory supplies the single protocol used for both input and
// output. Any error creating or opening the socket is returned unchanged.
func connectTestHeaderServer(
	addr net.Addr,
	transportFactory thrift.TransportFactory,
	protocolFactory thrift.ProtocolFactory,
) (*thrifttest.ThriftTestClient, error) {
	sock, err := thrift.NewSocket(
		thrift.SocketAddr(addr.String()),
		thrift.SocketTimeout(localConnTimeout),
	)
	if err != nil {
		return nil, err
	}
	if openErr := sock.Open(); openErr != nil {
		return nil, openErr
	}

	// The client talks through the raw socket unless a wrapping transport
	// factory was supplied.
	var clientTrans thrift.Transport = sock
	if transportFactory != nil {
		clientTrans = transportFactory.GetTransport(clientTrans)
	}

	proto := protocolFactory.GetProtocol(clientTrans)
	return thrifttest.NewThriftTestClient(clientTrans, proto, proto), nil
}

// doClientTest starts a header test server backed by a testHandler, connects
// a client built from the given transport and protocol factories, and runs
// a fixed sequence of checks against it:
//
//   - a single string echo, then 1000 more;
//   - an application exception configured via handler.ReturnError;
//   - ten round trips of a large Insanity struct.
//
// Any failure aborts the test through t.Fatalf. The server is stopped and
// the client closed when the function returns.
func doClientTest(t *testing.T, transportFactory thrift.TransportFactory, protocolFactory thrift.ProtocolFactory) {
	handler := &testHandler{}
	serv, addr, err := createTestHeaderServer(handler)
	if err != nil {
		t.Fatalf("failed to create test server: %s", err.Error())
	}
	defer serv.Stop()

	client, err := connectTestHeaderServer(addr, transportFactory, protocolFactory)
	if err != nil {
		t.Fatalf("failed to connect to test server: %s", err.Error())
	}
	defer client.Close()

	res, err := client.DoTestString(testCallString)
	if err != nil {
		t.Fatalf("failed to query test server: %s", err.Error())
	}

	if res != testCallString {
		t.Fatalf("server query compare failed")
	}

	// Try sending a lot of requests
	for i := 0; i < 1000; i++ {
		res, err = client.DoTestString(testCallString)
		if err != nil {
			t.Fatalf("failed to query test server: %s", err.Error())
		}
		if res != testCallString {
			t.Fatalf("server query compare failed")
		}
	}

	// Try getting an application Exception
	exp1 := thrifttest.NewXception()
	exp1.ErrorCode = 5
	exp1.Message = testCallString
	handler.ReturnError = exp1

	err = client.DoTestException(testCallString)
	if texp, ok := err.(*thrifttest.Xception); ok && texp != nil {
		if texp.ErrorCode != 5 || texp.Message != testCallString {
			t.Fatalf("application exception values incorrect: got=%s", texp.String())
		}
	} else {
		t.Fatalf("application exception type incorrect: got=%v", err)
	}
	handler.ReturnError = nil

	// Make a large-ish struct
	insanity := thrifttest.NewInsanity()
	insanity.UserMap = map[thrifttest.Numberz]thrifttest.UserId{}
	insanity.Str2str = map[string]string{}
	for i := 0; i < 50000; i++ {
		// NOTE(review): the same key is overwritten on every iteration, so
		// UserMap ends up with a single entry — confirm this is intended.
		insanity.UserMap[thrifttest.Numberz_SIX] = thrifttest.UserId(i)
		insanity.Xtructs = append(insanity.Xtructs, &thrifttest.Xtruct{
			StringThing: testCallString, ByteThing: 5, I32Thing: 50, I64Thing: 100,
		})
		insanity.Str2str[fmt.Sprintf("%d", i)] = testCallString
	}

	// Try sending a lot of large things
	for i := 0; i < 10; i++ {
		resp, terr := client.DoTestInsanity(insanity)
		if terr != nil {
			// Report terr, the error from this call; err still holds the
			// exception returned by DoTestException above.
			t.Fatalf("failed to query test server: %s", terr.Error())
		}

		num, ok := resp[thrifttest.UserId(3)]
		if !ok {
			t.Fatalf("incorrect response from server on insanity")
		}

		data, ok := num[thrifttest.Numberz_EIGHT]
		if !ok {
			t.Fatalf("incorrect response from server on insanity")
		}

		// The value found under UserId(3) / Numberz_EIGHT must equal the
		// struct that was sent.
		if !reflect.DeepEqual(data, insanity) {
			t.Fatalf("incorrect response from server on insanity")
		}
	}
}

func TestHeaderHeader(t *testing.T) {
	doClientTest(
		t,
		thrift.NewHeaderTransportFactory(thrift.NewTransportFactory()),
		thrift.NewHeaderProtocolFactory(),
	)
}

func TestHeaderFramedBinary(t *testing.T) {
	doClientTest(
		t,
		thrift.NewFramedTransportFactory(thrift.NewTransportFactory()),
		thrift.NewBinaryProtocolFactory(false, true),
	)
}

func TestHeaderFramedCompact(t *testing.T) {
	doClientTest(
		t,
		thrift.NewFramedTransportFactory(thrift.NewTransportFactory()),
		thrift.NewCompactProtocolFactory(),
	)
}

// Disabled: it is unclear whether unframed clients are supported here.
// func TestHeaderUnframedBinary(t *testing.T) {
// 	doClientTest(
// 		t,
// 		thrift.NewBufferedTransportFactory(8192),
// 		thrift.NewBinaryProtocolFactory(false, true),
// 	)
// }
//
// func TestHeaderUnframedCompact(t *testing.T) {
// 	doClientTest(
// 		t,
// 		thrift.NewBufferedTransportFactory(8192),
// 		thrift.NewCompactProtocolFactory(),
// 	)
// }
