﻿/*
 * Copyright (c) 2009,2010 Yoshikazu Kuramochi
 * All rights reserved.
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "StdAfx.h"
#include "DirectShowInput.h"


// Releases a COM interface pointer (when non-NULL) and resets it to NULL, so a
// second RELEASE on the same variable is harmless. Note: `p` is evaluated more
// than once, so pass a plain variable, not an expression with side effects.
#define RELEASE(p)						do { if (p) { p->Release(); p = NULL; } } while (false)
// Evaluates `hr` exactly once and returns that value from the enclosing
// function if it is a failure code; success codes fall through.
#define RETURN_HRESULT_IF_FAILED(hr)	do { HRESULT _hr = (hr); if (FAILED(_hr)) return _hr; } while (false)
// Jumps to `label` if `hr` is a failure code. Callers pass `hr = call()` as the
// argument so the result also stays in their own `hr` for the cleanup path.
#define BAIL_IF_FAILED(label, hr)		do { if (FAILED((hr))) goto label; } while (false)


// Starts with every COM interface pointer NULL, a zero duration and no
// Running Object Table registration; InitWithFile fills these in.
InputBase::InputBase()
	: mGraphBuilder(NULL)
	, mSampleGrabber(NULL)
	, mMediaControl(NULL)
	, mMediaSeeking(NULL)
	, mMediaEvent(NULL)
	, mDuration(0)
	, mRotRegister(0)
{
}

// Tears down whatever InitWithFile managed to build, even after a partial
// failure: revoke the ROT entry, detach the grabber callback so it stops
// calling back into this object, then drop the interface references in the
// reverse order they were acquired. RELEASE tolerates NULL members.
InputBase::~InputBase()
{
	RemoveFromRot();

	// Detach the callback first; it holds a raw pointer back to this object.
	if (mSampleGrabber) {
		mSampleGrabber->SetCallback(NULL, 1);
	}

	// Graph control interfaces (acquired last, released first).
	RELEASE(mMediaEvent);
	RELEASE(mMediaSeeking);
	RELEASE(mMediaControl);

	// Grabber and the graph itself (acquired first, released last).
	RELEASE(mSampleGrabber);
	RELEASE(mGraphBuilder);
}

/*
 * Waits for a filter graph whose IMediaControl::Pause() returned S_FALSE to
 * finish transitioning to State_Paused.
 *
 * Polls IMediaControl::GetState up to 10 times with a 1000 ms timeout each.
 * Returns S_OK once the graph reports State_Paused (VFW_S_CANT_CUE is
 * accepted as long as the reported state is paused). Every other outcome is
 * reported as a FAILURE code, so a caller testing with FAILED() cannot mistake
 * an unfinished pause for success. Those outcomes are a GetState error, a
 * state other than paused, or a transition still intermediate after all
 * retries.
 */
static HRESULT WaitUntilPaused(IMediaControl* mediaControl)
{
	for (int i = 0; i < 10; ++i) {
		OAFilterState state = State_Stopped;
		HRESULT hr = mediaControl->GetState(1000, &state);

		if (hr == VFW_S_STATE_INTERMEDIATE) {
			// Still transitioning; poll again.
			continue;
		}
		if (FAILED(hr)) {
			return hr;
		}
		if ((hr == S_OK || hr == VFW_S_CANT_CUE) && state == State_Paused) {
			return S_OK;
		}
		// The transition finished, but not in the paused state (or an
		// unexpected success code came back). Report it as a failure rather
		// than leaking a success code to the caller.
		return E_FAIL;
	}

	// The transition never completed within the polling budget.
	return E_FAIL;
}

/*
 * Builds the decoding graph for `file`:
 *
 *     Source -> (whatever ConnectSourceToGrabber inserts) -> Sample Grabber -> Null Renderer
 *
 * The Sample Grabber accepts the media type supplied by the virtual
 * GetMediaType(). It does not keep samples (SetBufferSamples(FALSE)), runs in
 * one-shot mode, and forwards each buffer to this object's BufferCB through a
 * Callback object (method 1 = BufferCB). The graph is given no reference clock.
 *
 * On success the graph is paused, the seeking time format is
 * TIME_FORMAT_MEDIA_TIME, mDuration holds the stream duration in that format,
 * and the graph is registered in the Running Object Table.
 *
 * Returns the first failing HRESULT (S_OK on success). Interfaces already
 * stored in members on failure are released by the destructor.
 */
