﻿/*
 * Copyright (c) 2009,2010 Yoshikazu Kuramochi
 * All rights reserved.
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "StdAfx.h"
#include "DirectShowInput.h"
#include <MMReg.h>


#define RELEASE(p)						do { if (p) { p->Release(); p = NULL; } } while (false)
#define RETURN_HRESULT_IF_FAILED(hr)	do { HRESULT _hr = (hr); if (FAILED(_hr)) return _hr; } while (false)
#define BAIL_IF_FAILED(label, hr)		do { if (FAILED((hr))) goto label; } while (false)


InputBase::InputBase()
	: mGraphBuilder(NULL)
	, mSampleGrabber(NULL)
	, mMediaControl(NULL)
	, mMediaSeeking(NULL)
	, mMediaEvent(NULL)
	, mDuration(0)
	, mRotRegister(0)
{
	// Every interface starts out empty; InitWithFile() acquires them and the
	// destructor releases whatever was acquired.
}

InputBase::~InputBase()
{
	// Unregister the graph from the Running Object Table first.
	RemoveFromRot();

	// Detach the callback object that InitWithFile() installed on the grabber.
	if (mSampleGrabber) {
		mSampleGrabber->SetCallback(NULL, 1);
	}

	// Release the interfaces in reverse order of acquisition; the graph goes last.
	RELEASE(mMediaEvent);
	RELEASE(mMediaSeeking);
	RELEASE(mMediaControl);
	RELEASE(mSampleGrabber);
	RELEASE(mGraphBuilder);
}

// Builds the DirectShow graph for `file`:
//   source filter --(ConnectSourceToGrabber)--> Sample Grabber --> Null Renderer
// The sample grabber uses the subclass's GetMediaType() and delivers data via
// BufferCB. The graph runs without a reference clock and is left paused. Its
// duration (TIME_FORMAT_MEDIA_TIME units) is stored in mDuration, and the
// graph is registered in the Running Object Table.
// Returns S_OK on success or a failure HRESULT. Interfaces already stored in
// members when a failure occurs are released by the destructor.
HRESULT InputBase::InitWithFile(LPCWSTR file)
{
	HRESULT hr;

	IBaseFilter*	grabberFilter = NULL;
	IBaseFilter*	sourceFilter = NULL;
	IBaseFilter*	nullRenderer = NULL;
	IMediaFilter*	mediaFilter = NULL;
	CMediaType		mediaType;


	// Graph

	BAIL_IF_FAILED(bail,
		hr = CoCreateInstance(CLSID_FilterGraph, NULL, CLSCTX_INPROC, IID_IGraphBuilder, (LPVOID*) &mGraphBuilder));


	// Sample Grabber

	BAIL_IF_FAILED(bail,
		hr = CoCreateInstance(CLSID_SampleGrabber, NULL, CLSCTX_INPROC, IID_IBaseFilter, (LPVOID*) &grabberFilter));

	BAIL_IF_FAILED(bail,
		hr = grabberFilter->QueryInterface(IID_ISampleGrabber, (LPVOID*) &mSampleGrabber));

	GetMediaType(&mediaType);

	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetMediaType(&mediaType));

	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetBufferSamples(FALSE));

	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetOneShot(TRUE));

	// NOTE(review): if SetCallback fails, the freshly allocated Callback is
	// never released by this function; confirm whether that leak matters.
	BAIL_IF_FAILED(bail,
		hr = mSampleGrabber->SetCallback(new Callback(this), 1));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->AddFilter(grabberFilter, L"Grabber"));


	// Source Filter

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->AddSourceFilter(file, L"Source", &sourceFilter));


	// Null Renderer

	BAIL_IF_FAILED(bail,
		hr = CoCreateInstance(CLSID_NullRenderer, NULL, CLSCTX_INPROC, IID_IBaseFilter, (LPVOID*) &nullRenderer));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->AddFilter(nullRenderer, L"Null Renderer"));


	// Connect

	BAIL_IF_FAILED(bail,
		hr = ConnectSourceToGrabber(sourceFilter, grabberFilter));

	BAIL_IF_FAILED(bail,
		hr = Connect(grabberFilter, nullRenderer));


	// No reference clock is used

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaFilter, (LPVOID*) &mediaFilter));

	BAIL_IF_FAILED(bail,
		hr = mediaFilter->SetSyncSource(NULL));


	// Control interfaces

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaControl, (LPVOID*) &mMediaControl));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaSeeking, (LPVOID*) &mMediaSeeking));

	BAIL_IF_FAILED(bail,
		hr = mGraphBuilder->QueryInterface(IID_IMediaEvent, (LPVOID*) &mMediaEvent));


	// Pause the graph. S_FALSE means the transition has not finished yet, so
	// poll GetState() (up to 10 x 1000 ms) until it reports State_Paused.

	BAIL_IF_FAILED(bail,
		hr = mMediaControl->Pause());

	if (hr == S_FALSE) {
		for (int i = 0; i < 10; ++i) {
			OAFilterState state;
			hr = mMediaControl->GetState(1000, &state);
			if (hr == VFW_S_STATE_INTERMEDIATE) {
				continue;
			} else if (hr == VFW_S_CANT_CUE && state == State_Paused) {
				hr = S_OK;
			} else if (hr == S_OK && state != State_Paused) {
				// Can this ever happen? Treat it as a failure just in case.
				hr = E_FAIL;
			}
			break;
		}
		if (hr != S_OK) {
			// Any success code other than S_OK here (VFW_S_STATE_INTERMEDIATE
			// after the polls time out, or VFW_S_CANT_CUE while not paused)
			// still means the graph is not paused. Callers test the result
			// with FAILED(), so it must be turned into a real failure.
			if (SUCCEEDED(hr)) {
				hr = E_FAIL;
			}
			goto bail;
		}
	}


	// Duration

	BAIL_IF_FAILED(bail,
		hr = mMediaSeeking->SetTimeFormat(&TIME_FORMAT_MEDIA_TIME));

	BAIL_IF_FAILED(bail,
		hr = mMediaSeeking->GetDuration(&mDuration));


	AddToRot();

	hr = S_OK;

bail:
	RELEASE(mediaFilter);
	RELEASE(nullRenderer);
	RELEASE(sourceFilter);
	RELEASE(grabberFilter);
	return hr;
}

// Advances `enumPins` until it yields a pin whose direction is `pindir`.
// On S_OK, *pin receives that pin and the caller owns the reference.
// Pins with another direction, or whose direction cannot be queried, are
// released and skipped. When IEnumPins::Next stops returning S_OK (typically
// S_FALSE at the end, or a failure code), that result is returned unchanged and
// *pin is not written.
HRESULT InputBase::NextPin(IEnumPins* enumPins, PIN_DIRECTION pindir, IPin** pin)
{
	HRESULT hr;

	IPin* tmpPin = NULL;
	while ((hr = enumPins->Next(1, &tmpPin, NULL)) == S_OK) {
		PIN_DIRECTION pd;
		// Only compare `pd` when QueryDirection actually filled it in.
		if (SUCCEEDED(tmpPin->QueryDirection(&pd)) && pd == pindir) {
			*pin = tmpPin;
			return S_OK;
		}
		tmpPin->Release();
		tmpPin = NULL;
	}

	return hr;
}

// Finds the first pin of `filter` whose direction is `pindir`.
// Returns whatever NextPin() returns (S_OK with *pin set, or a non-S_OK code),
// or the failure from IBaseFilter::EnumPins.
HRESULT InputBase::FirstPin(IBaseFilter* filter, PIN_DIRECTION pindir, IPin** pin)
{
	IEnumPins* pins = NULL;

	HRESULT hr = filter->EnumPins(&pins);
	if (SUCCEEDED(hr)) {
		hr = NextPin(pins, pindir, pin);
	}

	RELEASE(pins);
	return hr;
}

// Connects `from` to `to` by trying every (output pin of `from`, input pin of
// `to`) pair with IGraphBuilder::Connect and stopping at the first success.
// Returns:
//   - the successful IGraphBuilder::Connect result;
//   - a failure from pin enumeration;
//   - if no pair could be connected, the last IGraphBuilder::Connect error, or
//     VFW_E_CANNOT_CONNECT when there was no pair to try.
HRESULT InputBase::Connect(IBaseFilter* from, IBaseFilter* to)
{
	HRESULT hr;
	// Must be a failure code. Callers test the result with FAILED(), so letting
	// the exhausted enumerator's S_FALSE escape would look like a connection.
	HRESULT connectError = VFW_E_CANNOT_CONNECT;

	IEnumPins* outEnum = NULL;
	IEnumPins* inEnum = NULL;
	IPin* outPin = NULL;
	IPin* inPin = NULL;

	BAIL_IF_FAILED(bail,
		hr = from->EnumPins(&outEnum));

	BAIL_IF_FAILED(bail,
		hr = to->EnumPins(&inEnum));

	while (true) {
		BAIL_IF_FAILED(bail,
			hr = NextPin(outEnum, PINDIR_OUTPUT, &outPin));
		if (hr != S_OK) break;

		// Restart the input-pin scan for every output pin.
		inEnum->Reset();

		while (true) {
			BAIL_IF_FAILED(bail,
				hr = NextPin(inEnum, PINDIR_INPUT, &inPin));
			if (hr != S_OK) break;

			hr = mGraphBuilder->Connect(outPin, inPin);
			if (SUCCEEDED(hr)) goto bail;

			connectError = hr;
			RELEASE(inPin);
		}

		RELEASE(outPin);
	}

	// Every combination was tried without success.
	hr = connectError;

bail:
	RELEASE(inPin);
	RELEASE(outPin);
	RELEASE(inEnum);
	RELEASE(outEnum);
	return hr;
}

void InputBase::AddToRot()
{
	IRunningObjectTable* rot;
	if (FAILED(GetRunningObjectTable(0, &rot))) {
		return;
	}

	WCHAR name[256];
	wsprintfW(name, L"FilterGraph %08x pid %08x", (DWORD_PTR)mGraphBuilder, GetCurrentProcessId());

	IMoniker* moniker;
	HRESULT hr = CreateItemMoniker(L"!", name, &moniker);
	if (SUCCEEDED(hr)) {
		hr = rot->Register(ROTFLAGS_REGISTRATIONKEEPSALIVE, mGraphBuilder, moniker, &mRotRegister);
		moniker->Release();
	}
	rot->Release();
}

void InputBase::RemoveFromRot()
{
	if (mRotRegister != 0) {
		IRunningObjectTable* rot;
		if (SUCCEEDED(GetRunningObjectTable(0, &rot))) {
			rot->Revoke(mRotRegister);
			rot->Release();
			mRotRegister = 0;
		}
	}
}


VideoInput::VideoInput()
	: InputBase()
	, mForcibleARGB32(NULL)
	, mWidth(0)
	, mHeight(0)
	, mFrameDuration(0)
	, mFourCC(0)
	, mBuffer(NULL)
{
	// Frame geometry and timing are read from the connected media type in
	// InitWithFile(); mBuffer is set per request by FrameImageAtTime().
}

VideoInput::~VideoInput()
{
	// Drop the reference taken on the converter in ConnectSourceToGrabber().
	if (mForcibleARGB32 != NULL) {
		mForcibleARGB32->Release();
		mForcibleARGB32 = NULL;
	}
}

// Builds the graph via InputBase::InitWithFile(). It then caches the frame
// width, height (biHeight, sign preserved), average frame duration and the
// stream FOURCC from the media type negotiated on the Sample Grabber.
HRESULT VideoInput::InitWithFile(LPCWSTR file)
{
	HRESULT hr = InputBase::InitWithFile(file);
	if (FAILED(hr)) {
		return hr;
	}

	CMediaType connected;
	hr = mSampleGrabber->GetConnectedMediaType(&connected);
	if (FAILED(hr)) {
		return hr;
	}

	// GetMediaType() requested FORMAT_VideoInfo, so the format block is a VIDEOINFO.
	const VIDEOINFO* info = reinterpret_cast<const VIDEOINFO*>(connected.Format());
	const BITMAPINFOHEADER& header = info->bmiHeader;

	mWidth = header.biWidth;
	mHeight = header.biHeight;
	mFrameDuration = info->AvgTimePerFrame;
	mFourCC = FourCC();

	return S_OK;
}

// Returns the FOURCC derived from the subtype of the first connected video
// output pin of the graph filter named "AVI Splitter". Returns 0 when that
// filter, its pins, or a video stream cannot be found.
DWORD VideoInput::FourCC()
{
	DWORD result = 0;

	IBaseFilter* splitter = NULL;
	IEnumPins* pins = NULL;

	if (SUCCEEDED(mGraphBuilder->FindFilterByName(L"AVI Splitter", &splitter)) &&
	    SUCCEEDED(splitter->EnumPins(&pins))) {
		IPin* pin = NULL;
		// Stop at the first video stream found.
		while (result == 0 && NextPin(pins, PINDIR_OUTPUT, &pin) == S_OK) {
			CMediaType connection;
			if (pin->ConnectionMediaType(&connection) == S_OK &&
			    IsEqualGUID(*connection.Type(), MEDIATYPE_Video)) {
				result = FOURCCMap(connection.Subtype()).GetFOURCC();
			}
			RELEASE(pin);
		}
	}

	RELEASE(pins);
	RELEASE(splitter);
	return result;
}

// Media type the Sample Grabber will accept for video: ARGB32 frames described
// by a FORMAT_VideoInfo format block.
void VideoInput::GetMediaType(CMediaType* mediaType)
{
	CMediaType& wanted = *mediaType;
	wanted.SetType(&MEDIATYPE_Video);
	wanted.SetSubtype(&MEDIASUBTYPE_ARGB32);
	wanted.SetFormatType(&FORMAT_VideoInfo);
}

// Inserts a ForcibleARGB32 filter between the source and the Sample Grabber:
//   source --> ForcibleARGB32 --> grabber
// The converter is kept in mForcibleARGB32 with a reference released by the
// destructor. Returns S_OK or the first failure encountered.
HRESULT VideoInput::ConnectSourceToGrabber(IBaseFilter* source, IBaseFilter* grabber)
{
	HRESULT hr = S_OK;

	mForcibleARGB32 = new ForcibleARGB32(NULL, &hr);
	if (mForcibleARGB32 == NULL) {
		return E_OUTOFMEMORY;
	}
	mForcibleARGB32->AddRef();
	if (FAILED(hr)) {
		return hr;		// constructor reported an error
	}

	hr = mGraphBuilder->AddFilter(mForcibleARGB32, L"ForcibleARGB32");
	if (SUCCEEDED(hr)) {
		hr = Connect(source, mForcibleARGB32);
	}
	if (SUCCEEDED(hr)) {
		hr = Connect(mForcibleARGB32, grabber);
	}

	return FAILED(hr) ? hr : S_OK;
}

// Sample-grabber callback for video. Copies the grabbed ARGB32 frame into the
// buffer supplied to FrameImageAtTime(). The copy is the smaller of the
// incoming buffer and the destination frame (width * |height| * 4 bytes).
// Returns E_FAIL when no destination buffer is set, otherwise S_OK.
HRESULT VideoInput::BufferCB(void* buffer, long bufferSize)
{
	if (mBuffer == NULL) {
		return E_FAIL;
	}

	// mHeight comes from biHeight, which is negative for top-down bitmaps.
	// Using it unchanged would make dstBufferSize negative, and CopyMemory
	// would then receive it converted to a huge SIZE_T. The frame's byte size
	// depends only on |height|.
	long rows = (mHeight < 0) ? -mHeight : mHeight;
	long dstBufferSize = mWidth * rows * 4;
	long copySize = (bufferSize < dstBufferSize) ? bufferSize : dstBufferSize;

	if (copySize > 0) {
		CopyMemory(mBuffer, buffer, copySize);
	}

	return S_OK;
}

// Grabs the frame at `time` (TIME_FORMAT_MEDIA_TIME units) into `buffer`. The
// function seeks, runs the graph and blocks until WaitForCompletion returns;
// BufferCB() performs the copy. Returns S_OK or the first failure encountered.
// NOTE(review): mBuffer keeps pointing at `buffer` after this returns; any
// later BufferCB call before the next request would write into it — confirm
// that this cannot happen.
HRESULT VideoInput::FrameImageAtTime(LONGLONG time, void* buffer)
{
	mBuffer = buffer;

	HRESULT hr = mMediaSeeking->SetPositions(&time, AM_SEEKING_AbsolutePositioning, NULL, AM_SEEKING_NoPositioning);
	if (FAILED(hr)) {
		return hr;
	}

	hr = mMediaControl->Run();
	if (FAILED(hr)) {
		return hr;
	}

	long eventCode;
	hr = mMediaEvent->WaitForCompletion(INFINITE, &eventCode);
	return FAILED(hr) ? hr : S_OK;
}


AudioInput::AudioInput()
	: InputBase()
	, mSampleRate(0)
	, mSampleSizeInBits(0)
	, mChannels(0)
	, mFloat(false)
	, mBuffer(NULL)
	, mRestSize(0)
	, mCallbackCount(0)
{
	// Format fields are filled in by ConnectSourceToGrabber(); the buffer
	// state is set per request by FillBuffer().
}

AudioInput::~AudioInput()
{
	// No audio-specific resources; InputBase's destructor tears down the graph.
}

// Builds the graph via InputBase::InitWithFile(), then turns off the one-shot
// mode the base class enabled, so the grabber keeps delivering buffers.
HRESULT AudioInput::InitWithFile(LPCWSTR file)
{
	HRESULT hr = InputBase::InitWithFile(file);
	if (SUCCEEDED(hr)) {
		hr = mSampleGrabber->SetOneShot(FALSE);
	}
	return FAILED(hr) ? hr : S_OK;
}

// Media type the Sample Grabber will accept for audio: any audio subtype with
// a FORMAT_WaveFormatEx format block. The subtype is deliberately left open;
// ConnectSourceToGrabber() checks afterwards that the negotiated type is PCM
// or IEEE float.
void AudioInput::GetMediaType(CMediaType* mediaType)
{
	CMediaType& wanted = *mediaType;
	wanted.SetType(&MEDIATYPE_Audio);
	wanted.SetFormatType(&FORMAT_WaveFormatEx);
}

// Connects the source directly to the Sample Grabber and reads the negotiated
// format from the grabber's input pin. Only PCM or IEEE-float audio is
// accepted. Sample rate, bits per sample, channel count and the float flag are
// stored in the members. Returns the ConnectionMediaType result on success,
// E_FAIL for an unsupported format, a missing input pin or a missing format
// block, or the first failure from the connection steps.
HRESULT AudioInput::ConnectSourceToGrabber(IBaseFilter* source, IBaseFilter* grabber)
{
	HRESULT hr;
	IPin* inPin = NULL;
	CMediaType mt;
	CMediaType audioMt(&MEDIATYPE_Audio);
	// Declared before the first `goto bail`. Jumping over an initialized
	// declaration that is still in scope at the label is ill-formed C++.
	const WAVEFORMATEX* wfx = NULL;
	bool isFloat = false;

	BAIL_IF_FAILED(bail,
		hr = Connect(source, grabber));

	BAIL_IF_FAILED(bail,
		hr = FirstPin(grabber, PINDIR_INPUT, &inPin));

	// FirstPin() passes through a non-S_OK success code (e.g. S_FALSE) when
	// the grabber exposes no input pin.
	if (hr != S_OK) {
		hr = E_FAIL;
		goto bail;
	}

	BAIL_IF_FAILED(bail,
		hr = inPin->ConnectionMediaType(&mt));

	// Accept PCM first, then IEEE float; anything else is rejected.
	audioMt.SetSubtype(&MEDIASUBTYPE_PCM);
	if (!mt.MatchesPartial(&audioMt)) {
		audioMt.SetSubtype(&MEDIASUBTYPE_IEEE_FLOAT);
		if (!mt.MatchesPartial(&audioMt)) {
			hr = E_FAIL;
			goto bail;
		}
		// Derive "float" from the matched subtype, not from wFormatTag. A
		// WAVEFORMATEXTENSIBLE float stream has wFormatTag ==
		// WAVE_FORMAT_EXTENSIBLE and would otherwise be reported as integer
		// PCM.
		isFloat = true;
	}

	wfx = (const WAVEFORMATEX*) mt.Format();
	if (wfx == NULL) {
		hr = E_FAIL;
		goto bail;
	}

	mSampleRate = wfx->nSamplesPerSec;
	mSampleSizeInBits = wfx->wBitsPerSample;
	mChannels = wfx->nChannels;
	mFloat = isFloat;

bail:
	RELEASE(inPin);
	return hr;
}

// Sample-grabber callback for audio. Copies up to mRestSize bytes of the
// grabbed data into mBuffer and advances it. Once the request is complete,
// it clears mBuffer and pauses the graph.
// Returns E_FAIL when there is no pending request (mBuffer == NULL), otherwise S_OK.
// NOTE(review): the pause is skipped on the first callback after FillBuffer()
// (mCallbackCount == 0). A request satisfied by a single buffer is therefore
// completed only on a following callback. The reason is not visible in this
// file — confirm before changing.
// NOTE(review): Pause() is issued from inside the grabber callback; verify this
// cannot block or deadlock the thread delivering samples.
HRESULT AudioInput::BufferCB(void* buffer, long bufferSize)
{
	if (mBuffer == NULL) {
		return E_FAIL;
	}

	if (mRestSize > 0) {
		// Copy no more than is still requested; any excess in this buffer is dropped.
		long size = min(bufferSize, mRestSize);
		CopyMemory(mBuffer, buffer, size);
		mBuffer += size;
		mRestSize -= size;
	}

	// Request complete: signal FillBuffer() (mBuffer == NULL) and stop the stream.
	if (mRestSize == 0 && mCallbackCount > 0) {
		mBuffer = NULL;
		mMediaControl->Pause();
	}

	++mCallbackCount;

	return S_OK;
}

// Fills `buffer` with `bufferSize` bytes of decoded audio starting at `time`
// (TIME_FORMAT_MEDIA_TIME units, as configured in InputBase::InitWithFile).
// The function seeks, runs the graph and blocks in WaitForCompletion; the
// copying is done by BufferCB(), which clears mBuffer once it has completed
// the request.
// Returns S_OK on success, or the failure from SetPositions, Run, or (only
// when the request was not completed) WaitForCompletion.
// NOTE(review): mBuffer/mRestSize are also written by BufferCB(), which the
// grabber may invoke on another thread; nothing here synchronizes them — confirm.
HRESULT AudioInput::FillBuffer(LONGLONG time, void* buffer, long bufferSize)
{
	// Request state consumed by BufferCB().
	mBuffer = (char*) buffer;
	mRestSize = bufferSize;
	mCallbackCount = 0;

	RETURN_HRESULT_IF_FAILED(
		mMediaSeeking->SetPositions(&time, AM_SEEKING_AbsolutePositioning, NULL, AM_SEEKING_NoPositioning));

	RETURN_HRESULT_IF_FAILED(
		mMediaControl->Run());

	long eventCode;
	// The wait's result is only checked when BufferCB() did not complete the
	// request (mBuffer still set); otherwise it is ignored.
	HRESULT hr = mMediaEvent->WaitForCompletion(INFINITE, &eventCode);
	if (mBuffer != NULL) {
		RETURN_HRESULT_IF_FAILED(hr);
		// Zero-fill the part of the buffer that received no data.
		// NOTE(review): for 8-bit (unsigned) PCM, zero bytes are not silence.
		ZeroMemory(mBuffer, mRestSize);
	}

	return S_OK;
}


Callback::Callback(InputBase* input)
	: CUnknown(L"Callback", NULL)
	, mInput(input)
{
	// mInput is a non-owning pointer; grabbed buffers are forwarded to it.
}

Callback::~Callback()
{
	// Nothing to release: mInput is not owned by this object.
}

// Answers for ISampleGrabberCB; every other interface is left to CUnknown.
STDMETHODIMP Callback::NonDelegatingQueryInterface(REFIID riid, void** ppv)
{
	if (riid != IID_ISampleGrabberCB) {
		return CUnknown::NonDelegatingQueryInterface(riid, ppv);
	}
	return GetInterface(static_cast<ISampleGrabberCB*>(this), ppv);
}

// IMediaSample-based delivery is not supported. InputBase::InitWithFile()
// registers this object with SetCallback(..., 1) and uses BufferCB instead.
STDMETHODIMP Callback::SampleCB(double /*sampleTime*/, IMediaSample* /*sample*/)
{
	return E_NOTIMPL;
}

// Forwards each grabbed buffer to the owning input object and returns its
// result. The sample time is not used.
STDMETHODIMP Callback::BufferCB(double /*sampleTime*/, BYTE* buffer, long bufferSize)
{
	InputBase* target = mInput;
	return target->BufferCB(buffer, bufferSize);
}