HRESULT InputBase::InitWithFile(LPCWSTR file)
{
	HRESULT hr;

	// Locals released at `bail`. All declarations stay above the first goto.
	IBaseFilter*	grabberFilter = NULL;
	IBaseFilter*	sourceFilter = NULL;
	IBaseFilter*	nullRenderer = NULL;
	IMediaFilter*	mediaFilter = NULL;
	CMediaType		mediaType;


	// Graph

	BAIL_IF_FAILED(bail,
		hr = CoCreateInstance(CLSID_FilterGraph, NULL, CLSCTX_INPROC, IID_IGraphBuilder, (LPVOID*) &mGraphBuilder));


	// Sample Grabber

	BAIL_IF_FAILED(bail,
		hr = CoCreateInstance(CLSID_SampleGrabber, NULL, CLSCTX_INPROC, IID_IBaseFilter, (LPVOID*) &grabberFilter));

	BAIL_IF_FAILED(bail,
		hr = grabberFilter->QueryInterface(IID_ISampleGrabber, (LPVOID*) &mSampleGrabber));

	GetMediaType(&mediaType);

	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetMediaType(&mediaType));

	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetBufferSamples(FALSE));

	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetOneShot(TRUE));

	// 1 = deliver samples through ISampleGrabberCB::BufferCB.
	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetCallback(new Callback(this), 1));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->AddFilter(grabberFilter, L"Grabber"));


	// Source Filter

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->AddSourceFilter(file, L"Source", &sourceFilter));


	// Null Renderer

	BAIL_IF_FAILED(bail,
		hr = CoCreateInstance(CLSID_NullRenderer, NULL, CLSCTX_INPROC, IID_IBaseFilter, (LPVOID*) &nullRenderer));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->AddFilter(nullRenderer, L"Null Renderer"));


	// Connect

	BAIL_IF_FAILED(bail,
		hr = ConnectSourceToGrabber(sourceFilter, grabberFilter));

	BAIL_IF_FAILED(bail,
		hr = Connect(grabberFilter, nullRenderer));


	// No reference clock is used

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaFilter, (LPVOID*) &mediaFilter));

	BAIL_IF_FAILED(bail,
		hr = mediaFilter->SetSyncSource(NULL));


	// Control interfaces kept for the lifetime of this object

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaControl, (LPVOID*) &mMediaControl));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaSeeking, (LPVOID*) &mMediaSeeking));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaEvent, (LPVOID*) &mMediaEvent));


	// Pause the graph, waiting for the transition if it is asynchronous

	BAIL_IF_FAILED(bail,
		hr = mMediaControl->Pause());

	if (hr == S_FALSE) {
		// WaitUntilPaused returns only S_OK or a failure code, so an
		// incomplete transition can no longer be reported as success.
		BAIL_IF_FAILED(bail,
			hr = WaitUntilPaused(mMediaControl));
	}


	// Duration

	BAIL_IF_FAILED(bail,
		hr = mMediaSeeking->SetTimeFormat(&TIME_FORMAT_MEDIA_TIME));

	BAIL_IF_FAILED(bail,
		hr = mMediaSeeking->GetDuration(&mDuration));



	AddToRot();

	hr = S_OK;

bail:
	RELEASE(mediaFilter);
	RELEASE(nullRenderer);
	RELEASE(sourceFilter);
	RELEASE(grabberFilter);
	return hr;
}

/*
 * Returns the first pin of `filter` whose direction is `pindir`, or NULL when
 * there is no such pin or the pins cannot be enumerated.
 *
 * The returned pin carries the reference obtained from IEnumPins::Next; the
 * caller must Release() it. Pins whose direction cannot be queried are skipped.
 */
IPin* InputBase::GetPin(IBaseFilter* filter, PIN_DIRECTION pindir)
{
	IEnumPins* enumPins = NULL;
	if (filter == NULL || FAILED(filter->EnumPins(&enumPins)) || enumPins == NULL) {
		// Previously an enumeration failure left enumPins uninitialized
		// and it was dereferenced anyway.
		return NULL;
	}

	IPin* found = NULL;
	IPin* pin = NULL;
	while (found == NULL && enumPins->Next(1, &pin, NULL) == S_OK) {
		PIN_DIRECTION pd;
		if (SUCCEEDED(pin->QueryDirection(&pd)) && pd == pindir) {
			found = pin;	// hand this reference to the caller
		} else {
			pin->Release();
		}
		pin = NULL;
	}

	enumPins->Release();
	return found;
}

/*
 * Connects the first output pin of `from` to the first input pin of `to`
 * through IGraphBuilder::Connect.
 *
 * Returns VFW_E_NOT_FOUND if either pin is missing (instead of dereferencing
 * NULL); otherwise returns the result of IGraphBuilder::Connect.
 */
HRESULT InputBase::Connect(IBaseFilter* from, IBaseFilter* to)
{
	IPin* outPin = GetOutPin(from);
	IPin* inPin = GetInPin(to);

	HRESULT hr;
	if (outPin == NULL || inPin == NULL) {
		hr = VFW_E_NOT_FOUND;
	} else {
		hr = mGraphBuilder->Connect(outPin, inPin);
	}

	// RELEASE skips NULL, so a missing pin no longer crashes here.
	RELEASE(outPin);
	RELEASE(inPin);

	return hr;
}

void InputBase::AddToRot()
{
	IRunningObjectTable* rot;
	if (FAILED(GetRunningObjectTable(0, &rot))) {
		return;
	}

	WCHAR name[256];
	wsprintfW(name, L"FilterGraph %08x pid %08x", (DWORD_PTR)mGraphBuilder, GetCurrentProcessId());

	IMoniker* moniker;
	HRESULT hr = CreateItemMoniker(L"!", name, &moniker);
	if (SUCCEEDED(hr)) {
		hr = rot->Register(ROTFLAGS_REGISTRATIONKEEPSALIVE, mGraphBuilder, moniker, &mRotRegister);
		moniker->Release();
	}
	rot->Release();
}

void InputBase::RemoveFromRot()
{
	if (mRotRegister != 0) {
		IRunningObjectTable* rot;
		if (SUCCEEDED(GetRunningObjectTable(0, &rot))) {
			rot->Revoke(mRotRegister);
			rot->Release();
			mRotRegister = 0;
		}
	}
}


// Video input with no frame size, frame count or destination buffer yet;
// VideoInput::InitWithFile and FrameImageAtTime fill these in.
VideoInput::VideoInput()
	: InputBase()
	// mForcibleARGB32(NULL) is disabled along with the ForcibleARGB32 filter.
	, mWidth(0)
	, mHeight(0)
	, mFrameCount(0)
	, mBuffer(NULL)
{
}

VideoInput::~VideoInput()
{
	// Currently empty. The commented-out code below removes the (disabled)
	// ForcibleARGB32 filter from the graph; see the matching commented-out
	// code in VideoInput::ConnectSourceToGrabber.
	//if (mGraphBuilder != NULL && mForcibleARGB32 != NULL) {
	//	// After RemoveFilter, the object is deleted inside CUnknown.
	//	mGraphBuilder->RemoveFilter(mForcibleARGB32);
	//	mForcibleARGB32 = NULL;
	//}
}

/*
 * Opens `file` for frame extraction.
 *
 * After the graph is built by InputBase::InitWithFile, this reads the frame
 * size from the media type the Sample Grabber actually connected with. It then
 * reads the stream length in frames (TIME_FORMAT_FRAME) into mFrameCount and
 * switches the seeking time format back to TIME_FORMAT_MEDIA_TIME.
 *
 * Returns VFW_E_INVALIDMEDIATYPE if the connected type does not carry a
 * complete VIDEOINFOHEADER format block. Otherwise it returns the first
 * failing HRESULT, or S_OK.
 */
HRESULT VideoInput::InitWithFile(LPCWSTR file)
{
	RETURN_HRESULT_IF_FAILED(
		InputBase::InitWithFile(file));


	CMediaType mediaType;

	RETURN_HRESULT_IF_FAILED(
		mSampleGrabber->GetConnectedMediaType(&mediaType));

	// GetMediaType() requested FORMAT_VideoInfo. Still validate the format
	// block before dereferencing it instead of trusting the cast blindly.
	if (*mediaType.FormatType() != FORMAT_VideoInfo
			|| mediaType.Format() == NULL
			|| mediaType.FormatLength() < sizeof(VIDEOINFOHEADER)) {
		return VFW_E_INVALIDMEDIATYPE;
	}

	// FORMAT_VideoInfo describes a VIDEOINFOHEADER; only bmiHeader is read.
	const VIDEOINFOHEADER* vih = (const VIDEOINFOHEADER*) mediaType.Format();
	mWidth = vih->bmiHeader.biWidth;
	mHeight = vih->bmiHeader.biHeight;	// negative for a top-down DIB


	// Frame count: query the duration in frames, then restore media time.

	RETURN_HRESULT_IF_FAILED(
		mMediaSeeking->SetTimeFormat(&TIME_FORMAT_FRAME));

	RETURN_HRESULT_IF_FAILED(
		mMediaSeeking->GetDuration(&mFrameCount));

	RETURN_HRESULT_IF_FAILED(
		mMediaSeeking->SetTimeFormat(&TIME_FORMAT_MEDIA_TIME));


	return S_OK;
}

// Media type the Sample Grabber is restricted to for video: 32-bit ARGB
// frames described by a FORMAT_VideoInfo format block.
void VideoInput::GetMediaType(CMediaType* mediaType)
{
	CMediaType& accepted = *mediaType;

	accepted.SetFormatType(&FORMAT_VideoInfo);
	accepted.SetSubtype(&MEDIASUBTYPE_ARGB32);
	accepted.SetType(&MEDIATYPE_Video);
}

// Connects the source filter's output pin to the Sample Grabber's input pin
// with InputBase::Connect. Returns the connection failure, or S_OK.
HRESULT VideoInput::ConnectSourceToGrabber(IBaseFilter* source, IBaseFilter* grabber)
{
	RETURN_HRESULT_IF_FAILED(
		Connect(source, grabber));


	// TODO: handle the "forced alpha channel" mode. The disabled code below
	// would insert a ForcibleARGB32 filter between the source and the grabber.

	//mForcibleARGB32 = new ForcibleARGB32(&hr);
	//if (mForcibleARGB32 == NULL) {
	//	hr = E_OUTOFMEMORY;
	//}
	//RETURN_HRESULT_IF_FAILED(hr);

	//RETURN_HRESULT_IF_FAILED(
	//	mGraphBuilder->AddFilter(mForcibleARGB32, L"ForcibleARGB32"));

	//RETURN_HRESULT_IF_FAILED(
	//	Connect(source, mForcibleARGB32));

	//RETURN_HRESULT_IF_FAILED(
	//	Connect(mForcibleARGB32, grabber));


	return S_OK;
}

/*
 * Sample Grabber callback for video: copies one ARGB32 frame into the buffer
 * set by FrameImageAtTime.
 *
 * Copies min(bufferSize, |mWidth * mHeight * 4|) bytes. Returns E_FAIL when no
 * destination buffer is set, S_OK otherwise.
 */
HRESULT VideoInput::BufferCB(void* buffer, long bufferSize)
{
	if (mBuffer == NULL) {
		return E_FAIL;
	}

	// A top-down DIB reports a negative biHeight. The frame's byte size is the
	// same either way, so use the magnitude; a negative size would otherwise
	// reach CopyMemory as a huge SIZE_T and overrun mBuffer.
	long dstBufferSize = mWidth * mHeight * 4;
	if (dstBufferSize < 0) {
		dstBufferSize = -dstBufferSize;
	}

	CopyMemory(mBuffer, buffer, (bufferSize < dstBufferSize) ? bufferSize : dstBufferSize);

	return S_OK;
}

/*
 * Decodes the frame at `time` into `buffer`.
 *
 * `time` is interpreted in the seeking time format left by
 * VideoInput::InitWithFile (TIME_FORMAT_MEDIA_TIME). While the graph runs,
 * BufferCB writes up to |mWidth * mHeight * 4| bytes of ARGB32 data into
 * `buffer`.
 *
 * mBuffer is cleared on every exit path. This keeps a later grabber callback
 * from writing into the caller's buffer after this function has returned.
 *
 * Returns:
 *  - the failing HRESULT from SetPositions, Run or WaitForCompletion;
 *  - E_FAIL if the graph reported EC_ERRORABORT or EC_USERABORT instead of
 *    completing;
 *  - S_OK otherwise.
 */
HRESULT VideoInput::FrameImageAtTime(LONGLONG time, void* buffer)
{
	mBuffer = buffer;

	HRESULT hr = mMediaSeeking->SetPositions(&time, AM_SEEKING_AbsolutePositioning, NULL, AM_SEEKING_NoPositioning);

	if (SUCCEEDED(hr)) {
		hr = mMediaControl->Run();
	}

	long eventCode = 0;
	if (SUCCEEDED(hr)) {
		hr = mMediaEvent->WaitForCompletion(INFINITE, &eventCode);
	}

	// Do not leave a pointer to the caller's buffer behind.
	mBuffer = NULL;

	if (FAILED(hr)) {
		return hr;
	}
	if (eventCode == EC_ERRORABORT || eventCode == EC_USERABORT) {
		// The graph stopped without completing; the frame may not have been delivered.
		return E_FAIL;
	}

	return S_OK;
}


// Audio input with no pending destination; FillBuffer sets mBuffer/mRestSize.
AudioInput::AudioInput()
	: InputBase()
	, mBuffer(NULL)
	, mRestSize(0)
{
}

// Nothing extra to release. mBuffer only points into a caller-supplied buffer,
// and the graph is torn down by ~InputBase.
AudioInput::~AudioInput()
{
	// intentionally empty
}

// Builds the graph through InputBase::InitWithFile, then turns off the
// grabber's one-shot mode so samples keep arriving until BufferCB pauses the
// graph. Returns the first failure, or S_OK.
HRESULT AudioInput::InitWithFile(LPCWSTR file)
{
	HRESULT result = InputBase::InitWithFile(file);
	if (FAILED(result)) {
		return result;
	}

	result = mSampleGrabber->SetOneShot(FALSE);
	if (FAILED(result)) {
		return result;
	}

	return S_OK;
}

// Media type the Sample Grabber is restricted to for audio: PCM samples
// described by a FORMAT_WaveFormatEx format block. The sample rate, channel
// count and bit depth are left unconstrained.
void AudioInput::GetMediaType(CMediaType* mediaType)
{
	CMediaType& accepted = *mediaType;

	accepted.SetFormatType(&FORMAT_WaveFormatEx);
	accepted.SetSubtype(&MEDIASUBTYPE_PCM);
	accepted.SetType(&MEDIATYPE_Audio);
}

// Connects the source filter directly to the Sample Grabber with
// InputBase::Connect. Any success code from the connection is reported as S_OK.
HRESULT AudioInput::ConnectSourceToGrabber(IBaseFilter* source, IBaseFilter* grabber)
{
	const HRESULT connected = Connect(source, grabber);
	if (FAILED(connected)) {
		return connected;
	}

	// TODO: insert conversion (e.g. sample-rate) filters here when needed.

	return S_OK;
}

/*
 * Sample Grabber callback for audio: appends PCM bytes from `buffer` to the
 * destination set up by FillBuffer.
 *
 * - While this sample is smaller than the remaining space, the whole sample is
 *   copied and mBuffer / mRestSize are advanced.
 * - Once a sample covers the remaining space (bufferSize >= mRestSize), only
 *   mRestSize bytes are copied and the rest of the sample is discarded.
 *   mBuffer is then set to NULL (FillBuffer checks this to detect a full
 *   buffer) and the graph is paused so no more samples are consumed.
 *
 * Returns E_FAIL when no destination is set, S_OK otherwise.
 *
 * NOTE(review): this runs inside the grabber callback, presumably on the
 * streaming thread. DirectShow guidance discourages changing graph state from
 * the sample grabber callback; confirm that Pause() here cannot deadlock.
 */
HRESULT AudioInput::BufferCB(void* buffer, long bufferSize)
{
	if (mBuffer == NULL) {
		return E_FAIL;
	}

	if (bufferSize >= mRestSize) {
		// Final chunk: fill exactly what is left and signal completion.
		CopyMemory(mBuffer, buffer, mRestSize);
		mBuffer = NULL;
		mRestSize = 0;
		mMediaControl->Pause();
	} else {
		// Partial chunk: consume the whole sample and advance.
		CopyMemory(mBuffer, buffer, bufferSize);
		mBuffer += bufferSize;
		mRestSize -= bufferSize;
	}

	return S_OK;
}

/*
 * Fills `buffer` with `bufferSize` bytes of decoded PCM audio starting at
 * `time`. NOTE(review): `time` is presumably in TIME_FORMAT_MEDIA_TIME, the
 * format left by InputBase::InitWithFile; confirm with callers.
 *
 * The function seeks, runs the graph, and waits for completion while BufferCB
 * copies data in. When the buffer is full, BufferCB pauses the graph and sets
 * mBuffer to NULL. If WaitForCompletion returns while the buffer is still not
 * full (e.g. the stream ended), the unfilled remainder is zeroed.
 *
 * Returns the failure from SetPositions or Run. The WaitForCompletion failure
 * is returned only when the buffer was not completely filled. Otherwise S_OK.
 *
 * NOTE(review): on the early-return paths (and after a partial fill), mBuffer
 * still points into the caller's buffer.
 */
HRESULT AudioInput::FillBuffer(LONGLONG time, void* buffer, long bufferSize)
{
	mBuffer = (char*) buffer;
	mRestSize = bufferSize;

	RETURN_HRESULT_IF_FAILED(
		mMediaSeeking->SetPositions(&time, AM_SEEKING_AbsolutePositioning, NULL, AM_SEEKING_NoPositioning));

	RETURN_HRESULT_IF_FAILED(
		mMediaControl->Run());

	long eventCode;
	HRESULT hr = mMediaEvent->WaitForCompletion(INFINITE, &eventCode);
	// When BufferCB filled the buffer (mBuffer == NULL), the WaitForCompletion
	// result is deliberately ignored. NOTE(review): presumably it can report
	// an error because BufferCB paused the graph; confirm.
	if (mBuffer != NULL) {
		RETURN_HRESULT_IF_FAILED(hr);
		ZeroMemory(mBuffer, mRestSize);
	}

	return S_OK;
}


// Sample Grabber callback object that forwards every buffer to `input`.
// It is a standalone CUnknown (no outer unknown) holding a non-owning
// pointer to the input.
Callback::Callback(InputBase* input)
	: CUnknown(L"Callback", NULL)
	, mInput(input)
{
}

// Nothing to release: mInput is not owned by the callback.
Callback::~Callback()
{
	// intentionally empty
}

// Answers ISampleGrabberCB itself; every other interface request is delegated
// to the CUnknown base implementation.
STDMETHODIMP Callback::NonDelegatingQueryInterface(REFIID riid, void** ppv)
{
	return (riid == IID_ISampleGrabberCB)
		? GetInterface(static_cast<ISampleGrabberCB*>(this), ppv)
		: CUnknown::NonDelegatingQueryInterface(riid, ppv);
}

// Not used: InputBase::InitWithFile registers this callback for BufferCB
// delivery (method 1), so the IMediaSample-based entry point is unimplemented.
STDMETHODIMP Callback::SampleCB(double /*sampleTime*/, IMediaSample* /*sample*/)
{
	return E_NOTIMPL;
}

// Forwards the sample bytes to the owning input's BufferCB and returns its
// result. The sample time is not used.
STDMETHODIMP Callback::BufferCB(double /*sampleTime*/, BYTE* buffer, long bufferSize)
{
	InputBase* const target = mInput;
	return target->BufferCB(buffer, bufferSize);
}
